diff --git a/dbt/adapters/bigquery.py b/dbt/adapters/bigquery.py
index b8a2827703f..d2e2678ba31 100644
--- a/dbt/adapters/bigquery.py
+++ b/dbt/adapters/bigquery.py
@@ -4,6 +4,7 @@
 import dbt.compat
 import dbt.exceptions
+import dbt.schema
 import dbt.flags as flags
 import dbt.clients.gcloud
 import dbt.clients.agate_helper
@@ -29,7 +30,8 @@ class BigQueryAdapter(PostgresAdapter):
         "drop",
         "execute",
         "quote_schema_and_table",
-        "make_date_partitioned_table"
+        "make_date_partitioned_table",
+        "get_columns_in_table"
     ]
 
     SCOPE = ('https://www.googleapis.com/auth/bigquery',
@@ -38,6 +40,8 @@ class BigQueryAdapter(PostgresAdapter):
 
     QUERY_TIMEOUT = 300
 
+    Column = dbt.schema.BigQueryColumn
+
     @classmethod
     def handle_error(cls, error, message, sql):
         logger.debug(message.format(sql=sql))
@@ -377,15 +381,32 @@ def get_existing_schemas(cls, profile, model_name=None):
 
     @classmethod
     def get_columns_in_table(cls, profile, schema_name, table_name,
-                             model_name=None):
-        raise dbt.exceptions.NotImplementedException(
-            '`get_columns_in_table` is not implemented for this adapter!')
-
-    @classmethod
-    def get_columns_in_table(cls, profile, schema_name, table_name,
-                             model_name=None):
-        raise dbt.exceptions.NotImplementedException(
-            '`get_columns_in_table` is not implemented for this adapter!')
+                             database=None, model_name=None):
+
+        # BigQuery does not have databases -- the database parameter is here
+        # for consistency with the base implementation
+
+        conn = cls.get_connection(profile, model_name)
+        client = conn.get('handle')
+
+        try:
+            dataset_ref = client.dataset(schema_name)
+            table_ref = dataset_ref.table(table_name)
+            table = client.get_table(table_ref)
+            table_schema = table.schema
+        except (ValueError, google.cloud.exceptions.NotFound) as e:
+            logger.debug("get_columns_in_table error: {}".format(e))
+            table_schema = []
+
+        columns = []
+        for col in table_schema:
+            name = col.name
+            data_type = col.field_type
+
+            column = cls.Column(name, data_type, col.fields, col.mode)
+            columns.append(column)
+
+        return columns
 
     @classmethod
     def check_schema_exists(cls, profile, schema, model_name=None):
diff --git a/dbt/adapters/default.py b/dbt/adapters/default.py
index 1f6cd8968c0..92c04fa3e94 100644
--- a/dbt/adapters/default.py
+++ b/dbt/adapters/default.py
@@ -8,11 +8,11 @@
 import dbt.exceptions
 import dbt.flags
+import dbt.schema
 import dbt.clients.agate_helper
 
 from dbt.contracts.connection import validate_connection
 from dbt.logger import GLOBAL_LOGGER as logger
-from dbt.schema import Column
 
 lock = multiprocessing.Lock()
 
@@ -44,6 +44,8 @@ class DefaultAdapter(object):
         "quote",
     ]
 
+    Column = dbt.schema.Column
+
     ###
     # ADAPTER-SPECIFIC FUNCTIONS -- each of these must be overridden in
     # every adapter
@@ -181,10 +183,12 @@ def get_missing_columns(cls, profile,
            missing from to_table"""
         from_columns = {col.name: col for col in
                         cls.get_columns_in_table(
-                            profile, from_schema, from_table, model_name)}
+                            profile, from_schema, from_table,
+                            model_name=model_name)}
         to_columns = {col.name: col for col in
                       cls.get_columns_in_table(
-                          profile, to_schema, to_table, model_name)}
+                          profile, to_schema, to_table,
+                          model_name=model_name)}
 
         missing_columns = set(from_columns.keys()) - set(to_columns.keys())
 
@@ -192,24 +196,36 @@ def get_missing_columns(cls, profile,
                 if col_name in missing_columns]
 
     @classmethod
-    def _get_columns_in_table_sql(cls, schema_name, table_name):
-        sql = """
-        select column_name, data_type, character_maximum_length
-        from information_schema.columns
-        where table_name = '{table_name}'
""".format(table_name=table_name).strip() + def _get_columns_in_table_sql(cls, schema_name, table_name, database): + schema_filter = '1=1' if schema_name is not None: - sql += (" AND table_schema = '{schema_name}'" - .format(schema_name=schema_name)) + schema_filter = "table_schema = '{}'".format(schema_name) + + db_prefix = '' if database is None else '{}.'.format(database) + + sql = """ + select + column_name, + data_type, + character_maximum_length, + numeric_precision || ',' || numeric_scale as numeric_size + + from {db_prefix}information_schema.columns + where table_name = '{table_name}' + and {schema_filter} + order by ordinal_position + """.format(db_prefix=db_prefix, + table_name=table_name, + schema_filter=schema_filter).strip() return sql @classmethod def get_columns_in_table(cls, profile, schema_name, table_name, - model_name=None): + database=None, model_name=None): - sql = cls._get_columns_in_table_sql(schema_name, table_name) + sql = cls._get_columns_in_table_sql(schema_name, table_name, database) connection, cursor = cls.add_query( profile, sql, model_name) @@ -217,8 +233,8 @@ def get_columns_in_table(cls, profile, schema_name, table_name, columns = [] for row in data: - name, data_type, char_size = row - column = Column(name, data_type, char_size) + name, data_type, char_size, numeric_size = row + column = cls.Column(name, data_type, char_size, numeric_size) columns.append(column) return columns @@ -235,18 +251,19 @@ def expand_target_column_types(cls, profile, reference_columns = cls._table_columns_to_dict( cls.get_columns_in_table( - profile, None, temp_table, model_name)) + profile, None, temp_table, model_name=model_name)) target_columns = cls._table_columns_to_dict( cls.get_columns_in_table( - profile, to_schema, to_table, model_name)) + profile, to_schema, to_table, model_name=model_name)) for column_name, reference_column in reference_columns.items(): target_column = target_columns.get(column_name) if target_column is not None and \ target_column.can_expand_to(reference_column): - new_type = Column.string_type(reference_column.string_size()) + col_string_size = reference_column.string_size() + new_type = cls.Column.string_type(col_string_size) logger.debug("Changing col type from %s to %s in table %s.%s", target_column.data_type, new_type, diff --git a/dbt/adapters/redshift.py b/dbt/adapters/redshift.py index d3009e3e5ce..bab408e7628 100644 --- a/dbt/adapters/redshift.py +++ b/dbt/adapters/redshift.py @@ -18,7 +18,10 @@ def date_function(cls): return 'getdate()' @classmethod - def _get_columns_in_table_sql(cls, schema_name, table_name): + def _get_columns_in_table_sql(cls, schema_name, table_name, database): + # Redshift doesn't support cross-database queries, + # so we can ignore the `database` argument + # TODO : how do we make this a macro? 
         if schema_name is None:
             table_schema_filter = '1=1'
@@ -29,10 +32,12 @@ def _get_columns_in_table_sql(cls, schema_name, table_name):
         sql = """
             with bound_views as (
                 select
+                    ordinal_position,
                     table_schema,
                     column_name,
                     data_type,
-                    character_maximum_length
+                    character_maximum_length,
+                    numeric_precision || ',' || numeric_scale as numeric_size
 
                 from information_schema.columns
                 where table_name = '{table_name}'
@@ -40,18 +45,29 @@ def _get_columns_in_table_sql(cls, schema_name, table_name):
 
             unbound_views as (
                 select
+                    ordinal_position,
                     view_schema,
                     col_name,
-                    col_type,
+                    case
+                        when col_type ilike 'character varying%' then
+                            'character varying'
+                        when col_type ilike 'numeric%' then 'numeric'
+                        else col_type
+                    end as col_type,
                     case
                         when col_type like 'character%' then
                             nullif(REGEXP_SUBSTR(col_type, '[0-9]+'), '')::int
                         else null
-                    end as character_maximum_length
+                    end as character_maximum_length,
+                    case
+                        when col_type like 'numeric%'
+                            then nullif(REGEXP_SUBSTR(col_type, '[0-9,]+'), '')
+                        else null
+                    end as numeric_size
 
                 from pg_get_late_binding_view_cols()
                     cols(view_schema name, view_name name, col_name name,
-                         col_type varchar, col_num int)
+                         col_type varchar, ordinal_position int)
                 where view_name = '{table_name}'
             ),
@@ -61,9 +77,15 @@ def _get_columns_in_table_sql(cls, schema_name, table_name):
                 select * from unbound_views
             )
 
-            select column_name, data_type, character_maximum_length
+            select
+                column_name,
+                data_type,
+                character_maximum_length,
+                numeric_size
+
             from unioned
             where {table_schema_filter}
+            order by ordinal_position
             """.format(table_name=table_name,
                        table_schema_filter=table_schema_filter).strip()
 
         return sql
diff --git a/dbt/context/common.py b/dbt/context/common.py
index 97188eabea6..4b7a07feaf0 100644
--- a/dbt/context/common.py
+++ b/dbt/context/common.py
@@ -300,7 +300,7 @@ def generate(model, project, flat_graph, provider=None):
 
     context = dbt.utils.merge(context, {
         "adapter": db_wrapper,
-        "column": dbt.schema.Column,
+        "column": adapter.Column,
         "config": provider.Config(model),
         "env_var": _env_var,
         "exceptions": dbt.exceptions,
diff --git a/dbt/node_runners.py b/dbt/node_runners.py
index bc74b74558e..8b6f7b0420b 100644
--- a/dbt/node_runners.py
+++ b/dbt/node_runners.py
@@ -251,7 +251,7 @@ def _node_context(cls, adapter, project, node):
 
         def call_get_columns_in_table(schema_name, table_name):
             return adapter.get_columns_in_table(
-                profile, schema_name, table_name, node.get('name'))
+                profile, schema_name, table_name, model_name=node.get('name'))
 
         def call_get_missing_columns(from_schema, from_table,
                                      to_schema, to_table):
diff --git a/dbt/schema.py b/dbt/schema.py
index f41f46a247f..b66cf95a821 100644
--- a/dbt/schema.py
+++ b/dbt/schema.py
@@ -2,10 +2,11 @@
 
 class Column(object):
-    def __init__(self, column, dtype, char_size):
+    def __init__(self, column, dtype, char_size=None, numeric_size=None):
         self.column = column
         self.dtype = dtype
         self.char_size = char_size
+        self.numeric_size = numeric_size
 
     @property
     def name(self):
@@ -19,12 +20,17 @@ def quoted(self):
     def data_type(self):
         if self.is_string():
             return Column.string_type(self.string_size())
+        elif self.is_numeric():
+            return Column.numeric_type(self.dtype, self.numeric_size)
         else:
             return self.dtype
 
     def is_string(self):
         return self.dtype.lower() in ['text', 'character varying']
 
+    def is_numeric(self):
+        return self.dtype.lower() in ['numeric', 'number']
+
     def string_size(self):
         if not self.is_string():
             raise RuntimeError("Called string_size() on non-string field!")
@@ -47,5 +53,71 @@ def can_expand_to(cls, other_column):
     def string_type(cls, size):
         return "character varying({})".format(size)
 
+    @classmethod
+    def numeric_type(cls, dtype, size):
+        # This could be decimal(...), numeric(...), number(...)
+        # Just use whatever was fed in here -- don't try to get too clever
+        return "{}({})".format(dtype, size)
+
     def __repr__(self):
         return "<Column {} ({})>".format(self.name, self.data_type)
+
+
+class BigQueryColumn(Column):
+    def __init__(self, column, dtype, fields, mode):
+        super(BigQueryColumn, self).__init__(column, dtype)
+
+        self.mode = mode
+        self.fields = self.wrap_subfields(fields)
+
+    @classmethod
+    def wrap_subfields(cls, fields):
+        return [BigQueryColumn.create(field) for field in fields]
+
+    @classmethod
+    def create(cls, field):
+        return BigQueryColumn(field.name, field.field_type, field.fields,
+                              field.mode)
+
+    @classmethod
+    def _flatten_recursive(cls, col, prefix=None):
+        if prefix is None:
+            prefix = []
+
+        if len(col.fields) == 0:
+            prefixed_name = ".".join(prefix + [col.column])
+            new_col = BigQueryColumn(prefixed_name, col.dtype, col.fields,
+                                     col.mode)
+            return [new_col]
+
+        new_fields = []
+        for field in col.fields:
+            new_prefix = prefix + [col.column]
+            new_fields.extend(cls._flatten_recursive(field, new_prefix))
+
+        return new_fields
+
+    def flatten(self):
+        return self._flatten_recursive(self)
+
+    @property
+    def quoted(self):
+        return '`{}`'.format(self.column)
+
+    @property
+    def data_type(self):
+        return self.dtype
+
+    def is_string(self):
+        return self.dtype.lower() == 'string'
+
+    def is_numeric(self):
+        return False
+
+    def can_expand_to(self, other_column):
+        """returns True if both columns are strings"""
+        return self.is_string() and other_column.is_string()
+
+    def __repr__(self):
+        return "<BigQueryColumn {} ({}, {})>".format(self.name,
+                                                     self.data_type,
+                                                     self.mode)
diff --git a/test/integration/022_bigquery_test/adapter-models/schema.yml b/test/integration/022_bigquery_test/adapter-models/schema.yml
new file mode 100644
index 00000000000..31a8249b31f
--- /dev/null
+++ b/test/integration/022_bigquery_test/adapter-models/schema.yml
@@ -0,0 +1,21 @@
+
+
+test_get_columns_in_table:
+    constraints:
+        not_null:
+            - field_1
+            - field_2
+            - field_3
+            - nested_field
+            - repeated_column
+
+
+test_flattened_get_columns_in_table:
+    constraints:
+        not_null:
+            - field_1
+            - field_2
+            - field_3
+            - field_4
+            - field_5
+            - repeated_column
diff --git a/test/integration/022_bigquery_test/adapter-models/source.sql b/test/integration/022_bigquery_test/adapter-models/source.sql
new file mode 100644
index 00000000000..2c419f25d22
--- /dev/null
+++ b/test/integration/022_bigquery_test/adapter-models/source.sql
@@ -0,0 +1,41 @@
+
+with nested_base as (
+    select
+        struct(
+            'a' as field_a,
+            'b' as field_b
+        ) as repeated_nested
+
+    union all
+
+    select
+        struct(
+            'a' as field_a,
+            'b' as field_b
+        ) as repeated_nested
+),
+
+nested as (
+
+    select
+        array_agg(repeated_nested) as repeated_column
+
+    from nested_base
+
+),
+
+base as (
+
+    select
+        1 as field_1,
+        2 as field_2,
+        3 as field_3,
+
+        struct(
+            4 as field_4,
+            5 as field_5
+        ) as nested_field
+)
+
+select *
+from base, nested
diff --git a/test/integration/022_bigquery_test/adapter-models/test_flattened_get_columns_in_table.sql b/test/integration/022_bigquery_test/adapter-models/test_flattened_get_columns_in_table.sql
new file mode 100644
index 00000000000..1a741dbd102
--- /dev/null
+++ b/test/integration/022_bigquery_test/adapter-models/test_flattened_get_columns_in_table.sql
@@ -0,0 +1,21 @@
+
+
+{% set source = ref('source') %}
+{% set cols = adapter.get_columns_in_table(source.schema, source.name) %}
+
+{% set flattened = [] %}
+{% for col in cols %}
+    {% if col.mode == 'REPEATED' %}
+        {% set _ = flattened.append(col) %}
+    {% else %}
+        {% set _ = flattened.extend(col.flatten()) %}
+    {% endif %}
+{% endfor %}
+
+select
+    {% for col in flattened %}
+        {{ col.name }}
+        {% if not loop.last %}, {% endif %}
+    {% endfor %}
+
+from {{ source }}
diff --git a/test/integration/022_bigquery_test/adapter-models/test_get_columns_in_table.sql b/test/integration/022_bigquery_test/adapter-models/test_get_columns_in_table.sql
new file mode 100644
index 00000000000..3653fee869c
--- /dev/null
+++ b/test/integration/022_bigquery_test/adapter-models/test_get_columns_in_table.sql
@@ -0,0 +1,12 @@
+
+
+{% set source = ref('source') %}
+{% set cols = adapter.get_columns_in_table(source.schema, source.name) %}
+
+select
+    {% for col in cols %}
+        {{ col.name }}
+        {% if not loop.last %}, {% endif %}
+    {% endfor %}
+
+from {{ source }}
diff --git a/test/integration/022_bigquery_test/test_bigquery_adapter_functions.py b/test/integration/022_bigquery_test/test_bigquery_adapter_functions.py
new file mode 100644
index 00000000000..95bde58d566
--- /dev/null
+++ b/test/integration/022_bigquery_test/test_bigquery_adapter_functions.py
@@ -0,0 +1,33 @@
+from nose.plugins.attrib import attr
+from test.integration.base import DBTIntegrationTest, FakeArgs
+
+
+class TestBigqueryAdapterFunctions(DBTIntegrationTest):
+
+    @property
+    def schema(self):
+        return "bigquery_test_022"
+
+    @property
+    def models(self):
+        return "test/integration/022_bigquery_test/adapter-models"
+
+    @property
+    def profile_config(self):
+        return self.bigquery_profile()
+
+    @attr(type='bigquery')
+    def test__bigquery_adapter_functions(self):
+        self.use_profile('bigquery')
+        self.use_default_project()
+        self.run_dbt()
+
+        test_results = self.run_dbt(['test'])
+
+        self.assertTrue(len(test_results) > 0)
+        for result in test_results:
+            self.assertFalse(result.errored)
+            self.assertFalse(result.skipped)
+            # status = # of failing rows
+            self.assertEqual(result.status, 0)
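
For context, a minimal sketch of what the new BigQueryColumn.flatten() returns
for a RECORD column shaped like nested_field in source.sql above. It assumes
this patch is applied; SchemaField here is a hypothetical stand-in for the
BigQuery client's google.cloud.bigquery.schema.SchemaField, modeling only the
four attributes BigQueryColumn actually reads.

    # Stand-in for the BigQuery client's SchemaField (assumption: only
    # name, field_type, fields, and mode are read by BigQueryColumn).
    from collections import namedtuple

    from dbt.schema import BigQueryColumn

    SchemaField = namedtuple(
        'SchemaField', ['name', 'field_type', 'fields', 'mode'])

    # Mirrors the nested_field struct built in source.sql.
    nested_field = SchemaField(
        name='nested_field',
        field_type='RECORD',
        fields=(
            SchemaField('field_4', 'INTEGER', (), 'NULLABLE'),
            SchemaField('field_5', 'INTEGER', (), 'NULLABLE'),
        ),
        mode='NULLABLE',
    )

    col = BigQueryColumn.create(nested_field)
    print([c.name for c in col.flatten()])
    # -> ['nested_field.field_4', 'nested_field.field_5']

This dotted-leaf expansion is what lets test_flattened_get_columns_in_table.sql
select struct members by name while keeping REPEATED columns intact.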