Skip to content

Commit

Permalink
add an argument to control handling nulls in merge criteria
Browse files Browse the repository at this point in the history
  • Loading branch information
brendan-cook-87 committed Jul 12, 2024
1 parent 2a4234a commit 8273380
Show file tree
Hide file tree
Showing 4 changed files with 355 additions and 35 deletions.
151 changes: 117 additions & 34 deletions awswrangler/athena/_write_iceberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,109 @@ def _validate_args(
)


def _merge_iceberg(
df: pd.DataFrame,
database: str,
table: str,
source_table: str,
merge_cols: list[str] | None = None,
merge_condition: Literal["update", "ignore"] = "update",
merge_match_nulls: bool = False,
kms_key: str | None = None,
boto3_session: boto3.Session | None = None,
s3_output: str | None = None,
workgroup: str = "primary",
encryption: str | None = None,
data_source: str | None = None,
) -> None:
"""
Merge iceberg.
Merge data from source_table and write it to an Athena iceberg table.
Parameters
----------
df : pd.DataFrame
Pandas DataFrame.
database : str
AWS Glue/Athena database name - It is only the origin database from where the query will be launched.
You can still using and mixing several databases writing the full table name within the sql
(e.g. `database.table`).
table : str
AWS Glue/Athena destination table name.
source_table: str
AWS Glue/Athena source table name.
merge_cols: List[str], optional
List of column names that will be used for conditional inserts and updates.
https://docs.aws.amazon.com/athena/latest/ug/merge-into-statement.html
merge_condition: str, optional
The condition to be used in the MERGE INTO statement. Valid values: ['update', 'ignore'].
merge_match_nulls: bool, optional
Instruct whether to have nulls in the merge condition match other nulls
kms_key : str, optional
For SSE-KMS, this is the KMS key ARN or ID.
boto3_session : boto3.Session(), optional
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
s3_output : str, optional
Amazon S3 path used for query execution.
workgroup : str
Athena workgroup. Primary by default.
encryption : str, optional
Valid values: [None, 'SSE_S3', 'SSE_KMS']. Notice: 'CSE_KMS' is not supported.
data_source : str, optional
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
Returns
-------
None
"""
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)

sql_statement: str
if merge_cols:
if merge_condition == "update":
match_condition = f"""WHEN MATCHED THEN
UPDATE SET {', '.join([f'"{x}" = source."{x}"' for x in df.columns])}"""
else:
match_condition = ""

if merge_match_nulls:
merge_conditions = [f'(target."{x}" IS NOT DISTINCT FROM source."{x}")' for x in merge_cols]
else:
merge_conditions = [f'(target."{x}" = source."{x}")' for x in merge_cols]

sql_statement = f"""
MERGE INTO "{database}"."{table}" target
USING "{database}"."{source_table}" source
ON {' AND '.join(merge_conditions)}
{match_condition}
WHEN NOT MATCHED THEN
INSERT ({', '.join([f'"{x}"' for x in df.columns])})
VALUES ({', '.join([f'source."{x}"' for x in df.columns])})
"""
else:
sql_statement = f"""
INSERT INTO "{database}"."{table}" ({', '.join([f'"{x}"' for x in df.columns])})
SELECT {', '.join([f'"{x}"' for x in df.columns])}
FROM "{database}"."{source_table}"
"""

query_execution_id: str = _start_query_execution(
sql=sql_statement,
workgroup=workgroup,
wg_config=wg_config,
database=database,
data_source=data_source,
s3_output=s3_output,
encryption=encryption,
kms_key=kms_key,
boto3_session=boto3_session,
)
wait_query(query_execution_id=query_execution_id, boto3_session=boto3_session)


@apply_configs
@_utils.validate_distributed_kwargs(
unsupported_kwargs=["boto3_session", "s3_additional_kwargs"],
Expand All @@ -253,6 +356,7 @@ def to_iceberg(
partition_cols: list[str] | None = None,
merge_cols: list[str] | None = None,
merge_condition: Literal["update", "ignore"] = "update",
merge_match_nulls: bool = False,
keep_files: bool = True,
data_source: str | None = None,
s3_output: str | None = None,
Expand Down Expand Up @@ -301,6 +405,8 @@ def to_iceberg(
https://docs.aws.amazon.com/athena/latest/ug/merge-into-statement.html
merge_condition: str, optional
The condition to be used in the MERGE INTO statement. Valid values: ['update', 'ignore'].
merge_match_nulls: bool, optional
Instruct whether to have nulls in the merge condition match other nulls
keep_files : bool
Whether staging files produced by Athena are retained. 'True' by default.
data_source : str, optional
Expand Down Expand Up @@ -504,44 +610,21 @@ def to_iceberg(
glue_table_settings=glue_table_settings,
)

# Insert or merge into Iceberg table
sql_statement: str
if merge_cols:
if merge_condition == "update":
match_condition = f"""WHEN MATCHED THEN
UPDATE SET {', '.join([f'"{x}" = source."{x}"' for x in df.columns])}"""
else:
match_condition = ""
sql_statement = f"""
MERGE INTO "{database}"."{table}" target
USING "{database}"."{temp_table}" source
ON {' AND '.join([
f'(target."{x}" = source."{x}" OR (target."{x}" IS NULL AND source."{x}" IS NULL))'
for x in merge_cols])}
{match_condition}
WHEN NOT MATCHED THEN
INSERT ({', '.join([f'"{x}"' for x in df.columns])})
VALUES ({', '.join([f'source."{x}"' for x in df.columns])})
"""
else:
sql_statement = f"""
INSERT INTO "{database}"."{table}" ({', '.join([f'"{x}"' for x in df.columns])})
SELECT {', '.join([f'"{x}"' for x in df.columns])}
FROM "{database}"."{temp_table}"
"""

query_execution_id: str = _start_query_execution(
sql=sql_statement,
workgroup=workgroup,
wg_config=wg_config,
_merge_iceberg(
df=df,
database=database,
data_source=data_source,
s3_output=s3_output,
encryption=encryption,
table=table,
source_table=temp_table,
merge_cols=merge_cols,
merge_condition=merge_condition,
merge_match_nulls=merge_match_nulls,
kms_key=kms_key,
boto3_session=boto3_session,
s3_output=s3_output,
workgroup=workgroup,
encryption=encryption,
data_source=data_source,
)
wait_query(query_execution_id=query_execution_id, boto3_session=boto3_session)

except Exception as ex:
_logger.error(ex)
Expand Down
Loading

0 comments on commit 8273380

Please sign in to comment.