diff --git a/pybigquery/sqlalchemy_bigquery.py b/pybigquery/sqlalchemy_bigquery.py index 74573458..e074c82d 100644 --- a/pybigquery/sqlalchemy_bigquery.py +++ b/pybigquery/sqlalchemy_bigquery.py @@ -205,6 +205,9 @@ def visit_text(self, type_, **kw): def visit_string(self, type_, **kw): return 'STRING' + def visit_ARRAY(self, type_, **kw): + return "ARRAY<{}>".format(self.process(type_.item_type, **kw)) + def visit_BINARY(self, type_, **kw): return 'BYTES' diff --git a/test/test_sqlalchemy_bigquery.py b/test/test_sqlalchemy_bigquery.py index 831ee351..2ed998eb 100644 --- a/test/test_sqlalchemy_bigquery.py +++ b/test/test_sqlalchemy_bigquery.py @@ -358,6 +358,16 @@ def test_compiled_query_literal_binds(engine, engine_using_test_dataset, table, assert len(result) > 0 +@pytest.mark.parametrize(["column", "processed"], [ + (types.String(), "STRING"), + (types.NUMERIC(), "NUMERIC"), + (types.ARRAY(types.String), "ARRAY"), +]) +def test_compile_types(engine, column, processed): + result = engine.dialect.type_compiler.process(column) + assert result == processed + + def test_joins(session, table, table_one_row): result = (session.query(table.c.string, func.count(table_one_row.c.integer)) .join(table_one_row, table_one_row.c.string == table.c.string)