From 14248ed55d6718725bebb3263dc2dd9c68d408ed Mon Sep 17 00:00:00 2001 From: Jonathan Klabunde Tomer Date: Fri, 29 Mar 2019 20:11:08 -0700 Subject: [PATCH] training.py: two tweaks to feature selection 1. Include posting amounts as a feature. This allows us to distinguish different classes of payments to the same payee (e.g. recurring membership fees, which often have a constant amount, from individual purchases). 2. For example key/value pairs, include the key by itself (with no substring of the value) as a feature. This is useful because different account types often have non-overlapping sets of example keys, and including the bare key as a feature allows the decision tree to be effectively segmented by account type fairly close to the root. These two very small changes significantly improve training accuracy on my journal, from 94.81% to 99.32% (an 86% reduction in error rate!). --- beancount_import/training.py | 5 +++-- beancount_import/training_test.py | 6 +++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/beancount_import/training.py b/beancount_import/training.py index 30f74173..007233c4 100644 --- a/beancount_import/training.py +++ b/beancount_import/training.py @@ -30,10 +30,11 @@ def get_features(example: PredictionInput) -> Dict[str, bool]: features = collections.defaultdict(lambda: False) # type: Dict[str, bool] features['account:%s' % example.source_account] = True - - # For now, skip amount and date. + features['amount:%s' % example.amount.currency] = example.amount.number + # For now, skip date. 
for key, values in example.key_value_pairs.items(): + features[key] = True if isinstance(values, str): values = (values, ) for value in values: diff --git a/beancount_import/training_test.py b/beancount_import/training_test.py index aaabf27e..2ae65eb9 100644 --- a/beancount_import/training_test.py +++ b/beancount_import/training_test.py @@ -1,6 +1,7 @@ import datetime from beancount.core.data import Amount +from beancount.core.number import D from . import test_util from . import training @@ -21,7 +22,10 @@ def test_get_features(): 'a:hello': True, 'b:foo': True, 'b:bar': True, - 'b:foo bar': True + 'b:foo bar': True, + 'a': True, + 'b': True, + 'amount:USD': D(3) }