Skip to content

Commit

Permalink
one_hot_encoder: TSQL vs. boolean type (#9)
Browse files Browse the repository at this point in the history
* extract col list generation

* move col_list creation to core macro

* TSQL has no boolean type

Co-authored-by: James Weakley <jameswillisweakley@gmail.com>
  • Loading branch information
dataders and jamesweakley authored Mar 18, 2021
1 parent b5f5b38 commit 09191fe
Showing 1 changed file with 51 additions and 18 deletions.
69 changes: 51 additions & 18 deletions macros/one_hot_encoder.sql
Original file line number Diff line number Diff line change
Expand Up @@ -52,28 +52,30 @@
{%- do exceptions.raise_compiler_error(error_message) -%}
{%- endif -%}

{{ adapter.dispatch('one_hot_encoder',packages=['dbt_ml_preprocessing'])(source_table, source_column, category_values, handle_unknown, include_columns, exclude_columns) }}
{%- endmacro %}

{% macro default__one_hot_encoder(source_table, source_column, category_values, handle_unknown, include_columns, exclude_columns) %}
{% set columns = adapter.get_columns_in_relation( source_table ) %}

select
{%- if include_columns=='*' and exclude_columns is none -%}
{% for column in columns %}
{{ column.name }},
{%- endfor -%}
{%- elif include_columns !='*'-%}
{% for column in include_columns %}
{{ source_table }}.{{ column }},
{%- endfor -%}
{%- else -%}
{% for column in columns %}
{%- if include_columns=='*' and exclude_columns is none -%}
{% set col_list = columns %}
{%- elif include_columns !='*'-%}
{% set col_list = include_columns %}
{%- else -%}
{% set col_list = [] %}
{% for column in columns %}
{%- if column.name | lower not in exclude_columns | lower %}
{{ column.name }},
{% do col_list.append(column) %}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}

{{ adapter.dispatch('one_hot_encoder',packages=['dbt_ml_preprocessing'])(source_table, source_column, category_values, handle_unknown, col_list) }}
{%- endmacro %}

{% macro default__one_hot_encoder(source_table, source_column, category_values, handle_unknown, col_list) %}

select
{% for column in col_list %}
{{ column.name }},
{%- endfor -%}
{% for category in category_values %}
{% set no_whitespace_column_name = category | replace( " ", "_") -%}
{%- if handle_unknown=='ignore' %}
Expand All @@ -93,3 +95,34 @@
{% endfor %}
from {{ source_table }}
{%- endmacro %}

{% macro sqlserver__one_hot_encoder(source_table, source_column, category_values, handle_unknown, col_list) %}

select
{% for column in col_list %}
{{ column.name }},
{%- endfor -%}
{% for category in category_values %}
{% set no_whitespace_column_name = category | replace( " ", "_") -%}
{%- if handle_unknown=='ignore' %}
case
when {{ source_column }} = '{{ category }}' then 1
else 0
end as is_{{ source_column }}_{{ no_whitespace_column_name }}
{% endif %}
{%- if handle_unknown=='error' %}
case
when {{ source_column }} = '{{ category }}' then 1
when {{ source_column }} in ('{{ category_values | join("','") }}') then 0
else cast('Error: unknown value found and handle_unknown parameter was "error"' as bit)
end as is_{{ source_column }}_{{ no_whitespace_column_name }}
{% endif %}
{%- if not loop.last %},{% endif -%}
{% endfor %}
from {{ source_table }}

{%- endmacro %}

{% macro synapse__one_hot_encoder(source_table, source_column, category_values, handle_unknown, col_list) %}
{% do return( dbt_ml_preprocessing.sqlserver__one_hot_encoder(source_table, source_column, category_values, handle_unknown, col_list)) %}
{%- endmacro %}

0 comments on commit 09191fe

Please sign in to comment.