Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement support for referencing output names in WHERE and HAVING #29

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ Version 0.1 (unreleased)
behavior, the query should be written::

SELECT date, narration ORDER BY date DESC, narration DESC

- Output names defined with ``SELECT ... AS`` can now be used in the
``WHERE`` and ``HAVING`` clauses in addition to the ``GROUP BY`` and
``ORDER BY`` clauses where they were already supported.
57 changes: 40 additions & 17 deletions beanquery/query_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,20 +341,27 @@ class CompilationEnvironment:
# The name of the context.
context_name = None

# Maps of names to evaluators for columns and functions.
columns = None
functions = None
# Maps of names to evaluators for output names, columns, and functions.
names = {}
columns = {}
functions = {}

def get_column(self, name):
"""Return a column accessor for the given named column.
Args:
name: A string, the name of the column to access.
"""
try:
return self.columns[name]()
except KeyError as exc:
raise CompilationError("Invalid column name '{}' in {} context.".format(
name, self.context_name)) from exc
expr = self.names.get(name)
if expr is not None:
# Expression evaluators may keep state (for example
# aggregate functions), thus we need to return a copy.
return copy.copy(expr)

column = self.columns.get(name)
if column is not None:
return column()

raise CompilationError(f'Unknown column "{name}" in {self.context_name}')

def get_function(self, name, operands):
"""Return a function accessor for the given named function.
Expand Down Expand Up @@ -492,7 +499,7 @@ def is_hashable_type(node):
return not issubclass(node.dtype, inventory.Inventory)


def find_unique_name(name, allocated_set):
def unique_name(name, allocated_set):
"""Come up with a unique name for 'name' amongst 'allocated_set'.

Args:
Expand Down Expand Up @@ -525,9 +532,15 @@ def compile_targets(targets, environ):
Args:
targets: A list of target expressions from the parser.
environ: A compilation context for the targets.

Returns:
A list of compiled target expressions with resolved names.
A tuple containing a list of compiled expressions and a dictionary
mapping explicit output names assigned with the AS keyword to
compiled expressions.

"""
names = {}

# Bind the targets expressions to the execution context.
if isinstance(targets, query_parser.Wildcard):
# Insert the full list of available columns.
Expand All @@ -539,11 +552,19 @@ def compile_targets(targets, environ):
target_names = set()
for target in targets:
c_expr = compile_expression(target.expression, environ)
target_name = find_unique_name(
target.name or query_parser.get_expression_name(target.expression),
target_names)
target_names.add(target_name)
c_targets.append(EvalTarget(c_expr, target_name, is_aggregate(c_expr)))
if target.name:
# The target has an explicit output name: make sure that it
# does not collide with any other output name.
name = target.name
if name in target_names:
raise CompilationError(f'Duplicate output name "{name}" in SELECT list')
# Keep track of explicit output names.
names[name] = c_expr
else:
# Otherwise generate a unique output name.
name = unique_name(query_parser.get_expression_name(target.expression), target_names)
target_names.add(name)
c_targets.append(EvalTarget(c_expr, name, is_aggregate(c_expr)))

columns, aggregates = get_columns_and_aggregates(c_expr)

Expand All @@ -559,7 +580,7 @@ def compile_targets(targets, environ):
raise CompilationError(
"Aggregates of aggregates are not allowed")

return c_targets
return c_targets, names


def compile_group_by(group_by, c_targets, environ):
Expand Down Expand Up @@ -843,7 +864,9 @@ def compile_select(select, targets_environ, postings_environ, entries_environ):
c_from = compile_from(select.from_clause, entries_environ)

# Compile the targets.
c_targets = compile_targets(select.targets, targets_environ)
c_targets, output_names = compile_targets(select.targets, targets_environ)
targets_environ.names = output_names
postings_environ.names = output_names

# Bind the WHERE expression to the execution environment.
c_where = compile_expression(select.where_clause, postings_environ)
Expand Down
29 changes: 20 additions & 9 deletions beanquery/query_compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,11 @@ def test_compile_EvalSub(self):

class TestCompileMisc(unittest.TestCase):

def test_find_unique_names(self):
self.assertEqual('date', qc.find_unique_name('date', {}))
self.assertEqual('date', qc.find_unique_name('date', {'account', 'number'}))
self.assertEqual('date_1', qc.find_unique_name('date', {'date', 'number'}))
self.assertEqual('date_2',
qc.find_unique_name('date', {'date', 'date_1', 'date_3'}))
def test_unique_name(self):
self.assertEqual('date', qc.unique_name('date', {}))
self.assertEqual('date', qc.unique_name('date', {'account', 'number'}))
self.assertEqual('date_1', qc.unique_name('date', {'date', 'number'}))
self.assertEqual('date_2', qc.unique_name('date', {'date', 'date_1', 'date_3'}))


class CompileSelectBase(unittest.TestCase):
Expand Down Expand Up @@ -349,10 +348,9 @@ def test_compile_targets_wildcard(self):
for target in query.c_targets))

def test_compile_targets_named(self):
# Test the wildcard expansion.
query = self.compile("SELECT length(account), account as a, date;")
query = self.compile("SELECT length(account) AS l, account AS a, date;")
self.assertEqual(
[qc.EvalTarget(qe.F('length', str)([qe.AccountColumn()]), 'length_account', False),
[qc.EvalTarget(qe.F('length', str)([qe.AccountColumn()]), 'l', False),
qc.EvalTarget(qe.AccountColumn(), 'a', False),
qc.EvalTarget(qe.DateColumn(), 'date', False)],
query.c_targets)
Expand Down Expand Up @@ -610,6 +608,19 @@ def test_compile_order_by_aggregate(self):
self.assertEqual([(1, False)], query.order_spec)


class TestCompileSelectNamed(CompileSelectBase):

def test_compile_select_where_name(self):
query = self.compile("""
SELECT date AS d WHERE d = 2022-03-30;
""")

def test_compile_select_having_name(self):
query = self.compile("""
SELECT sum(position) AS s GROUP BY year HAVING not empty(s);
""")


class TestTranslationJournal(CompileSelectBase):

maxDiff = 4096
Expand Down
39 changes: 39 additions & 0 deletions beanquery/query_execute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,5 +1079,44 @@ def test_flatten(self):
])


class TestOutputNames(QueryBase):

data = """
2020-01-01 open Assets:Bank
2020-01-01 open Assets:Receivable
2020-01-01 open Income:Sponsorship

2020-03-01 * "Sponsorship from A"
invoice: "A01"
Assets:Receivable 100.00 USD
Income:Sponsorship -100.00 USD

2020-03-01 * "Sponsorship from B"
invoice: "B01"
Assets:Receivable 30.00 USD
Income:Sponsorship -30.00 USD

2020-03-10 * "Payment from A"
invoice: "A01"
Assets:Bank 100.00 USD
Assets:Receivable -100.00 USD
"""

def test_output_names(self):
self.check_query(self.data, """
SELECT
entry_meta('invoice') AS invoice,
sum(position) AS balance
WHERE
root(account, 2) = 'Assets:Receivable'
GROUP BY
invoice
HAVING
not empty(balance);
""",
[('invoice', object), ('balance', inventory.Inventory)],
[('B01', inventory.from_string("30.00 USD"))])


if __name__ == '__main__':
unittest.main()