get_columns_in_table parity #709

Merged: 11 commits, Apr 6, 2018
dbt/adapters/bigquery.py (29 additions, 9 deletions)

@@ -11,8 +11,10 @@
 from dbt.adapters.postgres import PostgresAdapter
 from dbt.contracts.connection import validate_connection
 from dbt.logger import GLOBAL_LOGGER as logger
+from dbt.schema import BigQueryColumn

 import google.auth
+import google.api_core
 import google.oauth2
 import google.cloud.exceptions
 import google.cloud.bigquery
@@ -29,7 +31,8 @@ class BigQueryAdapter(PostgresAdapter):
         "drop",
         "execute",
         "quote_schema_and_table",
-        "make_date_partitioned_table"
+        "make_date_partitioned_table",
+        "get_columns_in_table"
     ]

     SCOPE = ('https://www.googleapis.com/auth/bigquery',
@@ -377,15 +380,32 @@ def get_existing_schemas(cls, profile, model_name=None):

     @classmethod
     def get_columns_in_table(cls, profile, schema_name, table_name,
-                             model_name=None):
-        raise dbt.exceptions.NotImplementedException(
-            '`get_columns_in_table` is not implemented for this adapter!')
+                             database=None, model_name=None):
+        # BigQuery does not have databases -- the database parameter is here
+        # for consistency with the base implementation
+
+        conn = cls.get_connection(profile, model_name)
+        client = conn.get('handle')
+
+        try:
+            dataset_ref = client.dataset(schema_name)
+            table_ref = dataset_ref.table(table_name)
+            table = client.get_table(table_ref)
+            table_schema = table.schema
+        except (ValueError, google.api_core.exceptions.NotFound) as e:
+            logger.debug("get_columns_in_table error: {}".format(e))
(Member review comment on the exception handler: "i want to talk to you about error behavior in adapters, but this is fine for now")

+            table_schema = []
+
+        columns = []
+        for col in table_schema:
+            column = BigQueryColumn(col.name, col.field_type, col.fields)
+            columns.append(column)
+
+        return columns
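As a rough usage sketch (not part of the diff; the profile object and the dataset, table, and model names below are hypothetical), the new method can be driven like this:

    # `profile` stands in for a configured dbt profile dict (assumption)
    columns = BigQueryAdapter.get_columns_in_table(
        profile, 'my_dataset', 'my_table', model_name='my_model')

    for col in columns:
        # BigQueryColumn exposes name/data_type; RECORD columns carry subfields
        print(col.name, col.data_type)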

     @classmethod
     def check_schema_exists(cls, profile, schema, model_name=None):
dbt/adapters/default.py (30 additions, 16 deletions)

@@ -181,44 +181,58 @@ def get_missing_columns(cls, profile,
            missing from to_table"""
         from_columns = {col.name: col for col in
                         cls.get_columns_in_table(
-                            profile, from_schema, from_table, model_name)}
+                            profile, from_schema, from_table,
+                            model_name=model_name)}
         to_columns = {col.name: col for col in
                       cls.get_columns_in_table(
-                          profile, to_schema, to_table, model_name)}
+                          profile, to_schema, to_table,
+                          model_name=model_name)}

         missing_columns = set(from_columns.keys()) - set(to_columns.keys())

         return [col for (col_name, col) in from_columns.items()
                 if col_name in missing_columns]

     @classmethod
-    def _get_columns_in_table_sql(cls, schema_name, table_name):
-        sql = """
-        select column_name, data_type, character_maximum_length
-        from information_schema.columns
-        where table_name = '{table_name}'
-        """.format(table_name=table_name).strip()
-
-        if schema_name is not None:
-            sql += (" AND table_schema = '{schema_name}'"
-                    .format(schema_name=schema_name))
+    def _get_columns_in_table_sql(cls, schema_name, table_name, database):
+
+        schema_filter = '1=1'
+        if schema_name is not None:
+            schema_filter = "table_schema = '{}'".format(schema_name)
+
+        db_prefix = '' if database is None else '{}.'.format(database)
+
+        sql = """
+        select
+            column_name,
+            data_type,
+            character_maximum_length,
+            numeric_precision || ',' || numeric_scale as numeric_size
+
+        from {db_prefix}information_schema.columns
+        where table_name = '{table_name}'
+        and {schema_filter}
+        order by ordinal_position
+        """.format(db_prefix=db_prefix,
+                   table_name=table_name,
+                   schema_filter=schema_filter).strip()

         return sql
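To make the new query shape concrete, here is a sketch of the SQL this renders, assuming the base adapter in dbt/adapters/default.py is importable as DefaultAdapter and using made-up names:

    from dbt.adapters.default import DefaultAdapter

    # hypothetical database/schema/table names, for illustration only
    print(DefaultAdapter._get_columns_in_table_sql('analytics', 'users', 'raw'))

    # renders (roughly):
    # select
    #     column_name,
    #     data_type,
    #     character_maximum_length,
    #     numeric_precision || ',' || numeric_scale as numeric_size
    #
    # from raw.information_schema.columns
    # where table_name = 'users'
    # and table_schema = 'analytics'
    # order by ordinal_position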

     @classmethod
     def get_columns_in_table(cls, profile, schema_name, table_name,
-                             model_name=None):
+                             database=None, model_name=None):

-        sql = cls._get_columns_in_table_sql(schema_name, table_name)
+        sql = cls._get_columns_in_table_sql(schema_name, table_name, database)
         connection, cursor = cls.add_query(
             profile, sql, model_name)

         data = cursor.fetchall()
         columns = []

         for row in data:
-            name, data_type, char_size = row
-            column = Column(name, data_type, char_size)
+            name, data_type, char_size, numeric_size = row
+            column = Column(name, data_type, char_size, numeric_size)
             columns.append(column)

         return columns
@@ -235,11 +249,11 @@ def expand_target_column_types(cls, profile,

         reference_columns = cls._table_columns_to_dict(
             cls.get_columns_in_table(
-                profile, None, temp_table, model_name))
+                profile, None, temp_table, model_name=model_name))

         target_columns = cls._table_columns_to_dict(
             cls.get_columns_in_table(
-                profile, to_schema, to_table, model_name))
+                profile, to_schema, to_table, model_name=model_name))

         for column_name, reference_column in reference_columns.items():
             target_column = target_columns.get(column_name)
dbt/adapters/redshift.py (28 additions, 6 deletions)

@@ -18,7 +18,10 @@ def date_function(cls):
         return 'getdate()'

     @classmethod
-    def _get_columns_in_table_sql(cls, schema_name, table_name):
+    def _get_columns_in_table_sql(cls, schema_name, table_name, database):
+        # Redshift doesn't support cross-database queries,
+        # so we can ignore the `database` argument
+
         # TODO : how do we make this a macro?
         if schema_name is None:
             table_schema_filter = '1=1'

@@ -29,29 +32,42 @@ def _get_columns_in_table_sql(cls, schema_name, table_name):
sql = """
with bound_views as (
select
ordinal_position,
table_schema,
column_name,
data_type,
character_maximum_length
character_maximum_length,
numeric_precision || ',' || numeric_scale as numeric_size

from information_schema.columns
where table_name = '{table_name}'
),

unbound_views as (
select
ordinal_position,
view_schema,
col_name,
col_type,
case
when col_type ilike 'character varying%' then
'character varying'
when col_type ilike 'numeric%' then 'numeric'
else col_type
end as col_type,
case
when col_type like 'character%'
then nullif(REGEXP_SUBSTR(col_type, '[0-9]+'), '')::int
else null
end as character_maximum_length
end as character_maximum_length,
case
when col_type like 'numeric%'
then nullif(REGEXP_SUBSTR(col_type, '[0-9,]+'), '')
else null
end as numeric_size

from pg_get_late_binding_view_cols()
cols(view_schema name, view_name name, col_name name,
col_type varchar, col_num int)
col_type varchar, ordinal_position int)
where view_name = '{table_name}'
),

@@ -61,9 +77,15 @@ def _get_columns_in_table_sql(cls, schema_name, table_name):

         select * from unbound_views
         )

-        select column_name, data_type, character_maximum_length
+        select
+            column_name,
+            data_type,
+            character_maximum_length,
+            numeric_size

         from unioned
         where {table_schema_filter}
+        order by ordinal_position
         """.format(table_name=table_name,
                    table_schema_filter=table_schema_filter).strip()
         return sql
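Because pg_get_late_binding_view_cols() returns raw col_type strings rather than information_schema metadata, the query parses sizes out of strings like 'numeric(18,4)'. A rough Python equivalent of that REGEXP_SUBSTR logic, for illustration only:

    import re

    def late_binding_numeric_size(col_type):
        # mirrors: case when col_type like 'numeric%'
        #   then nullif(REGEXP_SUBSTR(col_type, '[0-9,]+'), '') else null end
        if not col_type.startswith('numeric'):
            return None
        match = re.search(r'[0-9,]+', col_type)
        return match.group(0) if match else None

    print(late_binding_numeric_size('numeric(18,4)'))           # '18,4'
    print(late_binding_numeric_size('character varying(256)'))  # None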
dbt/node_runners.py (1 addition, 1 deletion)

@@ -251,7 +251,7 @@ def _node_context(cls, adapter, project, node):

         def call_get_columns_in_table(schema_name, table_name):
             return adapter.get_columns_in_table(
-                profile, schema_name, table_name, node.get('name'))
+                profile, schema_name, table_name, model_name=node.get('name'))

         def call_get_missing_columns(from_schema, from_table,
                                      to_schema, to_table):
dbt/schema.py (69 additions, 1 deletion)

@@ -2,10 +2,11 @@


 class Column(object):
-    def __init__(self, column, dtype, char_size):
+    def __init__(self, column, dtype, char_size=None, numeric_size=None):
         self.column = column
         self.dtype = dtype
         self.char_size = char_size
+        self.numeric_size = numeric_size

     @property
     def name(self):

@@ -19,12 +20,17 @@ def quoted(self):
     def data_type(self):
         if self.is_string():
             return Column.string_type(self.string_size())
+        elif self.is_numeric():
+            return Column.numeric_type(self.dtype, self.numeric_size)
         else:
             return self.dtype

     def is_string(self):
         return self.dtype.lower() in ['text', 'character varying']

+    def is_numeric(self):
+        return self.dtype.lower() in ['numeric', 'number']
+
     def string_size(self):
         if not self.is_string():
             raise RuntimeError("Called string_size() on non-string field!")
@@ -47,5 +53,67 @@ def can_expand_to(self, other_column):
     def string_type(cls, size):
         return "character varying({})".format(size)

+    @classmethod
+    def numeric_type(cls, dtype, size):
+        # This could be decimal(...), numeric(...), number(...)
+        # Just use whatever was fed in here -- don't try to get too clever
+        return "{}({})".format(dtype, size)
+
     def __repr__(self):
         return "<Column {} ({})>".format(self.name, self.data_type)


+class BigQueryColumn(Column):
+    def __init__(self, column, dtype, fields):
+        super(BigQueryColumn, self).__init__(column, dtype)
+
+        self.fields = self.wrap_subfields(fields)
+
+    @classmethod
+    def wrap_subfields(cls, fields):
+        return [BigQueryColumn.create(field) for field in fields]
+
+    @classmethod
+    def create(cls, field):
+        return BigQueryColumn(field.name, field.field_type, field.fields)
+
+    @classmethod
+    def _flatten_recursive(cls, col, prefix=None):
+        if prefix is None:
+            prefix = []
+
+        if len(col.fields) == 0:
+            prefixed_name = ".".join(prefix + [col.column])
+            new_col = BigQueryColumn(prefixed_name, col.dtype, col.fields)
+            return [new_col]
+
+        new_fields = []
+        for field in col.fields:
+            new_prefix = prefix + [col.column]
+            new_fields.extend(cls._flatten_recursive(field, new_prefix))
+
+        return new_fields
+
+    def flatten(self):
+        return self._flatten_recursive(self)
+
+    @property
+    def quoted(self):
+        return '`{}`'.format(self.column)
+
+    @property
+    def data_type(self):
+        return self.dtype
+
+    def is_string(self):
+        return self.dtype.lower() == 'string'
+
+    def is_numeric(self):
+        return False
+
+    def can_expand_to(self, other_column):
+        """returns True if both columns are strings"""
+        return self.is_string() and other_column.is_string()
+
+    def __repr__(self):
+        return "<BigQueryColumn {} ({})>".format(self.name, self.data_type)
test/integration/022_bigquery_test/adapter-models/adapters.sql (12 additions)

@@ -0,0 +1,12 @@


{% set source = ref('source') %}
{% set cols = adapter.get_columns_in_table(source.schema, source.name) %}

select
    {% for col in cols %}
        {{ col.name }}
        {% if not loop.last %}, {% endif %}
    {% endfor %}

from {{ source }}
test/integration/022_bigquery_test/adapter-models/schema.yml (8 additions)

@@ -0,0 +1,8 @@


adapters:
  constraints:
    not_null:
      - field_1
      - field_2
      - field_3
test/integration/022_bigquery_test/adapter-models/source.sql (6 additions)

@@ -0,0 +1,6 @@

select
    1 as field_1,
    2 as field_2,
    3 as field_3

New integration test file (path not shown in the diff):
@@ -0,0 +1,33 @@
from nose.plugins.attrib import attr
from test.integration.base import DBTIntegrationTest, FakeArgs


class TestBigqueryAdapterFunctions(DBTIntegrationTest):

    @property
    def schema(self):
        return "bigquery_test_022"

    @property
    def models(self):
        return "test/integration/022_bigquery_test/adapter-models"

    @property
    def profile_config(self):
        return self.bigquery_profile()

    @attr(type='bigquery')
    def test__bigquery_adapter_functions(self):
        self.use_profile('bigquery')
        self.use_default_project()
        self.run_dbt()

        test_results = self.run_dbt(['test'])

        self.assertTrue(len(test_results) > 0)
        for result in test_results:
            self.assertFalse(result.errored)
            self.assertFalse(result.skipped)
            # status = # of failing rows
            self.assertEqual(result.status, 0)