Skip to content

fixed seed insertion #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions dbt/adapters/athena/connections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import Iterable
from contextlib import contextmanager
from datetime import datetime
import datetime
from getpass import getuser
import re
import decimal
@@ -92,7 +92,10 @@ def _escape_value(cls, value):
return "'{}'".format(value.replace("'", "''"))
elif isinstance(value, (int,float,decimal.Decimal)):
return value
elif isinstance(value, datetime):
elif isinstance(value, datetime.date):
date_formatted = value.strftime('%Y-%m-%d')
return "DATE '{}'".format(date_formatted)
elif isinstance(value, datetime.datetime):
time_formatted = value.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
return "TIMESTAMP '{}'".format(time_formatted)
else:
@@ -167,7 +170,7 @@ def open(cls, connection):
conn = connect(
s3_staging_dir=credentials.s3_staging_dir,
region_name=credentials.region_name,
schema_name=credentials.database,
schema_name=credentials.schema,
cursor_class=AsyncCursor
)
connection.state = 'open'
31 changes: 27 additions & 4 deletions dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
import agate
from dbt.adapters.sql import SQLAdapter
from dbt.contracts.graph.manifest import Manifest
from dbt.logger import GLOBAL_LOGGER as logger

import dbt
from dbt.adapters.athena import AthenaConnectionManager
from dbt.adapters.athena.relation import AthenaRelation

import agate


class AthenaAdapter(SQLAdapter):
ConnectionManager = AthenaConnectionManager
Relation = AthenaRelation

@classmethod
def date_function(cls):
return 'datenow()'

@classmethod
def convert_text_type(cls, agate_table, col_idx):
return "VARCHAR"
return "STRING"

@classmethod
def convert_number_type(cls, agate_table, col_idx):
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "DOUBLE" if decimals else "INTEGER"
return "DOUBLE" if decimals else "INT"

@classmethod
def convert_datetime_type(cls, agate_table, col_idx):
@@ -43,3 +47,22 @@ def drop_schema(self, database, schema, model_name=None):
schema=schema,
model_name=model_name
)

def _relations_cache_for_schemas(self, manifest: Manifest) -> None:
    """Fill the adapter's relations cache when caching is enabled.

    Does nothing if ``dbt.flags.USE_CACHE`` is falsy. Otherwise every
    relation yielded by
    ``list_relations_without_caching('INFORMATION_SCHEMA', ".*")`` is added
    to ``self.cache``, and the schemas reported by
    ``_get_cache_schemas(manifest, exec_only=True)`` are recorded through
    ``self.cache.update_schemas``.

    The method has no return value; its effect is the mutation of
    ``self.cache``.
    """
    if dbt.flags.USE_CACHE:
        schema_map = self._get_cache_schemas(manifest, exec_only=True)

        # The listing call uses fixed arguments rather than iterating the
        # schemas found in the manifest.
        listed = self.list_relations_without_caching('INFORMATION_SCHEMA', ".*")
        for cached_relation in listed:
            logger.debug("add relation to cache: {}".format(cached_relation))
            self.cache.add(cached_relation)

        # Some schemas may contain no relations at all; register every
        # schema that was searched so later lookups can tell "searched and
        # empty" apart from "never searched".
        self.cache.update_schemas(schema_map.schemas_searched())
10 changes: 5 additions & 5 deletions dbt/include/athena/macros/adapters.sql
Original file line number Diff line number Diff line change
@@ -24,11 +24,12 @@

from
{{ relation.information_schema('columns') }}
where table_name='{{ relation.identifier }}'
{% if relation.schema %}
and table_schema='{{ relation.schema }}'
{% endif %}


where {{ presto_ilike('table_name', relation.identifier) }}
{% if relation.schema %}
and {{ presto_ilike('table_schema', relation.schema) }}
{% endif %}
order by ordinal_position

{% endcall %}
@@ -37,7 +38,6 @@
{{ return(sql_convert_columns_in_relation(table)) }}
{% endmacro %}


{% macro athena__list_relations_without_caching(information_schema, schema) %}
{% call statement('list_relations_without_caching', fetch_result=True) -%}
select
147 changes: 147 additions & 0 deletions dbt/include/athena/macros/materializations/seed.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@

{# create_csv_table: forwards (model, agate_table) to adapter_macro under the
   name 'create_csv_table'. Presumably this resolves to the adapter-specific
   implementation (athena__create_csv_table) - confirm against adapter_macro. #}
{% macro create_csv_table(model, agate_table) -%}
{{ adapter_macro('create_csv_table', model, agate_table) }}
{%- endmacro %}

{# reset_csv_table: forwards (model, full_refresh, old_relation, agate_table)
   to adapter_macro under the name 'reset_csv_table'. Presumably this resolves
   to athena__reset_csv_table - confirm against adapter_macro. #}
{% macro reset_csv_table(model, full_refresh, old_relation, agate_table) -%}
{{ adapter_macro('reset_csv_table', model, full_refresh, old_relation, agate_table) }}
{%- endmacro %}

{# load_csv_rows: forwards (model, agate_table) to adapter_macro under the
   name 'load_csv_rows'. Presumably this resolves to athena__load_csv_rows -
   confirm against adapter_macro. #}
{% macro load_csv_rows(model, agate_table) -%}
{{ adapter_macro('load_csv_rows', model, agate_table) }}
{%- endmacro %}

{% macro athena__create_csv_table(model, agate_table) %}
{#
  Builds and executes the `create external table` DDL for a seed, then
  returns the DDL text.

  Column types come from the model config key 'column_types' when present,
  otherwise from adapter.convert_type(). Column quoting is driven by the
  model config key 'quote_columns'. The table LOCATION is assembled from
  target.s3_staging_dir, this.schema, this.table and invocation_id.
#}
{%- set type_overrides = model['config'].get('column_types', {}) -%}
{%- set quote_config = model['config'].get('quote_columns', None) -%}

{% set ddl %}
create external table {{ this.render() }} (
{%- for name in agate_table.column_names -%}
{%- set col_type = type_overrides.get(name, adapter.convert_type(agate_table, loop.index0)) -%}
{{ adapter.quote_seed_column(name | string, quote_config) }} {{ col_type }} {%- if not loop.last -%}, {%- endif -%}
{%- endfor -%}
)
LOCATION '{{ target.s3_staging_dir }}/{{ this.schema }}/{{ this.table }}/{{ invocation_id }}'
{% endset %}

{# Run the DDL under the throwaway statement name '_'. #}
{% call statement('_') -%}
{{ ddl }}
{%- endcall %}

{{ return(ddl) }}
{% endmacro %}


{% macro athena__reset_csv_table(model, full_refresh, old_relation, agate_table) %}
  {#
    Drops `old_relation` and recreates the seed table through
    create_csv_table(), returning the SQL that create_csv_table() returns.
    `full_refresh` is accepted for signature compatibility but is not
    consulted: the relation is dropped and recreated on every call.
  #}
  {% do adapter.drop_relation(old_relation) %}
  {{ return(create_csv_table(model, agate_table)) }}
{% endmacro %}


{% macro get_seed_column_quoted_csv(model, column_names) %}
  {#
    Returns the given column names as one ', '-separated string, each name
    passed through adapter.quote_seed_column() together with the model's
    'quote_columns' config value (None when the key is absent).
  #}
  {%- set quote_config = model['config'].get('quote_columns', None) -%}
  {%- set quoted_names = [] -%}
  {%- for name in column_names -%}
    {%- do quoted_names.append(adapter.quote_seed_column(name, quote_config)) -%}
  {%- endfor -%}
  {{ return(quoted_names | join(', ')) }}
{% endmacro %}


{#
  basic_load_csv_rows: inserts the rows of `agate_table` into `this` in
  batches of `batch_size` rows.

  For each batch it renders one multi-row `insert into ... values` statement
  with a `%s` placeholder per cell, flattens the batch's rows into a single
  bindings list, and submits both through adapter.add_query(). The SQL of
  the first batch only is kept and returned so it can be written to the
  compiled files.

  NOTE(review): when agate_table has no rows the loop never runs and
  `statements` stays empty, so `statements[0]` has nothing to return -
  confirm how callers handle an empty seed.
#}
{% macro basic_load_csv_rows(model, batch_size, agate_table) %}
{% set cols_sql = get_seed_column_quoted_csv(model, agate_table.column_names) %}
{# NOTE(review): this outer `bindings` is re-set inside the loop below and is
   not read anywhere else in this macro; it looks redundant - confirm. #}
{% set bindings = [] %}

{% set statements = [] %}

{% for chunk in agate_table.rows | batch(batch_size) %}
{# Fresh bindings list per batch. #}
{% set bindings = [] %}

{# Flatten the batch row by row so the values line up with the %s placeholders. #}
{% for row in chunk %}
{% do bindings.extend(row) %}
{% endfor %}

{% set sql %}
insert into {{ this.render() }} ({{ cols_sql }}) values
{% for row in chunk -%}
({%- for column in agate_table.column_names -%}
%s
{%- if not loop.last%},{%- endif %}
{%- endfor -%})
{%- if not loop.last%},{%- endif %}
{%- endfor %}
{% endset %}

{% do adapter.add_query(sql, bindings=bindings, abridge_sql_log=True) %}

{# Keep only the first batch's SQL for the return value. #}
{% if loop.index0 == 0 %}
{% do statements.append(sql) %}
{% endif %}
{% endfor %}

{# Return SQL so we can render it out into the compiled files #}
{{ return(statements[0]) }}
{% endmacro %}


{% macro athena__load_csv_rows(model, agate_table) %}
  {# Loads the seed rows via basic_load_csv_rows() using a fixed batch size
     of 10000 rows per insert statement, and returns its result. #}
  {% set rows_per_insert = 10000 %}
  {{ return(basic_load_csv_rows(model, rows_per_insert, agate_table)) }}
{% endmacro %}


{#
  seed materialization.

  Steps, in order:
    1. Look up any existing relation for the model alias and load the seed
       CSV as an agate table (also stored under the result name 'agate_table').
    2. Run pre-hooks (outside, then inside the transaction).
    3. Fail if the target exists as a view; call reset_csv_table() if it
       exists as a table; otherwise call create_csv_table().
    4. Load the rows with load_csv_rows().
    5. Record the create + insert SQL under the 'main' statement via
       noop_statement, with a status of '<CREATE|INSERT> <row count>'.
    6. Run post-hooks, commit, and return the target relation as a table.
#}
{% materialization seed, default %}

{%- set identifier = model['alias'] -%}
{%- set full_refresh_mode = (flags.FULL_REFRESH == True) -%}

{%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%}

{%- set exists_as_table = (old_relation is not none and old_relation.is_table) -%}
{%- set exists_as_view = (old_relation is not none and old_relation.is_view) -%}

{%- set agate_table = load_agate_table() -%}
{%- do store_result('agate_table', status='OK', agate_table=agate_table) -%}

{{ run_hooks(pre_hooks, inside_transaction=False) }}

-- `BEGIN` happens here:
{{ run_hooks(pre_hooks, inside_transaction=True) }}

-- build model
{% set create_table_sql = "" %}
{% if exists_as_view %}
{{ exceptions.raise_compiler_error("Cannot seed to '{}', it is a view".format(old_relation)) }}
{% elif exists_as_table %}
{# full_refresh_mode is passed through here; whether it is honoured depends on
   the reset_csv_table implementation. #}
{% set create_table_sql = reset_csv_table(model, full_refresh_mode, old_relation, agate_table) %}
{% else %}
{% set create_table_sql = create_csv_table(model, agate_table) %}
{% endif %}

{# The status label depends only on the FULL_REFRESH flag, not on which
   branch above was taken. #}
{% set status = 'CREATE' if full_refresh_mode else 'INSERT' %}
{% set num_rows = (agate_table.rows | length) %}
{% set sql = load_csv_rows(model, agate_table) %}

{# NOTE(review): presumably noop_statement only records this text and status
   as the 'main' result without executing it (the DDL and inserts were already
   issued by the macros above) - confirm against noop_statement. #}
{% call noop_statement('main', status ~ ' ' ~ num_rows) %}
{{ create_table_sql }};
-- dbt seed --
{{ sql }}
{% endcall %}

{{ run_hooks(post_hooks, inside_transaction=True) }}

-- `COMMIT` happens here
{{ adapter.commit() }}

{{ run_hooks(post_hooks, inside_transaction=False) }}

{% set target_relation = this.incorporate(type='table') %}
{{ return({'relations': [target_relation]}) }}

{% endmaterialization %}