Commit

update from main (manual merge resolution)

aaronsteers committed Feb 22, 2024
2 parents 816d807 + 81d1b9c commit 6e5ad29
Showing 15 changed files with 368 additions and 88 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python_lint.yml
@@ -67,5 +67,5 @@ jobs:
cache: 'poetry'

# Job-specific step(s):
- name: Check code format
run: poetry run ruff format --check .
- name: Check MyPy typing
run: poetry run mypy .
5 changes: 5 additions & 0 deletions airbyte/_factories/connector_factories.py
@@ -42,6 +42,7 @@ def get_source(
name: str,
config: dict[str, Any] | None = None,
*,
streams: str | list[str] | None = None,
version: str | None = None,
pip_url: str | None = None,
local_executable: Path | str | None = None,
@@ -53,6 +54,9 @@ def get_source(
name: connector name
config: connector config - if not provided, you need to set it later via the set_config
method.
streams: list of stream names to select for reading. If set to "*", all streams will be
selected. If not provided, you can set it later via the `select_streams()` or
`select_all_streams()` method.
version: connector version - if not provided, the currently installed version will be used.
If no version is installed, the latest available version will be used. The version can
also be set to "latest" to force the use of the latest available version.
@@ -88,6 +92,7 @@ def get_source(
return Source(
name=name,
config=config,
streams=streams,
executor=PathExecutor(
name=name,
path=local_executable,
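The new `streams` argument to `get_source` above lets callers pre-select streams at construction time instead of calling `select_streams()` or `select_all_streams()` afterward. A minimal usage sketch, assuming `get_source` is exposed at the package root and using a hypothetical connector name and config:

import airbyte as ab

# Hypothetical connector name and config values, for illustration only.
source = ab.get_source(
    "source-faker",
    config={"count": 1_000},
    streams="*",  # "*" selects all streams, equivalent to select_all_streams()
)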
3 changes: 3 additions & 0 deletions airbyte/_file_writers/__init__.py
@@ -1,13 +1,16 @@
from __future__ import annotations

from .base import FileWriterBase, FileWriterBatchHandle, FileWriterConfigBase
from .jsonl import JsonlWriter, JsonlWriterConfig
from .parquet import ParquetWriter, ParquetWriterConfig


__all__ = [
"FileWriterBatchHandle",
"FileWriterBase",
"FileWriterConfigBase",
"JsonlWriter",
"JsonlWriterConfig",
"ParquetWriter",
"ParquetWriterConfig",
]
68 changes: 68 additions & 0 deletions airbyte/_file_writers/jsonl.py
@@ -0,0 +1,68 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.

"""A Parquet cache implementation."""
from __future__ import annotations

import gzip
from pathlib import Path
from typing import TYPE_CHECKING, cast

import orjson
import ulid
from overrides import overrides

from airbyte._file_writers.base import (
FileWriterBase,
FileWriterBatchHandle,
FileWriterConfigBase,
)


if TYPE_CHECKING:
import pyarrow as pa


class JsonlWriterConfig(FileWriterConfigBase):
"""Configuration for the Snowflake cache."""

# Inherits `cache_dir` from base class


class JsonlWriter(FileWriterBase):
"""A Jsonl cache implementation."""

config_class = JsonlWriterConfig

def get_new_cache_file_path(
self,
stream_name: str,
batch_id: str | None = None, # ULID of the batch
) -> Path:
"""Return a new cache file path for the given stream."""
batch_id = batch_id or str(ulid.ULID())
config: JsonlWriterConfig = cast(JsonlWriterConfig, self.config)
target_dir = Path(config.cache_dir)
target_dir.mkdir(parents=True, exist_ok=True)
return target_dir / f"{stream_name}_{batch_id}.jsonl.gz"

@overrides
def _write_batch(
self,
stream_name: str,
batch_id: str,
record_batch: pa.Table,
) -> FileWriterBatchHandle:
"""Process a record batch.
Return the path to the cache file.
"""
_ = batch_id # unused
output_file_path = self.get_new_cache_file_path(stream_name)

with gzip.open(output_file_path, "w") as jsonl_file:
for record in record_batch.to_pylist():
jsonl_file.write(orjson.dumps(record) + b"\n")

batch_handle = FileWriterBatchHandle()
batch_handle.files.append(output_file_path)
return batch_handle
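For reference, records written by the `JsonlWriter` above can be read back with the same `gzip` + `orjson` pair. A quick round-trip sketch; the file name is hypothetical but follows the writer's `{stream_name}_{batch_id}.jsonl.gz` pattern:

import gzip

import orjson

# Hypothetical cache file name, matching the writer's naming pattern.
with gzip.open("users_01HQ3EXAMPLEULID.jsonl.gz", "rb") as jsonl_file:
    records = [orjson.loads(line) for line in jsonl_file if line.strip()]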
24 changes: 21 additions & 3 deletions airbyte/_file_writers/parquet.py
@@ -1,6 +1,12 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.

"""A Parquet cache implementation."""
"""A Parquet cache implementation.
NOTE: Parquet is a strongly typed columnar storage format, which has known issues when applied to
variable schemas, schemas with indeterminate types, and schemas that have empty data nodes.
This implementation is deprecated for now in favor of jsonl.gz, and may be removed or revamped in
the future.
"""
from __future__ import annotations

from pathlib import Path
@@ -81,8 +87,20 @@ def _write_batch(
for col in missing_columns:
record_batch = record_batch.append_column(col, null_array)

with parquet.ParquetWriter(output_file_path, schema=record_batch.schema) as writer:
writer.write_table(record_batch)
try:
with parquet.ParquetWriter(output_file_path, schema=record_batch.schema) as writer:
writer.write_table(record_batch)
except Exception as e:
raise exc.AirbyteLibInternalError(
message=f"Failed to write record batch to Parquet file: {e}",
context={
"stream_name": stream_name,
"batch_id": batch_id,
"output_file_path": output_file_path,
"schema": record_batch.schema,
"record_batch": record_batch,
},
) from e

batch_handle = FileWriterBatchHandle()
batch_handle.files.append(output_file_path)
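The "indeterminate types" caveat in the Parquet docstring above can be seen in a tiny sketch: a column that contains only nulls is inferred by PyArrow as the indeterminate `null` type, the kind of schema the note says Parquet handles poorly. The data here is made up purely for illustration:

import pandas as pd
import pyarrow as pa

# A batch where one column has no values at all: PyArrow infers the
# indeterminate `null` type for that column.
df = pd.DataFrame({"id": [1, 2], "notes": [None, None]})
table = pa.Table.from_pandas(df)
print(table.schema)  # id: int64, notes: null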
41 changes: 29 additions & 12 deletions airbyte/_processors.py
@@ -17,6 +17,7 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, cast, final

import pandas as pd
import pyarrow as pa
import ulid

@@ -36,6 +37,7 @@
from airbyte.config import CacheConfigBase
from airbyte.progress import progress
from airbyte.strategies import WriteStrategy
from airbyte.types import _get_pyarrow_type


if TYPE_CHECKING:
@@ -174,17 +176,16 @@ def process_airbyte_messages(
)

stream_batches: dict[str, list[dict]] = defaultdict(list, {})

# Process messages, writing to batches as we go
for message in messages:
if message.type is Type.RECORD:
record_msg = cast(AirbyteRecordMessage, message.record)
stream_name = record_msg.stream
stream_batch = stream_batches[stream_name]
stream_batch.append(protocol_util.airbyte_record_message_to_dict(record_msg))

if len(stream_batch) >= max_batch_size:
record_batch = pa.Table.from_pylist(stream_batch)
batch_df = pd.DataFrame(stream_batch)
record_batch = pa.Table.from_pandas(batch_df)
self._process_batch(stream_name, record_batch)
progress.log_batch_written(stream_name, len(stream_batch))
stream_batch.clear()
@@ -203,21 +204,23 @@
# Type.LOG, Type.TRACE, Type.CONTROL, etc.
pass

# Add empty streams to the dictionary, so we create a destination table for them
for stream_name in self._expected_streams:
if stream_name not in stream_batches:
if DEBUG_MODE:
print(f"Stream {stream_name} has no data")
stream_batches[stream_name] = []

# We are at the end of the stream. Process whatever else is queued.
for stream_name, stream_batch in stream_batches.items():
record_batch = pa.Table.from_pylist(stream_batch)
batch_df = pd.DataFrame(stream_batch)
record_batch = pa.Table.from_pandas(batch_df)
self._process_batch(stream_name, record_batch)
progress.log_batch_written(stream_name, len(stream_batch))

all_streams = list(self._pending_batches.keys())
# Add empty streams to the streams list, so we create a destination table for them
for stream_name in self._expected_streams:
if stream_name not in all_streams:
if DEBUG_MODE:
print(f"Stream {stream_name} has no data")
all_streams.append(stream_name)

# Finalize any pending batches
for stream_name in list(self._pending_batches.keys()):
for stream_name in all_streams:
self._finalize_batches(stream_name, write_strategy=write_strategy)
progress.log_stream_finalized(stream_name)

@@ -391,3 +394,17 @@ def _get_stream_json_schema(
) -> dict[str, Any]:
"""Return the column definitions for the given stream."""
return self._get_stream_config(stream_name).stream.json_schema

def _get_stream_pyarrow_schema(
self,
stream_name: str,
) -> pa.Schema:
"""Return the column definitions for the given stream."""
return pa.schema(
fields=[
pa.field(prop_name, _get_pyarrow_type(prop_def))
for prop_name, prop_def in self._get_stream_json_schema(stream_name)[
"properties"
].items()
]
)
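The batching change above routes buffered records through a pandas DataFrame before building the Arrow table, rather than calling `pa.Table.from_pylist` directly. A minimal sketch of that conversion path, using made-up records in place of a real stream batch:

import pandas as pd
import pyarrow as pa

# Made-up records standing in for a buffered stream batch.
stream_batch = [
    {"id": 1, "name": "alice"},
    {"id": 2, "name": "bob"},
]
batch_df = pd.DataFrame(stream_batch)
record_batch = pa.Table.from_pandas(batch_df)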
68 changes: 36 additions & 32 deletions airbyte/caches/base.py
@@ -9,7 +9,6 @@
from typing import TYPE_CHECKING, Any, cast, final

import pandas as pd
import pyarrow as pa
import sqlalchemy
import ulid
from overrides import overrides
@@ -43,6 +42,7 @@
from collections.abc import Generator
from pathlib import Path

import pyarrow as pa
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.reflection import Inspector
@@ -164,6 +164,11 @@ def register_source(
This method is called by the source when it is initialized.
"""
self._source_name = source_name
self.file_writer.register_source(
source_name,
incoming_source_catalog,
stream_names=stream_names,
)
self._ensure_schema_exists()
super().register_source(
source_name,
@@ -507,17 +512,6 @@ def _finalize_batches(
although this is a fairly rare edge case we can ignore in V1.
"""
with self._finalizing_batches(stream_name) as batches_to_finalize:
if not batches_to_finalize:
return {}

files: list[Path] = []
# Get a list of all files to finalize from all pending batches.
for batch_handle in batches_to_finalize.values():
batch_handle = cast(FileWriterBatchHandle, batch_handle)
files += batch_handle.files
# Use the max batch ID as the batch ID for table names.
max_batch_id = max(batches_to_finalize.keys())

# Make sure the target schema and target table exist.
self._ensure_schema_exists()
final_table_name = self._ensure_final_table_exists(
@@ -529,6 +523,18 @@
raise_on_error=True,
)

if not batches_to_finalize:
# If there are no batches to finalize, return after ensuring the table exists.
return {}

files: list[Path] = []
# Get a list of all files to finalize from all pending batches.
for batch_handle in batches_to_finalize.values():
batch_handle = cast(FileWriterBatchHandle, batch_handle)
files += batch_handle.files
# Use the max batch ID as the batch ID for table names.
max_batch_id = max(batches_to_finalize.keys())

temp_table_name = self._write_files_to_new_table(
files=files,
stream_name=stream_name,
@@ -609,27 +615,25 @@ def _write_files_to_new_table(
"""
temp_table_name = self._create_table_for_loading(stream_name, batch_id)
for file_path in files:
with pa.parquet.ParquetFile(file_path) as pf:
record_batch = pf.read()
dataframe = record_batch.to_pandas()

# Pandas will auto-create the table if it doesn't exist, which we don't want.
if not self._table_exists(temp_table_name):
raise exc.AirbyteLibInternalError(
message="Table does not exist after creation.",
context={
"temp_table_name": temp_table_name,
},
)

dataframe.to_sql(
temp_table_name,
self.get_sql_alchemy_url(),
schema=self.config.schema_name,
if_exists="append",
index=False,
dtype=self._get_sql_column_definitions(stream_name),
dataframe = pd.read_json(file_path, lines=True)

# Pandas will auto-create the table if it doesn't exist, which we don't want.
if not self._table_exists(temp_table_name):
raise exc.AirbyteLibInternalError(
message="Table does not exist after creation.",
context={
"temp_table_name": temp_table_name,
},
)

dataframe.to_sql(
temp_table_name,
self.get_sql_alchemy_url(),
schema=self.config.schema_name,
if_exists="append",
index=False,
dtype=self._get_sql_column_definitions(stream_name),
)
return temp_table_name

@final
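The rewritten `_write_files_to_new_table` above now loads the gzipped JSONL cache files with pandas instead of PyArrow's Parquet reader before appending to the temp table. A standalone sketch of that load-and-append path, with a hypothetical SQLite URL, cache file name, and table name:

import pandas as pd
import sqlalchemy

# Hypothetical database URL, cache file, and table name, for illustration only.
engine = sqlalchemy.create_engine("sqlite:///example_cache.db")
dataframe = pd.read_json("users_01HQ3EXAMPLEULID.jsonl.gz", lines=True)  # gzip inferred from extension
dataframe.to_sql(
    "users_batch_tmp",
    engine,
    if_exists="append",
    index=False,
)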
(Remaining changed files not shown.)
