Skip to content

Commit

Permalink
AirbyteLib: Add support for JSON and VARIANT types (#35117)
Browse files Browse the repository at this point in the history
Co-authored-by: Joe Reuter <joe@airbyte.io>
  • Loading branch information
aaronsteers and Joe Reuter authored Feb 14, 2024
1 parent ada1196 commit 686c31d
Show file tree
Hide file tree
Showing 13 changed files with 344 additions and 64 deletions.
15 changes: 8 additions & 7 deletions airbyte-lib/airbyte_lib/caches/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,13 +809,14 @@ def _swap_temp_table_with_final_table(

_ = stream_name
deletion_name = f"{final_table_name}_deleteme"
commands = [
f"ALTER TABLE {final_table_name} RENAME TO {deletion_name}",
f"ALTER TABLE {temp_table_name} RENAME TO {final_table_name}",
f"DROP TABLE {deletion_name}",
]
for cmd in commands:
self._execute_sql(cmd)
commands = "\n".join(
[
f"ALTER TABLE {final_table_name} RENAME TO {deletion_name};",
f"ALTER TABLE {temp_table_name} RENAME TO {final_table_name};",
f"DROP TABLE {deletion_name};",
]
)
self._execute_sql(commands)

def _merge_temp_table_to_final_table(
self,
Expand Down
93 changes: 80 additions & 13 deletions airbyte-lib/airbyte_lib/caches/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,21 @@

from __future__ import annotations

from textwrap import dedent, indent
from typing import TYPE_CHECKING

import sqlalchemy
from overrides import overrides
from snowflake.sqlalchemy import URL
from snowflake.sqlalchemy import URL, VARIANT

from airbyte_lib._file_writers import ParquetWriter, ParquetWriterConfig
from airbyte_lib.caches.base import RecordDedupeMode, SQLCacheBase, SQLCacheConfigBase
from airbyte_lib.caches.base import (
RecordDedupeMode,
SQLCacheBase,
SQLCacheConfigBase,
)
from airbyte_lib.telemetry import CacheTelemetryInfo
from airbyte_lib.types import SQLTypeConverter


if TYPE_CHECKING:
Expand Down Expand Up @@ -58,6 +65,26 @@ def get_database_name(self) -> str:
return self.database


class SnowflakeTypeConverter(SQLTypeConverter):
    """SQL type converter specialized for Snowflake."""

    @overrides
    def to_sql_type(
        self,
        json_schema_property_def: dict[str, str | dict | list],
    ) -> sqlalchemy.types.TypeEngine:
        """Convert a JSON schema property definition to a SQL type.

        The parent class performs the conversion first. If the type it returns
        is a JSON type, a Snowflake VARIANT type is returned in its place; any
        other type is passed through unchanged.
        """
        base_type = super().to_sql_type(json_schema_property_def)
        # Swap the generic JSON type for the VARIANT type from snowflake.sqlalchemy.
        return VARIANT() if isinstance(base_type, sqlalchemy.types.JSON) else base_type


class SnowflakeSQLCache(SQLCacheBase):
"""A Snowflake implementation of the cache.
Expand All @@ -66,6 +93,7 @@ class SnowflakeSQLCache(SQLCacheBase):

config_class = SnowflakeCacheConfig
file_writer_class = ParquetWriter
type_converter_class = SnowflakeTypeConverter

@overrides
def _write_files_to_new_table(
Expand All @@ -74,23 +102,62 @@ def _write_files_to_new_table(
stream_name: str,
batch_id: str,
) -> str:
"""Write a file(s) to a new table.
TODO: Override the base implementation to use the COPY command.
TODO: Make sure this works for all data types.
"""
return super()._write_files_to_new_table(files, stream_name, batch_id)
"""Write files to a new table."""
temp_table_name = self._create_table_for_loading(
stream_name=stream_name,
batch_id=batch_id,
)
internal_sf_stage_name = f"@%{temp_table_name}"
put_files_statements = "\n".join(
[
f"PUT 'file://{file_path.absolute()!s}' {internal_sf_stage_name};"
for file_path in files
]
)
self._execute_sql(put_files_statements)

columns_list = [
self._quote_identifier(c)
for c in list(self._get_sql_column_definitions(stream_name).keys())
]
files_list = ", ".join([f"'{f.name}'" for f in files])
columns_list_str: str = indent("\n, ".join(columns_list), " " * 12)
variant_cols_str: str = ("\n" + " " * 21 + ", ").join([f"$1:{col}" for col in columns_list])
copy_statement = dedent(
f"""
COPY INTO {temp_table_name}
(
{columns_list_str}
)
FROM (
SELECT {variant_cols_str}
FROM {internal_sf_stage_name}
)
FILES = ( {files_list} )
FILE_FORMAT = ( TYPE = PARQUET )
;
"""
)
self._execute_sql(copy_statement)
return temp_table_name

@overrides
def _init_connection_settings(self, connection: Connection) -> None:
"""We override this method to set the QUOTED_IDENTIFIERS_IGNORE_CASE setting to True.
This is necessary because Snowflake otherwise will treat quoted table and column references
as case-sensitive.
"""We set Snowflake-specific settings for the session.
This sets QUOTED_IDENTIFIERS_IGNORE_CASE setting to True, which is necessary because
Snowflake otherwise will treat quoted table and column references as case-sensitive.
More info: https://docs.snowflake.com/en/sql-reference/identifiers-syntax
This also sets MULTI_STATEMENT_COUNT to 0, which allows multi-statement commands.
"""
connection.execute("ALTER SESSION SET QUOTED_IDENTIFIERS_IGNORE_CASE = TRUE")
connection.execute(
"""
ALTER SESSION SET
QUOTED_IDENTIFIERS_IGNORE_CASE = TRUE
MULTI_STATEMENT_COUNT = 0
"""
)

@overrides
def get_telemetry_info(self) -> CacheTelemetryInfo:
Expand Down
4 changes: 2 additions & 2 deletions airbyte-lib/airbyte_lib/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
"time_without_timezone": sqlalchemy.types.TIME,
# Technically 'object' and 'array' are JSON Schema types, not airbyte types.
# We include them here for completeness.
"object": sqlalchemy.types.VARCHAR,
"array": sqlalchemy.types.VARCHAR,
"object": sqlalchemy.types.JSON,
"array": sqlalchemy.types.JSON,
}


Expand Down
13 changes: 12 additions & 1 deletion airbyte-lib/docs/generated/airbyte_lib/caches.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions airbyte-lib/examples/run_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
}
)
source.check()
source.select_streams(["issues", "pull_requests", "commits"])
source.select_streams(["issues", "pull_requests", "commits", "collaborators"])

result = source.read()
result = source.read(cache=ab.new_local_cache("github"))
print(result.processed_records)

for name, records in result.streams.items():
print(f"Stream {name}: {len(records)} records")
2 changes: 1 addition & 1 deletion airbyte-lib/examples/run_snowflake_faker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

source.check()

source.set_streams(["products"])
source.select_streams(["products"])
result = source.read(cache)

for name in ["products"]:
Expand Down
2 changes: 1 addition & 1 deletion airbyte-lib/examples/run_spacex.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

source.check()

source.set_streams(["launches", "rockets", "capsules"])
source.select_streams(["launches", "rockets", "capsules"])

result = source.read(cache)

Expand Down
62 changes: 61 additions & 1 deletion airbyte-lib/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions airbyte-lib/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ google-cloud-secret-manager = "^2.17.0"
types-requests = "2.31.0.4"
freezegun = "^1.4.0"
airbyte-source-faker = "^6.0.0"
viztracer = "^0.16.2"

[build-system]
requires = ["poetry-core"]
Expand Down
Loading

0 comments on commit 686c31d

Please sign in to comment.