Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AirbyteLib: Use case-insensitive method of finding column objects #34985

Merged
merged 2 commits into from
Feb 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions airbyte-lib/airbyte_lib/caches/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import ulid
from overrides import overrides
from sqlalchemy import (
Column,
Table,
and_,
create_engine,
insert,
Expand Down Expand Up @@ -839,6 +841,25 @@ def _merge_temp_table_to_final_table(
""",
)

def _get_column_by_name(self, table: str | Table, column_name: str) -> Column:
    """Return the column object on `table` whose name matches `column_name`.

    An exact (case-sensitive) name match is preferred. If no column matches
    exactly, the first column whose name matches case-insensitively is returned.
    Preferring the exact match keeps the lookup unambiguous when a table holds
    columns whose names differ only by case.

    Args:
        table: A table object, or a table name that is resolved through
            `self._get_table_by_name()`.
        column_name: The name of the column to look up.

    Returns:
        The matching column object from `table.c`.

    Raises:
        exc.AirbyteLibInternalError: If no column matches, even when ignoring case.
    """
    if isinstance(table, str):
        table = self._get_table_by_name(table)

    # Lower-case the requested name once instead of once per column.
    lowered_name = column_name.lower()
    case_insensitive_match = None
    for col in table.c:
        if col.name == column_name:
            # An exact match always wins over a case-insensitive one.
            return col
        if case_insensitive_match is None and col.name.lower() == lowered_name:
            # Remember only the first case-insensitive hit; keep scanning in case
            # an exact match appears later in the column collection.
            case_insensitive_match = col

    if case_insensitive_match is not None:
        return case_insensitive_match

    raise exc.AirbyteLibInternalError(
        message="Could not find matching column.",
        context={
            "table": table,
            "column_name": column_name,
        },
    )

def _emulated_merge_temp_table_to_final_table(
self,
stream_name: str,
Expand All @@ -859,13 +880,16 @@ def _emulated_merge_temp_table_to_final_table(

# Create a dictionary mapping columns in users_final to users_stage for updating
update_values = {
getattr(final_table.c, column): getattr(temp_table.c, column)
self._get_column_by_name(final_table, column): (
self._get_column_by_name(temp_table, column)
)
for column in columns_to_update
}

# Craft the WHERE clause for composite primary keys
join_conditions = [
getattr(final_table.c, pk_column) == getattr(temp_table.c, pk_column)
self._get_column_by_name(final_table, pk_column)
== self._get_column_by_name(temp_table, pk_column)
for pk_column in pk_columns
]
join_clause = and_(*join_conditions)
Expand All @@ -878,7 +902,7 @@ def _emulated_merge_temp_table_to_final_table(

# Define a condition that checks for records in temp_table that do not have a corresponding
# record in final_table
where_not_exists_clause = getattr(final_table.c, pk_columns[0]) == null()
where_not_exists_clause = self._get_column_by_name(final_table, pk_columns[0]) == null()

# Select records from temp_table that are not in final_table
select_new_records_stmt = (
Expand Down
Loading