AirbyteLib: Fix column count mismatch bug #34783

Merged · 13 commits · Feb 3, 2024
4 changes: 2 additions & 2 deletions airbyte-lib/airbyte_lib/_file_writers/base.py
@@ -51,7 +51,7 @@ def _write_batch(
         self,
         stream_name: str,
         batch_id: str,
-        record_batch: pa.Table | pa.RecordBatch,
+        record_batch: pa.Table,
     ) -> FileWriterBatchHandle:
         """Process a record batch.

@@ -64,7 +64,7 @@ def write_batch(
         self,
         stream_name: str,
         batch_id: str,
-        record_batch: pa.Table | pa.RecordBatch,
+        record_batch: pa.Table,
     ) -> FileWriterBatchHandle:
         """Write a batch of records to the cache.

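Note: narrowing the accepted type to pa.Table means a caller holding a pa.RecordBatch must convert before calling the writer. A minimal caller-side sketch (illustrative only, not part of this diff):

    import pyarrow as pa

    # Hypothetical caller that previously passed a pa.RecordBatch directly.
    batch = pa.RecordBatch.from_pydict({"id": [1, 2], "name": ["a", "b"]})
    table = pa.Table.from_batches([batch])  # same schema, now a pa.Table
    # writer.write_batch(stream_name, batch_id, record_batch=table)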
32 changes: 28 additions & 4 deletions airbyte-lib/airbyte_lib/_file_writers/parquet.py
@@ -11,7 +11,12 @@
 from overrides import overrides
 from pyarrow import parquet

-from .base import FileWriterBase, FileWriterBatchHandle, FileWriterConfigBase
+from airbyte_lib import exceptions as exc
+from airbyte_lib._file_writers.base import (
+    FileWriterBase,
+    FileWriterBatchHandle,
+    FileWriterConfigBase,
+)


 class ParquetWriterConfig(FileWriterConfigBase):
@@ -37,12 +42,24 @@ def get_new_cache_file_path(
         target_dir.mkdir(parents=True, exist_ok=True)
         return target_dir / f"{stream_name}_{batch_id}.parquet"

+    def _get_missing_columns(
+        self,
+        stream_name: str,
+        record_batch: pa.Table,
+    ) -> list[str]:
+        """Return a list of columns that are missing in the batch."""
+        if not self._catalog_manager:
+            raise exc.AirbyteLibInternalError(message="Catalog manager should exist but does not.")
+        stream = self._catalog_manager.get_stream_config(stream_name)
+        stream_property_names = stream.stream.json_schema["properties"].keys()
+        return [col for col in stream_property_names if col not in record_batch.schema.names]
+
     @overrides
     def _write_batch(
         self,
         stream_name: str,
         batch_id: str,
-        record_batch: pa.Table | pa.RecordBatch,
+        record_batch: pa.Table,
     ) -> FileWriterBatchHandle:
         """Process a record batch.

@@ -51,8 +68,15 @@ def _write_batch(
         _ = batch_id  # unused
         output_file_path = self.get_new_cache_file_path(stream_name)

-        with parquet.ParquetWriter(output_file_path, record_batch.schema) as writer:
-            writer.write_table(cast(pa.Table, record_batch))
+        missing_columns = self._get_missing_columns(stream_name, record_batch)
+        if missing_columns:
+            # We need to append columns with the missing column name(s) and a null type
+            null_array = cast(pa.Array, pa.array([None] * len(record_batch), type=pa.null()))
+            for col in missing_columns:
+                record_batch = record_batch.append_column(col, null_array)
+
+        with parquet.ParquetWriter(output_file_path, schema=record_batch.schema) as writer:
+            writer.write_table(record_batch)

         batch_handle = FileWriterBatchHandle()
         batch_handle.files.append(output_file_path)
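The core of the fix is above: before writing, the Parquet writer pads the batch with null-typed columns for any properties declared in the stream's JSON schema but absent from the batch, so the file's column count always matches the catalog. A self-contained sketch of the same pyarrow pattern (the expected column list here is a stand-in for the json_schema["properties"] lookup):

    import pyarrow as pa

    table = pa.Table.from_pydict({"id": [1, 2]})
    expected_columns = ["id", "name"]  # stand-in for json_schema["properties"].keys()

    missing = [col for col in expected_columns if col not in table.schema.names]
    for col in missing:
        # Append a null-typed column of the same length as the table.
        table = table.append_column(col, pa.array([None] * len(table), type=pa.null()))

    print(table.schema)  # id: int64, name: null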
51 changes: 39 additions & 12 deletions airbyte-lib/airbyte_lib/_processors.py
@@ -27,16 +27,19 @@
     AirbyteStateType,
     AirbyteStreamState,
     ConfiguredAirbyteCatalog,
+    ConfiguredAirbyteStream,
     Type,
 )

+from airbyte_lib import exceptions as exc
 from airbyte_lib._util import protocol_util  # Internal utility functions
 from airbyte_lib.progress import progress


 if TYPE_CHECKING:
     from collections.abc import Generator, Iterable, Iterator

+    from airbyte_lib.caches._catalog_manager import CatalogManager
     from airbyte_lib.config import CacheConfigBase


@@ -60,6 +63,8 @@ class RecordProcessor(abc.ABC):
     def __init__(
         self,
         config: CacheConfigBase | dict | None,
+        *,
+        catalog_manager: CatalogManager | None = None,
     ) -> None:
         if isinstance(config, dict):
             config = self.config_class(**config)
@@ -72,8 +77,6 @@ def __init__(
             )
             raise TypeError(err_msg)

-        self.source_catalog: ConfiguredAirbyteCatalog | None = None
-
         self._pending_batches: dict[str, dict[str, Any]] = defaultdict(lambda: {}, {})
         self._finalized_batches: dict[str, dict[str, Any]] = defaultdict(lambda: {}, {})

@@ -83,22 +86,25 @@ def __init__(
             list[AirbyteStateMessage],
         ] = defaultdict(list, {})

+        self._catalog_manager: CatalogManager | None = catalog_manager
         self._setup()

     def register_source(
         self,
         source_name: str,
         incoming_source_catalog: ConfiguredAirbyteCatalog,
+        stream_names: set[str],
     ) -> None:
-        """Register the source name and catalog.
-
-        For now, only one source at a time is supported.
-        If this method is called multiple times, the last call will overwrite the previous one.
-
-        TODO: Expand this to handle multiple sources.
-        """
-        _ = source_name
-        self.source_catalog = incoming_source_catalog
+        """Register the source name and catalog."""
+        if not self._catalog_manager:
+            raise exc.AirbyteLibInternalError(
+                message="Catalog manager should exist but does not.",
+            )
+        self._catalog_manager.register_source(
+            source_name,
+            incoming_source_catalog=incoming_source_catalog,
+            incoming_stream_names=stream_names,
+        )

     @property
     def _streams_with_data(self) -> set[str]:
@@ -215,7 +221,7 @@ def _write_batch(
         self,
         stream_name: str,
         batch_id: str,
-        record_batch: pa.Table | pa.RecordBatch,
+        record_batch: pa.Table,
     ) -> BatchHandle:
         """Process a single batch.
@@ -319,3 +325,24 @@ def _teardown(self) -> None:
     def __del__(self) -> None:
         """Teardown temporary resources when instance is unloaded from memory."""
         self._teardown()
+
+    @final
+    def _get_stream_config(
+        self,
+        stream_name: str,
+    ) -> ConfiguredAirbyteStream:
+        """Return the stream configuration for the given stream."""
+        if not self._catalog_manager:
+            raise exc.AirbyteLibInternalError(
+                message="Catalog manager should exist but does not.",
+            )
+
+        return self._catalog_manager.get_stream_config(stream_name)
+
+    @final
+    def _get_stream_json_schema(
+        self,
+        stream_name: str,
+    ) -> dict[str, Any]:
+        """Return the column definitions (JSON schema) for the given stream."""
+        return self._get_stream_config(stream_name).stream.json_schema
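RecordProcessor now receives the catalog manager as a keyword-only constructor argument and guards every schema lookup behind an internal-error check. A toy sketch of the same injection-plus-guard pattern, using stand-in classes rather than the real airbyte-lib types:

    from typing import Any


    class CatalogManagerStub:
        """Stand-in for airbyte_lib.caches._catalog_manager.CatalogManager."""

        def get_stream_json_schema(self, stream_name: str) -> dict[str, Any]:
            return {"properties": {"id": {"type": "integer"}}}


    class ProcessorStub:
        def __init__(self, *, catalog_manager: CatalogManagerStub | None = None) -> None:
            self._catalog_manager = catalog_manager

        def stream_schema(self, stream_name: str) -> dict[str, Any]:
            if not self._catalog_manager:  # mirrors the AirbyteLibInternalError guard
                raise RuntimeError("Catalog manager should exist but does not.")
            return self._catalog_manager.get_stream_json_schema(stream_name)


    processor = ProcessorStub(catalog_manager=CatalogManagerStub())
    print(processor.stream_schema("users"))  # {'properties': {'id': {'type': 'integer'}}}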
108 changes: 83 additions & 25 deletions airbyte-lib/airbyte_lib/caches/_catalog_manager.py
@@ -46,8 +46,23 @@ def __init__(
     ) -> None:
         self._engine: Engine = engine
         self._table_name_resolver = table_name_resolver
-        self.source_catalog: ConfiguredAirbyteCatalog | None = None
+        self._source_catalog: ConfiguredAirbyteCatalog | None = None
         self._load_catalog_from_internal_table()
+        assert self._source_catalog is not None
+
+    @property
+    def source_catalog(self) -> ConfiguredAirbyteCatalog:
+        """Return the source catalog.
+
+        Raises:
+            AirbyteLibInternalError: If the source catalog is not set.
+        """
+        if not self._source_catalog:
+            raise exc.AirbyteLibInternalError(
+                message="Source catalog should be initialized but is not.",
+            )
+
+        return self._source_catalog

     def _ensure_internal_tables(self) -> None:
         engine = self._engine
@@ -57,34 +72,70 @@ def register_source(
         self,
         source_name: str,
         incoming_source_catalog: ConfiguredAirbyteCatalog,
+        incoming_stream_names: set[str],
     ) -> None:
-        if not self.source_catalog:
-            self.source_catalog = incoming_source_catalog
-        else:
-            # merge in the new streams, keyed by name
-            new_streams = {stream.stream.name: stream for stream in incoming_source_catalog.streams}
-            for stream in self.source_catalog.streams:
-                if stream.stream.name not in new_streams:
-                    new_streams[stream.stream.name] = stream
-            self.source_catalog = ConfiguredAirbyteCatalog(
-                streams=list(new_streams.values()),
-            )
+        """Register a source and its streams in the cache."""
+        self._update_catalog(
+            incoming_source_catalog=incoming_source_catalog,
+            incoming_stream_names=incoming_stream_names,
+        )
+        self._save_catalog_to_internal_table(
+            source_name=source_name,
+            incoming_source_catalog=incoming_source_catalog,
+            incoming_stream_names=incoming_stream_names,
+        )
+
+    def _update_catalog(
+        self,
+        incoming_source_catalog: ConfiguredAirbyteCatalog,
+        incoming_stream_names: set[str],
+    ) -> None:
+        if not self._source_catalog:
+            self._source_catalog = ConfiguredAirbyteCatalog(
+                streams=[
+                    stream
+                    for stream in incoming_source_catalog.streams
+                    if stream.stream.name in incoming_stream_names
+                ],
+            )
+            assert len(self._source_catalog.streams) == len(incoming_stream_names)
+            return
+
+        # Keep existing streams untouched if not incoming
+        unchanged_streams: list[ConfiguredAirbyteStream] = [
+            stream
+            for stream in self._source_catalog.streams
+            if stream.stream.name not in incoming_stream_names
+        ]
+        new_streams: list[ConfiguredAirbyteStream] = [
+            stream
+            for stream in incoming_source_catalog.streams
+            if stream.stream.name in incoming_stream_names
+        ]
+        self._source_catalog = ConfiguredAirbyteCatalog(streams=unchanged_streams + new_streams)
+
+    def _save_catalog_to_internal_table(
+        self,
+        source_name: str,
+        incoming_source_catalog: ConfiguredAirbyteCatalog,
+        incoming_stream_names: set[str],
+    ) -> None:
         self._ensure_internal_tables()
         engine = self._engine
         with Session(engine) as session:
-            # delete all existing streams from the db
-            session.query(CachedStream).filter(
-                CachedStream.table_name.in_(
-                    [
-                        self._table_name_resolver(stream.stream.name)
-                        for stream in self.source_catalog.streams
-                    ]
-                )
-            ).delete()
+            # Delete and replace existing stream entries from the catalog cache
+            table_name_entries_to_delete = [
+                self._table_name_resolver(incoming_stream_name)
+                for incoming_stream_name in incoming_stream_names
+            ]
+            result = (
+                session.query(CachedStream)
+                .filter(CachedStream.table_name.in_(table_name_entries_to_delete))
+                .delete()
+            )
+            _ = result
             session.commit()
             # add the new ones
-            streams = [
+            insert_streams = [
                 CachedStream(
                     source_name=source_name,
                     stream_name=stream.stream.name,
@@ -93,8 +144,7 @@ def register_source(
                 )
                 for stream in incoming_source_catalog.streams
             ]
-            session.add_all(streams)
-
+            session.add_all(insert_streams)
             session.commit()

     def get_stream_config(
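The persistence step follows a delete-then-insert pattern: remove any cached rows for the incoming streams, commit, then insert fresh entries. A runnable sketch of the same SQLAlchemy pattern against an in-memory SQLite database (model and column names are stand-ins, not the real CachedStream model):

    from sqlalchemy import Column, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class CachedStreamStub(Base):
        __tablename__ = "airbyte_streams"
        table_name = Column(String, primary_key=True)
        catalog_metadata = Column(String)


    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(CachedStreamStub(table_name="users", catalog_metadata="old"))
        session.commit()

        # Delete rows for the incoming streams, then insert replacements.
        session.query(CachedStreamStub).filter(
            CachedStreamStub.table_name.in_(["users"])
        ).delete()
        session.commit()
        session.add_all([CachedStreamStub(table_name="users", catalog_metadata="new")])
        session.commit()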
@@ -113,6 +163,11 @@ def get_stream_config(
         if not matching_streams:
             raise exc.AirbyteStreamNotFoundError(
                 stream_name=stream_name,
+                context={
+                    "available_streams": [
+                        stream.stream.name for stream in self.source_catalog.streams
+                    ],
+                },
             )

         if len(matching_streams) > 1:
@@ -133,10 +188,13 @@ def _load_catalog_from_internal_table(self) -> None:
             streams: list[CachedStream] = session.query(CachedStream).all()
             if not streams:
                 # no streams means the cache is pristine
+                if not self._source_catalog:
+                    self._source_catalog = ConfiguredAirbyteCatalog(streams=[])
+
                 return

             # load the catalog
-            self.source_catalog = ConfiguredAirbyteCatalog(
+            self._source_catalog = ConfiguredAirbyteCatalog(
                 streams=[
                     ConfiguredAirbyteStream(
                         stream=AirbyteStream(
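For reference, the merge rule implemented by _update_catalog: streams named in the incoming set replace their cached counterparts, and every other cached stream is carried over untouched. The same keyed-merge logic on plain dicts (stand-ins for the ConfiguredAirbyteStream models):

    cached = [{"name": "users", "v": 1}, {"name": "orders", "v": 1}]
    incoming = [{"name": "users", "v": 2}, {"name": "events", "v": 1}]
    incoming_names = {"users", "events"}

    unchanged = [s for s in cached if s["name"] not in incoming_names]
    replaced = [s for s in incoming if s["name"] in incoming_names]
    merged = unchanged + replaced

    print([(s["name"], s["v"]) for s in merged])
    # [('orders', 1), ('users', 2), ('events', 1)]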