aboutsummaryrefslogtreecommitdiffstats
path: root/beancount_extras_kris7t/importers/rules.py
diff options
context:
space:
mode:
Diffstat (limited to 'beancount_extras_kris7t/importers/rules.py')
-rw-r--r--beancount_extras_kris7t/importers/rules.py134
1 files changed, 134 insertions, 0 deletions
diff --git a/beancount_extras_kris7t/importers/rules.py b/beancount_extras_kris7t/importers/rules.py
new file mode 100644
index 0000000..3890f24
--- /dev/null
+++ b/beancount_extras_kris7t/importers/rules.py
@@ -0,0 +1,134 @@
1#!/usr/bin/env python3
2
3from typing import cast, Dict, List, NamedTuple, Optional, Tuple, Union
4import re
5
6from beancount.core.amount import Amount
7
8from beancount_extras_kris7t.importers.utils import Extractor, Row
9
10WILDCARD = re.compile('.*')
11
12
13class When(NamedTuple):
14 payee: re.Pattern
15 text: re.Pattern
16 amount: Optional[Amount]
17
18
19def _compile_regex(s: str) -> re.Pattern:
20 return re.compile(s, re.IGNORECASE)
21
22
23def when(payee: Optional[Union[re.Pattern, str]] = None,
24 text: Optional[Union[re.Pattern, str]] = None,
25 amount: Optional[Amount] = None) -> When:
26 if not payee and not text:
27 raise TypeError('at least one of payee and desc must be provided')
28 if isinstance(payee, str):
29 payee_regex = _compile_regex(payee)
30 else:
31 payee_regex = payee or WILDCARD
32 if isinstance(text, str):
33 text_regex = _compile_regex(text)
34 else:
35 text_regex = text or WILDCARD
36 return When(payee_regex, text_regex, amount)
37
38
39Condition = Union[str, re.Pattern, When]
40
41
42def _compile_condition(cond: Condition) -> When:
43 if isinstance(cond, When):
44 return cond
45 else:
46 return when(text=cond)
47
48
49class let(NamedTuple):
50 payee: Optional[str] = None
51 desc: Optional[str] = None
52 account: Optional[str] = None
53 flag: Optional[str] = None
54 tag: Optional[str] = None
55
56
57Action = Union[str,
58 Tuple[str, str],
59 Tuple[str, str, str],
60 Tuple[str, str, str, str],
61 let]
62
63
64def _compile_action(action: Action) -> let:
65 if isinstance(action, str):
66 return let(account=action)
67 if isinstance(action, let):
68 return action
69 elif isinstance(action, tuple):
70 if len(action) == 2:
71 payee, account = cast(Tuple[str, str], action)
72 return let(payee=payee, account=account)
73 elif len(action) == 3:
74 payee, desc, account = cast(Tuple[str, str, str], action)
75 return let(payee, desc, account)
76 else:
77 flag, payee, desc, account = cast(Tuple[str, str, str, str], action)
78 return let(payee, desc, account, flag)
79 else:
80 raise ValueError(f'Unknown action: {action}')
81
82
83Rules = Dict[Condition, Action]
84CompiledRules = List[Tuple[When, let]]
85
86
87def _compile_rules(rules: Rules) -> CompiledRules:
88 return [(_compile_condition(cond), _compile_action(action))
89 for cond, action in rules.items()]
90
91
92def _rule_condition_matches(cond: When, row: Row) -> bool:
93 if row.payee:
94 payee_valid = cond.payee.search(row.payee) is not None
95 else:
96 payee_valid = cond.payee == WILDCARD
97 if cond.text == WILDCARD:
98 text_valid = True
99 else:
100 characteristics: List[str] = []
101 if row.entry_type:
102 characteristics.append(row.entry_type)
103 if row.payee:
104 characteristics.append(row.payee)
105 if row.comment:
106 characteristics.append(row.comment)
107 row_str = ' '.join(characteristics)
108 text_valid = cond.text.search(row_str) is not None
109 amount_valid = not cond.amount or row.transacted_amount == cond.amount
110 return payee_valid and text_valid and amount_valid
111
112
113def extract_rules(input_rules: Rules) -> Extractor:
114 compiled_rules = _compile_rules(input_rules)
115
116 def do_extract(row: Row) -> None:
117 for cond, (payee, desc, account, flag, tag) in compiled_rules:
118 if not _rule_condition_matches(cond, row):
119 continue
120 if payee is not None:
121 if row.payee == row.comment:
122 row.comment = ''
123 row.payee = payee
124 if desc is not None:
125 row.comment = desc
126 if account is not None:
127 row.assign_to_account(account)
128 if flag is not None:
129 row.flag = flag
130 if tag is not None:
131 row.tags.add(tag)
132 if row.postings:
133 return
134 return do_extract