Skip to content
This repository has been archived by the owner on Sep 23, 2024. It is now read-only.

Distinguish between nulls and empty strings #69

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions target_postgres/db_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def validate_config(config):


# pylint: disable=fixme
-def column_type(schema_property):
+def column_type(schema_property, logger=None):
property_type = schema_property['type']
property_format = schema_property['format'] if 'format' in schema_property else None
col_type = 'character varying'
Expand Down Expand Up @@ -69,7 +69,8 @@ def column_type(schema_property):
elif 'boolean' in property_type:
col_type = 'boolean'

-get_logger('target_postgres').debug("schema_property: %s -> col_type: %s", schema_property, col_type)
+if logger:
+    logger.debug("schema_property: %s -> col_type: %s", schema_property, col_type)

return col_type

Expand All @@ -78,8 +79,8 @@ def safe_column_name(name):
return '"{}"'.format(name).lower()


-def column_clause(name, schema_property):
-return '{} {}'.format(safe_column_name(name), column_type(schema_property))
+def column_clause(name, schema_property, logger):
+return '{} {}'.format(safe_column_name(name), column_type(schema_property, logger))


def flatten_key(k, parent_key, sep):
Expand Down Expand Up @@ -349,7 +350,7 @@ def record_to_csv_line(self, record):
return ','.join(
[
json.dumps(flatten[name], ensure_ascii=False)
-if name in flatten and (flatten[name] == 0 or flatten[name]) else ''
+if name in flatten and (flatten[name] == 0 or flatten[name] or (column_type(self.flatten_schema[name]) == 'character varying' and flatten[name] == '')) else '###NULL###'
for name in self.flatten_schema
]
)
Expand All @@ -367,7 +368,7 @@ def load_csv(self, file, count, size_bytes):
temp_table = self.table_name(stream_schema_message['stream'], is_temporary=True)
cur.execute(self.create_table_query(table_name=temp_table, is_temporary=True))

-copy_sql = "COPY {} ({}) FROM STDIN WITH (FORMAT CSV, ESCAPE '\\')".format(
+copy_sql = "COPY {} ({}) FROM STDIN WITH (FORMAT CSV, ESCAPE '\\', NULL '###NULL###')".format(
temp_table,
', '.join(self.column_names())
)
Expand Down Expand Up @@ -436,7 +437,8 @@ def create_table_query(self, table_name=None, is_temporary=False):
columns = [
column_clause(
name,
-schema
+schema,
+self.logger
)
for (name, schema) in self.flatten_schema.items()
]
Expand Down Expand Up @@ -537,7 +539,8 @@ def update_columns(self):
columns_to_add = [
column_clause(
name,
-properties_schema
+properties_schema,
+self.logger
)
for (name, properties_schema) in self.flatten_schema.items()
if name.lower() not in columns_dict
Expand All @@ -549,11 +552,12 @@ def update_columns(self):
columns_to_replace = [
(safe_column_name(name), column_clause(
name,
-properties_schema
+properties_schema,
+self.logger
))
for (name, properties_schema) in self.flatten_schema.items()
if name.lower() in columns_dict and
-columns_dict[name.lower()]['data_type'].lower() != column_type(properties_schema).lower()
+columns_dict[name.lower()]['data_type'].lower() != column_type(properties_schema, self.logger).lower()
]

for (column_name, column) in columns_to_replace:
Expand Down