concurrency: make it work #345

Merged (17 commits) on Mar 24, 2017
527 changes: 527 additions & 0 deletions dbt/adapters/default.py

Large diffs are not rendered by default.

496 changes: 60 additions & 436 deletions dbt/adapters/postgres.py

Large diffs are not rendered by default.
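The two largest diffs are collapsed: dbt/adapters/default.py is brand new (527 lines) and dbt/adapters/postgres.py hands most of its body over to it. Their contents are not rendered here, but the hunks further down show the surface the other adapters now rely on: classmethods such as cls.get_connection(profile, name), cls.add_query(profile, sql, model_name), cls.rollback(connection), and cls.get_default_schema(profile). The sketch below is speculative, inferred only from those call sites and from the updated connection contract near the bottom of the diff; it is not a copy of default.py.

import threading


class DefaultAdapterSketch(object):
    """Speculative sketch of the adapter surface implied by the hunks below.

    This is not dbt's default.py; it only illustrates the classmethod,
    named-connection style that the Snowflake and Redshift overrides call
    into. The connection dict mirrors the updated contract in
    dbt/contracts/connection.py further down.
    """

    _connections = {}         # one connection dict per connection name
    _lock = threading.Lock()  # model runs now happen on several threads

    @classmethod
    def type(cls):
        raise NotImplementedError('adapters declare their type')

    @classmethod
    def date_function(cls):
        raise NotImplementedError('adapters declare their "now" expression')

    @classmethod
    def open_connection(cls, connection):
        # a real adapter would create a database handle here; the sketch
        # just marks the connection as open
        connection['state'] = 'open'
        return connection

    @classmethod
    def get_connection(cls, profile, name='master'):
        # reuse the named connection if it exists, otherwise open one
        with cls._lock:
            if name not in cls._connections:
                cls._connections[name] = cls.open_connection({
                    'type': cls.type(),
                    'name': name,
                    'state': 'init',
                    'transaction_open': False,
                    'handle': None,
                    'credentials': profile,
                })
            return cls._connections[name]

    @classmethod
    def rollback(cls, connection):
        if connection['handle'] is not None:
            connection['handle'].rollback()

    @classmethod
    def add_query(cls, profile, sql, model_name=None):
        connection = cls.get_connection(profile, model_name or 'master')
        cursor = connection['handle'].cursor()
        cursor.execute(sql)
        return connection, cursor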

35 changes: 6 additions & 29 deletions dbt/adapters/redshift.py
@@ -1,39 +1,16 @@
-import copy
-
-import dbt.flags as flags
-
 from dbt.adapters.postgres import PostgresAdapter
-from dbt.contracts.connection import validate_connection
-from dbt.logger import GLOBAL_LOGGER as logger
+from dbt.logger import GLOBAL_LOGGER as logger  # noqa
 
 
 class RedshiftAdapter(PostgresAdapter):
 
-    date_function = 'getdate()'
-
     @classmethod
-    def acquire_connection(cls, profile):
-        # profile requires some marshalling right now because it includes a
-        # wee bit of global config.
-        # TODO remove this
-        credentials = copy.deepcopy(profile)
-
-        credentials.pop('type', None)
-        credentials.pop('threads', None)
+    def type(cls):
+        return 'redshift'
 
-        result = {
-            'type': 'redshift',
-            'state': 'init',
-            'handle': None,
-            'credentials': credentials
-        }
-
-        logger.info('Connecting to redshift.')
-
-        if flags.STRICT_MODE:
-            validate_connection(result)
-
-        return cls.open_connection(result)
+    @classmethod
+    def date_function(cls):
+        return 'getdate()'
 
     @classmethod
     def dist_qualifier(cls, dist):
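The practical effect of the Redshift change above: adapter metadata moves from class attributes (date_function = 'getdate()') to classmethods, which is why the compilation hunks below switch from adapter.date_function to adapter.date_function(). A toy contrast with stand-in classes, not dbt's own:

# Before: metadata as a class attribute, read directly.
class OldStyleAdapter(object):
    date_function = 'getdate()'


# After: metadata as classmethods, so every adapter answers the same
# questions through the same callable interface.
class NewStyleAdapter(object):
    @classmethod
    def type(cls):
        return 'redshift'

    @classmethod
    def date_function(cls):
        return 'getdate()'


assert OldStyleAdapter.date_function == 'getdate()'    # attribute access
assert NewStyleAdapter.date_function() == 'getdate()'  # now a call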
212 changes: 61 additions & 151 deletions dbt/adapters/snowflake.py
@@ -1,9 +1,6 @@
from __future__ import absolute_import

import copy
import re
import time
import yaml

import snowflake.connector
import snowflake.connector.errors
@@ -17,74 +14,53 @@
from dbt.contracts.connection import validate_connection
from dbt.logger import GLOBAL_LOGGER as logger

connection_cache = {}


@contextmanager
def exception_handler(connection, cursor, model_name, query):
handle = connection.get('handle')
schema = connection.get('credentials', {}).get('schema')

try:
yield
except snowflake.connector.errors.ProgrammingError as e:
if 'Empty SQL statement' in e.msg:
logger.debug("got empty sql statement, moving on")
elif 'This session does not have a current database' in e.msg:
handle.rollback()
raise dbt.exceptions.FailedToConnectException(
('{}\n\nThis error sometimes occurs when invalid credentials '
'are provided, or when your default role does not have '
'access to use the specified database. Please double check '
'your profile and try again.').format(str(e)))
else:
handle.rollback()
raise dbt.exceptions.ProgrammingException(str(e))
except Exception as e:
handle.rollback()
logger.debug("Error running SQL: %s", query)
logger.debug("rolling back connection")
raise e


class SnowflakeAdapter(PostgresAdapter):

date_function = 'CURRENT_TIMESTAMP()'

@classmethod
def acquire_connection(cls, profile):

# profile requires some marshalling right now because it includes a
# wee bit of global config.
# TODO remove this
credentials = copy.deepcopy(profile)
@contextmanager
def exception_handler(cls, profile, sql, model_name=None,
connection_name='master'):
connection = cls.get_connection(profile, connection_name)

credentials.pop('type', None)
credentials.pop('threads', None)
try:
yield
except snowflake.connector.errors.ProgrammingError as e:
if 'Empty SQL statement' in e.msg:
logger.debug("got empty sql statement, moving on")
elif 'This session does not have a current database' in e.msg:
cls.rollback(connection)
raise dbt.exceptions.FailedToConnectException(
('{}\n\nThis error sometimes occurs when invalid '
'credentials are provided, or when your default role '
'does not have access to use the specified database. '
'Please double check your profile and try again.')
.format(str(e)))
else:
cls.rollback(connection)
raise dbt.exceptions.ProgrammingException(str(e))
except Exception as e:
logger.debug("Error running SQL: %s", sql)
logger.debug("Rolling back transaction.")
cls.rollback(connection)
raise e

result = {
'type': 'snowflake',
'state': 'init',
'handle': None,
'credentials': credentials
}
@classmethod
def type(cls):
return 'snowflake'

logger.info('Connecting to snowflake.')
@classmethod
def date_function(cls):
return 'CURRENT_TIMESTAMP()'

if flags.STRICT_MODE:
validate_connection(result)
@classmethod
def get_status(cls, cursor):
state = cursor.sqlstate

return cls.open_connection(result)
if state is None:
state = 'SUCCESS'

@staticmethod
def hash_profile(profile):
return ("{}--{}--{}--{}--{}".format(
profile.get('account'),
profile.get('database'),
profile.get('schema'),
profile.get('user'),
profile.get('warehouse'),
))
return "{} {}".format(state, cursor.rowcount)

@classmethod
def open_connection(cls, connection):
@@ -122,20 +98,14 @@ def open_connection(cls, connection):
return result

@classmethod
def query_for_existing(cls, profile, schema):
query = """
def query_for_existing(cls, profile, schema, model_name=None):
sql = """
select TABLE_NAME as name, TABLE_TYPE as type
from INFORMATION_SCHEMA.TABLES
where TABLE_SCHEMA = '{schema}'
""".format(schema=schema).strip() # noqa

connection = cls.get_connection(profile)

if flags.STRICT_MODE:
validate_connection(connection)

_, cursor = cls.add_query_to_transaction(
query, connection, schema)
_, cursor = cls.add_query(profile, sql, model_name)
results = cursor.fetchall()

relation_type_lookup = {
@@ -148,90 +118,38 @@ def query_for_existing(cls, profile, schema):

return dict(existing)

@classmethod
def get_status(cls, cursor):
state = cursor.sqlstate

if state is None:
state = 'SUCCESS'

return "{} {}".format(state, cursor.rowcount)

@classmethod
def rename(cls, profile, from_name, to_name, model_name=None):
connection = cls.get_connection(profile)
schema = cls.get_default_schema(profile)

if flags.STRICT_MODE:
validate_connection(connection)

schema = connection.get('credentials', {}).get('schema')

# in snowflake, if you fail to include the quoted schema in the
# identifier, the new table will have `schema.upper()` as its new
# schema
query = ('''
alter table "{schema}"."{from_name}"
rename to "{schema}"."{to_name}"
'''.format(
schema=schema,
from_name=from_name,
to_name=to_name)).strip()
sql = (('alter table "{schema}"."{from_name}" '
'rename to "{schema}"."{to_name}"')
.format(schema=schema,
from_name=from_name,
to_name=to_name))

handle, cursor = cls.add_query_to_transaction(
query, connection, model_name)
connection, cursor = cls.add_query(profile, sql, model_name)

@classmethod
def execute_model(cls, profile, model):
parts = re.split(r'-- (DBT_OPERATION .*)', model.get('wrapped_sql'))
connection = cls.get_connection(profile)
connection = cls.get_connection(profile, model.get('name'))

if flags.STRICT_MODE:
validate_connection(connection)

# snowflake requires a schema to be specified for temporary tables
# TODO setup templates to be adapter-specific. then we can just use
# the existing schema for temp tables.
cls.add_query_to_transaction(
'USE SCHEMA "{}"'.format(
connection.get('credentials', {}).get('schema')),
connection)

for i, part in enumerate(parts):
matches = re.match(r'^DBT_OPERATION ({.*})$', part)
if matches is not None:
instruction_string = matches.groups()[0]
instruction = yaml.safe_load(instruction_string)
function = instruction['function']
kwargs = instruction['args']

def call_expand_target_column_types(kwargs):
kwargs.update({'profile': profile})
return cls.expand_target_column_types(**kwargs)

func_map = {
'expand_column_types_if_needed':
call_expand_target_column_types
}

func_map[function](kwargs)
else:
handle, cursor = cls.add_query_to_transaction(
part, connection, model.get('name'))

handle.commit()

status = cls.get_status(cursor)
cursor.close()

return status
return super(PostgresAdapter, cls).execute_model(
profile, model)

@classmethod
def add_query_to_transaction(cls, query, connection, model_name=None):
handle = connection.get('handle')
cursor = handle.cursor()

def add_query(cls, profile, sql, model_name=None):
# snowflake only allows one query per api call.
queries = query.strip().split(";")
queries = sql.strip().split(";")
cursor = None

super(PostgresAdapter, cls).add_query(
profile,
'use schema "{}"'.format(cls.get_default_schema(profile)),
model_name)

for individual_query in queries:
# hack -- after the last ';', remove comments and don't run
@@ -243,15 +161,7 @@ def add_query_to_transaction(cls, query, connection, model_name=None):

if without_comments == "":
continue
connection, cursor = super(PostgresAdapter, cls).add_query(
profile, individual_query, model_name)

with exception_handler(connection, cursor,
model_name, individual_query):
logger.debug("SQL: %s", individual_query)
pre = time.time()
cursor.execute(individual_query)
post = time.time()
logger.debug(
"SQL status: %s in %0.2f seconds",
cls.get_status(cursor), post-pre)

return handle, cursor
return connection, cursor
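The Snowflake connector only accepts one statement per execute() call (see the comment in add_query above), so the override splits the incoming SQL on ';', skips a trailing comment-only fragment, and forwards each remaining piece to the shared add_query in the base adapter after switching to the profile's default schema. A rough sketch of just the splitting step, with a hypothetical run_statement callback standing in for that base-class call:

import re


def split_and_run(sql, run_statement):
    """Run each ';'-separated statement, skipping comment-only trailers.

    run_statement is a hypothetical callback standing in for the base
    adapter's add_query; this illustrates the splitting logic only, not
    the merged implementation.
    """
    result = None
    for statement in sql.strip().split(';'):
        # drop '--' line comments so a trailing comment does not get sent
        # to the warehouse as an empty statement
        without_comments = re.sub(
            re.compile(r'--.*$', re.MULTILINE), '', statement).strip()

        if without_comments == '':
            continue

        result = run_statement(statement)

    return result


# usage: only the two real statements are executed
executed = []
split_and_run(
    'create table x as (select 1); select * from x; -- done',
    lambda stmt: executed.append(stmt) or stmt)
assert len(executed) == 2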
4 changes: 2 additions & 2 deletions dbt/compilation.py
@@ -282,7 +282,7 @@ def get_compiler_context(self, linker, model, flat_graph):
# these get re-interpolated at runtime!
context['run_started_at'] = '{{ run_started_at }}'
context['invocation_id'] = '{{ invocation_id }}'
context['sql_now'] = adapter.date_function
context['sql_now'] = adapter.date_function()

context = recursively_parse_macros_for_node(
model, flat_graph, context)
@@ -310,7 +310,7 @@ def get_context(self, linker, model, models):
context['invocation_id'] = '{{ invocation_id }}'

adapter = get_adapter(self.project.run_environment())
context['sql_now'] = adapter.date_function
context['sql_now'] = adapter.date_function()

runtime.update_global(context)

6 changes: 4 additions & 2 deletions dbt/contracts/connection.py
@@ -1,13 +1,15 @@
from voluptuous import Schema, Required, All, Any, Extra, Range, Optional
from voluptuous import Schema, Required, All, Any, Range, Optional

from dbt.compat import basestring
from dbt.contracts.common import validate_with
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.logger import GLOBAL_LOGGER as logger # noqa


connection_contract = Schema({
Required('type'): Any('postgres', 'redshift', 'snowflake'),
Required('name'): Any(None, basestring),
Required('state'): Any('init', 'open', 'closed', 'fail'),
Required('transaction_open'): bool,
Required('handle'): Any(None, object),
Required('credentials'): object,
})
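The connection contract gains two required keys, name (which named connection in the pool this is) and transaction_open (whether a transaction is currently in flight), and drops the unused Extra import. A quick check of what now validates, using voluptuous directly; str stands in for dbt.compat.basestring to keep the snippet self-contained, and the credentials value is a placeholder rather than a real profile:

from voluptuous import Schema, Required, Any

# trimmed copy of the contract above, just to exercise the new keys
connection_contract = Schema({
    Required('type'): Any('postgres', 'redshift', 'snowflake'),
    Required('name'): Any(None, str),
    Required('state'): Any('init', 'open', 'closed', 'fail'),
    Required('transaction_open'): bool,
    Required('handle'): Any(None, object),
    Required('credentials'): object,
})

connection_contract({
    'type': 'redshift',
    'name': 'model.my_project.my_model',  # placeholder connection name
    'state': 'init',
    'transaction_open': False,
    'handle': None,
    'credentials': {},                    # placeholder, not a real profile
})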
4 changes: 4 additions & 0 deletions dbt/exceptions.py
@@ -5,6 +5,10 @@ class Exception(BaseException):
pass


class InternalException(Exception):
pass


class RuntimeException(RuntimeError, Exception):
pass

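exceptions.py also grows an InternalException alongside the existing classes. Its call sites are outside the rendered hunks, so the reading here is a guess: it presumably marks bugs in dbt's own bookkeeping (say, asking for a connection that was never opened) rather than user-facing errors. A hypothetical illustration, not taken from this PR:

class InternalException(Exception):
    pass


def get_named_connection(connections, name):
    # hypothetical helper: a missing name would mean dbt's own state is
    # inconsistent, not that the user misconfigured a profile
    if name not in connections:
        raise InternalException(
            'connection {} was requested but never opened'.format(name))
    return connections[name]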