diff --git a/beanquery/query_env.py b/beanquery/query_env.py index 9c2ba0c..532849c 100644 --- a/beanquery/query_env.py +++ b/beanquery/query_env.py @@ -666,13 +666,64 @@ def interval(x): unit = m.group(2) if unit == 'day': return relativedelta(days=number) + if unit == 'week': + return relativedelta(weeks=number) if unit == 'month': return relativedelta(months=number) if unit == 'year': return relativedelta(years=number) + if unit == 'decade': + return relativedelta(years=number * 10) + if unit == 'century': + return relativedelta(years=number * 100) + if unit == 'millennium': + return relativedelta(years=number * 1000) return None +@function([relativedelta, datetime.date, datetime.date], datetime.date) +def date_bin(stride, source, origin): + """Bin a date into the specified stride aligned with the specified origin. + + As an extension to the the SQL standard ``date_bin()`` function this + function also accepts strides containing units of months and years. + """ + if stride.months or stride.years: + if origin + stride <= origin: + # FIXME: this should raise and error: stride must be greater than zero + return None + if source >= origin: + d = n = origin + while True: + n += stride + if n >= source: + return d + d = n + else: + n = origin + while True: + n -= stride + if n <= source: + return n + else: + seconds = stride.days * 86400 + stride.hours * 3600 + stride.minutes * 60 + stride.seconds + if seconds < 0: + # FIXME: this should raise and error: stride must be greater than zero + return None + diff = (source - origin).total_seconds() + modulo = diff % seconds + delta = diff - modulo + result = origin + datetime.timedelta(seconds=delta) + if modulo < 0: + result -= datetime.timedelta(seconds=seconds) + return result + + +@function([str, datetime.date, datetime.date], datetime.date, name='date_bin') +def date_bin_str(stride, source, origin): + return date_bin(interval(stride), source, origin) + + def aggregator(intypes, name=None): def decorator(cls): cls.__intypes__ = intypes diff --git a/beanquery/query_execute_test.py b/beanquery/query_execute_test.py index 3ddf1f3..af8ec83 100644 --- a/beanquery/query_execute_test.py +++ b/beanquery/query_execute_test.py @@ -401,6 +401,14 @@ def test_interval_ops(self): self.assertResult('''SELECT interval("1 baz")''', None, relativedelta) self.assertResult('''SELECT interval("A days")''', None, relativedelta) + def test_date_bin(self): + self.assertResult('''SELECT date_bin(interval('1 year'), 2024-11-10, 2024-06-01)''', datetime.date(2024, 6, 1)) + self.assertResult('''SELECT date_bin('1 year', 2024-11-10, 2024-06-01)''', datetime.date(2024, 6, 1)) + self.assertResult('''SELECT date_bin('1 year', 2024-11-10, 2025-06-01)''', datetime.date(2024, 6, 1)) + self.assertResult('''SELECT date_bin('1 month', 2024-11-10, 2024-06-03)''', datetime.date(2024, 11, 3)) + self.assertResult('''SELECT date_bin('3 days', 2024-11-10, 2024-11-02)''', datetime.date(2024, 11, 8)) + self.assertResult('''SELECT date_bin('3 days', 2024-11-10, 2024-11-14)''', datetime.date(2024, 11, 8)) + class TestBeancountFunctions(QueryBase): INPUT = textwrap.dedent("""