From 686c31d05f42403c00425d0f1c83cd918824a996 Mon Sep 17 00:00:00 2001 From: "Aaron (\"AJ\") Steers" Date: Wed, 14 Feb 2024 10:04:19 -0800 Subject: [PATCH] AirbyteLib: Add support for JSON and VARIANT types (#35117) Co-authored-by: Joe Reuter --- airbyte-lib/airbyte_lib/caches/base.py | 15 +- airbyte-lib/airbyte_lib/caches/snowflake.py | 93 +++++++++-- airbyte-lib/airbyte_lib/types.py | 4 +- .../docs/generated/airbyte_lib/caches.html | 13 +- airbyte-lib/examples/run_github.py | 5 +- airbyte-lib/examples/run_snowflake_faker.py | 2 +- airbyte-lib/examples/run_spacex.py | 2 +- airbyte-lib/poetry.lock | 62 +++++++- airbyte-lib/pyproject.toml | 1 + .../integration_tests/test_snowflake_cache.py | 150 ++++++++++++++++++ .../test_source_faker_integration.py | 44 ++--- .../test_source_test_fixture.py | 10 +- .../tests/unit_tests/test_type_translation.py | 7 +- 13 files changed, 344 insertions(+), 64 deletions(-) create mode 100644 airbyte-lib/tests/integration_tests/test_snowflake_cache.py diff --git a/airbyte-lib/airbyte_lib/caches/base.py b/airbyte-lib/airbyte_lib/caches/base.py index 606707692930..b5ee35e680ac 100644 --- a/airbyte-lib/airbyte_lib/caches/base.py +++ b/airbyte-lib/airbyte_lib/caches/base.py @@ -809,13 +809,14 @@ def _swap_temp_table_with_final_table( _ = stream_name deletion_name = f"{final_table_name}_deleteme" - commands = [ - f"ALTER TABLE {final_table_name} RENAME TO {deletion_name}", - f"ALTER TABLE {temp_table_name} RENAME TO {final_table_name}", - f"DROP TABLE {deletion_name}", - ] - for cmd in commands: - self._execute_sql(cmd) + commands = "\n".join( + [ + f"ALTER TABLE {final_table_name} RENAME TO {deletion_name};", + f"ALTER TABLE {temp_table_name} RENAME TO {final_table_name};", + f"DROP TABLE {deletion_name};", + ] + ) + self._execute_sql(commands) def _merge_temp_table_to_final_table( self, diff --git a/airbyte-lib/airbyte_lib/caches/snowflake.py b/airbyte-lib/airbyte_lib/caches/snowflake.py index 05735a88026a..2a59f723af06 100644 --- a/airbyte-lib/airbyte_lib/caches/snowflake.py +++ b/airbyte-lib/airbyte_lib/caches/snowflake.py @@ -4,14 +4,21 @@ from __future__ import annotations +from textwrap import dedent, indent from typing import TYPE_CHECKING +import sqlalchemy from overrides import overrides -from snowflake.sqlalchemy import URL +from snowflake.sqlalchemy import URL, VARIANT from airbyte_lib._file_writers import ParquetWriter, ParquetWriterConfig -from airbyte_lib.caches.base import RecordDedupeMode, SQLCacheBase, SQLCacheConfigBase +from airbyte_lib.caches.base import ( + RecordDedupeMode, + SQLCacheBase, + SQLCacheConfigBase, +) from airbyte_lib.telemetry import CacheTelemetryInfo +from airbyte_lib.types import SQLTypeConverter if TYPE_CHECKING: @@ -58,6 +65,26 @@ def get_database_name(self) -> str: return self.database +class SnowflakeTypeConverter(SQLTypeConverter): + """A class to convert types for Snowflake.""" + + @overrides + def to_sql_type( + self, + json_schema_property_def: dict[str, str | dict | list], + ) -> sqlalchemy.types.TypeEngine: + """Convert a value to a SQL type. + + We first call the parent class method to get the type. Then if the type JSON, we + replace it with VARIANT. + """ + sql_type = super().to_sql_type(json_schema_property_def) + if isinstance(sql_type, sqlalchemy.types.JSON): + return VARIANT() + + return sql_type + + class SnowflakeSQLCache(SQLCacheBase): """A Snowflake implementation of the cache. @@ -66,6 +93,7 @@ class SnowflakeSQLCache(SQLCacheBase): config_class = SnowflakeCacheConfig file_writer_class = ParquetWriter + type_converter_class = SnowflakeTypeConverter @overrides def _write_files_to_new_table( @@ -74,23 +102,62 @@ def _write_files_to_new_table( stream_name: str, batch_id: str, ) -> str: - """Write a file(s) to a new table. - - TODO: Override the base implementation to use the COPY command. - TODO: Make sure this works for all data types. - """ - return super()._write_files_to_new_table(files, stream_name, batch_id) + """Write files to a new table.""" + temp_table_name = self._create_table_for_loading( + stream_name=stream_name, + batch_id=batch_id, + ) + internal_sf_stage_name = f"@%{temp_table_name}" + put_files_statements = "\n".join( + [ + f"PUT 'file://{file_path.absolute()!s}' {internal_sf_stage_name};" + for file_path in files + ] + ) + self._execute_sql(put_files_statements) + + columns_list = [ + self._quote_identifier(c) + for c in list(self._get_sql_column_definitions(stream_name).keys()) + ] + files_list = ", ".join([f"'{f.name}'" for f in files]) + columns_list_str: str = indent("\n, ".join(columns_list), " " * 12) + variant_cols_str: str = ("\n" + " " * 21 + ", ").join([f"$1:{col}" for col in columns_list]) + copy_statement = dedent( + f""" + COPY INTO {temp_table_name} + ( + {columns_list_str} + ) + FROM ( + SELECT {variant_cols_str} + FROM {internal_sf_stage_name} + ) + FILES = ( {files_list} ) + FILE_FORMAT = ( TYPE = PARQUET ) + ; + """ + ) + self._execute_sql(copy_statement) + return temp_table_name @overrides def _init_connection_settings(self, connection: Connection) -> None: - """We override this method to set the QUOTED_IDENTIFIERS_IGNORE_CASE setting to True. - - This is necessary because Snowflake otherwise will treat quoted table and column references - as case-sensitive. + """We set Snowflake-specific settings for the session. + This sets QUOTED_IDENTIFIERS_IGNORE_CASE setting to True, which is necessary because + Snowflake otherwise will treat quoted table and column references as case-sensitive. More info: https://docs.snowflake.com/en/sql-reference/identifiers-syntax + + This also sets MULTI_STATEMENT_COUNT to 0, which allows multi-statement commands. """ - connection.execute("ALTER SESSION SET QUOTED_IDENTIFIERS_IGNORE_CASE = TRUE") + connection.execute( + """ + ALTER SESSION SET + QUOTED_IDENTIFIERS_IGNORE_CASE = TRUE + MULTI_STATEMENT_COUNT = 0 + """ + ) @overrides def get_telemetry_info(self) -> CacheTelemetryInfo: diff --git a/airbyte-lib/airbyte_lib/types.py b/airbyte-lib/airbyte_lib/types.py index c133b090c347..a95dbf59d68e 100644 --- a/airbyte-lib/airbyte_lib/types.py +++ b/airbyte-lib/airbyte_lib/types.py @@ -22,8 +22,8 @@ "time_without_timezone": sqlalchemy.types.TIME, # Technically 'object' and 'array' as JSON Schema types, not airbyte types. # We include them here for completeness. - "object": sqlalchemy.types.VARCHAR, - "array": sqlalchemy.types.VARCHAR, + "object": sqlalchemy.types.JSON, + "array": sqlalchemy.types.JSON, } diff --git a/airbyte-lib/docs/generated/airbyte_lib/caches.html b/airbyte-lib/docs/generated/airbyte_lib/caches.html index 3326ed02db8a..cf1eb7276567 100644 --- a/airbyte-lib/docs/generated/airbyte_lib/caches.html +++ b/airbyte-lib/docs/generated/airbyte_lib/caches.html @@ -924,6 +924,18 @@
Inherited Members
+ +
+
+ type_converter_class = +<class 'airbyte_lib.caches.snowflake.SnowflakeTypeConverter'> + + +
+ + + +
@@ -944,7 +956,6 @@
Inherited Members
SQLCacheBase
SQLCacheBase
-
type_converter_class
supports_merge_insert
use_singleton_connection
config
diff --git a/airbyte-lib/examples/run_github.py b/airbyte-lib/examples/run_github.py index fdb6c9f298b5..3faf22b60e19 100644 --- a/airbyte-lib/examples/run_github.py +++ b/airbyte-lib/examples/run_github.py @@ -24,9 +24,10 @@ } ) source.check() -source.select_streams(["issues", "pull_requests", "commits"]) +source.select_streams(["issues", "pull_requests", "commits", "collaborators"]) -result = source.read() +result = source.read(cache=ab.new_local_cache("github")) +print(result.processed_records) for name, records in result.streams.items(): print(f"Stream {name}: {len(records)} records") diff --git a/airbyte-lib/examples/run_snowflake_faker.py b/airbyte-lib/examples/run_snowflake_faker.py index 4587a0afce9e..56d8af8f10ef 100644 --- a/airbyte-lib/examples/run_snowflake_faker.py +++ b/airbyte-lib/examples/run_snowflake_faker.py @@ -39,7 +39,7 @@ source.check() -source.set_streams(["products"]) +source.select_streams(["products"]) result = source.read(cache) for name in ["products"]: diff --git a/airbyte-lib/examples/run_spacex.py b/airbyte-lib/examples/run_spacex.py index 02137f8ececf..f2695d7ff695 100644 --- a/airbyte-lib/examples/run_spacex.py +++ b/airbyte-lib/examples/run_spacex.py @@ -22,7 +22,7 @@ source.check() -source.set_streams(["launches", "rockets", "capsules"]) +source.select_streams(["launches", "rockets", "capsules"]) result = source.read(cache) diff --git a/airbyte-lib/poetry.lock b/airbyte-lib/poetry.lock index 752f802856eb..b4ca0726238c 100644 --- a/airbyte-lib/poetry.lock +++ b/airbyte-lib/poetry.lock @@ -1128,6 +1128,17 @@ files = [ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, ] +[[package]] +name = "objprint" +version = "0.2.3" +description = "A library that can print Python objects in human readable format" +optional = false +python-versions = ">=3.6" +files = [ + {file = "objprint-0.2.3-py3-none-any.whl", hash = "sha256:1721e6f97bae5c5b86c2716a0d45a9dd2c9a4cd9f52cfc8a0dfbe801805554cb"}, + {file = "objprint-0.2.3.tar.gz", hash = "sha256:73d0ad5a7c3151fce634c8892e5c2a050ccae3b1a353bf1316f08b7854da863b"}, +] + [[package]] name = "orjson" version = "3.9.13" @@ -2509,6 +2520,55 @@ brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotl secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "viztracer" +version = "0.16.2" +description = "A debugging and profiling tool that can trace and visualize python code execution" +optional = false +python-versions = ">=3.8" +files = [ + {file = "viztracer-0.16.2-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:bdc62e90a2957e4119632e98f8b77d0ff1ab4db7029dd2e265bb3748e0fc0e05"}, + {file = "viztracer-0.16.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:789ac930e1c9621f04d275ee3ebb75a5d6109bcd4634796a77934608c60424d0"}, + {file = "viztracer-0.16.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee504771e3182045996a966d94d95d71693e59717b2643199162ec754a6e2400"}, + {file = "viztracer-0.16.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef9ecf4110d379245f17429d2a10391f3612f60b5618d0d61a30c110e9df2313"}, + {file = "viztracer-0.16.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:57c2574cc15b688eb0ce4e24a2c30f06c1df3bbe1dd16a1d18676e411e785f96"}, + {file = "viztracer-0.16.2-cp310-cp310-win32.whl", hash = "sha256:9fe652834f5073bf99debc25d8ba6084690fa2f26420621ca38a09efcae71b2f"}, + {file = "viztracer-0.16.2-cp310-cp310-win_amd64.whl", hash = "sha256:d59f57e3e46e116ce77e144f419739d1d8d976a903c51a822ba4ef167e5b37d4"}, + {file = "viztracer-0.16.2-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:b0bd434c43b7f87f76ddd21cf7371d910edb74b131aaff670a8fcc9f28251e67"}, + {file = "viztracer-0.16.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1bbbb9c80b08db692993c67e7b10d7b06db3eedc6c38f0d93a40ea31de82076e"}, + {file = "viztracer-0.16.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1e7842e437d81fb47ef8266b2dde76bf755c95305014eeec8346b2fce9711c0"}, + {file = "viztracer-0.16.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bddfe6a6f2a66f363fcca79a694986b0602ba0dc3dede57dc182cdd6d0823585"}, + {file = "viztracer-0.16.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc4a2639e6f18200b73a70f3e7dca4cbb3ba08e3807023fd526f44ebf2185d1e"}, + {file = "viztracer-0.16.2-cp311-cp311-win32.whl", hash = "sha256:371496734ebb3eafd6a6e033dbf04960618089e021dc7eded95179a8f3700c40"}, + {file = "viztracer-0.16.2-cp311-cp311-win_amd64.whl", hash = "sha256:d9c7670e7fb077fe48c92036766a6772e10a3caf41455d6244b8b1c8d48bbd87"}, + {file = "viztracer-0.16.2-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2fd8b5aa8143b5be4d696e53e8ac5027c20187c178396839f39f8aa610d5873d"}, + {file = "viztracer-0.16.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f3a8ddc4990154f2d400b09deefc9236d963a733d458b2825bd590ced7e7bf89"}, + {file = "viztracer-0.16.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fcf8b14dc8dd1567bca3f8cb13e31665a3cbf2ee95552de0afe9179e3a7bde22"}, + {file = "viztracer-0.16.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:309cf5d545222adb2581ae6aeb48d3d03d7241d335142408d87c49f1d0793f85"}, + {file = "viztracer-0.16.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee749a2a3f4ed662d35eb9378ff0648907aa6321befa16ad1d8bec6034b4d260"}, + {file = "viztracer-0.16.2-cp312-cp312-win32.whl", hash = "sha256:a082dab37b6b8cea43438b80a11a6e859f1b45522b8684a2fb9af03539d83803"}, + {file = "viztracer-0.16.2-cp312-cp312-win_amd64.whl", hash = "sha256:03cd21181fe9a630ac5fb9ff1ee83fb7a67814e51e130f0ed83300e163fbac23"}, + {file = "viztracer-0.16.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:e920d383abae1b9314f2a60dd94e04c83998bfe759556af49d3c422d1d64d11e"}, + {file = "viztracer-0.16.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb9941b198fed8ba5b3f9d8105e59d37ab15f7f00b9a576686b1073990806d12"}, + {file = "viztracer-0.16.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1b7030aa6f934ff02882dfd48eca5a9442951b8be24c1dc5dc99fabbfb1997c"}, + {file = "viztracer-0.16.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:258087076c06d065d2786dc8a0f1f017d655d3753a8fe6836640c005c66a0c43"}, + {file = "viztracer-0.16.2-cp38-cp38-win32.whl", hash = "sha256:f0fd53e2fec972f9332677e6d11332ba789fcccf59060d7b9f309041602dc712"}, + {file = "viztracer-0.16.2-cp38-cp38-win_amd64.whl", hash = "sha256:ab067398029a50cc784d5456c5e8bef339b4bffaa1c3f0f9384a26b57c0efdaa"}, + {file = "viztracer-0.16.2-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:45879cf54ad9116245e2a6115660307f98ae3aa98a77347f2b336a904f260370"}, + {file = "viztracer-0.16.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:abc61cfc36b33a301b950554d9e9027a506d580ebf1e764aa6656af0acfa3354"}, + {file = "viztracer-0.16.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:419f738bba8204e7ddb422faff3a40576896d030bbbf4fb79ace006147ca60e7"}, + {file = "viztracer-0.16.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c594022093bf9eee57ad2b9656f836dca2ed9c0b8e4d94a9d13a6cbc531386fe"}, + {file = "viztracer-0.16.2-cp39-cp39-win32.whl", hash = "sha256:4f98da282e87013a93917c2ae080ba52845e98ed5280faecdc42ee0c7fb74a4a"}, + {file = "viztracer-0.16.2-cp39-cp39-win_amd64.whl", hash = "sha256:64b97120374a572d2320fb795473c051c92d39dfc99fb74754e61e4c212e7617"}, + {file = "viztracer-0.16.2.tar.gz", hash = "sha256:8dff5637a7b42ffdbc1ed3768ce43979e71b09893ff370bc3c3ede54afed93ee"}, +] + +[package.dependencies] +objprint = ">0.1.3" + +[package.extras] +full = ["orjson"] + [[package]] name = "wcmatch" version = "8.4" @@ -2605,4 +2665,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "15665328452a67b8dfce18573caeae7856425cca2a3bafc5e9e455a619548314" +content-hash = "71b8790d6cb735dfe4349431938c1a7db17cf3578007481626d8408990eedd29" diff --git a/airbyte-lib/pyproject.toml b/airbyte-lib/pyproject.toml index 0101f8b49e18..203c446552ee 100644 --- a/airbyte-lib/pyproject.toml +++ b/airbyte-lib/pyproject.toml @@ -49,6 +49,7 @@ google-cloud-secret-manager = "^2.17.0" types-requests = "2.31.0.4" freezegun = "^1.4.0" airbyte-source-faker = "^6.0.0" +viztracer = "^0.16.2" [build-system] requires = ["poetry-core"] diff --git a/airbyte-lib/tests/integration_tests/test_snowflake_cache.py b/airbyte-lib/tests/integration_tests/test_snowflake_cache.py new file mode 100644 index 000000000000..be8aacd8decc --- /dev/null +++ b/airbyte-lib/tests/integration_tests/test_snowflake_cache.py @@ -0,0 +1,150 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. + +"""Integration tests which leverage the source-faker connector to test the framework end-to-end. + +Since source-faker is included in dev dependencies, we can assume `source-faker` is installed +and available on PATH for the poetry-managed venv. +""" +from __future__ import annotations +from collections.abc import Generator +import os +import sys +import shutil +from pathlib import Path + +import pytest +import ulid +import viztracer + +from airbyte_cdk.models import ConfiguredAirbyteCatalog + +import airbyte_lib as ab +from airbyte_lib import caches + + +# Product count is always the same, regardless of faker scale. +NUM_PRODUCTS = 100 + +SEED_A = 1234 +SEED_B = 5678 + +# Number of records in each of the 'users' and 'purchases' streams. +FAKER_SCALE_A = 200 +# We want this to be different from FAKER_SCALE_A. +FAKER_SCALE_B = 300 + + +# Patch PATH to include the source-faker executable. + +@pytest.fixture(autouse=True) +def add_venv_bin_to_path(monkeypatch): + # Get the path to the bin directory of the virtual environment + venv_bin_path = os.path.join(sys.prefix, 'bin') + + # Add the bin directory to the PATH + new_path = f"{venv_bin_path}:{os.environ['PATH']}" + monkeypatch.setenv('PATH', new_path) + + +@pytest.fixture(scope="function") # Each test gets a fresh source-faker instance. +def source_faker_seed_a() -> ab.Source: + """Fixture to return a source-faker connector instance.""" + source = ab.get_source( + "source-faker", + local_executable="source-faker", + config={ + "count": FAKER_SCALE_A, + "seed": SEED_A, + "parallelism": 16, # Otherwise defaults to 4. + }, + install_if_missing=False, # Should already be on PATH + ) + source.check() + source.select_streams([ + "users", + ]) + return source + + +@pytest.fixture(scope="function") # Each test gets a fresh source-faker instance. +def source_faker_seed_b() -> ab.Source: + """Fixture to return a source-faker connector instance.""" + source = ab.get_source( + "source-faker", + local_executable="source-faker", + config={ + "count": FAKER_SCALE_B, + "seed": SEED_B, + "parallelism": 16, # Otherwise defaults to 4. + }, + install_if_missing=False, # Should already be on PATH + ) + source.check() + source.select_streams([ + "users", + ]) + return source + + +@pytest.fixture(scope="function") +def snowflake_cache(snowflake_config) -> Generator[caches.SnowflakeCache, None, None]: + """Fixture to return a fresh cache.""" + cache: caches.SnowflakeCache = caches.SnowflakeSQLCache(snowflake_config) + yield cache + # TODO: Delete cache DB file after test is complete. + return + + +# Uncomment this line if you want to see performance trace logs. +# You can render perf traces using the viztracer CLI or the VS Code VizTracer Extension. +#@viztracer.trace_and_save(output_dir=".pytest_cache/snowflake_trace/") +def test_faker_read_to_snowflake( + source_faker_seed_a: ab.Source, + snowflake_cache: ab.SnowflakeCache, +) -> None: + """Test that the append strategy works as expected.""" + result = source_faker_seed_a.read( + snowflake_cache, write_strategy="replace", force_full_refresh=True + ) + assert len(list(result.cache.streams["users"])) == FAKER_SCALE_A + + +def test_replace_strategy( + source_faker_seed_a: ab.Source, + snowflake_cache: ab.SnowflakeCache, +) -> None: + """Test that the append strategy works as expected.""" + for _ in range(2): + result = source_faker_seed_a.read( + snowflake_cache, write_strategy="replace", force_full_refresh=True + ) + assert len(list(result.cache.streams["users"])) == FAKER_SCALE_A + + +def test_merge_strategy( + source_faker_seed_a: ab.Source, + source_faker_seed_b: ab.Source, + snowflake_cache: ab.DuckDBCache, +) -> None: + """Test that the merge strategy works as expected. + + Since all streams have primary keys, we should expect the auto strategy to be identical to the + merge strategy. + """ + # First run, seed A (counts should match the scale or the product count) + result = source_faker_seed_a.read(snowflake_cache, write_strategy="merge") + assert len(list(result.cache.streams["users"])) == FAKER_SCALE_A + + # Second run, also seed A (should have same exact data, no change in counts) + result = source_faker_seed_a.read(snowflake_cache, write_strategy="merge") + assert len(list(result.cache.streams["users"])) == FAKER_SCALE_A + + # Third run, seed B - should increase record count to the scale of B, which is greater than A. + # TODO: See if we can reliably predict the exact number of records, since we use fixed seeds. + result = source_faker_seed_b.read(snowflake_cache, write_strategy="merge") + assert len(list(result.cache.streams["users"])) == FAKER_SCALE_B + + # Third run, seed A again - count should stay at scale B, since A is smaller. + # TODO: See if we can reliably predict the exact number of records, since we use fixed seeds. + result = source_faker_seed_a.read(snowflake_cache, write_strategy="merge") + assert len(list(result.cache.streams["users"])) == FAKER_SCALE_B diff --git a/airbyte-lib/tests/integration_tests/test_source_faker_integration.py b/airbyte-lib/tests/integration_tests/test_source_faker_integration.py index 54da1f1f681e..22e16b6ef0af 100644 --- a/airbyte-lib/tests/integration_tests/test_source_faker_integration.py +++ b/airbyte-lib/tests/integration_tests/test_source_faker_integration.py @@ -13,11 +13,13 @@ from pathlib import Path import pytest +import ulid +import viztracer + +from airbyte_cdk.models import ConfiguredAirbyteCatalog import airbyte_lib as ab from airbyte_lib import caches -from airbyte_cdk.models import ConfiguredAirbyteCatalog -import ulid # Product count is always the same, regardless of faker scale. @@ -64,8 +66,8 @@ def source_faker_seed_a() -> ab.Source: install_if_missing=False, # Should already be on PATH ) source.check() - # TODO: We can optionally add back 'users' once Postgres can handle complex object types. - source.set_streams([ + source.select_streams([ + "users", "products", "purchases", ]) @@ -86,10 +88,10 @@ def source_faker_seed_b() -> ab.Source: install_if_missing=False, # Should already be on PATH ) source.check() - # TODO: We can optionally add back 'users' once Postgres can handle complex object types. - source.set_streams([ + source.select_streams([ "products", "purchases", + "users", ]) return source @@ -103,15 +105,6 @@ def duckdb_cache() -> Generator[caches.DuckDBCache, None, None]: return -@pytest.fixture(scope="function") -def snowflake_cache(snowflake_config) -> Generator[caches.SnowflakeCache, None, None]: - """Fixture to return a fresh cache.""" - cache: caches.SnowflakeCache = caches.SnowflakeSQLCache(snowflake_config) - yield cache - # TODO: Delete cache DB file after test is complete. - return - - @pytest.fixture(scope="function") def postgres_cache(new_pg_cache_config) -> Generator[caches.PostgresCache, None, None]: """Fixture to return a fresh cache.""" @@ -124,17 +117,14 @@ def postgres_cache(new_pg_cache_config) -> Generator[caches.PostgresCache, None, @pytest.fixture def all_cache_types( duckdb_cache: ab.DuckDBCache, - snowflake_cache: ab.SnowflakeCache, postgres_cache: ab.PostgresCache, ): _ = postgres_cache return [ duckdb_cache, postgres_cache, - # snowflake_cache, # Snowflake works, but is slow and expensive to test. # TODO: Re-enable. ] - def test_faker_pks( source_faker_seed_a: ab.Source, duckdb_cache: ab.DuckDBCache, @@ -143,7 +133,6 @@ def test_faker_pks( catalog: ConfiguredAirbyteCatalog = source_faker_seed_a.configured_catalog - assert len(catalog.streams) == 2 assert catalog.streams[0].primary_key assert catalog.streams[1].primary_key @@ -162,7 +151,6 @@ def test_replace_strategy( result = source_faker_seed_a.read( cache, write_strategy="replace", force_full_refresh=True ) - assert len(result.cache.streams) == 2 assert len(list(result.cache.streams["products"])) == NUM_PRODUCTS assert len(list(result.cache.streams["purchases"])) == FAKER_SCALE_A @@ -175,7 +163,6 @@ def test_append_strategy( for cache in all_cache_types: # Function-scoped fixtures can't be used in parametrized(). for iteration in range(1, 3): result = source_faker_seed_a.read(cache, write_strategy="append") - assert len(result.cache.streams) == 2 assert len(list(result.cache.streams["products"])) == NUM_PRODUCTS * iteration assert len(list(result.cache.streams["purchases"])) == FAKER_SCALE_A * iteration @@ -195,7 +182,6 @@ def test_merge_strategy( for cache in all_cache_types: # Function-scoped fixtures can't be used in parametrized(). # First run, seed A (counts should match the scale or the product count) result = source_faker_seed_a.read(cache, write_strategy=strategy) - assert len(result.cache.streams) == 2 assert len(list(result.cache.streams["products"])) == NUM_PRODUCTS assert len(list(result.cache.streams["purchases"])) == FAKER_SCALE_A @@ -232,9 +218,9 @@ def test_incremental_sync( result1 = source_faker_seed_a.read(duckdb_cache) assert len(list(result1.cache.streams["products"])) == NUM_PRODUCTS assert len(list(result1.cache.streams["purchases"])) == FAKER_SCALE_A - assert result1.processed_records == NUM_PRODUCTS + FAKER_SCALE_A + assert result1.processed_records == NUM_PRODUCTS + FAKER_SCALE_A * 2 - assert not duckdb_cache.get_state() == [] + assert not duckdb_cache.get_state() == [] # Second run should not return records as it picks up the state and knows it's up to date. result2 = source_faker_seed_b.read(duckdb_cache) @@ -257,13 +243,13 @@ def test_incremental_state_cache_persistence( cache_name = str(ulid.ULID()) cache = ab.new_local_cache(cache_name) result = source_faker_seed_a.read(cache) - assert result.processed_records == NUM_PRODUCTS + FAKER_SCALE_A + assert result.processed_records == NUM_PRODUCTS + FAKER_SCALE_A * 2 second_cache = ab.new_local_cache(cache_name) # The state should be persisted across cache instances. result2 = source_faker_seed_b.read(second_cache) assert result2.processed_records == 0 - assert not second_cache.get_state() == [] + assert not second_cache.get_state() == [] assert len(list(result2.cache.streams["products"])) == NUM_PRODUCTS assert len(list(result2.cache.streams["purchases"])) == FAKER_SCALE_A @@ -284,10 +270,10 @@ def test_incremental_state_prefix_isolation( different_prefix_cache = ab.DuckDBCache(config=ab.DuckDBCacheConfig(db_path=db_path, table_prefix="different_prefix_")) result = source_faker_seed_a.read(cache) - assert result.processed_records == NUM_PRODUCTS + FAKER_SCALE_A + assert result.processed_records == NUM_PRODUCTS + FAKER_SCALE_A * 2 result2 = source_faker_seed_b.read(different_prefix_cache) - assert result2.processed_records == NUM_PRODUCTS + FAKER_SCALE_B + assert result2.processed_records == NUM_PRODUCTS + FAKER_SCALE_B * 2 assert len(list(result2.cache.streams["products"])) == NUM_PRODUCTS - assert len(list(result2.cache.streams["purchases"])) == FAKER_SCALE_B \ No newline at end of file + assert len(list(result2.cache.streams["purchases"])) == FAKER_SCALE_B diff --git a/airbyte-lib/tests/integration_tests/test_source_test_fixture.py b/airbyte-lib/tests/integration_tests/test_source_test_fixture.py index 2824f0db9c6f..6fc9d04bc047 100644 --- a/airbyte-lib/tests/integration_tests/test_source_test_fixture.py +++ b/airbyte-lib/tests/integration_tests/test_source_test_fixture.py @@ -296,7 +296,7 @@ def test_read_isolated_by_prefix(expected_test_stream_data: dict[str, list[dict[ assert len(list(no_prefix_cache.__iter__())) == 0 # read partial data into the other two caches - source.set_streams(["stream1"]) + source.select_streams(["stream1"]) source.read(different_prefix_cache) source.read(no_prefix_cache) @@ -318,7 +318,7 @@ def test_merge_streams_in_cache(expected_test_stream_data: dict[str, list[dict[s source = ab.get_source("source-test", config={"apiKey": "test"}) cache = ab.new_local_cache(cache_name) - source.set_streams(["stream1"]) + source.select_streams(["stream1"]) source.read(cache) # Assert that the cache only contains stream1 @@ -327,7 +327,7 @@ def test_merge_streams_in_cache(expected_test_stream_data: dict[str, list[dict[s # Create a new cache with the same name second_cache = ab.new_local_cache(cache_name) - source.set_streams(["stream2"]) + source.select_streams(["stream2"]) result = source.read(second_cache) # Assert that the read result only contains stream2 @@ -604,7 +604,7 @@ def test_tracking( request_call_fails: bool, extra_env: dict[str, str], expected_flags: dict[str, bool], - cache_type: str, + cache_type: str, number_of_records_read: int ): """ @@ -709,7 +709,7 @@ def test_sync_limited_streams(expected_test_stream_data): source = ab.get_source("source-test", config={"apiKey": "test"}) cache = ab.new_local_cache() - source.set_streams(["stream2"]) + source.select_streams(["stream2"]) result = source.read(cache) diff --git a/airbyte-lib/tests/unit_tests/test_type_translation.py b/airbyte-lib/tests/unit_tests/test_type_translation.py index cb5f59f7feba..a2c255c5b0d7 100644 --- a/airbyte-lib/tests/unit_tests/test_type_translation.py +++ b/airbyte-lib/tests/unit_tests/test_type_translation.py @@ -15,6 +15,9 @@ ({"type": ["null", "string"]}, types.VARCHAR), ({"type": "boolean"}, types.BOOLEAN), ({"type": "string", "format": "date"}, types.DATE), + ({"type": ["null", "string"]}, types.VARCHAR), + ({"type": ["null", "boolean"]}, types.BOOLEAN), + ({"type": ["null", "number"]}, types.DECIMAL), ({"type": "string", "format": "date-time", "airbyte_type": "timestamp_without_timezone"}, types.TIMESTAMP), ({"type": "string", "format": "date-time", "airbyte_type": "timestamp_with_timezone"}, types.TIMESTAMP), ({"type": "string", "format": "time", "airbyte_type": "time_without_timezone"}, types.TIME), @@ -22,8 +25,8 @@ ({"type": "integer"}, types.BIGINT), ({"type": "number", "airbyte_type": "integer"}, types.BIGINT), ({"type": "number"}, types.DECIMAL), - ({"type": "array"}, types.VARCHAR), - ({"type": "object"}, types.VARCHAR), + ({"type": "array", "items": {"type": "object"}}, types.JSON), + ({"type": "object", "properties": {}}, types.JSON), ], ) def test_to_sql_type(json_schema_property_def, expected_sql_type):