Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

run hooks outside of a transaction #510

Merged
merged 11 commits into from
Aug 29, 2017
20 changes: 15 additions & 5 deletions dbt/adapters/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ class DefaultAdapter(object):
"truncate",
"add_query",
"expand_target_column_types",
"quote_schema_and_table",
]

raw_functions = [
"get_status",
"get_result_from_cursor",
"quote",
"quote_schema_and_table",
]

###
Expand Down Expand Up @@ -396,6 +396,10 @@ def reload(cls, connection):
def add_begin_query(cls, profile, name):
return cls.add_query(profile, 'BEGIN', name, auto_begin=False)

@classmethod
def add_commit_query(cls, profile, name):
    """Send a literal ``COMMIT`` statement through ``cls.add_query``.

    ``name`` is forwarded as the third positional argument of
    ``add_query`` and ``auto_begin`` is explicitly set to ``False``.

    Returns whatever ``cls.add_query`` returns.
    """
    result = cls.add_query(profile, 'COMMIT', name, auto_begin=False)
    return result

@classmethod
def begin(cls, profile, name='master'):
global connections_in_use
Expand Down Expand Up @@ -428,10 +432,10 @@ def commit_if_has_connection(cls, profile, name):

connection = cls.get_connection(profile, name, False)

return cls.commit(connection)
return cls.commit(profile, connection)

@classmethod
def commit(cls, connection):
def commit(cls, profile, connection):
global connections_in_use

if dbt.flags.STRICT_MODE:
Expand All @@ -445,7 +449,7 @@ def commit(cls, connection):
'it does not have one open!'.format(connection.get('name')))

logger.debug('On {}: COMMIT'.format(connection.get('name')))
connection.get('handle').commit()
cls.add_commit_query(profile, connection.get('name'))

connection['transaction_open'] = False
connections_in_use[connection.get('name')] = connection
Expand Down Expand Up @@ -512,6 +516,12 @@ def add_query(cls, profile, sql, model_name=None, auto_begin=True):

return connection, cursor

@classmethod
def clear_transaction(cls, profile, conn_name='master'):
    """Call ``begin`` and then ``commit`` for the connection ``conn_name``.

    The connection returned by ``cls.begin`` is handed straight to
    ``cls.commit``.

    Returns the connection *name* that was passed in, not the connection.
    """
    connection = cls.begin(profile, conn_name)
    cls.commit(profile, connection)
    return conn_name

@classmethod
def execute_one(cls, profile, sql, model_name=None, auto_begin=False):
cls.get_connection(profile, model_name)
Expand Down Expand Up @@ -576,6 +586,6 @@ def quote(cls, identifier):
return '"{}"'.format(identifier)

@classmethod
def quote_schema_and_table(cls, profile, schema, table):
def quote_schema_and_table(cls, profile, schema, table, model_name=None):
return '{}.{}'.format(cls.quote(schema),
cls.quote(table))
4 changes: 2 additions & 2 deletions dbt/adapters/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def exception_handler(cls, profile, sql, model_name=None,
except psycopg2.DatabaseError as e:
logger.debug('Postgres error: {}'.format(str(e)))

cls.rollback(connection)
cls.release_connection(profile, connection_name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this automatically roll back open transactions? I want to be sure we don't create a deadlock because of an open transaction.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, release_connection calls rollback under the hood if the transaction is open. This change fixes an issue where we'd try to roll back transactions that weren't open.

raise dbt.exceptions.DatabaseException(
dbt.compat.to_string(e).strip())

except Exception as e:
logger.debug("Error running SQL: %s", sql)
logger.debug("Rolling back transaction.")
cls.rollback(connection)
cls.release_connection(profile, connection_name)
raise dbt.exceptions.RuntimeException(e)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ def drop(cls, profile, relation, relation_type, model_name=None):
connection = cls.get_connection(profile, model_name)

if connection.get('transaction_open'):
cls.commit(connection)
cls.commit(profile, connection)

cls.begin(profile, connection.get('name'))

to_return = super(PostgresAdapter, cls).drop(
profile, relation, relation_type, model_name)

cls.commit(connection)
cls.commit(profile, connection)
cls.begin(profile, connection.get('name'))

return to_return
Expand Down
6 changes: 3 additions & 3 deletions dbt/adapters/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ def exception_handler(cls, profile, sql, model_name=None,
if 'Empty SQL statement' in msg:
logger.debug("got empty sql statement, moving on")
elif 'This session does not have a current database' in msg:
cls.rollback(connection)
cls.release_connection(profile, connection_name)
raise dbt.exceptions.FailedToConnectException(
('{}\n\nThis error sometimes occurs when invalid '
'credentials are provided, or when your default role '
'does not have access to use the specified database. '
'Please double check your profile and try again.')
.format(msg))
else:
cls.rollback(connection)
cls.release_connection(profile, connection_name)
raise dbt.exceptions.DatabaseException(msg)
except Exception as e:
logger.debug("Error running SQL: %s", sql)
logger.debug("Rolling back transaction.")
cls.rollback(connection)
cls.release_connection(profile, connection_name)
raise dbt.exceptions.RuntimeException(e.msg)

@classmethod
Expand Down
27 changes: 15 additions & 12 deletions dbt/context/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,17 @@
import voluptuous

from dbt.adapters.factory import get_adapter
from dbt.compat import basestring
from dbt.compat import basestring, to_string

import dbt.clients.jinja
import dbt.flags
import dbt.schema
import dbt.tracking
import dbt.utils

from dbt.logger import GLOBAL_LOGGER as logger # noqa


def get_hooks(model, context, hook_key):
hooks = model.get('config', {}).get(hook_key, [])
import dbt.hooks

if isinstance(hooks, basestring):
hooks = [hooks]

return hooks
from dbt.logger import GLOBAL_LOGGER as logger # noqa


class DatabaseWrapper(object):
Expand Down Expand Up @@ -227,6 +220,15 @@ def fn(string):
return fn


def fromjson(node):
    """Build the ``fromjson`` helper.

    ``node`` is accepted but not used by the returned function.

    The returned callable takes ``string`` and an optional ``default``
    (``None`` if omitted). It returns the decoded JSON value, or
    ``default`` when ``string`` cannot be decoded.
    """
    def fn(string, default=None):
        try:
            return json.loads(string)
        # ValueError: malformed JSON text.
        # TypeError: the input is not text at all (e.g. None); fall back to
        # the default instead of raising.
        except (ValueError, TypeError):
            return default
    return fn


def generate(model, project, flat_graph, provider=None):
"""
Not meant to be called directly. Call with either:
Expand All @@ -248,8 +250,8 @@ def generate(model, project, flat_graph, provider=None):
context = {'env': target}
schema = profile.get('schema', 'public')

pre_hooks = get_hooks(model, context, 'pre-hook')
post_hooks = get_hooks(model, context, 'post-hook')
pre_hooks = model.get('config', {}).get('pre-hook')
post_hooks = model.get('config', {}).get('post-hook')

db_wrapper = DatabaseWrapper(model, adapter, profile)

Expand All @@ -270,6 +272,7 @@ def generate(model, project, flat_graph, provider=None):
"schema": schema,
"sql": model.get('injected_sql'),
"sql_now": adapter.date_function(),
"fromjson": fromjson(model),
"target": target,
"this": dbt.utils.This(
schema,
Expand Down
12 changes: 10 additions & 2 deletions dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@

from dbt.logger import GLOBAL_LOGGER as logger # noqa

hook_contract = Schema({
Required('sql'): basestring,
Required('transaction'): bool,
})

config_contract = Schema({
Required('enabled'): bool,
Required('materialized'): basestring,
Required('post-hook'): list,
Required('pre-hook'): list,
Required('post-hook'): [hook_contract],
Required('pre-hook'): [hook_contract],
Required('vars'): dict,
}, extra=ALLOW_EXTRA)

Expand Down Expand Up @@ -69,6 +73,10 @@
})


def validate_hook(hook):
    """Validate a single hook against ``hook_contract``.

    The argument is passed to ``validate_with`` together with the contract.
    """
    # The parameter is `hook`; the previous body referenced an undefined
    # name `hooks`, which raised NameError on every call.
    validate_with(hook_contract, hook)


def validate_nodes(parsed_nodes):
validate_with(parsed_nodes_contract, parsed_nodes)

Expand Down
9 changes: 1 addition & 8 deletions dbt/graph/selector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import networkx as nx
from dbt.logger import GLOBAL_LOGGER as logger

from dbt.utils import is_enabled, get_materialization
from dbt.utils import is_enabled, get_materialization, coalesce
from dbt.node_types import NodeType

SELECTOR_PARENTS = '+'
Expand Down Expand Up @@ -43,13 +43,6 @@ def parse_spec(node_spec):
}


def coalesce(*args):
    """Return the first argument that is not ``None``.

    Falsy values such as ``0`` or ``False`` count as present. ``None`` is
    returned when every argument is ``None`` or no arguments are given.
    """
    present = (candidate for candidate in args if candidate is not None)
    return next(present, None)


def get_package_names(graph):
return set([node.split(".")[1] for node in graph.nodes()])

Expand Down
40 changes: 40 additions & 0 deletions dbt/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

import json
from dbt.compat import to_string


class ModelHookType:
    """String constants naming the two kinds of model hook."""

    PreHook = 'pre-hook'
    PostHook = 'post-hook'

    # Both kinds, pre-hook first.
    Both = [PreHook, PostHook]


def _parse_hook_to_dict(hook_string):
try:
hook_dict = json.loads(hook_string)
except ValueError as e:
hook_dict = {"sql": hook_string}

if 'transaction' not in hook_dict:
hook_dict['transaction'] = True

return hook_dict


def get_hook_dict(hook):
    """Return ``hook`` in dict form.

    A dict is returned as-is (the same object). Any other value is
    converted with ``to_string`` and passed to ``_parse_hook_to_dict``.
    """
    if isinstance(hook, dict):
        return hook

    return _parse_hook_to_dict(to_string(hook))


def get_hooks(model, hook_key):
    """Collect the hooks stored under ``hook_key`` in ``model['config']``.

    A single hook (anything that is not a list or tuple) is treated as a
    one-element list. Each hook is converted with ``get_hook_dict``.

    Returns a list of hook dicts. The list is empty when the model has no
    config, the key is missing, or the key is set to ``None``.
    """
    hooks = model.get('config', {}).get(hook_key)

    # An explicit None means "no hooks". Wrapping it would produce a hook
    # whose SQL is the literal text "None".
    if hooks is None:
        return []

    if not isinstance(hooks, (list, tuple)):
        hooks = [hooks]

    return [get_hook_dict(hook) for hook in hooks]
4 changes: 2 additions & 2 deletions dbt/include/global_project/macros/core.sql
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{% macro statement(name=None, fetch_result=False) -%}
{% macro statement(name=None, fetch_result=False, auto_begin=True) -%}
{%- if execute: -%}
{%- set sql = render(caller()) -%}

Expand All @@ -7,7 +7,7 @@
{{ write(sql) }}
{%- endif -%}

{%- set _, cursor = adapter.add_query(sql) -%}
{%- set _, cursor = adapter.add_query(sql, auto_begin=auto_begin) -%}
{%- if name is not none -%}
{%- set result = [] if not fetch_result else adapter.get_result_from_cursor(cursor) -%}
{{ store_result(name, status=adapter.get_status(cursor), data=result) }}
Expand Down
28 changes: 24 additions & 4 deletions dbt/include/global_project/macros/materializations/helpers.sql
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{% macro run_hooks(hooks) %}
{% for hook in hooks %}
{% call statement() %}
{{ hook }};
{% macro run_hooks(hooks, inside_transaction=True) %}
{% for hook in hooks | selectattr('transaction', 'equalto', inside_transaction) %}
{% call statement(auto_begin=inside_transaction) %}
{{ hook.get('sql') }}
{% endcall %}
{% endfor %}
{% endmacro %}
Expand All @@ -21,6 +21,26 @@
{% endmacro %}


{# Render a hook config as JSON: the given `sql` plus a `transaction` flag set to `inside_transaction`. #}
{% macro make_hook_config(sql, inside_transaction) %}
{{ {"sql": sql, "transaction": inside_transaction} | tojson }}
{% endmacro %}


{# Hook config for `sql` with the `transaction` flag set to False. #}
{% macro before_begin(sql) %}
{{ make_hook_config(sql, inside_transaction=False) }}
{% endmacro %}


{# Hook config for `sql` with the `transaction` flag set to True. #}
{% macro in_transaction(sql) %}
{{ make_hook_config(sql, inside_transaction=True) }}
{% endmacro %}


{# Hook config for `sql` with the `transaction` flag set to False (same output as `before_begin`). #}
{% macro after_commit(sql) %}
{{ make_hook_config(sql, inside_transaction=False) }}
{% endmacro %}


{% macro drop_if_exists(existing, name) %}
{% set existing_type = existing.get(name) %}
{% if existing_type is not none %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@
{{ adapter.drop(identifier, existing_type) }}
{%- endif %}

{{ run_hooks(pre_hooks) }}
{{ run_hooks(pre_hooks, inside_transaction=False) }}

-- `BEGIN` happens here:
{{ run_hooks(pre_hooks, inside_transaction=True) }}

-- build model
{% if force_create or not adapter.already_exists(schema, identifier) -%}
Expand Down Expand Up @@ -79,8 +82,11 @@
{% endcall %}
{%- endif %}

{{ run_hooks(post_hooks) }}
{{ run_hooks(post_hooks, inside_transaction=True) }}

-- `COMMIT` happens here
{{ adapter.commit() }}

{{ run_hooks(post_hooks, inside_transaction=False) }}

{%- endmaterialization %}
10 changes: 8 additions & 2 deletions dbt/include/global_project/macros/materializations/table.sql
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
{%- endif %}
{%- endif %}

{{ run_hooks(pre_hooks) }}
{{ run_hooks(pre_hooks, inside_transaction=False) }}

-- `BEGIN` happens here:
{{ run_hooks(pre_hooks, inside_transaction=True) }}

-- build model
{% call statement('main') -%}
Expand All @@ -39,7 +42,7 @@
{%- endif -%}
{%- endcall %}

{{ run_hooks(post_hooks) }}
{{ run_hooks(post_hooks, inside_transaction=True) }}

-- cleanup
{% if non_destructive_mode -%}
Expand All @@ -49,5 +52,8 @@
{{ adapter.rename(tmp_identifier, identifier) }}
{%- endif %}

-- `COMMIT` happens here
{{ adapter.commit() }}

{{ run_hooks(post_hooks, inside_transaction=False) }}
{% endmaterialization %}
Loading