Skip to content

Commit

Permalink
AirbyteLib: Fix column count mismatch bug (#34783)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronsteers authored Feb 3, 2024
1 parent 28021d9 commit d9b500c
Show file tree
Hide file tree
Showing 11 changed files with 229 additions and 89 deletions.
4 changes: 2 additions & 2 deletions airbyte-lib/airbyte_lib/_file_writers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _write_batch(
self,
stream_name: str,
batch_id: str,
record_batch: pa.Table | pa.RecordBatch,
record_batch: pa.Table,
) -> FileWriterBatchHandle:
"""Process a record batch.
Expand All @@ -64,7 +64,7 @@ def write_batch(
self,
stream_name: str,
batch_id: str,
record_batch: pa.Table | pa.RecordBatch,
record_batch: pa.Table,
) -> FileWriterBatchHandle:
"""Write a batch of records to the cache.
Expand Down
32 changes: 28 additions & 4 deletions airbyte-lib/airbyte_lib/_file_writers/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from overrides import overrides
from pyarrow import parquet

from .base import FileWriterBase, FileWriterBatchHandle, FileWriterConfigBase
from airbyte_lib import exceptions as exc
from airbyte_lib._file_writers.base import (
FileWriterBase,
FileWriterBatchHandle,
FileWriterConfigBase,
)


class ParquetWriterConfig(FileWriterConfigBase):
Expand All @@ -37,12 +42,24 @@ def get_new_cache_file_path(
target_dir.mkdir(parents=True, exist_ok=True)
return target_dir / f"{stream_name}_{batch_id}.parquet"

def _get_missing_columns(
self,
stream_name: str,
record_batch: pa.Table,
) -> list[str]:
"""Return a list of columns that are missing in the batch."""
if not self._catalog_manager:
raise exc.AirbyteLibInternalError(message="Catalog manager should exist but does not.")
stream = self._catalog_manager.get_stream_config(stream_name)
stream_property_names = stream.stream.json_schema["properties"].keys()
return [col for col in stream_property_names if col not in record_batch.schema.names]

@overrides
def _write_batch(
self,
stream_name: str,
batch_id: str,
record_batch: pa.Table | pa.RecordBatch,
record_batch: pa.Table,
) -> FileWriterBatchHandle:
"""Process a record batch.
Expand All @@ -51,8 +68,15 @@ def _write_batch(
_ = batch_id # unused
output_file_path = self.get_new_cache_file_path(stream_name)

with parquet.ParquetWriter(output_file_path, record_batch.schema) as writer:
writer.write_table(cast(pa.Table, record_batch))
missing_columns = self._get_missing_columns(stream_name, record_batch)
if missing_columns:
# We need to append columns with the missing column name(s) and a null type
null_array = cast(pa.Array, pa.array([None] * len(record_batch), type=pa.null()))
for col in missing_columns:
record_batch = record_batch.append_column(col, null_array)

with parquet.ParquetWriter(output_file_path, schema=record_batch.schema) as writer:
writer.write_table(record_batch)

batch_handle = FileWriterBatchHandle()
batch_handle.files.append(output_file_path)
Expand Down
51 changes: 39 additions & 12 deletions airbyte-lib/airbyte_lib/_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,19 @@
AirbyteStateType,
AirbyteStreamState,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
Type,
)

from airbyte_lib import exceptions as exc
from airbyte_lib._util import protocol_util # Internal utility functions
from airbyte_lib.progress import progress


if TYPE_CHECKING:
from collections.abc import Generator, Iterable, Iterator

from airbyte_lib.caches._catalog_manager import CatalogManager
from airbyte_lib.config import CacheConfigBase


Expand All @@ -60,6 +63,8 @@ class RecordProcessor(abc.ABC):
def __init__(
self,
config: CacheConfigBase | dict | None,
*,
catalog_manager: CatalogManager | None = None,
) -> None:
if isinstance(config, dict):
config = self.config_class(**config)
Expand All @@ -72,8 +77,6 @@ def __init__(
)
raise TypeError(err_msg)

self.source_catalog: ConfiguredAirbyteCatalog | None = None

self._pending_batches: dict[str, dict[str, Any]] = defaultdict(lambda: {}, {})
self._finalized_batches: dict[str, dict[str, Any]] = defaultdict(lambda: {}, {})

Expand All @@ -83,22 +86,25 @@ def __init__(
list[AirbyteStateMessage],
] = defaultdict(list, {})

self._catalog_manager: CatalogManager | None = catalog_manager
self._setup()

def register_source(
self,
source_name: str,
incoming_source_catalog: ConfiguredAirbyteCatalog,
stream_names: set[str],
) -> None:
"""Register the source name and catalog.
For now, only one source at a time is supported.
If this method is called multiple times, the last call will overwrite the previous one.
TODO: Expand this to handle multiple sources.
"""
_ = source_name
self.source_catalog = incoming_source_catalog
"""Register the source name and catalog."""
if not self._catalog_manager:
raise exc.AirbyteLibInternalError(
message="Catalog manager should exist but does not.",
)
self._catalog_manager.register_source(
source_name,
incoming_source_catalog=incoming_source_catalog,
incoming_stream_names=stream_names,
)

@property
def _streams_with_data(self) -> set[str]:
Expand Down Expand Up @@ -215,7 +221,7 @@ def _write_batch(
self,
stream_name: str,
batch_id: str,
record_batch: pa.Table | pa.RecordBatch,
record_batch: pa.Table,
) -> BatchHandle:
"""Process a single batch.
Expand Down Expand Up @@ -319,3 +325,24 @@ def _teardown(self) -> None:
def __del__(self) -> None:
"""Teardown temporary resources when instance is unloaded from memory."""
self._teardown()

@final
def _get_stream_config(
self,
stream_name: str,
) -> ConfiguredAirbyteStream:
"""Return the column definitions for the given stream."""
if not self._catalog_manager:
raise exc.AirbyteLibInternalError(
message="Catalog manager should exist but does not.",
)

return self._catalog_manager.get_stream_config(stream_name)

@final
def _get_stream_json_schema(
self,
stream_name: str,
) -> dict[str, Any]:
"""Return the column definitions for the given stream."""
return self._get_stream_config(stream_name).stream.json_schema
108 changes: 83 additions & 25 deletions airbyte-lib/airbyte_lib/caches/_catalog_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,23 @@ def __init__(
) -> None:
self._engine: Engine = engine
self._table_name_resolver = table_name_resolver
self.source_catalog: ConfiguredAirbyteCatalog | None = None
self._source_catalog: ConfiguredAirbyteCatalog | None = None
self._load_catalog_from_internal_table()
assert self._source_catalog is not None

@property
def source_catalog(self) -> ConfiguredAirbyteCatalog:
"""Return the source catalog.
Raises:
AirbyteLibInternalError: If the source catalog is not set.
"""
if not self._source_catalog:
raise exc.AirbyteLibInternalError(
message="Source catalog should be initialized but is not.",
)

return self._source_catalog

def _ensure_internal_tables(self) -> None:
engine = self._engine
Expand All @@ -57,34 +72,70 @@ def register_source(
self,
source_name: str,
incoming_source_catalog: ConfiguredAirbyteCatalog,
incoming_stream_names: set[str],
) -> None:
if not self.source_catalog:
self.source_catalog = incoming_source_catalog
else:
# merge in the new streams, keyed by name
new_streams = {stream.stream.name: stream for stream in incoming_source_catalog.streams}
for stream in self.source_catalog.streams:
if stream.stream.name not in new_streams:
new_streams[stream.stream.name] = stream
self.source_catalog = ConfiguredAirbyteCatalog(
streams=list(new_streams.values()),
"""Register a source and its streams in the cache."""
self._update_catalog(
incoming_source_catalog=incoming_source_catalog,
incoming_stream_names=incoming_stream_names,
)
self._save_catalog_to_internal_table(
source_name=source_name,
incoming_source_catalog=incoming_source_catalog,
incoming_stream_names=incoming_stream_names,
)

def _update_catalog(
self,
incoming_source_catalog: ConfiguredAirbyteCatalog,
incoming_stream_names: set[str],
) -> None:
if not self._source_catalog:
self._source_catalog = ConfiguredAirbyteCatalog(
streams=[
stream
for stream in incoming_source_catalog.streams
if stream.stream.name in incoming_stream_names
],
)
assert len(self._source_catalog.streams) == len(incoming_stream_names)
return

# Keep existing streams untouched if not incoming
unchanged_streams: list[ConfiguredAirbyteStream] = [
stream
for stream in self._source_catalog.streams
if stream.stream.name not in incoming_stream_names
]
new_streams: list[ConfiguredAirbyteStream] = [
stream
for stream in incoming_source_catalog.streams
if stream.stream.name in incoming_stream_names
]
self._source_catalog = ConfiguredAirbyteCatalog(streams=unchanged_streams + new_streams)

def _save_catalog_to_internal_table(
self,
source_name: str,
incoming_source_catalog: ConfiguredAirbyteCatalog,
incoming_stream_names: set[str],
) -> None:
self._ensure_internal_tables()
engine = self._engine
with Session(engine) as session:
# delete all existing streams from the db
session.query(CachedStream).filter(
CachedStream.table_name.in_(
[
self._table_name_resolver(stream.stream.name)
for stream in self.source_catalog.streams
]
)
).delete()
# Delete and replace existing stream entries from the catalog cache
table_name_entries_to_delete = [
self._table_name_resolver(incoming_stream_name)
for incoming_stream_name in incoming_stream_names
]
result = (
session.query(CachedStream)
.filter(CachedStream.table_name.in_(table_name_entries_to_delete))
.delete()
)
_ = result
session.commit()
# add the new ones
streams = [
insert_streams = [
CachedStream(
source_name=source_name,
stream_name=stream.stream.name,
Expand All @@ -93,8 +144,7 @@ def register_source(
)
for stream in incoming_source_catalog.streams
]
session.add_all(streams)

session.add_all(insert_streams)
session.commit()

def get_stream_config(
Expand All @@ -113,6 +163,11 @@ def get_stream_config(
if not matching_streams:
raise exc.AirbyteStreamNotFoundError(
stream_name=stream_name,
context={
"available_streams": [
stream.stream.name for stream in self.source_catalog.streams
],
},
)

if len(matching_streams) > 1:
Expand All @@ -133,10 +188,13 @@ def _load_catalog_from_internal_table(self) -> None:
streams: list[CachedStream] = session.query(CachedStream).all()
if not streams:
# no streams means the cache is pristine
if not self._source_catalog:
self._source_catalog = ConfiguredAirbyteCatalog(streams=[])

return

# load the catalog
self.source_catalog = ConfiguredAirbyteCatalog(
self._source_catalog = ConfiguredAirbyteCatalog(
streams=[
ConfiguredAirbyteStream(
stream=AirbyteStream(
Expand Down
Loading

0 comments on commit d9b500c

Please sign in to comment.