Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for use in importer hooks #16

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions src/beanahead/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from beancount.core.data import Transaction, Entries
from beancount.ingest import extract

from . import utils
from . import reconcile
from .errors import BeanaheadWriteError


class ReconcileExpected:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could a simple test be added to verify the hook is working as intended?

"""
Hook class for smart_importer to reconcile expected entries on the fly.

You also need to use the adapted duplicate hook to avoid false positives
using the new style import invocation:

...
hools = [ReconcileExpected.adapted_find_duplicate_entries]
beancount.ingest.scripts_utils.ingest(CONFIG, hooks=hooks)
"""

def __init__(self, x_txns_file):
path = utils.get_verified_path(x_txns_file)
utils.set_root_accounts_context(x_txns_file)
_ = utils.get_verified_ledger_file_key(path) # just verify that a ledger
self.expected_txns_path = path
self.expected_txns: list[Transaction] = utils.get_unverified_txns(path)

def __call__(self, importer, file, imported_entries, existing_entries) -> Entries:
"""Apply the hook and modify the imported entries.

Args:
importer: The importer that this hooks is being applied to.
file: The file that is being imported.
imported_entries: The current list of imported entries.
existing_entries: The existing entries, as passed to the extract
function.

Returns:
The updated imported entries.
"""
new_txns, new_other = reconcile.separate_out_txns(imported_entries)
reconciled_x_txns = reconcile.reconcile_x_txns(self.expected_txns, new_txns)

updated_new_txns = reconcile.update_new_txns(new_txns, reconciled_x_txns)
updated_entries = updated_new_txns + new_other

# Update expected transation file
x_txns_to_remove = []
for x_txn, _ in reconciled_x_txns:
if x_txn in self.expected_txns:
x_txns_to_remove.append(x_txn)

prev_contents = utils.get_content(self.expected_txns_path)
try:
utils.remove_txns_from_ledger(self.expected_txns_path, x_txns_to_remove)
except Exception as err:
utils.write(self.expected_txns_path, prev_contents)
raise BeanaheadWriteError(
self.expected_txns_path, [self.expected_txns_path]
) from err

return updated_entries

@staticmethod
def adapted_find_duplicate_entries(new_entries_list, existing_entries):
keep = []
# filter out expected transactions from duplicate detection
for entry in existing_entries:
if isinstance(entry, Transaction) and utils.TAGS_X & entry.tags:
continue
keep.append(entry)
return extract.find_duplicate_entries(new_entries_list, keep)
20 changes: 15 additions & 5 deletions src/beanahead/reconcile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from decimal import Decimal
from pathlib import Path
import re
import sys

import beancount
from beancount import loader
Expand Down Expand Up @@ -147,7 +148,9 @@ def get_pattern(x_txn: Transaction) -> re.Pattern:
def get_payee_matches(txns: list[Transaction], x_txn: Transaction) -> list[Transaction]:
"""Return transactions matching an Expected Transaction's payee."""
pattern = get_pattern(x_txn)
return [txn for txn in txns if pattern.search(txn.payee) is not None]
return [
txn for txn in txns if (txn.payee and pattern.search(txn.payee) is not None)
]


def get_common_accounts(a: Transaction, b: Transaction) -> set[str]:
Expand Down Expand Up @@ -362,7 +365,8 @@ def confirm_single(
print(
f"{utils.SEPARATOR_LINE}Expected Transaction:\n"
f"{utils.compose_entries_content(x_txn)}\n"
f"Incoming Transaction:\n{utils.compose_entries_content(matches[0])}"
f"Incoming Transaction:\n{utils.compose_entries_content(matches[0])}",
file=sys.stderr,
maread99 marked this conversation as resolved.
Show resolved Hide resolved
)
response = utils.get_input(MSG_SINGLE_MATCH).lower()
while response not in ["n", "y"]:
Expand Down Expand Up @@ -392,10 +396,11 @@ def get_mult_match(
print(
f"{utils.SEPARATOR_LINE}Expected Transaction:\n"
f"{utils.compose_entries_content(x_txn)}\n\n"
f"Incoming Transactions:\n"
f"Incoming Transactions:\n",
file=sys.stderr,
)
for i, match in enumerate(matches):
print(f"{i}\n{utils.compose_entries_content(match)}")
print(f"{i}\n{utils.compose_entries_content(match)}", file=sys.stderr)

max_value = len(matches) - 1
options = f"[0-{max_value}]/n"
Expand Down Expand Up @@ -521,7 +526,12 @@ def update_new_txn(new_txn: Transaction, x_txn: Transaction) -> Transaction:
new_txn_posting = get_posting_to_account(new_txn, account)

# carry over any meta not otherwise defined on new_txn
updated_posting = new_txn_posting._replace(meta=new_txn_posting.meta.copy())
if new_txn_posting.meta:
updated_posting = new_txn_posting._replace(
meta=new_txn_posting.meta.copy()
)
else:
updated_posting = new_txn_posting._replace(meta={})
for k, v in posting.meta.items():
updated_posting.meta.setdefault(k, v)

Expand Down
9 changes: 5 additions & 4 deletions src/beanahead/rx_txns.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,9 @@ def get_definition_group(definition: Transaction) -> GrouperKey:
if account == bal_sheet_account:
continue
account_type = get_account_type(account)
if account_type == "Assets":
if account_type == utils.RootAccountsContext.get("name_assets", "Assets"):
other_sides.add("Assets")
elif account_type == "Income":
elif account_type == utils.RootAccountsContext.get("name_income", "Income"):
other_sides.add("Income")
else:
other_sides.add("Expenses")
Expand Down Expand Up @@ -657,8 +657,9 @@ def add_txns(self, end: str | pd.Timestamp = END_DFLT):
ledger_txns = self.rx_txns + new_txns

# ensure all new content checks out before writting anything
content_ledger = compose_new_content("rx", ledger_txns)
content_defs = compose_new_content("rx_def", new_defs)
name_options = utils.set_root_accounts_context(self.path_ledger_main)
content_ledger = compose_new_content("rx", ledger_txns, name_options)
maread99 marked this conversation as resolved.
Show resolved Hide resolved
content_defs = compose_new_content("rx_def", new_defs, name_options)

self._overwrite_beancount_file(self.path_ledger, content_ledger)
also_revert = [self.path_ledger]
Expand Down
6 changes: 6 additions & 0 deletions src/beanahead/scripts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

def make_file(args: argparse.Namespace):
"""Pass through command line args to make a new beanahead file."""
utils.set_root_accounts_context(args.main_ledger)
utils.create_beanahead_file(args.key, args.dirpath, args.filename)


Expand Down Expand Up @@ -89,6 +90,11 @@ def main():
choices=["x", "rx", "rx_def"],
metavar="key",
)
parser_make.add_argument(
maread99 marked this conversation as resolved.
Show resolved Hide resolved
"main_ledger",
help="Path to the main ledger file to read its options.",
metavar="main_ledger",
)
parser_make.add_argument(
*["-d", "--dirpath"],
help=(
Expand Down
54 changes: 47 additions & 7 deletions src/beanahead/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import datetime
from pathlib import Path
import re
import sys

from beancount import loader
from beancount.core import data
Expand All @@ -29,6 +30,14 @@
TAG_RX = "rx_txn"
TAGS_X = set([TAG_X, TAG_RX])

ADOPT_OPTIONS = [
"name_assets",
"name_liabilities",
"name_income",
"name_expenses",
"name_equity",
]

RX_META_DFLTS = {
"final": None,
"roll": True,
Expand Down Expand Up @@ -67,6 +76,26 @@

LEDGER_FILE_KEYS = ["x", "rx"]

RootAccountsContext = {} # global context


def set_root_accounts_context(path_ledger: str) -> dict[str]:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind adding a simple test to verify that all's working as expected when users are using non-default category names?

"""Set the root accounts context from the ledger file path.

Returns
-------
dict[str]
The name options set in the ledger.
"""
name_options: dict[str] = {}
options = get_options(path_ledger)
for opt in ADOPT_OPTIONS:
if opt in options:
name_options[opt] = options[opt]
global RootAccountsContext
RootAccountsContext = name_options
return name_options


def validate_file_key(file_key: str):
"""Validate a file_key.
Expand Down Expand Up @@ -114,8 +143,15 @@ def compose_header_footer(file_key: str) -> tuple[str, str]:
"""
config = FILE_CONFIG[file_key]
plugin, tag, comment = config["plugin"], config["tag"], config["comment"]
extra_headers = ""
for k, v in RootAccountsContext.items():
extra_headers += f'option "{k}" "{v}"\n'

header = f"""option "title" "{config['title']}"\n"""
if extra_headers:
header += "\n"
header += extra_headers
header += "\n"
if plugin is not None:
header += f'plugin "{plugin}"\n'
header += f"pushtag #{tag}\n"
Expand Down Expand Up @@ -518,7 +554,7 @@ def reverse_automatic_balancing(txn: Transaction) -> Transaction:
"""
new_postings = []
for posting in txn.postings:
if AUTOMATIC_META in posting.meta:
if AUTOMATIC_META in (posting.meta or {}):
meta = {k: v for k, v in posting.meta.items() if k != AUTOMATIC_META}
posting = posting._replace(units=None, meta=meta)
new_postings.append(posting)
Expand All @@ -537,10 +573,7 @@ def is_assets_account(string: str) -> bool:
>>> is_assets_account("Assets:US:BofA:Checking")
True
"""
return is_account_type("Assets", string)


BAL_SHEET_ACCS = ["Assets", "Liabilities"]
return is_account_type(RootAccountsContext.get("name_assets", "Assets"), string)


def is_balance_sheet_account(string: str) -> bool:
Expand All @@ -566,7 +599,13 @@ def is_balance_sheet_account(string: str) -> bool:
>>> is_balance_sheet_account("Income:US:BayBook:Match401k")
False
"""
return any(is_account_type(acc_type, string) for acc_type in BAL_SHEET_ACCS)
return any(
is_account_type(acc_type, string)
for acc_type in [
RootAccountsContext.get("name_assets", "Assets"),
RootAccountsContext.get("name_liabilities", "Liabilities"),
]
)


def get_balance_sheet_accounts(txn: Transaction) -> list[str]:
Expand Down Expand Up @@ -845,7 +884,8 @@ def get_input(text: str) -> str:
-----
Function included to facilitate mocking user input when testing.
"""
return input(text)
print(text, file=sys.stderr, end=": ")
return input()


def response_is_valid_number(response: str, max_value: int) -> bool:
Expand Down
Loading