Skip to content

Commit

Permalink
Fix S3ToRedshiftOperator (#19358)
Browse files Browse the repository at this point in the history
  • Loading branch information
mariotaddeucci authored Nov 3, 2021
1 parent 338822b commit 6148ddd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 17 deletions.
25 changes: 11 additions & 14 deletions airflow/providers/amazon/aws/transfers/s3_to_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,27 +149,24 @@ def execute(self, context) -> None:
copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options)

if self.method == 'REPLACE':
sql = f"""
BEGIN;
DELETE FROM {destination};
{copy_statement}
COMMIT
"""
sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, "COMMIT"]
elif self.method == 'UPSERT':
keys = self.upsert_keys or redshift_hook.get_table_primary_key(self.table, self.schema)
if not keys:
raise AirflowException(
f"No primary key on {self.schema}.{self.table}. Please provide keys on 'upsert_keys'"
)
where_statement = ' AND '.join([f'{self.table}.{k} = {copy_destination}.{k}' for k in keys])
sql = f"""
CREATE TABLE {copy_destination} (LIKE {destination});
{copy_statement}
BEGIN;
DELETE FROM {destination} USING {copy_destination} WHERE {where_statement};
INSERT INTO {destination} SELECT * FROM {copy_destination};
COMMIT
"""

sql = [
f"CREATE TABLE {copy_destination} (LIKE {destination});",
copy_statement,
"BEGIN;",
f"DELETE FROM {destination} USING {copy_destination} WHERE {where_statement};",
f"INSERT INTO {destination} SELECT * FROM {copy_destination};",
"COMMIT",
]

else:
sql = copy_statement

Expand Down
6 changes: 3 additions & 3 deletions tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_deprecated_truncate(self, mock_run, mock_session, mock_connection, mock
{copy_statement}
COMMIT
"""
assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], transaction)
assert_equal_ignore_multiple_spaces(self, "\n".join(mock_run.call_args[0][0]), transaction)

assert mock_run.call_count == 1

Expand Down Expand Up @@ -222,7 +222,7 @@ def test_replace(self, mock_run, mock_session, mock_connection, mock_hook):
{copy_statement}
COMMIT
"""
assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], transaction)
assert_equal_ignore_multiple_spaces(self, "\n".join(mock_run.call_args[0][0]), transaction)

assert mock_run.call_count == 1

Expand Down Expand Up @@ -277,7 +277,7 @@ def test_upsert(self, mock_run, mock_session, mock_connection, mock_hook):
INSERT INTO {schema}.{table} SELECT * FROM #{table};
COMMIT
"""
assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], transaction)
assert_equal_ignore_multiple_spaces(self, "\n".join(mock_run.call_args[0][0]), transaction)

assert mock_run.call_count == 1

Expand Down

0 comments on commit 6148ddd

Please sign in to comment.