get_columns_in_table parity #709

Merged · 11 commits · Apr 6, 2018
39 changes: 30 additions & 9 deletions dbt/adapters/bigquery.py
@@ -4,6 +4,7 @@

import dbt.compat
import dbt.exceptions
import dbt.schema
import dbt.flags as flags
import dbt.clients.gcloud
import dbt.clients.agate_helper
@@ -29,7 +30,8 @@ class BigQueryAdapter(PostgresAdapter):
"drop",
"execute",
"quote_schema_and_table",
"make_date_partitioned_table"
"make_date_partitioned_table",
"get_columns_in_table"
]

SCOPE = ('https://www.googleapis.com/auth/bigquery',
@@ -38,6 +40,8 @@ class BigQueryAdapter(PostgresAdapter):

QUERY_TIMEOUT = 300

Column = dbt.schema.BigQueryColumn

@classmethod
def handle_error(cls, error, message, sql):
logger.debug(message.format(sql=sql))
@@ -377,15 +381,32 @@ def get_existing_schemas(cls, profile, model_name=None):

@classmethod
def get_columns_in_table(cls, profile, schema_name, table_name,
model_name=None):
raise dbt.exceptions.NotImplementedException(
'`get_columns_in_table` is not implemented for this adapter!')
database=None, model_name=None):

# BigQuery does not have databases -- the database parameter is here
# for consistency with the base implementation

conn = cls.get_connection(profile, model_name)
client = conn.get('handle')

try:
dataset_ref = client.dataset(schema_name)
table_ref = dataset_ref.table(table_name)
table = client.get_table(table_ref)
table_schema = table.schema
except (ValueError, google.cloud.exceptions.NotFound) as e:
logger.debug("get_columns_in_table error: {}".format(e))
table_schema = []

columns = []
for col in table_schema:
column = cls.Column(col.name, col.field_type, col.fields, col.mode)
columns.append(column)

return columns
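
For reference, here is a minimal sketch of the client calls this method depends on, against the google-cloud-bigquery API current at the time of this PR; the dataset and table names are hypothetical:

```python
# Minimal sketch, assuming the google-cloud-bigquery client library of this
# era; 'analytics' and 'events' are hypothetical names.
from google.cloud import bigquery

client = bigquery.Client()
dataset_ref = client.dataset('analytics')   # DatasetReference
table_ref = dataset_ref.table('events')     # TableReference
table = client.get_table(table_ref)         # API call; raises NotFound if absent

for field in table.schema:
    # Each SchemaField exposes exactly the attributes BigQueryColumn consumes
    print(field.name, field.field_type, field.mode, len(field.fields))
```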

@classmethod
def check_schema_exists(cls, profile, schema, model_name=None):
53 changes: 35 additions & 18 deletions dbt/adapters/default.py
@@ -8,11 +8,11 @@

import dbt.exceptions
import dbt.flags
import dbt.schema
import dbt.clients.agate_helper

from dbt.contracts.connection import validate_connection
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.schema import Column


lock = multiprocessing.Lock()
@@ -44,6 +44,8 @@ class DefaultAdapter(object):
"quote",
]

Column = dbt.schema.Column

###
# ADAPTER-SPECIFIC FUNCTIONS -- each of these must be overridden in
# every adapter
@@ -181,44 +183,58 @@ def get_missing_columns(cls, profile,
missing from to_table"""
from_columns = {col.name: col for col in
cls.get_columns_in_table(
profile, from_schema, from_table, model_name)}
profile, from_schema, from_table,
model_name=model_name)}
to_columns = {col.name: col for col in
cls.get_columns_in_table(
profile, to_schema, to_table, model_name)}
profile, to_schema, to_table,
model_name=model_name)}

missing_columns = set(from_columns.keys()) - set(to_columns.keys())

return [col for (col_name, col) in from_columns.items()
if col_name in missing_columns]
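
The set arithmetic here is straightforward; a hypothetical illustration with plain dicts:

```python
# Hypothetical column maps, shaped like the dict comprehensions above
from_columns = {'id': 'integer', 'email': 'text', 'created_at': 'timestamp'}
to_columns = {'id': 'integer', 'created_at': 'timestamp'}

missing_columns = set(from_columns) - set(to_columns)
# {'email'} -- present in from_table, absent from to_table
```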

@classmethod
def _get_columns_in_table_sql(cls, schema_name, table_name):
sql = """
select column_name, data_type, character_maximum_length
from information_schema.columns
where table_name = '{table_name}'
""".format(table_name=table_name).strip()
def _get_columns_in_table_sql(cls, schema_name, table_name, database):

schema_filter = '1=1'
if schema_name is not None:
sql += (" AND table_schema = '{schema_name}'"
.format(schema_name=schema_name))
schema_filter = "table_schema = '{}'".format(schema_name)

db_prefix = '' if database is None else '{}.'.format(database)

sql = """
select
column_name,
data_type,
character_maximum_length,
numeric_precision || ',' || numeric_scale as numeric_size

from {db_prefix}information_schema.columns
where table_name = '{table_name}'
and {schema_filter}
order by ordinal_position
""".format(db_prefix=db_prefix,
table_name=table_name,
schema_filter=schema_filter).strip()

return sql
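
As a sketch, for the hypothetical arguments schema_name='analytics', table_name='users', and database='raw', the template renders to:

```python
# Hypothetical invocation; the comments show the (condensed) rendered SQL
from dbt.adapters.default import DefaultAdapter

sql = DefaultAdapter._get_columns_in_table_sql('analytics', 'users', 'raw')
# select column_name, data_type, character_maximum_length,
#        numeric_precision || ',' || numeric_scale as numeric_size
# from raw.information_schema.columns
# where table_name = 'users' and table_schema = 'analytics'
# order by ordinal_position
```

With database=None the prefix disappears, and with schema_name=None the filter degrades to 1=1, which is what the temp-table lookup in expand_target_column_types relies on.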

@classmethod
def get_columns_in_table(cls, profile, schema_name, table_name,
model_name=None):
database=None, model_name=None):

sql = cls._get_columns_in_table_sql(schema_name, table_name)
sql = cls._get_columns_in_table_sql(schema_name, table_name, database)
connection, cursor = cls.add_query(
profile, sql, model_name)

data = cursor.fetchall()
columns = []

for row in data:
name, data_type, char_size = row
column = Column(name, data_type, char_size)
name, data_type, char_size, numeric_size = row
column = cls.Column(name, data_type, char_size, numeric_size)
columns.append(column)

return columns
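
To make the row-to-Column mapping concrete, a hypothetical fetched row and what it produces:

```python
from dbt.schema import Column

# Hypothetical row, shaped like the information_schema select list above
name, data_type, char_size, numeric_size = ('email', 'character varying',
                                            255, None)

column = Column(name, data_type, char_size, numeric_size)
column.data_type  # 'character varying(255)'
```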
@@ -235,18 +251,19 @@ def expand_target_column_types(cls, profile,

reference_columns = cls._table_columns_to_dict(
cls.get_columns_in_table(
profile, None, temp_table, model_name))
profile, None, temp_table, model_name=model_name))

target_columns = cls._table_columns_to_dict(
cls.get_columns_in_table(
profile, to_schema, to_table, model_name))
profile, to_schema, to_table, model_name=model_name))

for column_name, reference_column in reference_columns.items():
target_column = target_columns.get(column_name)

if target_column is not None and \
target_column.can_expand_to(reference_column):
new_type = Column.string_type(reference_column.string_size())
col_string_size = reference_column.string_size()
new_type = cls.Column.string_type(col_string_size)
logger.debug("Changing col type from %s to %s in table %s.%s",
target_column.data_type,
new_type,
34 changes: 28 additions & 6 deletions dbt/adapters/redshift.py
@@ -18,7 +18,10 @@ def date_function(cls):
return 'getdate()'

@classmethod
def _get_columns_in_table_sql(cls, schema_name, table_name):
def _get_columns_in_table_sql(cls, schema_name, table_name, database):
# Redshift doesn't support cross-database queries,
# so we can ignore the `database` argument

# TODO : how do we make this a macro?
if schema_name is None:
table_schema_filter = '1=1'
@@ -29,29 +32,42 @@ def _get_columns_in_table_sql(cls, schema_name, table_name):
sql = """
with bound_views as (
select
ordinal_position,
table_schema,
column_name,
data_type,
character_maximum_length
character_maximum_length,
numeric_precision || ',' || numeric_scale as numeric_size

from information_schema.columns
where table_name = '{table_name}'
),

unbound_views as (
select
ordinal_position,
view_schema,
col_name,
col_type,
case
when col_type ilike 'character varying%' then
'character varying'
when col_type ilike 'numeric%' then 'numeric'
else col_type
end as col_type,
case
when col_type like 'character%'
then nullif(REGEXP_SUBSTR(col_type, '[0-9]+'), '')::int
else null
end as character_maximum_length
end as character_maximum_length,
case
when col_type like 'numeric%'
then nullif(REGEXP_SUBSTR(col_type, '[0-9,]+'), '')
else null
end as numeric_size

from pg_get_late_binding_view_cols()
cols(view_schema name, view_name name, col_name name,
col_type varchar, col_num int)
col_type varchar, ordinal_position int)
where view_name = '{table_name}'
),

@@ -61,9 +77,15 @@ def _get_columns_in_table_sql(cls, schema_name, table_name):
select * from unbound_views
)

select column_name, data_type, character_maximum_length
select
column_name,
data_type,
character_maximum_length,
numeric_size

from unioned
where {table_schema_filter}
order by ordinal_position
""".format(table_name=table_name,
table_schema_filter=table_schema_filter).strip()
return sql
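
Since pg_get_late_binding_view_cols() returns rendered type text such as 'character varying(256)', the unbound-view branch has to split the size back out of the string. A hypothetical Python rendering of the same normalization:

```python
import re

# Hypothetical type strings, normalized the way the case expressions above do
for col_type in ['character varying(256)', 'numeric(12,2)', 'integer']:
    char_size = (re.search(r'[0-9]+', col_type).group()
                 if col_type.startswith('character') else None)
    numeric_size = (re.search(r'[0-9,]+', col_type).group()
                    if col_type.startswith('numeric') else None)
    print(col_type, char_size, numeric_size)
# character varying(256) 256 None
# numeric(12,2) None 12,2
# integer None None
```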
2 changes: 1 addition & 1 deletion dbt/context/common.py
@@ -300,7 +300,7 @@ def generate(model, project, flat_graph, provider=None):

context = dbt.utils.merge(context, {
"adapter": db_wrapper,
"column": dbt.schema.Column,
"column": adapter.Column,
"config": provider.Config(model),
"env_var": _env_var,
"exceptions": dbt.exceptions,
2 changes: 1 addition & 1 deletion dbt/node_runners.py
@@ -251,7 +251,7 @@ def _node_context(cls, adapter, project, node):

def call_get_columns_in_table(schema_name, table_name):
return adapter.get_columns_in_table(
profile, schema_name, table_name, node.get('name'))
profile, schema_name, table_name, model_name=node.get('name'))

def call_get_missing_columns(from_schema, from_table,
to_schema, to_table):
74 changes: 73 additions & 1 deletion dbt/schema.py
@@ -2,10 +2,11 @@


class Column(object):
def __init__(self, column, dtype, char_size):
def __init__(self, column, dtype, char_size=None, numeric_size=None):
self.column = column
self.dtype = dtype
self.char_size = char_size
self.numeric_size = numeric_size

@property
def name(self):
@@ -19,12 +20,17 @@ def quoted(self):
def data_type(self):
if self.is_string():
return Column.string_type(self.string_size())
elif self.is_numeric():
return Column.numeric_type(self.dtype, self.numeric_size)
else:
return self.dtype

def is_string(self):
return self.dtype.lower() in ['text', 'character varying']

def is_numeric(self):
return self.dtype.lower() in ['numeric', 'number']

def string_size(self):
if not self.is_string():
raise RuntimeError("Called string_size() on non-string field!")
@@ -47,5 +53,71 @@ def can_expand_to(self, other_column):
def string_type(cls, size):
return "character varying({})".format(size)

@classmethod
def numeric_type(cls, dtype, size):
# This could be decimal(...), numeric(...), number(...)
# Just use whatever was fed in here -- don't try to get too clever
return "{}({})".format(dtype, size)

def __repr__(self):
return "<Column {} ({})>".format(self.name, self.data_type)


class BigQueryColumn(Column):
def __init__(self, column, dtype, fields, mode):
super(BigQueryColumn, self).__init__(column, dtype)

self.mode = mode
self.fields = self.wrap_subfields(fields)

@classmethod
def wrap_subfields(cls, fields):
return [BigQueryColumn.create(field) for field in fields]

@classmethod
def create(cls, field):
return BigQueryColumn(field.name, field.field_type, field.fields,
field.mode)

@classmethod
def _flatten_recursive(cls, col, prefix=None):
if prefix is None:
prefix = []

if len(col.fields) == 0:
prefixed_name = ".".join(prefix + [col.column])
new_col = BigQueryColumn(prefixed_name, col.dtype, col.fields,
col.mode)
return [new_col]

new_fields = []
for field in col.fields:
new_prefix = prefix + [col.column]
new_fields.extend(cls._flatten_recursive(field, new_prefix))

return new_fields

def flatten(self):
return self._flatten_recursive(self)

@property
def quoted(self):
return '`{}`'.format(self.column)

@property
def data_type(self):
return self.dtype

def is_string(self):
return self.dtype.lower() == 'string'

def is_numeric(self):
return False

def can_expand_to(self, other_column):
"""returns True if both columns are strings"""
return self.is_string() and other_column.is_string()

def __repr__(self):
return "<BigQueryColumn {} ({}, {})>".format(self.name, self.data_type,
self.mode)
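
To illustrate the flattening of nested RECORD fields, a hypothetical schema and what flatten() yields:

```python
from google.cloud import bigquery
from dbt.schema import BigQueryColumn

# Hypothetical nested field: location RECORD(city STRING, geo RECORD(lat, lon))
geo = bigquery.SchemaField('geo', 'RECORD', fields=[
    bigquery.SchemaField('lat', 'FLOAT'),
    bigquery.SchemaField('lon', 'FLOAT'),
])
location = bigquery.SchemaField('location', 'RECORD', fields=[
    bigquery.SchemaField('city', 'STRING'),
    geo,
])

col = BigQueryColumn.create(location)
[c.name for c in col.flatten()]
# ['location.city', 'location.geo.lat', 'location.geo.lon']
```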