From f9c12440263bddf80bea2507e9cb747fe17947df Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Wed, 21 Feb 2024 14:33:12 -0800 Subject: [PATCH 1/6] add error handling, migrate all caches to new jsonl file writer --- airbyte/_file_writers/__init__.py | 3 ++ airbyte/_file_writers/jsonl.py | 65 +++++++++++++++++++++++++++++++ airbyte/_file_writers/parquet.py | 24 ++++++++++-- airbyte/_processors.py | 20 ++++++++-- airbyte/caches/duckdb.py | 11 +++--- airbyte/caches/postgres.py | 8 ++-- airbyte/caches/snowflake.py | 10 ++--- airbyte/types.py | 60 ++++++++++++++++++++++++++++ 8 files changed, 180 insertions(+), 21 deletions(-) create mode 100644 airbyte/_file_writers/jsonl.py diff --git a/airbyte/_file_writers/__init__.py b/airbyte/_file_writers/__init__.py index aae8c474..bdec092c 100644 --- a/airbyte/_file_writers/__init__.py +++ b/airbyte/_file_writers/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations from .base import FileWriterBase, FileWriterBatchHandle, FileWriterConfigBase +from .jsonl import JsonlWriter, JsonlWriterConfig from .parquet import ParquetWriter, ParquetWriterConfig @@ -8,6 +9,8 @@ "FileWriterBatchHandle", "FileWriterBase", "FileWriterConfigBase", + "JsonlWriter", + "JsonlWriterConfig", "ParquetWriter", "ParquetWriterConfig", ] diff --git a/airbyte/_file_writers/jsonl.py b/airbyte/_file_writers/jsonl.py new file mode 100644 index 00000000..18e94a8f --- /dev/null +++ b/airbyte/_file_writers/jsonl.py @@ -0,0 +1,65 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. + +"""A Parquet cache implementation.""" +from __future__ import annotations + +import gzip +from pathlib import Path +from typing import cast + +import orjson +import pyarrow as pa +import ulid +from overrides import overrides + +from airbyte._file_writers.base import ( + FileWriterBase, + FileWriterBatchHandle, + FileWriterConfigBase, +) + + +class JsonlWriterConfig(FileWriterConfigBase): + """Configuration for the Snowflake cache.""" + + # Inherits `cache_dir` from base class + + +class JsonlWriter(FileWriterBase): + """A Jsonl cache implementation.""" + + config_class = JsonlWriterConfig + + def get_new_cache_file_path( + self, + stream_name: str, + batch_id: str | None = None, # ULID of the batch + ) -> Path: + """Return a new cache file path for the given stream.""" + batch_id = batch_id or str(ulid.ULID()) + config: JsonlWriterConfig = cast(JsonlWriterConfig, self.config) + target_dir = Path(config.cache_dir) + target_dir.mkdir(parents=True, exist_ok=True) + return target_dir / f"{stream_name}_{batch_id}.jsonl.gz" + + @overrides + def _write_batch( + self, + stream_name: str, + batch_id: str, + record_batch: pa.Table, + ) -> FileWriterBatchHandle: + """Process a record batch. + + Return the path to the cache file. + """ + _ = batch_id # unused + output_file_path = self.get_new_cache_file_path(stream_name) + + with gzip.open(output_file_path, "w") as jsonl_file: + for record in record_batch.to_pylist(): + jsonl_file.write(orjson.dumps(record) + b"\n") + + batch_handle = FileWriterBatchHandle() + batch_handle.files.append(output_file_path) + return batch_handle diff --git a/airbyte/_file_writers/parquet.py b/airbyte/_file_writers/parquet.py index 8b03cfea..2c4d696b 100644 --- a/airbyte/_file_writers/parquet.py +++ b/airbyte/_file_writers/parquet.py @@ -1,6 +1,12 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. -"""A Parquet cache implementation.""" +"""A Parquet cache implementation. + +NOTE: Parquet is a strongly typed columnar storage format, which has known issues when applied to +variable schemas, schemas with indeterminate types, and schemas that have empty data nodes. +This implementation is deprecated for now in favor of jsonl.gz, and may be removed or revamped in +the future. +""" from __future__ import annotations from pathlib import Path @@ -83,8 +89,20 @@ def _write_batch( for col in missing_columns: record_batch = record_batch.append_column(col, null_array) - with parquet.ParquetWriter(output_file_path, schema=record_batch.schema) as writer: - writer.write_table(record_batch) + try: + with parquet.ParquetWriter(output_file_path, schema=record_batch.schema) as writer: + writer.write_table(record_batch) + except Exception as e: + raise exc.AirbyteLibInternalError( + message=f"Failed to write record batch to Parquet file: {e}", + context={ + "stream_name": stream_name, + "batch_id": batch_id, + "output_file_path": output_file_path, + "schema": record_batch.schema, + "record_batch": record_batch, + }, + ) from e batch_handle = FileWriterBatchHandle() batch_handle.files.append(output_file_path) diff --git a/airbyte/_processors.py b/airbyte/_processors.py index b8db7b9f..e67069b5 100644 --- a/airbyte/_processors.py +++ b/airbyte/_processors.py @@ -17,6 +17,7 @@ from collections import defaultdict from typing import TYPE_CHECKING, Any, cast, final +import pandas as pd import pyarrow as pa import ulid @@ -35,6 +36,7 @@ from airbyte._util import protocol_util from airbyte.progress import progress from airbyte.strategies import WriteStrategy +from airbyte.types import _get_pyarrow_type if TYPE_CHECKING: @@ -177,17 +179,26 @@ def process_airbyte_messages( ) stream_batches: dict[str, list[dict]] = defaultdict(list, {}) - + pyarrow_schemas: dict[str, pa.Schema] = {} # Process messages, writing to batches as we go for message in messages: if message.type is Type.RECORD: record_msg = cast(AirbyteRecordMessage, message.record) stream_name = record_msg.stream + if stream_name not in pyarrow_schemas: + pyarrow_schemas[stream_name] = pa.schema( + fields=[ + (prop_name, _get_pyarrow_type(prop_def)) + for prop_name, prop_def in self._get_stream_json_schema(stream_name)[ + "properties" + ].items() + ] + ) stream_batch = stream_batches[stream_name] stream_batch.append(protocol_util.airbyte_record_message_to_dict(record_msg)) - if len(stream_batch) >= max_batch_size: - record_batch = pa.Table.from_pylist(stream_batch) + batch_df = pd.DataFrame(stream_batch) + record_batch = pa.Table.from_pandas(batch_df) self._process_batch(stream_name, record_batch) progress.log_batch_written(stream_name, len(stream_batch)) stream_batch.clear() @@ -215,7 +226,8 @@ def process_airbyte_messages( # We are at the end of the stream. Process whatever else is queued. for stream_name, stream_batch in stream_batches.items(): - record_batch = pa.Table.from_pylist(stream_batch) + batch_df = pd.DataFrame(stream_batch) + record_batch = pa.Table.from_pandas(batch_df) self._process_batch(stream_name, record_batch) progress.log_batch_written(stream_name, len(stream_batch)) diff --git a/airbyte/caches/duckdb.py b/airbyte/caches/duckdb.py index 6570158b..254b3c9e 100644 --- a/airbyte/caches/duckdb.py +++ b/airbyte/caches/duckdb.py @@ -11,7 +11,7 @@ from overrides import overrides -from airbyte._file_writers import ParquetWriter, ParquetWriterConfig +from airbyte._file_writers import JsonlWriter, JsonlWriterConfig from airbyte.caches.base import SQLCacheBase, SQLCacheConfigBase from airbyte.telemetry import CacheTelemetryInfo @@ -24,10 +24,10 @@ ) -class DuckDBCacheConfig(SQLCacheConfigBase, ParquetWriterConfig): +class DuckDBCacheConfig(SQLCacheConfigBase, JsonlWriterConfig): """Configuration for the DuckDB cache. - Also inherits config from the ParquetWriter, which is responsible for writing files to disk. + Also inherits config from the JsonlWriter, which is responsible for writing files to disk. """ db_path: Path | str @@ -88,7 +88,7 @@ class DuckDBCache(DuckDBCacheBase): so we insert as values instead. """ - file_writer_class = ParquetWriter + file_writer_class = JsonlWriter # TODO: Delete or rewrite this method after DuckDB adds support for primary key inspection. # @overrides @@ -195,8 +195,9 @@ def _write_files_to_new_table( ) SELECT {columns_list_str} - FROM read_parquet( + FROM read_json_auto( [{files_list}], + format = 'newline_delimited', union_by_name = true ) """ diff --git a/airbyte/caches/postgres.py b/airbyte/caches/postgres.py index 33ea585d..f1cd6468 100644 --- a/airbyte/caches/postgres.py +++ b/airbyte/caches/postgres.py @@ -6,15 +6,15 @@ from overrides import overrides -from airbyte._file_writers import ParquetWriter, ParquetWriterConfig +from airbyte._file_writers import JsonlWriter, JsonlWriterConfig from airbyte.caches.base import SQLCacheBase, SQLCacheConfigBase from airbyte.telemetry import CacheTelemetryInfo -class PostgresCacheConfig(SQLCacheConfigBase, ParquetWriterConfig): +class PostgresCacheConfig(SQLCacheConfigBase, JsonlWriterConfig): """Configuration for the Postgres cache. - Also inherits config from the ParquetWriter, which is responsible for writing files to disk. + Also inherits config from the JsonlWriter, which is responsible for writing files to disk. """ host: str @@ -47,7 +47,7 @@ class PostgresCache(SQLCacheBase): """ config_class = PostgresCacheConfig - file_writer_class = ParquetWriter + file_writer_class = JsonlWriter supports_merge_insert = False # TODO: Add native implementation for merge insert @overrides diff --git a/airbyte/caches/snowflake.py b/airbyte/caches/snowflake.py index dde50335..ec2a8692 100644 --- a/airbyte/caches/snowflake.py +++ b/airbyte/caches/snowflake.py @@ -11,7 +11,7 @@ from overrides import overrides from snowflake.sqlalchemy import URL, VARIANT -from airbyte._file_writers import ParquetWriter, ParquetWriterConfig +from airbyte._file_writers import JsonlWriter, JsonlWriterConfig from airbyte.caches.base import ( RecordDedupeMode, SQLCacheBase, @@ -27,10 +27,10 @@ from sqlalchemy.engine import Connection -class SnowflakeCacheConfig(SQLCacheConfigBase, ParquetWriterConfig): +class SnowflakeCacheConfig(SQLCacheConfigBase, JsonlWriterConfig): """Configuration for the Snowflake cache. - Also inherits config from the ParquetWriter, which is responsible for writing files to disk. + Also inherits config from the JsonlWriter, which is responsible for writing files to disk. """ account: str @@ -92,7 +92,7 @@ class SnowflakeSQLCache(SQLCacheBase): """ config_class = SnowflakeCacheConfig - file_writer_class = ParquetWriter + file_writer_class = JsonlWriter type_converter_class = SnowflakeTypeConverter @overrides @@ -134,7 +134,7 @@ def _write_files_to_new_table( FROM {internal_sf_stage_name} ) FILES = ( {files_list} ) - FILE_FORMAT = ( TYPE = PARQUET ) + FILE_FORMAT = ( TYPE = JSON ) ; """ ) diff --git a/airbyte/types.py b/airbyte/types.py index a95dbf59..5c795d84 100644 --- a/airbyte/types.py +++ b/airbyte/types.py @@ -5,6 +5,7 @@ from typing import cast +import pyarrow as pa import sqlalchemy from rich import print @@ -80,6 +81,65 @@ def _get_airbyte_type( # noqa: PLR0911 # Too many return statements raise SQLTypeConversionError(err_msg) +def _get_pyarrow_type( # noqa: PLR0911 # Too many return statements + json_schema_property_def: dict[str, str | dict | list], +) -> pa.DataType: + airbyte_type = cast(str, json_schema_property_def.get("airbyte_type", None)) + if airbyte_type: + return airbyte_type, None + + json_schema_type = json_schema_property_def.get("type", None) + json_schema_format = json_schema_property_def.get("format", None) + + # if json_schema_type is an array of two strings with one of them being null, pick the other one + # this strategy is often used by connectors to indicate a field might not be set all the time + if isinstance(json_schema_type, list): + non_null_types = [t for t in json_schema_type if t != "null"] + if len(non_null_types) == 1: + json_schema_type = non_null_types[0] + + if json_schema_type == "string": + if json_schema_format == "date": + return pa.date64() + + if json_schema_format == "date-time": + return pa.timestamp("ns") + + if json_schema_format == "time": + return pa.timestamp("ns") + + if json_schema_type == "string": + return pa.string() + + if json_schema_type == "number": + return pa.float64() + + if json_schema_type == "integer": + return pa.int64() + + if json_schema_type == "boolean": + return pa.bool_() + + if json_schema_type == "object": + return pa.struct( + fields={ + k: _get_pyarrow_type(v) + for k, v in json_schema_property_def.get("properties", {}).items() + } + ) + + if json_schema_type == "array": + items_def = json_schema_property_def.get("items", None) + if isinstance(items_def, dict): + subtype: pa.DataType = _get_pyarrow_type(items_def) + return pa.list_(subtype) + + return pa.list_(pa.string()) + + err_msg = f"Could not determine PyArrow type from JSON schema type: {json_schema_property_def}" + raise SQLTypeConversionError(err_msg) + + class SQLTypeConverter: """A base class to perform type conversions.""" From 22a6a40371077915fc0454e053c52dbf184186cc Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Wed, 21 Feb 2024 16:03:56 -0800 Subject: [PATCH 2/6] move out pyarrow schema detection --- airbyte/_processors.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/airbyte/_processors.py b/airbyte/_processors.py index e67069b5..cd1b1ae7 100644 --- a/airbyte/_processors.py +++ b/airbyte/_processors.py @@ -179,21 +179,11 @@ def process_airbyte_messages( ) stream_batches: dict[str, list[dict]] = defaultdict(list, {}) - pyarrow_schemas: dict[str, pa.Schema] = {} # Process messages, writing to batches as we go for message in messages: if message.type is Type.RECORD: record_msg = cast(AirbyteRecordMessage, message.record) stream_name = record_msg.stream - if stream_name not in pyarrow_schemas: - pyarrow_schemas[stream_name] = pa.schema( - fields=[ - (prop_name, _get_pyarrow_type(prop_def)) - for prop_name, prop_def in self._get_stream_json_schema(stream_name)[ - "properties" - ].items() - ] - ) stream_batch = stream_batches[stream_name] stream_batch.append(protocol_util.airbyte_record_message_to_dict(record_msg)) if len(stream_batch) >= max_batch_size: @@ -406,3 +396,17 @@ def _get_stream_json_schema( ) -> dict[str, Any]: """Return the column definitions for the given stream.""" return self._get_stream_config(stream_name).stream.json_schema + + def _get_stream_pyarrow_schema( + self, + stream_name: str, + ) -> pa.Schema: + """Return the column definitions for the given stream.""" + return pa.schema( + fields=[ + (prop_name, _get_pyarrow_type(prop_def)) + for prop_name, prop_def in self._get_stream_json_schema(stream_name)[ + "properties" + ].items() + ] + ) From 2346e3295d414c2e41ff1d2737efce0cc8387006 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Wed, 21 Feb 2024 16:04:17 -0800 Subject: [PATCH 3/6] make default and duckdb implementations compatible with jsonl --- airbyte/caches/base.py | 38 ++++++++++++++++++-------------------- airbyte/caches/duckdb.py | 23 +++++++++++++++++------ 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/airbyte/caches/base.py b/airbyte/caches/base.py index 02a1e977..2e22e1f9 100644 --- a/airbyte/caches/base.py +++ b/airbyte/caches/base.py @@ -659,27 +659,25 @@ def _write_files_to_new_table( """ temp_table_name = self._create_table_for_loading(stream_name, batch_id) for file_path in files: - with pa.parquet.ParquetFile(file_path) as pf: - record_batch = pf.read() - dataframe = record_batch.to_pandas() - - # Pandas will auto-create the table if it doesn't exist, which we don't want. - if not self._table_exists(temp_table_name): - raise exc.AirbyteLibInternalError( - message="Table does not exist after creation.", - context={ - "temp_table_name": temp_table_name, - }, - ) - - dataframe.to_sql( - temp_table_name, - self.get_sql_alchemy_url(), - schema=self.config.schema_name, - if_exists="append", - index=False, - dtype=self._get_sql_column_definitions(stream_name), + dataframe = pd.read_json(file_path, lines=True) + + # Pandas will auto-create the table if it doesn't exist, which we don't want. + if not self._table_exists(temp_table_name): + raise exc.AirbyteLibInternalError( + message="Table does not exist after creation.", + context={ + "temp_table_name": temp_table_name, + }, ) + + dataframe.to_sql( + temp_table_name, + self.get_sql_alchemy_url(), + schema=self.config.schema_name, + if_exists="append", + index=False, + dtype=self._get_sql_column_definitions(stream_name), + ) return temp_table_name @final diff --git a/airbyte/caches/duckdb.py b/airbyte/caches/duckdb.py index 254b3c9e..66bc046e 100644 --- a/airbyte/caches/duckdb.py +++ b/airbyte/caches/duckdb.py @@ -181,12 +181,22 @@ def _write_files_to_new_table( stream_name=stream_name, batch_id=batch_id, ) - columns_list = [ - self._quote_identifier(c) - for c in list(self._get_sql_column_definitions(stream_name).keys()) - ] - columns_list_str = indent("\n, ".join(columns_list), " ") + columns_list = list(self._get_sql_column_definitions(stream_name=stream_name).keys()) + columns_list_str = indent( + "\n, ".join([self._quote_identifier(c) for c in columns_list]), + " ", + ) files_list = ", ".join([f"'{f!s}'" for f in files]) + columns_type_map = indent( + "\n, ".join( + [ + f"{self._quote_identifier(c)}: " + f"{self._get_sql_column_definitions(stream_name)[c]!s}" + for c in columns_list + ] + ), + " ", + ) insert_statement = dedent( f""" INSERT INTO {self.config.schema_name}.{temp_table_name} @@ -198,7 +208,8 @@ def _write_files_to_new_table( FROM read_json_auto( [{files_list}], format = 'newline_delimited', - union_by_name = true + union_by_name = true, + columns = {{ { columns_type_map } }} ) """ ) From ffb336a8957d91d032a737f42e456b922c7b710a Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Wed, 21 Feb 2024 16:06:28 -0800 Subject: [PATCH 4/6] fix lint issues --- airbyte/_file_writers/jsonl.py | 7 +++++-- airbyte/caches/base.py | 2 +- tests/unit_tests/test_caches.py | 6 +++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/airbyte/_file_writers/jsonl.py b/airbyte/_file_writers/jsonl.py index 18e94a8f..c47056e5 100644 --- a/airbyte/_file_writers/jsonl.py +++ b/airbyte/_file_writers/jsonl.py @@ -5,10 +5,9 @@ import gzip from pathlib import Path -from typing import cast +from typing import TYPE_CHECKING, cast import orjson -import pyarrow as pa import ulid from overrides import overrides @@ -19,6 +18,10 @@ ) +if TYPE_CHECKING: + import pyarrow as pa + + class JsonlWriterConfig(FileWriterConfigBase): """Configuration for the Snowflake cache.""" diff --git a/airbyte/caches/base.py b/airbyte/caches/base.py index 2e22e1f9..9095df8c 100644 --- a/airbyte/caches/base.py +++ b/airbyte/caches/base.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, cast, final import pandas as pd -import pyarrow as pa import sqlalchemy import ulid from overrides import overrides @@ -42,6 +41,7 @@ from collections.abc import Generator, Iterator from pathlib import Path + import pyarrow as pa from sqlalchemy.engine import Connection, Engine from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.engine.reflection import Inspector diff --git a/tests/unit_tests/test_caches.py b/tests/unit_tests/test_caches.py index 2830fbac..90301931 100644 --- a/tests/unit_tests/test_caches.py +++ b/tests/unit_tests/test_caches.py @@ -4,7 +4,7 @@ import pytest -from airbyte._file_writers import ParquetWriterConfig +from airbyte._file_writers import JsonlWriterConfig from airbyte.caches.base import SQLCacheBase, SQLCacheConfigBase from airbyte.caches.duckdb import DuckDBCacheBase, DuckDBCacheConfig @@ -28,7 +28,7 @@ def test_get_sql_alchemy_url_with_default_schema_name(): def test_duck_db_cache_config_inheritance(): assert issubclass(DuckDBCacheConfig, SQLCacheConfigBase) - assert issubclass(DuckDBCacheConfig, ParquetWriterConfig) + assert issubclass(DuckDBCacheConfig, JsonlWriterConfig) def test_duck_db_cache_config_get_sql_alchemy_url(): config = DuckDBCacheConfig(db_path='test_path', schema_name='test_schema') @@ -57,4 +57,4 @@ def test_duck_db_cache_config_inheritance_from_sql_cache_config_base(): assert issubclass(DuckDBCacheConfig, SQLCacheConfigBase) def test_duck_db_cache_config_inheritance_from_parquet_writer_config(): - assert issubclass(DuckDBCacheConfig, ParquetWriterConfig) + assert issubclass(DuckDBCacheConfig, JsonlWriterConfig) From 7023f0ad87b9465f52e1b8ab5e96b7473f013c43 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Wed, 21 Feb 2024 16:25:04 -0800 Subject: [PATCH 5/6] fix mypy --- airbyte/_processors.py | 2 +- airbyte/types.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/airbyte/_processors.py b/airbyte/_processors.py index cd1b1ae7..83d7f4ca 100644 --- a/airbyte/_processors.py +++ b/airbyte/_processors.py @@ -404,7 +404,7 @@ def _get_stream_pyarrow_schema( """Return the column definitions for the given stream.""" return pa.schema( fields=[ - (prop_name, _get_pyarrow_type(prop_def)) + pa.field(prop_name, _get_pyarrow_type(prop_def)) for prop_name, prop_def in self._get_stream_json_schema(stream_name)[ "properties" ].items() diff --git a/airbyte/types.py b/airbyte/types.py index 5c795d84..64322bda 100644 --- a/airbyte/types.py +++ b/airbyte/types.py @@ -84,10 +84,6 @@ def _get_airbyte_type( # noqa: PLR0911 # Too many return statements def _get_pyarrow_type( # noqa: PLR0911 # Too many return statements json_schema_property_def: dict[str, str | dict | list], ) -> pa.DataType: - airbyte_type = cast(str, json_schema_property_def.get("airbyte_type", None)) - if airbyte_type: - return airbyte_type, None - json_schema_type = json_schema_property_def.get("type", None) json_schema_format = json_schema_property_def.get("format", None) @@ -124,7 +120,7 @@ def _get_pyarrow_type( # noqa: PLR0911 # Too many return statements return pa.struct( fields={ k: _get_pyarrow_type(v) - for k, v in json_schema_property_def.get("properties", {}).items() + for k, v in cast(dict, json_schema_property_def.get("properties", {})).items() } ) From 75abf9828d081cafe60b0434cb5c818cb320208f Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Wed, 21 Feb 2024 20:11:33 -0800 Subject: [PATCH 6/6] fix hang from loading blank file to duckdb --- airbyte/_processors.py | 17 +++++------ airbyte/caches/base.py | 28 +++++++++++-------- .../test_source_test_fixture.py | 7 +++-- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/airbyte/_processors.py b/airbyte/_processors.py index 83d7f4ca..5b8c94d1 100644 --- a/airbyte/_processors.py +++ b/airbyte/_processors.py @@ -207,13 +207,6 @@ def process_airbyte_messages( # Type.LOG, Type.TRACE, Type.CONTROL, etc. pass - # Add empty streams to the dictionary, so we create a destination table for it - for stream_name in self._expected_streams: - if stream_name not in stream_batches: - if DEBUG_MODE: - print(f"Stream {stream_name} has no data") - stream_batches[stream_name] = [] - # We are at the end of the stream. Process whatever else is queued. for stream_name, stream_batch in stream_batches.items(): batch_df = pd.DataFrame(stream_batch) @@ -221,8 +214,16 @@ def process_airbyte_messages( self._process_batch(stream_name, record_batch) progress.log_batch_written(stream_name, len(stream_batch)) + all_streams = list(self._pending_batches.keys()) + # Add empty streams to the streams list, so we create a destination table for it + for stream_name in self._expected_streams: + if stream_name not in all_streams: + if DEBUG_MODE: + print(f"Stream {stream_name} has no data") + all_streams.append(stream_name) + # Finalize any pending batches - for stream_name in list(self._pending_batches.keys()): + for stream_name in all_streams: self._finalize_batches(stream_name, write_strategy=write_strategy) progress.log_stream_finalized(stream_name) diff --git a/airbyte/caches/base.py b/airbyte/caches/base.py index 9095df8c..47e0b743 100644 --- a/airbyte/caches/base.py +++ b/airbyte/caches/base.py @@ -545,17 +545,6 @@ def _finalize_batches( although this is a fairly rare edge case we can ignore in V1. """ with self._finalizing_batches(stream_name) as batches_to_finalize: - if not batches_to_finalize: - return {} - - files: list[Path] = [] - # Get a list of all files to finalize from all pending batches. - for batch_handle in batches_to_finalize.values(): - batch_handle = cast(FileWriterBatchHandle, batch_handle) - files += batch_handle.files - # Use the max batch ID as the batch ID for table names. - max_batch_id = max(batches_to_finalize.keys()) - # Make sure the target schema and target table exist. self._ensure_schema_exists() final_table_name = self._ensure_final_table_exists( @@ -567,6 +556,18 @@ def _finalize_batches( raise_on_error=True, ) + if not batches_to_finalize: + # If there are no batches to finalize, return after ensuring the table exists. + return {} + + files: list[Path] = [] + # Get a list of all files to finalize from all pending batches. + for batch_handle in batches_to_finalize.values(): + batch_handle = cast(FileWriterBatchHandle, batch_handle) + files += batch_handle.files + # Use the max batch ID as the batch ID for table names. + max_batch_id = max(batches_to_finalize.keys()) + temp_table_name = self._write_files_to_new_table( files=files, stream_name=stream_name, @@ -957,6 +958,11 @@ def register_source( This method is called by the source when it is initialized. """ self._source_name = source_name + self.file_writer.register_source( + source_name, + incoming_source_catalog, + stream_names=stream_names, + ) self._ensure_schema_exists() super().register_source( source_name, diff --git a/tests/integration_tests/test_source_test_fixture.py b/tests/integration_tests/test_source_test_fixture.py index c96a676a..76c30b7c 100644 --- a/tests/integration_tests/test_source_test_fixture.py +++ b/tests/integration_tests/test_source_test_fixture.py @@ -195,8 +195,11 @@ def test_file_write_and_cleanup() -> None: _ = source.read(cache_w_cleanup) _ = source.read(cache_wo_cleanup) - assert len(list(Path(temp_dir_1).glob("*.parquet"))) == 0, "Expected files to be cleaned up" - assert len(list(Path(temp_dir_2).glob("*.parquet"))) == 3, "Expected files to exist" + # We expect all files to be cleaned up: + assert len(list(Path(temp_dir_1).glob("*.jsonl.gz"))) == 0, "Expected files to be cleaned up" + + # There are three streams, but only two of them have data: + assert len(list(Path(temp_dir_2).glob("*.jsonl.gz"))) == 2, "Expected files to exist" def assert_cache_data(expected_test_stream_data: dict[str, list[dict[str, str | int]]], cache: SQLCacheBase, streams: list[str] = None):