Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanups #213

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions beanquery/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import tatsu

from ..errors import ProgrammingError
from .parser import BQLParser
from . import ast
from . import parser


class BQLSemantics:
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(self, parseinfo):

def parse(text):
try:
return parser.BQLParser().parse(text, semantics=BQLSemantics())
return BQLParser().parse(text, semantics=BQLSemantics())
except tatsu.exceptions.ParseError as exc:
line = exc.tokenizer.line_info(exc.pos).line
parseinfo = tatsu.infos.ParseInfo(exc.tokenizer, exc.item, exc.pos, exc.pos + 1, line, [])
Expand Down
140 changes: 52 additions & 88 deletions beanquery/query_compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import datetime
import unittest
from decimal import Decimal

from decimal import Decimal as D

from beanquery import Connection, CompilationError, ProgrammingError
from beanquery import compiler
Expand All @@ -12,9 +13,11 @@
from beanquery import parser
from beanquery import tables
from beanquery.parser import ast
from beanquery.tests.tables import TestTable as _Table


class Table:
# mock table to be used in tests
def __init__(self, name):
self.name = name

Expand All @@ -24,70 +27,66 @@ def __eq__(self, other):
return other.name == self.name


class Column(qc.EvalColumn):
# mock column to be used in tests
def __init__(self, name, dtype=str):
self.name = name
self.dtype = dtype

def __eq__(self, other):
if not isinstance(other, qc.EvalColumn):
return NotImplemented
return type(other).__name__ == self.name


class TestCompileExpression(unittest.TestCase):

@classmethod
def setUpClass(cls):
context = Connection()
cls.compiler = compiler.Compiler(context)
cls.compiler.table = qe.PostingsEnvironment()
cls.context = Connection()
cls.context.tables['test'] = _Table(0)
# parser for expressions
cls.parser = parser.BQLParser(start='expression')
cls.compiler = compiler.Compiler(cls.context)
cls.compiler.table = cls.context.tables['test']
# reference to the column ``x`` in table ``test`` to simplify the tests
cls.x = cls.context.tables['test'].columns['x']

def compile(self, expr):
expr = self.parser.parse(expr, semantics=parser.BQLSemantics())
return self.compiler.compile(expr)

def test_expr_invalid(self):
with self.assertRaises(CompilationError):
self.compile(ast.Column('invalid'))
self.compile('''invalid''')

def test_expr_column(self):
self.assertEqual(
qe.Column('filename'),
self.compile(ast.Column('filename')))
self.assertEqual(self.compile('''x'''), self.x)

def test_expr_function(self):
self.assertEqual(
qe.SumPosition(None, [qe.Column('position')]),
self.compile(ast.Function('sum', [ast.Column('position')])))
self.assertEqual(self.compile('''sum(x)'''), qe.SumInt(None, [self.x]))

def test_expr_unaryop(self):
self.assertEqual(
qc.Operator(ast.Not, [qe.Column('account')]),
self.compile(ast.Not(ast.Column('account'))))
self.assertEqual(self.compile('''not x'''), qc.Operator(ast.Not, [self.x]))

def test_expr_binaryop(self):
self.assertEqual(
qc.Operator(ast.Equal, [
qe.Column('date'),
qc.EvalConstant(datetime.date(2014, 1, 1))
]),
self.compile(ast.Equal(ast.Column('date'), ast.Constant(datetime.date(2014, 1, 1)))))
self.assertEqual(self.compile('''x = 1'''), qc.Operator(ast.Equal, [self.x, qc.EvalConstant(1)]))

def test_expr_constant(self):
self.assertEqual(
qc.EvalConstant(Decimal(17)),
self.compile(ast.Constant(Decimal(17))))
self.assertEqual(self.compile('''17'''), qc.EvalConstant(D('17')))

def test_expr_function_arity(self):
# Compile with the correct number of arguments.
self.compile(ast.Function('sum', [ast.Column('number')]))

# Compile with an incorrect number of arguments.
# compile with an incorrect number of arguments.
with self.assertRaises(CompilationError):
self.compile(ast.Function('sum', [ast.Column('date'), ast.Column('account')]))
self.compile('''sum(1, 2)''')

def test_constants_folding(self):
# unary op
self.assertEqual(
self.compile(ast.Neg(ast.Constant(2))),
qc.EvalConstant(-2))
self.assertEqual(self.compile('''-2'''), qc.EvalConstant(D('-2')))
# binary op
self.assertEqual(
self.compile(ast.Add(ast.Constant(2), ast.Constant(2))),
qc.EvalConstant(4))
self.assertEqual(self.compile('''2 + 2'''), qc.EvalConstant(D('4')))
# funtion
self.assertEqual(
self.compile(ast.Function('root', [ast.Constant('Assets:Cash'), ast.Constant(1)])),
qc.EvalConstant('Assets'))
self.assertEqual(self.compile('''root('Assets:Cash', 1)'''), qc.EvalConstant('Assets'))


class TestCompileAggregateChecks(unittest.TestCase):
Expand All @@ -96,13 +95,13 @@ def test_is_aggregate_derived(self):
columns, aggregates = compiler.get_columns_and_aggregates(
qc.EvalAnd([
qc.Operator(ast.Equal, [
qe.Column('lineno'),
Column('lineno', int),
qc.EvalConstant(42),
]),
qc.EvalOr([
qc.Operator(ast.Not, [
qc.Operator(ast.Equal, [
qe.Column('date'),
Column('date', datetime.date),
qc.EvalConstant(datetime.date(2014, 1, 1)),
]),
]),
Expand All @@ -114,14 +113,14 @@ def test_is_aggregate_derived(self):
columns, aggregates = compiler.get_columns_and_aggregates(
qc.EvalAnd([
qc.Operator(ast.Equal, [
qe.Column('lineno'),
Column('lineno', int),
qc.EvalConstant(42),
]),
qc.EvalOr([
qc.Operator(ast.Not, [
qc.Operator(ast.Not, [
qc.Operator(ast.Equal, [
qe.Column('date'),
Column('date', datetime.date),
qc.EvalConstant(datetime.date(2014, 1, 1)),
]),
]),
Expand All @@ -134,39 +133,39 @@ def test_is_aggregate_derived(self):

def test_get_columns_and_aggregates(self):
# Simple column.
c_query = qe.Column('position')
c_query = Column('position')
columns, aggregates = compiler.get_columns_and_aggregates(c_query)
self.assertEqual((1, 0), (len(columns), len(aggregates)))
self.assertFalse(compiler.is_aggregate(c_query))

# Multiple columns.
c_query = qc.EvalAnd([qe.Column('position'), qe.Column('date')])
c_query = qc.EvalAnd([Column('position'), Column('date')])
columns, aggregates = compiler.get_columns_and_aggregates(c_query)
self.assertEqual((2, 0), (len(columns), len(aggregates)))
self.assertFalse(compiler.is_aggregate(c_query))

# Simple aggregate.
c_query = qe.SumPosition(None, [qe.Column('position')])
c_query = qe.SumPosition(None, [Column('position')])
columns, aggregates = compiler.get_columns_and_aggregates(c_query)
self.assertEqual((0, 1), (len(columns), len(aggregates)))
self.assertTrue(compiler.is_aggregate(c_query))

# Multiple aggregates.
c_query = qc.EvalAnd([qe.First(None, [qe.Column('date')]), qe.Last(None, [qe.Column('flag')])])
c_query = qc.EvalAnd([qe.First(None, [Column('date')]), qe.Last(None, [Column('flag')])])
columns, aggregates = compiler.get_columns_and_aggregates(c_query)
self.assertEqual((0, 2), (len(columns), len(aggregates)))
self.assertTrue(compiler.is_aggregate(c_query))

# Simple non-aggregate function.
c_query = qe.Function('length', [qe.Column('account')])
c_query = qe.Function('length', [Column('account')])
columns, aggregates = compiler.get_columns_and_aggregates(c_query)
self.assertEqual((1, 0), (len(columns), len(aggregates)))
self.assertFalse(compiler.is_aggregate(c_query))

# Mix of column and aggregates (this is used to detect this illegal case).
c_query = qc.EvalAnd([
qe.Function('length', [qe.Column('account')]),
qe.SumPosition(None, [qe.Column('position')]),
qe.Function('length', [Column('account')]),
qe.SumPosition(None, [Column('position')]),
])
columns, aggregates = compiler.get_columns_and_aggregates(c_query)
self.assertEqual((1, 1), (len(columns), len(aggregates)))
Expand Down Expand Up @@ -275,40 +274,6 @@ def assertCompile(self, expected, query, debug=False):
raise


class TestCompileFundamentals(CompileSelectBase):

def test_operaotors(self):
expr = self.compile("SELECT 1 + 1 AS expr")
self.assertEqual(expr, qc.EvalQuery(Table('postings'), [
qc.EvalTarget(qc.EvalConstant(2), 'expr', False)
], None, None, None, None, None, None))

expr = self.compile("SELECT 1 + meta['int'] AS expr")
self.assertEqual(expr, qc.EvalQuery(Table('postings'), [
qc.EvalTarget(
qc.Operator(ast.Add, [
qc.EvalConstant(1),
qe.Function('decimal', [
qc.EvalGetItem(qe.Column('meta'), 'int')
]),
]), 'expr', False)
], None, None, None, None, None, None))

def test_coalesce(self):
expr = self.compile("SELECT coalesce(narration, str(date), '~') AS expr")
self.assertEqual(expr, qc.EvalQuery(Table('postings'), [
qc.EvalTarget(
qc.EvalCoalesce([
qe.Column('narration'),
qe.Function('str', [qe.Column('date')]),
qc.EvalConstant('~'),
]), 'expr', False)
], None, None, None, None, None, None))

with self.assertRaises(CompilationError):
self.compile("SELECT coalesce(narration, date, 1)")


class TestCompileSelect(CompileSelectBase):

def test_compile_from(self):
Expand Down Expand Up @@ -342,16 +307,15 @@ def test_compile_targets_wildcard(self):
query = self.compile("SELECT *;")
self.assertTrue(list, type(query.c_targets))
self.assertGreater(len(query.c_targets), 3)
self.assertTrue(all(isinstance(target.c_expr, qc.EvalColumn)
for target in query.c_targets))
self.assertTrue(all(isinstance(target.c_expr, qc.EvalColumn) for target in query.c_targets))

def test_compile_targets_named(self):
# Test the wildcard expansion.
query = self.compile("SELECT length(account), account as a, date;")
self.assertEqual(
[qc.EvalTarget(qe.Function('length', [qe.Column('account')]), 'length(account)', False),
qc.EvalTarget(qe.Column('account'), 'a', False),
qc.EvalTarget(qe.Column('date'), 'date', False)],
[qc.EvalTarget(qe.Function('length', [Column('account')]), 'length(account)', False),
qc.EvalTarget(Column('account'), 'a', False),
qc.EvalTarget(Column('date'), 'date', False)],
query.c_targets)

def test_compile_mixed_aggregates(self):
Expand Down Expand Up @@ -744,7 +708,7 @@ def test_print_from(self):
qc.EvalPrint(
Table('entries'),
qc.Operator(ast.Equal, [
qe.Column('year'),
Column('year', int),
qc.EvalConstant(2014),
])),
"PRINT FROM year = 2014;")
Expand Down
Loading