aboutsummaryrefslogblamecommitdiffstats
path: root/beancount_extras_kris7t/importers/rules.py
blob: 3890f24d048cef534143f40a51640206f64224b9 (plain) (tree)





































































































































                                                                                
#!/usr/bin/env python3

from typing import cast, Dict, List, NamedTuple, Optional, Tuple, Union
import re

from beancount.core.amount import Amount

from beancount_extras_kris7t.importers.utils import Extractor, Row

# Pattern that matches any string. ``when`` substitutes it for an omitted
# payee/text pattern, and ``_rule_condition_matches`` compares against it
# (pattern equality) to recognise "no constraint on this field".
WILDCARD = re.compile('.*')


class When(NamedTuple):
    """A compiled rule condition; a row must satisfy every part (see ``when``)."""
    # Searched (``re.search``) in the row's payee.
    payee: re.Pattern
    # Searched in the row's entry type, payee and comment joined by spaces.
    text: re.Pattern
    # When not None, must equal the row's transacted amount.
    amount: Optional[Amount]


def _compile_regex(s: str) -> re.Pattern:
    """Compile ``s`` into a case-insensitive regular expression."""
    flags = re.IGNORECASE
    return re.compile(s, flags)


def when(payee: Optional[Union[re.Pattern, str]] = None,
         text: Optional[Union[re.Pattern, str]] = None,
         amount: Optional[Amount] = None) -> When:
    """Build a rule condition; a row must satisfy every given part.

    Args:
        payee: Pattern searched in the row's payee. A string is compiled
            case-insensitively; an already compiled pattern is used as-is.
            ``None`` means "any payee" (:data:`WILDCARD`).
        text: Pattern searched in the row's entry type, payee and comment
            joined by spaces. Compiled and defaulted like ``payee``.
        amount: If given, the row's transacted amount must equal it.

    Returns:
        The compiled :class:`When` condition.

    Raises:
        TypeError: if neither ``payee`` nor ``text`` is provided (an empty
            string counts as not provided). An amount alone is rejected.
    """
    if not payee and not text:
        raise TypeError('at least one of payee and text must be provided')
    if isinstance(payee, str):
        payee_regex = _compile_regex(payee)
    else:
        payee_regex = payee or WILDCARD
    if isinstance(text, str):
        text_regex = _compile_regex(text)
    else:
        text_regex = text or WILDCARD
    return When(payee_regex, text_regex, amount)


# Left-hand side of a rule. A bare string or compiled pattern is shorthand for
# ``when(text=...)`` (see ``_compile_condition``); a ``When`` is used as-is.
Condition = Union[str, re.Pattern, When]


def _compile_condition(cond: Condition) -> When:
    """Normalise a rule condition into a :class:`When`.

    A bare string or compiled pattern is treated as ``when(text=cond)``;
    a :class:`When` passes through unchanged.
    """
    if not isinstance(cond, When):
        cond = when(text=cond)
    return cond


class let(NamedTuple):
    """Changes a matching rule applies to a row; ``None`` fields are left alone."""
    # Replacement payee for the row.
    payee: Optional[str] = None
    # Replacement description, stored in ``row.comment``.
    desc: Optional[str] = None
    # Account handed to ``row.assign_to_account``.
    account: Optional[str] = None
    # Replacement value for ``row.flag``.
    flag: Optional[str] = None
    # Tag added to ``row.tags``.
    tag: Optional[str] = None


# Right-hand side of a rule; ``_compile_action`` normalises each form into a
# ``let``:
Action = Union[str,  # account
               Tuple[str, str],  # (payee, account)
               Tuple[str, str, str],  # (payee, desc, account)
               Tuple[str, str, str, str],  # (flag, payee, desc, account)
               let]  # explicit set of changes


def _compile_action(action: Action) -> let:
    """Normalise a rule action shorthand into a :class:`let`.

    Accepted forms:

    * ``'Account'``: assign the account only;
    * ``(payee, account)``;
    * ``(payee, desc, account)``;
    * ``(flag, payee, desc, account)``;
    * a :class:`let` instance, returned unchanged.

    Raises:
        ValueError: if ``action`` is a tuple of any other length, or is not
            one of the accepted types.
    """
    if isinstance(action, str):
        return let(account=action)
    # ``let`` is a NamedTuple, i.e. a tuple subclass, so it must be checked
    # before the generic tuple shorthands below.
    if isinstance(action, let):
        return action
    if isinstance(action, tuple):
        if len(action) == 2:
            payee, account = cast(Tuple[str, str], action)
            return let(payee=payee, account=account)
        if len(action) == 3:
            payee, desc, account = cast(Tuple[str, str, str], action)
            return let(payee=payee, desc=desc, account=account)
        if len(action) == 4:
            flag, payee, desc, account = cast(Tuple[str, str, str, str],
                                              action)
            return let(payee=payee, desc=desc, account=account, flag=flag)
        # Previously any other length fell into the 4-tuple unpack and failed
        # with an opaque "values to unpack" error; report it explicitly.
        raise ValueError(
            f'Action tuple must have 2, 3 or 4 elements, '
            f'got {len(action)}: {action}')
    raise ValueError(f'Unknown action: {action}')


# Maps each condition to the action applied when it matches; the rules are
# compiled and tried in dict iteration (insertion) order.
Rules = Dict[Condition, Action]
# Normalised form of ``Rules`` produced by ``_compile_rules``.
CompiledRules = List[Tuple[When, let]]


def _compile_rules(rules: Rules) -> CompiledRules:
    """Compile every ``condition: action`` pair of ``rules``, keeping order."""
    return [
        (_compile_condition(condition), _compile_action(change))
        for condition, change in rules.items()
    ]


def _rule_condition_matches(cond: When, row: Row) -> bool:
    """Return whether ``row`` satisfies every part of ``cond``.

    * payee: ``cond.payee`` must be found in ``row.payee``. A row without a
      payee only satisfies the wildcard payee pattern.
    * amount: when ``cond.amount`` is set, it must equal
      ``row.transacted_amount``.
    * text: ``cond.text`` is searched in the non-empty values of
      ``row.entry_type``, ``row.payee`` and ``row.comment`` joined by single
      spaces. The wildcard text pattern matches unconditionally.
    """
    if row.payee:
        if cond.payee.search(row.payee) is None:
            return False
    elif cond.payee != WILDCARD:
        return False
    # Explicit None test instead of relying on the truthiness of an Amount.
    if cond.amount is not None and row.transacted_amount != cond.amount:
        return False
    if cond.text == WILDCARD:
        return True
    # Only build the text haystack once the cheaper checks have passed.
    parts = (row.entry_type, row.payee, row.comment)
    row_str = ' '.join(part for part in parts if part)
    return cond.text.search(row_str) is not None


def extract_rules(input_rules: Rules) -> Extractor:
    """Return an extractor that rewrites rows according to ``input_rules``.

    ``input_rules`` is compiled once, when this function is called. For each
    row, the returned callable tries the rules in order. For every rule whose
    condition matches, it applies the rule's non-``None`` fields:

    * ``payee`` replaces ``row.payee``, first clearing ``row.comment`` if it
      merely repeated the old payee;
    * ``desc`` replaces ``row.comment``;
    * ``account`` is passed to ``row.assign_to_account``;
    * ``flag`` replaces ``row.flag``;
    * ``tag`` is added to ``row.tags``.

    Processing stops as soon as a matching rule has been applied and
    ``row.postings`` is non-empty.
    """
    compiled = _compile_rules(input_rules)

    def apply(change: let, row: Row) -> None:
        # ``None`` in a field of ``change`` means "leave this part alone".
        if change.payee is not None:
            # A comment identical to the old payee would be redundant next to
            # the new payee, so drop it.
            if row.payee == row.comment:
                row.comment = ''
            row.payee = change.payee
        if change.desc is not None:
            row.comment = change.desc
        if change.account is not None:
            row.assign_to_account(change.account)
        if change.flag is not None:
            row.flag = change.flag
        if change.tag is not None:
            row.tags.add(change.tag)

    def do_extract(row: Row) -> None:
        for condition, change in compiled:
            if _rule_condition_matches(condition, row):
                apply(change, row)
                if row.postings:
                    break

    return do_extract