aboutsummaryrefslogtreecommitdiffstats
path: root/beancount_extras_kris7t/importers/utils.py
blob: f0a8134eefd266236afb0be6db5a0ecd175ac873 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
'''
Utilities for custom importers.
'''
__copyright__ = 'Copyright (c) 2020  Kristóf Marussy <kristof@marussy.com>'
__license__ = 'GNU GPLv2'

from abc import ABC, abstractmethod
import datetime as dt
from decimal import Decimal
from typing import cast, Callable, Iterable, List, NamedTuple, Optional, Set, TypeVar, Union

from beancount.core import amount as am, data
from beancount.core.amount import Amount
from beancount.core.flags import FLAG_OKAY, FLAG_WARNING
from beancount.core.number import D, ZERO

# Placeholder for an absent amount: the value is None at runtime, cast to Amount
# so that type checkers accept it wherever an Amount is expected.
MISSING_AMOUNT = cast(Amount, None)
# NOTE(review): presumably the metadata keys under which importers store the raw
# comment and payee text; they are not referenced in the code visible here, so
# confirm against the importers that use them.
COMMENT_META = 'import-raw-comment'
PAYEE_META = 'import-raw-payee'


class InvalidEntry(Exception):
    '''
    Raised when an imported entry cannot be processed.

    The helpers in this module raise it for unparseable dates and numbers and
    for invalid account assignments on a row.
    '''


class Posting(NamedTuple):
    '''
    An account paired with the amount assigned to it.

    Row.assign_to_accounts sums the amounts of the postings it receives and
    compares the total with the transacted amount of the row.
    '''
    # Name of the account the amount is assigned to.
    account: str
    # Amount assigned to the account.
    amount: Amount


class Row(ABC):
    '''
    Base class for a single imported row that has to be assigned to accounts.

    Subclasses provide the transacted amount. A row starts out with the
    FLAG_OKAY flag and without postings; it is considered done as soon as
    postings were assigned to it.
    '''
    entry_type: Optional[str]
    payee: Optional[str]
    comment: str
    meta: data.Meta
    flag: str
    tags: Set[str]
    links: Set[str]
    _postings: Optional[List[Posting]]

    def __init__(self,
                 file_name: str,
                 line_number: int,
                 entry_type: Optional[str],
                 payee: Optional[str],
                 comment: str):
        '''
        Create an unassigned row.

        file_name and line_number are used to create the metadata of the row.
        '''
        self.entry_type = entry_type
        self.payee = payee
        self.comment = comment
        self.meta = data.new_metadata(file_name, line_number)
        self.flag = FLAG_OKAY
        self.tags = set()
        self.links = set()
        self._postings = None

    @property
    @abstractmethod
    def transacted_amount(self) -> Amount:
        '''
        The amount transacted by this row; must be provided by subclasses.
        '''

    @property
    def transacted_currency(self) -> str:
        '''
        The currency of the transacted amount.
        '''
        return self.transacted_amount.currency

    @property
    def postings(self) -> Optional[List[Posting]]:
        '''
        The assigned postings, or None if the row was not assigned yet.
        '''
        return self._postings

    def assign_to_accounts(self, *postings: Posting) -> None:
        '''
        Assign the row to the accounts given by the postings.

        The row is flagged with FLAG_WARNING if the sum of the posting amounts
        differs from the transacted amount.

        Raises InvalidEntry if the row is already done or if no postings were
        passed. The row is left unchanged in both cases, so it is not marked
        as done by a failed assignment.
        '''
        if self.done:
            raise InvalidEntry('Transaction is already done processing')
        # Validate before touching any state: storing an empty list first would
        # make `done` report True even though the assignment was rejected.
        if not postings:
            raise InvalidEntry('Not assigned to any accounts')
        head, *rest = postings
        total = head.amount
        for posting in rest:
            total = am.add(total, posting.amount)
        # Commit only after the postings were summed, so an error raised while
        # adding the amounts cannot leave a half-assigned row behind.
        self._postings = list(postings)
        if total != self.transacted_amount:
            self.flag = FLAG_WARNING

    def assign_to_account(self, account: str) -> None:
        '''
        Assign the whole transacted amount to a single account.
        '''
        self.assign_to_accounts(Posting(account, self.transacted_amount))

    @property
    def done(self) -> bool:
        '''
        Whether postings were already assigned to this row.
        '''
        return self._postings is not None


# An extractor is a callable that receives a row and returns nothing;
# extract_unknown below builds one that assigns the row to an account.
Extractor = Callable[[Row], None]
# Type variable bound to Row, used by run_row_extractors to tie the type of
# the row to the parameter type of the extractors applied to it.
TRow = TypeVar('TRow', bound=Row)


def run_row_extractors(row: TRow, extractors: Iterable[Callable[[TRow], None]]) -> None:
    '''
    Apply the extractors to the row in order until the row is done.

    The done state is only checked after an extractor ran, so the first
    extractor is always called; the remaining ones are skipped once the row
    reports that it is done.
    '''
    for apply_extractor in extractors:
        apply_extractor(row)
        if row.done:
            break


def extract_unknown(expenses_account: str, income_account: str) -> Extractor:
    '''
    Create an extractor that assigns a row to one of two fallback accounts.

    Rows with a negative transacted amount are assigned to expenses_account,
    all other rows to income_account. The row is flagged with FLAG_WARNING
    after the assignment.
    '''
    def do_extract(row: Row) -> None:
        is_negative = row.transacted_amount.number < ZERO
        row.assign_to_account(expenses_account if is_negative else income_account)
        # Set the flag after the assignment so that it always ends up as a warning.
        row.flag = FLAG_WARNING
    return do_extract


def parse_date(date_str: str, format_string: str) -> dt.date:
    '''
    Parse date_str according to the strptime format format_string.

    Returns the date part of the parsed value. Raises InvalidEntry, chained to
    the original ValueError, if the string does not match the format.
    '''
    try:
        parsed = dt.datetime.strptime(date_str, format_string)
    except ValueError as exc:
        raise InvalidEntry(f'Cannot parse date: {date_str}') from exc
    return parsed.date()


def parse_number(in_amount: Union[str, int, float, Decimal]) -> Decimal:
    '''
    Convert in_amount to a Decimal with the D helper of beancount.

    Raises InvalidEntry if the conversion fails with a ValueError (chained to
    that error) or if it yields None.
    '''
    try:
        parsed = D(in_amount)
    except ValueError as exc:
        raise InvalidEntry(f'Cannot parse number: {in_amount}') from exc
    if parsed is not None:
        return parsed
    raise InvalidEntry(f'Parse number returned None: {in_amount}')