diff --git a/airbyte/_factories/cache_factories.py b/airbyte/_factories/cache_factories.py
index 9399369a..7a48b1a9 100644
--- a/airbyte/_factories/cache_factories.py
+++ b/airbyte/_factories/cache_factories.py
@@ -15,9 +15,10 @@ def get_default_cache() -> DuckDBCache:
     Cache files are stored in the `.cache` directory, relative to the current
     working directory.
     """
-
+    cache_dir = Path("./.cache/default_cache")
     return DuckDBCache(
-        db_path="./.cache/default_cache_db.duckdb",
+        db_path=cache_dir / "default_cache.duckdb",
+        cache_dir=cache_dir,
     )
 
 
diff --git a/airbyte/_file_writers/__init__.py b/airbyte/_file_writers/__init__.py
deleted file mode 100644
index bdec092c..00000000
--- a/airbyte/_file_writers/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from __future__ import annotations
-
-from .base import FileWriterBase, FileWriterBatchHandle, FileWriterConfigBase
-from .jsonl import JsonlWriter, JsonlWriterConfig
-from .parquet import ParquetWriter, ParquetWriterConfig
-
-
-__all__ = [
-    "FileWriterBatchHandle",
-    "FileWriterBase",
-    "FileWriterConfigBase",
-    "JsonlWriter",
-    "JsonlWriterConfig",
-    "ParquetWriter",
-    "ParquetWriterConfig",
-]
diff --git a/airbyte/_processors/__init__.py b/airbyte/_processors/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/airbyte/_processors.py b/airbyte/_processors/base.py
similarity index 94%
rename from airbyte/_processors.py
rename to airbyte/_processors/base.py
index c4126dfb..35cb5b08 100644
--- a/airbyte/_processors.py
+++ b/airbyte/_processors/base.py
@@ -1,8 +1,7 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+"""Abstract base class for Processors, including SQL and File writers.
 
-"""Define abstract base class for Processors, including Caches and File writers.
-
-Processors can all take input from STDIN or a stream of Airbyte messages.
+Processors can take input from STDIN or a stream of Airbyte messages.
 
 Caches will pass their input to the File Writer. They share a common base class so certain
 abstractions like "write" and "finalize" can be handled in either layer, or both.
@@ -34,7 +33,7 @@
 
 from airbyte import exceptions as exc
 from airbyte._util import protocol_util
-from airbyte.config import CacheConfigBase
+from airbyte.caches.base import CacheBase
 from airbyte.progress import progress
 from airbyte.strategies import WriteStrategy
 from airbyte.types import _get_pyarrow_type
@@ -62,20 +61,20 @@ class RecordProcessor(abc.ABC):
     """Abstract base class for classes which can process input records."""
 
     skip_finalize_step: bool = False
-    _expected_streams: set[str]
 
     def __init__(
         self,
-        config: CacheConfigBase,
+        cache: CacheBase,
         *,
         catalog_manager: CatalogManager | None = None,
     ) -> None:
-        self.config = config
-        if not isinstance(self.config, CacheConfigBase):
+        self._expected_streams: set[str] | None = None
+        self.cache: CacheBase = cache
+        if not isinstance(self.cache, CacheBase):
             raise exc.AirbyteLibInputError(
                 message=(
-                    f"Expected config class of type 'CacheConfigBase'. "
-                    f"Instead received type '{type(self.config).__name__}'."
+                    f"Expected config class of type 'CacheBase'. "
+                    f"Instead received type '{type(self.cache).__name__}'."
), ) @@ -94,6 +93,11 @@ def __init__( self._catalog_manager: CatalogManager | None = catalog_manager self._setup() + @property + def expected_streams(self) -> set[str]: + """Return the expected stream names.""" + return self._expected_streams or set() + def register_source( self, source_name: str, @@ -112,11 +116,6 @@ def register_source( ) self._expected_streams = stream_names - @property - def _streams_with_data(self) -> set[str]: - """Return a list of known streams.""" - return self._pending_batches.keys() | self._finalized_batches.keys() - @final def process_stdin( self, @@ -213,7 +212,7 @@ def process_airbyte_messages( all_streams = list(self._pending_batches.keys()) # Add empty streams to the streams list, so we create a destination table for it - for stream_name in self._expected_streams: + for stream_name in self.expected_streams: if stream_name not in all_streams: if DEBUG_MODE: print(f"Stream {stream_name} has no data") diff --git a/airbyte/_processors/file/__init__.py b/airbyte/_processors/file/__init__.py new file mode 100644 index 00000000..26c25484 --- /dev/null +++ b/airbyte/_processors/file/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""File processors.""" + +from __future__ import annotations + +from .base import FileWriterBase, FileWriterBatchHandle +from .jsonl import JsonlWriter +from .parquet import ParquetWriter + + +__all__ = [ + "FileWriterBatchHandle", + "FileWriterBase", + "JsonlWriter", + "ParquetWriter", +] diff --git a/airbyte/_file_writers/base.py b/airbyte/_processors/file/base.py similarity index 90% rename from airbyte/_file_writers/base.py rename to airbyte/_processors/file/base.py index 8f0dbc60..26833e2f 100644 --- a/airbyte/_file_writers/base.py +++ b/airbyte/_processors/file/base.py @@ -1,6 +1,5 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. - -"""Define abstract base class for File Writers, which write and read from file storage.""" +"""Abstract base class for File Writers, which write and read from file storage.""" from __future__ import annotations @@ -11,8 +10,8 @@ from overrides import overrides -from airbyte._processors import BatchHandle, RecordProcessor -from airbyte.config import CacheConfigBase +from airbyte._processors.base import BatchHandle, RecordProcessor +from airbyte.caches.base import CacheBase if TYPE_CHECKING: @@ -34,7 +33,7 @@ class FileWriterBatchHandle(BatchHandle): files: list[Path] = field(default_factory=list) -class FileWriterConfigBase(CacheConfigBase): +class FileWriterConfigBase(CacheBase): """Configuration for the Snowflake cache.""" cache_dir: Path = Path("./.cache/files/") @@ -46,8 +45,6 @@ class FileWriterConfigBase(CacheConfigBase): class FileWriterBase(RecordProcessor, abc.ABC): """A generic base implementation for a file-based cache.""" - config: FileWriterConfigBase - @abc.abstractmethod @overrides def _write_batch( @@ -90,7 +87,7 @@ def _cleanup_batch( This method is a no-op if the `cleanup` config option is set to False. """ - if self.config.cleanup: + if self.cache.cleanup: batch_handle = cast(FileWriterBatchHandle, batch_handle) _ = stream_name, batch_id for file_path in batch_handle.files: diff --git a/airbyte/_file_writers/jsonl.py b/airbyte/_processors/file/jsonl.py similarity index 78% rename from airbyte/_file_writers/jsonl.py rename to airbyte/_processors/file/jsonl.py index c47056e5..1ea66bde 100644 --- a/airbyte/_file_writers/jsonl.py +++ b/airbyte/_processors/file/jsonl.py @@ -1,20 +1,19 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- """A Parquet cache implementation.""" + from __future__ import annotations import gzip from pathlib import Path -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING import orjson import ulid from overrides import overrides -from airbyte._file_writers.base import ( +from airbyte._processors.file.base import ( FileWriterBase, FileWriterBatchHandle, - FileWriterConfigBase, ) @@ -22,17 +21,9 @@ import pyarrow as pa -class JsonlWriterConfig(FileWriterConfigBase): - """Configuration for the Snowflake cache.""" - - # Inherits `cache_dir` from base class - - class JsonlWriter(FileWriterBase): """A Jsonl cache implementation.""" - config_class = JsonlWriterConfig - def get_new_cache_file_path( self, stream_name: str, @@ -40,8 +31,7 @@ def get_new_cache_file_path( ) -> Path: """Return a new cache file path for the given stream.""" batch_id = batch_id or str(ulid.ULID()) - config: JsonlWriterConfig = cast(JsonlWriterConfig, self.config) - target_dir = Path(config.cache_dir) + target_dir = Path(self.cache.cache_dir) target_dir.mkdir(parents=True, exist_ok=True) return target_dir / f"{stream_name}_{batch_id}.jsonl.gz" diff --git a/airbyte/_file_writers/parquet.py b/airbyte/_processors/file/parquet.py similarity index 95% rename from airbyte/_file_writers/parquet.py rename to airbyte/_processors/file/parquet.py index 91d5ed12..8159a319 100644 --- a/airbyte/_file_writers/parquet.py +++ b/airbyte/_processors/file/parquet.py @@ -1,6 +1,5 @@ -# Copyright (c) 2023 Airbyte, Inc., all rights reserved. - -"""A Parquet cache implementation. +# Copyright (c) 2023 Airbyte, Inc., all rights reserved +"""A Parquet file writer implementation. NOTE: Parquet is a strongly typed columnar storage format, which has known issues when applied to variable schemas, schemas with indeterminate types, and schemas that have empty data nodes. @@ -18,7 +17,7 @@ from pyarrow import parquet from airbyte import exceptions as exc -from airbyte._file_writers.base import ( +from airbyte._processors.file.base import ( FileWriterBase, FileWriterBatchHandle, FileWriterConfigBase, @@ -42,7 +41,7 @@ def get_new_cache_file_path( ) -> Path: """Return a new cache file path for the given stream.""" batch_id = batch_id or str(ulid.ULID()) - config: ParquetWriterConfig = cast(ParquetWriterConfig, self.config) + config: ParquetWriterConfig = cast(ParquetWriterConfig, self.cache) target_dir = Path(config.cache_dir) target_dir.mkdir(parents=True, exist_ok=True) return target_dir / f"{stream_name}_{batch_id}.parquet" diff --git a/airbyte/_processors/sql/__init__.py b/airbyte/_processors/sql/__init__.py new file mode 100644 index 00000000..61851a71 --- /dev/null +++ b/airbyte/_processors/sql/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""SQL processors.""" diff --git a/airbyte/_processors/sql/base.py b/airbyte/_processors/sql/base.py new file mode 100644 index 00000000..5fce9f7e --- /dev/null +++ b/airbyte/_processors/sql/base.py @@ -0,0 +1,902 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+"""The base SQL Cache implementation.""" + +from __future__ import annotations + +import abc +import enum +from contextlib import contextmanager +from functools import cached_property +from typing import TYPE_CHECKING, cast, final + +import pandas as pd +import sqlalchemy +import ulid +from overrides import overrides +from sqlalchemy import ( + Column, + Table, + and_, + create_engine, + insert, + null, + select, + text, + update, +) +from sqlalchemy.pool import StaticPool +from sqlalchemy.sql.elements import TextClause + +from airbyte import exceptions as exc +from airbyte._processors.base import BatchHandle, RecordProcessor +from airbyte._processors.file.base import FileWriterBase, FileWriterBatchHandle +from airbyte._util.text_util import lower_case_set +from airbyte.caches._catalog_manager import CatalogManager +from airbyte.datasets._sql import CachedDataset +from airbyte.strategies import WriteStrategy +from airbyte.types import SQLTypeConverter + + +if TYPE_CHECKING: + from collections.abc import Generator + from pathlib import Path + + import pyarrow as pa + from sqlalchemy.engine import Connection, Engine + from sqlalchemy.engine.cursor import CursorResult + from sqlalchemy.engine.reflection import Inspector + from sqlalchemy.sql.base import Executable + + from airbyte_protocol.models import ( + AirbyteStateMessage, + ConfiguredAirbyteCatalog, + ) + + from airbyte.caches.base import CacheBase + from airbyte.telemetry import CacheTelemetryInfo + + +DEBUG_MODE = False # Set to True to enable additional debug logging. + + +class RecordDedupeMode(enum.Enum): + APPEND = "append" + REPLACE = "replace" + + +class SQLRuntimeError(Exception): + """Raised when an SQL operation fails.""" + + +class SqlProcessorBase(RecordProcessor): + """A base class to be used for SQL Caches.""" + + type_converter_class: type[SQLTypeConverter] = SQLTypeConverter + file_writer_class: type[FileWriterBase] + + supports_merge_insert = False + use_singleton_connection = False # If true, the same connection is used for all operations. + + # Constructor: + + @final # We don't want subclasses to have to override the constructor. + def __init__( + self, + cache: CacheBase, + file_writer: FileWriterBase | None = None, + ) -> None: + self._engine: Engine | None = None + self._connection_to_reuse: Connection | None = None + super().__init__(cache, catalog_manager=None) + self._ensure_schema_exists() + self._catalog_manager = CatalogManager( + engine=self.get_sql_engine(), + table_name_resolver=lambda stream_name: self.get_sql_table_name(stream_name), + ) + self.file_writer = file_writer or self.file_writer_class( + cache, catalog_manager=self._catalog_manager + ) + self.type_converter = self.type_converter_class() + self._cached_table_definitions: dict[str, sqlalchemy.Table] = {} + + # Public interface: + + def get_sql_alchemy_url(self) -> str: + """Return the SQLAlchemy URL to use.""" + return self.cache.get_sql_alchemy_url() + + @final + @cached_property + def database_name(self) -> str: + """Return the name of the database.""" + return self.cache.get_database_name() + + @final + def get_sql_engine(self) -> Engine: + """Return a new SQL engine to use.""" + if self._engine: + return self._engine + + sql_alchemy_url = self.get_sql_alchemy_url() + + execution_options = {"schema_translate_map": {None: self.cache.schema_name}} + if self.use_singleton_connection: + if self._connection_to_reuse is None: + # This temporary bootstrap engine will be created once and is needed to + # create the long-lived connection object. 
+ bootstrap_engine = create_engine( + sql_alchemy_url, + ) + self._connection_to_reuse = bootstrap_engine.connect() + + self._engine = create_engine( + sql_alchemy_url, + creator=lambda: self._connection_to_reuse, + poolclass=StaticPool, + echo=DEBUG_MODE, + execution_options=execution_options, + # isolation_level="AUTOCOMMIT", + ) + else: + # Regular engine creation for new connections + self._engine = create_engine( + sql_alchemy_url, + echo=DEBUG_MODE, + execution_options=execution_options, + # isolation_level="AUTOCOMMIT", + ) + + return self._engine + + @overrides + def register_source( + self, + source_name: str, + incoming_source_catalog: ConfiguredAirbyteCatalog, + stream_names: set[str], + ) -> None: + """Register the source with the cache. + + We use stream_names to determine which streams will receive data, and + we only register the stream if is expected to receive data. + + This method is called by the source when it is initialized. + """ + self._source_name = source_name + self.file_writer.register_source( + source_name, + incoming_source_catalog, + stream_names=stream_names, + ) + self._ensure_schema_exists() + super().register_source( + source_name, + incoming_source_catalog, + stream_names=stream_names, + ) + + @contextmanager + def get_sql_connection(self) -> Generator[sqlalchemy.engine.Connection, None, None]: + """A context manager which returns a new SQL connection for running queries. + + If the connection needs to close, it will be closed automatically. + """ + if self.use_singleton_connection and self._connection_to_reuse is not None: + connection = self._connection_to_reuse + self._init_connection_settings(connection) + yield connection + + else: + with self.get_sql_engine().begin() as connection: + self._init_connection_settings(connection) + yield connection + + if not self.use_singleton_connection: + connection.close() + del connection + + def get_sql_table_name( + self, + stream_name: str, + ) -> str: + """Return the name of the SQL table for the given stream.""" + table_prefix = self.cache.table_prefix or "" + + # TODO: Add default prefix based on the source name. + + return self._normalize_table_name( + f"{table_prefix}{stream_name}{self.cache.table_suffix}", + ) + + @final + def get_sql_table( + self, + stream_name: str, + ) -> sqlalchemy.Table: + """Return the main table object for the stream.""" + return self._get_table_by_name(self.get_sql_table_name(stream_name)) + + # Read methods: + + def get_records( + self, + stream_name: str, + ) -> CachedDataset: + """Uses SQLAlchemy to select all rows from the table.""" + return CachedDataset(self.cache, stream_name) + + def get_pandas_dataframe( + self, + stream_name: str, + ) -> pd.DataFrame: + """Return a Pandas data frame with the stream's data.""" + table_name = self.get_sql_table_name(stream_name) + engine = self.get_sql_engine() + return pd.read_sql_table(table_name, engine) + + # Protected members (non-public interface): + + def _init_connection_settings(self, connection: Connection) -> None: + """This is called automatically whenever a new connection is created. + + By default this is a no-op. Subclasses can use this to set connection settings, such as + timezone, case-sensitivity settings, and other session-level variables. + """ + pass + + def _get_table_by_name( + self, + table_name: str, + *, + force_refresh: bool = False, + ) -> sqlalchemy.Table: + """Return a table object from a table name. + + To prevent unnecessary round-trips to the database, the table is cached after the first + query. 
To ignore the cache and force a refresh, set 'force_refresh' to True.
+        """
+        if force_refresh or table_name not in self._cached_table_definitions:
+            self._cached_table_definitions[table_name] = sqlalchemy.Table(
+                table_name,
+                sqlalchemy.MetaData(schema=self.cache.schema_name),
+                autoload_with=self.get_sql_engine(),
+            )
+
+        return self._cached_table_definitions[table_name]
+
+    def _ensure_schema_exists(
+        self,
+    ) -> None:
+        """Create the target schema if it does not already exist."""
+        schema_name = self.cache.schema_name
+        if schema_name in self._get_schemas_list():
+            return
+
+        sql = f"CREATE SCHEMA IF NOT EXISTS {schema_name}"
+
+        try:
+            self._execute_sql(sql)
+        except Exception as ex:
+            # Ignore schema exists errors.
+            if "already exists" not in str(ex):
+                raise
+
+        if DEBUG_MODE:
+            found_schemas = self._get_schemas_list()
+            assert (
+                schema_name in found_schemas
+            ), f"Schema {schema_name} was not created. Found: {found_schemas}"
+
+    def _quote_identifier(self, identifier: str) -> str:
+        """Return the given identifier, quoted."""
+        return f'"{identifier}"'
+
+    @final
+    def _get_temp_table_name(
+        self,
+        stream_name: str,
+        batch_id: str | None = None,  # ULID of the batch
+    ) -> str:
+        """Return a new (unique) temporary table name."""
+        batch_id = batch_id or str(ulid.ULID())
+        return self._normalize_table_name(f"{stream_name}_{batch_id}")
+
+    def _fully_qualified(
+        self,
+        table_name: str,
+    ) -> str:
+        """Return the fully qualified name of the given table."""
+        return f"{self.cache.schema_name}.{self._quote_identifier(table_name)}"
+
+    @final
+    def _create_table_for_loading(
+        self,
+        /,
+        stream_name: str,
+        batch_id: str,
+    ) -> str:
+        """Create a new table for loading data."""
+        temp_table_name = self._get_temp_table_name(stream_name, batch_id)
+        column_definition_str = ",\n  ".join(
+            f"{self._quote_identifier(column_name)} {sql_type}"
+            for column_name, sql_type in self._get_sql_column_definitions(stream_name).items()
+        )
+        self._create_table(temp_table_name, column_definition_str)
+
+        return temp_table_name
+
+    def _get_tables_list(
+        self,
+    ) -> list[str]:
+        """Return a list of all tables in the database."""
+        with self.get_sql_connection() as conn:
+            inspector: Inspector = sqlalchemy.inspect(conn)
+            return inspector.get_table_names(schema=self.cache.schema_name)
+
+    def _get_schemas_list(
+        self,
+        database_name: str | None = None,
+    ) -> list[str]:
+        """Return a list of all schemas in the database."""
+        inspector: Inspector = sqlalchemy.inspect(self.get_sql_engine())
+        database_name = database_name or self.database_name
+        found_schemas = inspector.get_schema_names()
+        return [
+            found_schema.split(".")[-1].strip('"')
+            for found_schema in found_schemas
+            if "." not in found_schema
+            or (found_schema.split(".")[0].lower().strip('"') == database_name.lower())
+        ]
+
+    def _ensure_final_table_exists(
+        self,
+        stream_name: str,
+        *,
+        create_if_missing: bool = True,
+    ) -> str:
+        """Create the final table if it doesn't already exist.
+
+        Return the table name.
+ """ + table_name = self.get_sql_table_name(stream_name) + did_exist = self._table_exists(table_name) + if not did_exist and create_if_missing: + column_definition_str = ",\n ".join( + f"{self._quote_identifier(column_name)} {sql_type}" + for column_name, sql_type in self._get_sql_column_definitions( + stream_name, + ).items() + ) + self._create_table(table_name, column_definition_str) + + return table_name + + def _ensure_compatible_table_schema( + self, + stream_name: str, + *, + raise_on_error: bool = False, + ) -> bool: + """Return true if the given table is compatible with the stream's schema. + + If raise_on_error is true, raise an exception if the table is not compatible. + + TODO: Expand this to check for column types and sizes, and to add missing columns. + + Returns true if the table is compatible, false if it is not. + """ + json_schema = self._get_stream_json_schema(stream_name) + stream_column_names: list[str] = json_schema["properties"].keys() + table_column_names: list[str] = self.get_sql_table(stream_name).columns.keys() + + lower_case_table_column_names = lower_case_set(table_column_names) + missing_columns = [ + stream_col + for stream_col in stream_column_names + if stream_col.lower() not in lower_case_table_column_names + ] + if missing_columns: + if raise_on_error: + raise exc.AirbyteLibCacheTableValidationError( + violation="Cache table is missing expected columns.", + context={ + "stream_column_names": stream_column_names, + "table_column_names": table_column_names, + "missing_columns": missing_columns, + }, + ) + return False # Some columns are missing. + + return True # All columns exist. + + @final + def _create_table( + self, + table_name: str, + column_definition_str: str, + primary_keys: list[str] | None = None, + ) -> None: + if DEBUG_MODE: + assert table_name not in self._get_tables_list(), f"Table {table_name} already exists." + + if primary_keys: + pk_str = ", ".join(primary_keys) + column_definition_str += f",\n PRIMARY KEY ({pk_str})" + + cmd = f""" + CREATE TABLE {self._fully_qualified(table_name)} ( + {column_definition_str} + ) + """ + _ = self._execute_sql(cmd) + if DEBUG_MODE: + tables_list = self._get_tables_list() + assert ( + table_name in tables_list + ), f"Table {table_name} was not created. Found: {tables_list}" + + def _normalize_column_name( + self, + raw_name: str, + ) -> str: + return raw_name.lower().replace(" ", "_").replace("-", "_") + + def _normalize_table_name( + self, + raw_name: str, + ) -> str: + return raw_name.lower().replace(" ", "_").replace("-", "_") + + @final + def _get_sql_column_definitions( + self, + stream_name: str, + ) -> dict[str, sqlalchemy.types.TypeEngine]: + """Return the column definitions for the given stream.""" + columns: dict[str, sqlalchemy.types.TypeEngine] = {} + properties = self._get_stream_json_schema(stream_name)["properties"] + for property_name, json_schema_property_def in properties.items(): + clean_prop_name = self._normalize_column_name(property_name) + columns[clean_prop_name] = self.type_converter.to_sql_type( + json_schema_property_def, + ) + + # TODO: Add the metadata columns (this breaks tests) + # columns["_airbyte_extracted_at"] = sqlalchemy.TIMESTAMP() + # columns["_airbyte_loaded_at"] = sqlalchemy.TIMESTAMP() + return columns + + @overrides + def _write_batch( + self, + stream_name: str, + batch_id: str, + record_batch: pa.Table, + ) -> FileWriterBatchHandle: + """Process a record batch. + + Return the path to the cache file. 
+ """ + return self.file_writer.write_batch(stream_name, batch_id, record_batch) + + def _cleanup_batch( + self, + stream_name: str, + batch_id: str, + batch_handle: BatchHandle, + ) -> None: + """Clean up the cache. + + For SQL caches, we only need to call the cleanup operation on the file writer. + + Subclasses should call super() if they override this method. + """ + self.file_writer.cleanup_batch(stream_name, batch_id, batch_handle) + + @final + @overrides + def _finalize_batches( + self, + stream_name: str, + write_strategy: WriteStrategy, + ) -> dict[str, BatchHandle]: + """Finalize all uncommitted batches. + + This is a generic 'final' implementation, which should not be overridden. + + Returns a mapping of batch IDs to batch handles, for those processed batches. + + TODO: Add a dedupe step here to remove duplicates from the temp table. + Some sources will send us duplicate records within the same stream, + although this is a fairly rare edge case we can ignore in V1. + """ + with self._finalizing_batches(stream_name) as batches_to_finalize: + # Make sure the target schema and target table exist. + self._ensure_schema_exists() + final_table_name = self._ensure_final_table_exists( + stream_name, + create_if_missing=True, + ) + self._ensure_compatible_table_schema( + stream_name=stream_name, + raise_on_error=True, + ) + + if not batches_to_finalize: + # If there are no batches to finalize, return after ensuring the table exists. + return {} + + files: list[Path] = [] + # Get a list of all files to finalize from all pending batches. + for batch_handle in batches_to_finalize.values(): + batch_handle = cast(FileWriterBatchHandle, batch_handle) + files += batch_handle.files + # Use the max batch ID as the batch ID for table names. + max_batch_id = max(batches_to_finalize.keys()) + + temp_table_name = self._write_files_to_new_table( + files=files, + stream_name=stream_name, + batch_id=max_batch_id, + ) + try: + self._write_temp_table_to_final_table( + stream_name=stream_name, + temp_table_name=temp_table_name, + final_table_name=final_table_name, + write_strategy=write_strategy, + ) + finally: + self._drop_temp_table(temp_table_name, if_exists=True) + + # Return the batch handles as measure of work completed. 
+ return batches_to_finalize + + @overrides + def _finalize_state_messages( + self, + stream_name: str, + state_messages: list[AirbyteStateMessage], + ) -> None: + """Handle state messages by passing them to the catalog manager.""" + if not self._catalog_manager: + raise exc.AirbyteLibInternalError( + message="Catalog manager should exist but does not.", + ) + if state_messages and self._source_name: + self._catalog_manager.save_state( + source_name=self._source_name, + stream_name=stream_name, + state=state_messages[-1], + ) + + def _execute_sql(self, sql: str | TextClause | Executable) -> CursorResult: + """Execute the given SQL statement.""" + if isinstance(sql, str): + sql = text(sql) + if isinstance(sql, TextClause): + sql = sql.execution_options( + autocommit=True, + ) + + with self.get_sql_connection() as conn: + try: + result = conn.execute(sql) + except ( + sqlalchemy.exc.ProgrammingError, + sqlalchemy.exc.SQLAlchemyError, + ) as ex: + msg = f"Error when executing SQL:\n{sql}\n{type(ex).__name__}{ex!s}" + raise SQLRuntimeError(msg) from None # from ex + + return result + + def _drop_temp_table( + self, + table_name: str, + *, + if_exists: bool = True, + ) -> None: + """Drop the given table.""" + exists_str = "IF EXISTS" if if_exists else "" + self._execute_sql(f"DROP TABLE {exists_str} {self._fully_qualified(table_name)}") + + def _write_files_to_new_table( + self, + files: list[Path], + stream_name: str, + batch_id: str, + ) -> str: + """Write a file(s) to a new table. + + This is a generic implementation, which can be overridden by subclasses + to improve performance. + """ + temp_table_name = self._create_table_for_loading(stream_name, batch_id) + for file_path in files: + dataframe = pd.read_json(file_path, lines=True) + + # Pandas will auto-create the table if it doesn't exist, which we don't want. 
+ if not self._table_exists(temp_table_name): + raise exc.AirbyteLibInternalError( + message="Table does not exist after creation.", + context={ + "temp_table_name": temp_table_name, + }, + ) + + dataframe.to_sql( + temp_table_name, + self.get_sql_alchemy_url(), + schema=self.cache.schema_name, + if_exists="append", + index=False, + dtype=self._get_sql_column_definitions(stream_name), + ) + return temp_table_name + + @final + def _write_temp_table_to_final_table( + self, + stream_name: str, + temp_table_name: str, + final_table_name: str, + write_strategy: WriteStrategy, + ) -> None: + """Write the temp table into the final table using the provided write strategy.""" + has_pks: bool = bool(self._get_primary_keys(stream_name)) + has_incremental_key: bool = bool(self._get_incremental_key(stream_name)) + if write_strategy == WriteStrategy.MERGE and not has_pks: + raise exc.AirbyteLibInputError( + message="Cannot use merge strategy on a stream with no primary keys.", + context={ + "stream_name": stream_name, + }, + ) + + if write_strategy == WriteStrategy.AUTO: + if has_pks: + write_strategy = WriteStrategy.MERGE + elif has_incremental_key: + write_strategy = WriteStrategy.APPEND + else: + write_strategy = WriteStrategy.REPLACE + + if write_strategy == WriteStrategy.REPLACE: + self._swap_temp_table_with_final_table( + stream_name=stream_name, + temp_table_name=temp_table_name, + final_table_name=final_table_name, + ) + return + + if write_strategy == WriteStrategy.APPEND: + self._append_temp_table_to_final_table( + stream_name=stream_name, + temp_table_name=temp_table_name, + final_table_name=final_table_name, + ) + return + + if write_strategy == WriteStrategy.MERGE: + if not self.supports_merge_insert: + # Fallback to emulated merge if the database does not support merge natively. + self._emulated_merge_temp_table_to_final_table( + stream_name=stream_name, + temp_table_name=temp_table_name, + final_table_name=final_table_name, + ) + return + + self._merge_temp_table_to_final_table( + stream_name=stream_name, + temp_table_name=temp_table_name, + final_table_name=final_table_name, + ) + return + + raise exc.AirbyteLibInternalError( + message="Write strategy is not supported.", + context={ + "write_strategy": write_strategy, + }, + ) + + def _append_temp_table_to_final_table( + self, + temp_table_name: str, + final_table_name: str, + stream_name: str, + ) -> None: + nl = "\n" + columns = [self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)] + self._execute_sql( + f""" + INSERT INTO {self._fully_qualified(final_table_name)} ( + {f',{nl} '.join(columns)} + ) + SELECT + {f',{nl} '.join(columns)} + FROM {self._fully_qualified(temp_table_name)} + """, + ) + + def _get_primary_keys( + self, + stream_name: str, + ) -> list[str]: + pks = self._get_stream_config(stream_name).primary_key + if not pks: + return [] + + joined_pks = [".".join(pk) for pk in pks] + for pk in joined_pks: + if "." in pk: + msg = "Nested primary keys are not yet supported. Found: {pk}" + raise NotImplementedError(msg) + + return joined_pks + + def _get_incremental_key( + self, + stream_name: str, + ) -> str | None: + return self._get_stream_config(stream_name).cursor_field + + def _swap_temp_table_with_final_table( + self, + stream_name: str, + temp_table_name: str, + final_table_name: str, + ) -> None: + """Merge the temp table into the main one. + + This implementation requires MERGE support in the SQL DB. + Databases that do not support this syntax can override this method. 
+        """
+        if final_table_name is None:
+            raise exc.AirbyteLibInternalError(message="Arg 'final_table_name' cannot be None.")
+        if temp_table_name is None:
+            raise exc.AirbyteLibInternalError(message="Arg 'temp_table_name' cannot be None.")
+
+        _ = stream_name
+        deletion_name = f"{final_table_name}_deleteme"
+        commands = "\n".join(
+            [
+                f"ALTER TABLE {final_table_name} RENAME TO {deletion_name};",
+                f"ALTER TABLE {temp_table_name} RENAME TO {final_table_name};",
+                f"DROP TABLE {deletion_name};",
+            ]
+        )
+        self._execute_sql(commands)
+
+    def _merge_temp_table_to_final_table(
+        self,
+        stream_name: str,
+        temp_table_name: str,
+        final_table_name: str,
+    ) -> None:
+        """Merge the temp table into the main one.
+
+        This implementation requires MERGE support in the SQL DB.
+        Databases that do not support this syntax can override this method.
+        """
+        nl = "\n"
+        columns = {self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)}
+        pk_columns = {self._quote_identifier(c) for c in self._get_primary_keys(stream_name)}
+        non_pk_columns = columns - pk_columns
+        join_clause = f"{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns)
+        set_clause = f"{nl}    , ".join(f"{col} = tmp.{col}" for col in non_pk_columns)
+        self._execute_sql(
+            f"""
+            MERGE INTO {self._fully_qualified(final_table_name)} final
+            USING (
+            SELECT *
+            FROM {self._fully_qualified(temp_table_name)}
+            ) AS tmp
+            ON {join_clause}
+            WHEN MATCHED THEN UPDATE
+            SET
+                {set_clause}
+            WHEN NOT MATCHED THEN INSERT
+            (
+                {f',{nl}  '.join(columns)}
+            )
+            VALUES (
+                tmp.{f',{nl}  tmp.'.join(columns)}
+            );
+            """,
+        )
+
+    def _get_column_by_name(self, table: str | Table, column_name: str) -> Column:
+        """Return the column object for the given column name.
+
+        This method is case-insensitive.
+        """
+        if isinstance(table, str):
+            table = self._get_table_by_name(table)
+        try:
+            # Try to get the column in a case-insensitive manner
+            return next(col for col in table.c if col.name.lower() == column_name.lower())
+        except StopIteration:
+            raise exc.AirbyteLibInternalError(
+                message="Could not find matching column.",
+                context={
+                    "table": table,
+                    "column_name": column_name,
+                },
+            ) from None
+
+    def _emulated_merge_temp_table_to_final_table(
+        self,
+        stream_name: str,
+        temp_table_name: str,
+        final_table_name: str,
+    ) -> None:
+        """Emulate the merge operation using a series of SQL commands.
+
+        This is a fallback implementation for databases that do not support MERGE.
+ """ + final_table = self._get_table_by_name(final_table_name) + temp_table = self._get_table_by_name(temp_table_name) + pk_columns = self._get_primary_keys(stream_name) + + columns_to_update: set[str] = self._get_sql_column_definitions( + stream_name=stream_name + ).keys() - set(pk_columns) + + # Create a dictionary mapping columns in users_final to users_stage for updating + update_values = { + self._get_column_by_name(final_table, column): ( + self._get_column_by_name(temp_table, column) + ) + for column in columns_to_update + } + + # Craft the WHERE clause for composite primary keys + join_conditions = [ + self._get_column_by_name(final_table, pk_column) + == self._get_column_by_name(temp_table, pk_column) + for pk_column in pk_columns + ] + join_clause = and_(*join_conditions) + + # Craft the UPDATE statement + update_stmt = update(final_table).values(update_values).where(join_clause) + + # Define a join between temp_table and final_table + joined_table = temp_table.outerjoin(final_table, join_clause) + + # Define a condition that checks for records in temp_table that do not have a corresponding + # record in final_table + where_not_exists_clause = self._get_column_by_name(final_table, pk_columns[0]) == null() + + # Select records from temp_table that are not in final_table + select_new_records_stmt = ( + select([temp_table]).select_from(joined_table).where(where_not_exists_clause) + ) + + # Craft the INSERT statement using the select statement + insert_new_records_stmt = insert(final_table).from_select( + names=[column.name for column in temp_table.columns], select=select_new_records_stmt + ) + + if DEBUG_MODE: + print(str(update_stmt)) + print(str(insert_new_records_stmt)) + + with self.get_sql_connection() as conn: + conn.execute(update_stmt) + conn.execute(insert_new_records_stmt) + + @final + def _table_exists( + self, + table_name: str, + ) -> bool: + """Return true if the given table exists.""" + return table_name in self._get_tables_list() + + @abc.abstractmethod + def _get_telemetry_info(self) -> CacheTelemetryInfo: + pass diff --git a/airbyte/_processors/sql/duckdb.py b/airbyte/_processors/sql/duckdb.py new file mode 100644 index 00000000..d58b5297 --- /dev/null +++ b/airbyte/_processors/sql/duckdb.py @@ -0,0 +1,129 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""A DuckDB implementation of the cache.""" + +from __future__ import annotations + +import warnings +from textwrap import dedent, indent +from typing import TYPE_CHECKING + +from overrides import overrides + +from airbyte._processors.file import JsonlWriter +from airbyte._processors.sql.base import SqlProcessorBase +from airbyte.telemetry import CacheTelemetryInfo + + +if TYPE_CHECKING: + from pathlib import Path + + +# Suppress warnings from DuckDB about reflection on indices. +# https://github.com/Mause/duckdb_engine/issues/905 +warnings.filterwarnings( + "ignore", + message="duckdb-engine doesn't yet support reflection on indices", +) + + +class DuckDBSqlProcessor(SqlProcessorBase): + """A DuckDB implementation of the cache. + + Jsonl is used for local file storage before bulk loading. + Unlike the Snowflake implementation, we can't use the COPY command to load data + so we insert as values instead. 
+    """
+
+    supports_merge_insert = False
+    file_writer_class = JsonlWriter
+
+    # @overrides
+    # def _setup(self) -> None:
+    #     """Create the database parent folder if it doesn't yet exist."""
+    #     config = cast(DuckDBCache, self.config)
+
+    #     if config.db_path == ":memory:":
+    #         return
+
+    #     Path(config.db_path).parent.mkdir(parents=True, exist_ok=True)
+
+    @overrides
+    def _ensure_compatible_table_schema(
+        self,
+        stream_name: str,
+        *,
+        raise_on_error: bool = True,
+    ) -> bool:
+        """Return true if the given table is compatible with the stream's schema.
+
+        In addition to the base implementation, this also checks primary keys.
+        """
+        # call super
+        if not super()._ensure_compatible_table_schema(
+            stream_name=stream_name,
+            raise_on_error=raise_on_error,
+        ):
+            return False
+
+        # TODO: Add validation for primary keys after DuckDB adds support for primary key
+        # inspection: https://github.com/Mause/duckdb_engine/issues/594
+
+        return True
+
+    def _write_files_to_new_table(
+        self,
+        files: list[Path],
+        stream_name: str,
+        batch_id: str,
+    ) -> str:
+        """Write a file(s) to a new table.
+
+        We use DuckDB's `read_json_auto` function to efficiently read the files and insert
+        them into the table in a single operation.
+
+        Note: This implementation is fragile in regards to column ordering. However, since
+        we are inserting into a temp table we have just created, there should be no
+        drift between the table schema and the file schema.
+        """
+        temp_table_name = self._create_table_for_loading(
+            stream_name=stream_name,
+            batch_id=batch_id,
+        )
+        columns_list = list(self._get_sql_column_definitions(stream_name=stream_name).keys())
+        columns_list_str = indent(
+            "\n, ".join([self._quote_identifier(c) for c in columns_list]),
+            "    ",
+        )
+        files_list = ", ".join([f"'{f!s}'" for f in files])
+        columns_type_map = indent(
+            "\n, ".join(
+                [
+                    f"{self._quote_identifier(c)}: "
+                    f"{self._get_sql_column_definitions(stream_name)[c]!s}"
+                    for c in columns_list
+                ]
+            ),
+            "    ",
+        )
+        insert_statement = dedent(
+            f"""
+            INSERT INTO {self.cache.schema_name}.{temp_table_name}
+            (
+                {columns_list_str}
+            )
+            SELECT
+                {columns_list_str}
+            FROM read_json_auto(
+                [{files_list}],
+                format = 'newline_delimited',
+                union_by_name = true,
+                columns = {{ { columns_type_map } }}
+            )
+            """
+        )
+        self._execute_sql(insert_statement)
+        return temp_table_name
+
+    @overrides
+    def _get_telemetry_info(self) -> CacheTelemetryInfo:
+        return CacheTelemetryInfo("duckdb")
diff --git a/airbyte/_processors/sql/postgres.py b/airbyte/_processors/sql/postgres.py
new file mode 100644
index 00000000..3573e051
--- /dev/null
+++ b/airbyte/_processors/sql/postgres.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+"""A Postgres implementation of the cache."""
+
+from __future__ import annotations
+
+from overrides import overrides
+
+from airbyte._processors.file import JsonlWriter
+from airbyte._processors.sql.base import SqlProcessorBase
+from airbyte.telemetry import CacheTelemetryInfo
+
+
+class PostgresSqlProcessor(SqlProcessorBase):
+    """A Postgres implementation of the cache.
+
+    Jsonl is used for local file storage before bulk loading.
+    Unlike the Snowflake implementation, we can't use the COPY command to load data
+    so we insert as values instead.
+
+    TODO: Add optimized bulk load path for Postgres. Could use an alternate file writer
+    or another import method. (Relatively low priority, since for now it works fine as-is.)
+ """ + + file_writer_class = JsonlWriter + supports_merge_insert = False # TODO: Add native implementation for merge insert + + @overrides + def _get_telemetry_info(self) -> CacheTelemetryInfo: + return CacheTelemetryInfo("postgres") diff --git a/airbyte/_processors/sql/snowflake.py b/airbyte/_processors/sql/snowflake.py new file mode 100644 index 00000000..72025fd4 --- /dev/null +++ b/airbyte/_processors/sql/snowflake.py @@ -0,0 +1,120 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""A Snowflake implementation of the SQL processor.""" + +from __future__ import annotations + +from textwrap import dedent, indent +from typing import TYPE_CHECKING + +import sqlalchemy +from overrides import overrides +from snowflake.sqlalchemy import VARIANT + +from airbyte._processors.file.jsonl import JsonlWriter +from airbyte._processors.sql.base import SqlProcessorBase +from airbyte.telemetry import CacheTelemetryInfo +from airbyte.types import SQLTypeConverter + + +if TYPE_CHECKING: + from pathlib import Path + + from sqlalchemy.engine import Connection + + +class SnowflakeTypeConverter(SQLTypeConverter): + """A class to convert types for Snowflake.""" + + @overrides + def to_sql_type( + self, + json_schema_property_def: dict[str, str | dict | list], + ) -> sqlalchemy.types.TypeEngine: + """Convert a value to a SQL type. + + We first call the parent class method to get the type. Then if the type JSON, we + replace it with VARIANT. + """ + sql_type = super().to_sql_type(json_schema_property_def) + if isinstance(sql_type, sqlalchemy.types.JSON): + return VARIANT() + + return sql_type + + +class SnowflakeSQLSqlProcessor(SqlProcessorBase): + """A Snowflake implementation of the cache. + + Parquet is used for local file storage before bulk loading. + """ + + file_writer_class = JsonlWriter + type_converter_class = SnowflakeTypeConverter + + @overrides + def _write_files_to_new_table( + self, + files: list[Path], + stream_name: str, + batch_id: str, + ) -> str: + """Write files to a new table.""" + temp_table_name = self._create_table_for_loading( + stream_name=stream_name, + batch_id=batch_id, + ) + internal_sf_stage_name = f"@%{temp_table_name}" + put_files_statements = "\n".join( + [ + f"PUT 'file://{file_path.absolute()!s}' {internal_sf_stage_name};" + for file_path in files + ] + ) + self._execute_sql(put_files_statements) + + columns_list = [ + self._quote_identifier(c) + for c in list(self._get_sql_column_definitions(stream_name).keys()) + ] + files_list = ", ".join([f"'{f.name}'" for f in files]) + columns_list_str: str = indent("\n, ".join(columns_list), " " * 12) + variant_cols_str: str = ("\n" + " " * 21 + ", ").join([f"$1:{col}" for col in columns_list]) + copy_statement = dedent( + f""" + COPY INTO {temp_table_name} + ( + {columns_list_str} + ) + FROM ( + SELECT {variant_cols_str} + FROM {internal_sf_stage_name} + ) + FILES = ( {files_list} ) + FILE_FORMAT = ( TYPE = JSON ) + ; + """ + ) + self._execute_sql(copy_statement) + return temp_table_name + + @overrides + def _init_connection_settings(self, connection: Connection) -> None: + """We set Snowflake-specific settings for the session. + + This sets QUOTED_IDENTIFIERS_IGNORE_CASE setting to True, which is necessary because + Snowflake otherwise will treat quoted table and column references as case-sensitive. + More info: https://docs.snowflake.com/en/sql-reference/identifiers-syntax + + This also sets MULTI_STATEMENT_COUNT to 0, which allows multi-statement commands. 
+ """ + connection.execute( + """ + ALTER SESSION SET + QUOTED_IDENTIFIERS_IGNORE_CASE = TRUE + MULTI_STATEMENT_COUNT = 0 + """ + ) + + @overrides + def _get_telemetry_info(self) -> CacheTelemetryInfo: + return CacheTelemetryInfo("snowflake") diff --git a/airbyte/caches/__init__.py b/airbyte/caches/__init__.py index 2e19dc06..a1db85b4 100644 --- a/airbyte/caches/__init__.py +++ b/airbyte/caches/__init__.py @@ -1,7 +1,8 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. """Base module for all caches.""" from __future__ import annotations -from airbyte.caches.base import SQLCacheBase +from airbyte.caches.base import CacheBase from airbyte.caches.duckdb import DuckDBCache from airbyte.caches.postgres import PostgresCache from airbyte.caches.snowflake import SnowflakeCache @@ -11,6 +12,6 @@ __all__ = [ "DuckDBCache", "PostgresCache", - "SQLCacheBase", + "CacheBase", "SnowflakeCache", ] diff --git a/airbyte/caches/base.py b/airbyte/caches/base.py index f2e982a8..33c6b728 100644 --- a/airbyte/caches/base.py +++ b/airbyte/caches/base.py @@ -1,926 +1,39 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. """A SQL Cache implementation.""" + from __future__ import annotations import abc -import enum -from contextlib import contextmanager -from functools import cached_property -from typing import TYPE_CHECKING, Any, cast, final +from pathlib import Path +from typing import TYPE_CHECKING, Any, final -import pandas as pd -import sqlalchemy import ulid -from overrides import overrides -from pydantic import PrivateAttr -from sqlalchemy import ( - Column, - Table, - and_, - create_engine, - insert, - null, - select, - text, - update, -) -from sqlalchemy.pool import StaticPool -from sqlalchemy.sql.elements import TextClause +from pydantic import BaseModel, Field, PrivateAttr -from airbyte import exceptions as exc -from airbyte._file_writers.base import FileWriterBase, FileWriterBatchHandle -from airbyte._processors import BatchHandle, RecordProcessor -from airbyte._util.text_util import lower_case_set -from airbyte.caches._catalog_manager import CatalogManager -from airbyte.config import CacheConfigBase from airbyte.datasets._sql import CachedDataset -from airbyte.strategies import WriteStrategy -from airbyte.types import SQLTypeConverter if TYPE_CHECKING: from collections.abc import Generator - from pathlib import Path - - import pyarrow as pa - from sqlalchemy.engine import Connection, Engine - from sqlalchemy.engine.cursor import CursorResult - from sqlalchemy.engine.reflection import Inspector - from sqlalchemy.sql.base import Executable - from airbyte_protocol.models import ( - AirbyteStateMessage, - ConfiguredAirbyteCatalog, - ) + from sqlalchemy.engine import Engine + from airbyte._processors.sql.base import SqlProcessorBase from airbyte.datasets._base import DatasetBase - from airbyte.telemetry import CacheTelemetryInfo - - -DEBUG_MODE = False # Set to True to enable additional debug logging. - - -class RecordDedupeMode(enum.Enum): - APPEND = "append" - REPLACE = "replace" - - -class SQLRuntimeError(Exception): - """Raised when an SQL operation fails.""" - - -class SQLCacheInstanceBase(RecordProcessor): - """A base class to be used for SQL Caches.""" - - type_converter_class: type[SQLTypeConverter] = SQLTypeConverter - file_writer_class: type[FileWriterBase] - - supports_merge_insert = False - use_singleton_connection = False # If true, the same connection is used for all operations. - - # Constructor: - - @final # We don't want subclasses to have to override the constructor. 
- def __init__( - self, - config: SQLCacheBase, - file_writer: FileWriterBase | None = None, - ) -> None: - self._engine: Engine | None = None - self._connection_to_reuse: Connection | None = None - super().__init__(config, catalog_manager=None) - self._ensure_schema_exists() - self._catalog_manager = CatalogManager( - engine=self.get_sql_engine(), - table_name_resolver=lambda stream_name: self.get_sql_table_name(stream_name), - ) - self.file_writer = file_writer or self.file_writer_class( - config, catalog_manager=self._catalog_manager - ) - self.type_converter = self.type_converter_class() - self._cached_table_definitions: dict[str, sqlalchemy.Table] = {} - - # Public interface: - - def get_sql_alchemy_url(self) -> str: - """Return the SQLAlchemy URL to use.""" - return self.config.get_sql_alchemy_url() - - @final - @cached_property - def database_name(self) -> str: - """Return the name of the database.""" - return self.config.get_database_name() - - @final - def get_sql_engine(self) -> Engine: - """Return a new SQL engine to use.""" - if self._engine: - return self._engine - - sql_alchemy_url = self.get_sql_alchemy_url() - - execution_options = {"schema_translate_map": {None: self.config.schema_name}} - if self.use_singleton_connection: - if self._connection_to_reuse is None: - # This temporary bootstrap engine will be created once and is needed to - # create the long-lived connection object. - bootstrap_engine = create_engine( - sql_alchemy_url, - ) - self._connection_to_reuse = bootstrap_engine.connect() - - self._engine = create_engine( - sql_alchemy_url, - creator=lambda: self._connection_to_reuse, - poolclass=StaticPool, - echo=DEBUG_MODE, - execution_options=execution_options, - # isolation_level="AUTOCOMMIT", - ) - else: - # Regular engine creation for new connections - self._engine = create_engine( - sql_alchemy_url, - echo=DEBUG_MODE, - execution_options=execution_options, - # isolation_level="AUTOCOMMIT", - ) - - return self._engine - - @overrides - def register_source( - self, - source_name: str, - incoming_source_catalog: ConfiguredAirbyteCatalog, - stream_names: set[str], - ) -> None: - """Register the source with the cache. - - We use stream_names to determine which streams will receive data, and - we only register the stream if is expected to receive data. - - This method is called by the source when it is initialized. - """ - self._source_name = source_name - self.file_writer.register_source( - source_name, - incoming_source_catalog, - stream_names=stream_names, - ) - self._ensure_schema_exists() - super().register_source( - source_name, - incoming_source_catalog, - stream_names=stream_names, - ) - - @contextmanager - def get_sql_connection(self) -> Generator[sqlalchemy.engine.Connection, None, None]: - """A context manager which returns a new SQL connection for running queries. - - If the connection needs to close, it will be closed automatically. - """ - if self.use_singleton_connection and self._connection_to_reuse is not None: - connection = self._connection_to_reuse - self._init_connection_settings(connection) - yield connection - - else: - with self.get_sql_engine().begin() as connection: - self._init_connection_settings(connection) - yield connection - - if not self.use_singleton_connection: - connection.close() - del connection - - def get_sql_table_name( - self, - stream_name: str, - ) -> str: - """Return the name of the SQL table for the given stream.""" - table_prefix = self.config.table_prefix or "" - - # TODO: Add default prefix based on the source name. 
- - return self._normalize_table_name( - f"{table_prefix}{stream_name}{self.config.table_suffix}", - ) - - @final - def get_sql_table( - self, - stream_name: str, - ) -> sqlalchemy.Table: - """Return the main table object for the stream.""" - return self._get_table_by_name(self.get_sql_table_name(stream_name)) - # Read methods: - def get_records( - self, - stream_name: str, - ) -> CachedDataset: - """Uses SQLAlchemy to select all rows from the table.""" - return CachedDataset(self.config, stream_name) - - def get_pandas_dataframe( - self, - stream_name: str, - ) -> pd.DataFrame: - """Return a Pandas data frame with the stream's data.""" - table_name = self.get_sql_table_name(stream_name) - engine = self.get_sql_engine() - return pd.read_sql_table(table_name, engine) - - # Protected members (non-public interface): - - def _init_connection_settings(self, connection: Connection) -> None: - """This is called automatically whenever a new connection is created. - - By default this is a no-op. Subclasses can use this to set connection settings, such as - timezone, case-sensitivity settings, and other session-level variables. - """ - pass - - def _get_table_by_name( - self, - table_name: str, - *, - force_refresh: bool = False, - ) -> sqlalchemy.Table: - """Return a table object from a table name. - - To prevent unnecessary round-trips to the database, the table is cached after the first - query. To ignore the cache and force a refresh, set 'force_refresh' to True. - """ - if force_refresh or table_name not in self._cached_table_definitions: - self._cached_table_definitions[table_name] = sqlalchemy.Table( - table_name, - sqlalchemy.MetaData(schema=self.config.schema_name), - autoload_with=self.get_sql_engine(), - ) - - return self._cached_table_definitions[table_name] - - def _ensure_schema_exists( - self, - ) -> None: - """Return a new (unique) temporary table name.""" - schema_name = self.config.schema_name - if schema_name in self._get_schemas_list(): - return - - sql = f"CREATE SCHEMA IF NOT EXISTS {schema_name}" - - try: - self._execute_sql(sql) - except Exception as ex: - # Ignore schema exists errors. - if "already exists" not in str(ex): - raise - - if DEBUG_MODE: - found_schemas = self._get_schemas_list() - assert ( - schema_name in found_schemas - ), f"Schema {schema_name} was not created. 
Found: {found_schemas}" - - def _quote_identifier(self, identifier: str) -> str: - """Return the given identifier, quoted.""" - return f'"{identifier}"' - - @final - def _get_temp_table_name( - self, - stream_name: str, - batch_id: str | None = None, # ULID of the batch - ) -> str: - """Return a new (unique) temporary table name.""" - batch_id = batch_id or str(ulid.ULID()) - return self._normalize_table_name(f"{stream_name}_{batch_id}") - - def _fully_qualified( - self, - table_name: str, - ) -> str: - """Return the fully qualified name of the given table.""" - return f"{self.config.schema_name}.{self._quote_identifier(table_name)}" - - @final - def _create_table_for_loading( - self, - /, - stream_name: str, - batch_id: str, - ) -> str: - """Create a new table for loading data.""" - temp_table_name = self._get_temp_table_name(stream_name, batch_id) - column_definition_str = ",\n ".join( - f"{self._quote_identifier(column_name)} {sql_type}" - for column_name, sql_type in self._get_sql_column_definitions(stream_name).items() - ) - self._create_table(temp_table_name, column_definition_str) - - return temp_table_name - - def _get_tables_list( - self, - ) -> list[str]: - """Return a list of all tables in the database.""" - with self.get_sql_connection() as conn: - inspector: Inspector = sqlalchemy.inspect(conn) - return inspector.get_table_names(schema=self.config.schema_name) - - def _get_schemas_list( - self, - database_name: str | None = None, - ) -> list[str]: - """Return a list of all tables in the database.""" - inspector: Inspector = sqlalchemy.inspect(self.get_sql_engine()) - database_name = database_name or self.database_name - found_schemas = inspector.get_schema_names() - return [ - found_schema.split(".")[-1].strip('"') - for found_schema in found_schemas - if "." not in found_schema - or (found_schema.split(".")[0].lower().strip('"') == database_name.lower()) - ] - - def _ensure_final_table_exists( - self, - stream_name: str, - *, - create_if_missing: bool = True, - ) -> str: - """Create the final table if it doesn't already exist. - - Return the table name. - """ - table_name = self.get_sql_table_name(stream_name) - did_exist = self._table_exists(table_name) - if not did_exist and create_if_missing: - column_definition_str = ",\n ".join( - f"{self._quote_identifier(column_name)} {sql_type}" - for column_name, sql_type in self._get_sql_column_definitions( - stream_name, - ).items() - ) - self._create_table(table_name, column_definition_str) - - return table_name +# TODO: meta=EnforceOverrides (Pydantic doesn't like it currently.) +class CacheBase(BaseModel): + """Base configuration for a cache.""" - def _ensure_compatible_table_schema( - self, - stream_name: str, - *, - raise_on_error: bool = False, - ) -> bool: - """Return true if the given table is compatible with the stream's schema. - - If raise_on_error is true, raise an exception if the table is not compatible. - - TODO: Expand this to check for column types and sizes, and to add missing columns. + cache_dir: Path = Path(".cache") + """The directory to store the cache in.""" - Returns true if the table is compatible, false if it is not. 
- """ - json_schema = self._get_stream_json_schema(stream_name) - stream_column_names: list[str] = json_schema["properties"].keys() - table_column_names: list[str] = self.get_sql_table(stream_name).columns.keys() - - lower_case_table_column_names = lower_case_set(table_column_names) - missing_columns = [ - stream_col - for stream_col in stream_column_names - if stream_col.lower() not in lower_case_table_column_names - ] - if missing_columns: - if raise_on_error: - raise exc.AirbyteLibCacheTableValidationError( - violation="Cache table is missing expected columns.", - context={ - "stream_column_names": stream_column_names, - "table_column_names": table_column_names, - "missing_columns": missing_columns, - }, - ) - return False # Some columns are missing. - - return True # All columns exist. - - @final - def _create_table( - self, - table_name: str, - column_definition_str: str, - primary_keys: list[str] | None = None, - ) -> None: - if DEBUG_MODE: - assert table_name not in self._get_tables_list(), f"Table {table_name} already exists." - - if primary_keys: - pk_str = ", ".join(primary_keys) - column_definition_str += f",\n PRIMARY KEY ({pk_str})" - - cmd = f""" - CREATE TABLE {self._fully_qualified(table_name)} ( - {column_definition_str} - ) - """ - _ = self._execute_sql(cmd) - if DEBUG_MODE: - tables_list = self._get_tables_list() - assert ( - table_name in tables_list - ), f"Table {table_name} was not created. Found: {tables_list}" - - def _normalize_column_name( - self, - raw_name: str, - ) -> str: - return raw_name.lower().replace(" ", "_").replace("-", "_") - - def _normalize_table_name( - self, - raw_name: str, - ) -> str: - return raw_name.lower().replace(" ", "_").replace("-", "_") - - @final - def _get_sql_column_definitions( - self, - stream_name: str, - ) -> dict[str, sqlalchemy.types.TypeEngine]: - """Return the column definitions for the given stream.""" - columns: dict[str, sqlalchemy.types.TypeEngine] = {} - properties = self._get_stream_json_schema(stream_name)["properties"] - for property_name, json_schema_property_def in properties.items(): - clean_prop_name = self._normalize_column_name(property_name) - columns[clean_prop_name] = self.type_converter.to_sql_type( - json_schema_property_def, - ) - - # TODO: Add the metadata columns (this breaks tests) - # columns["_airbyte_extracted_at"] = sqlalchemy.TIMESTAMP() - # columns["_airbyte_loaded_at"] = sqlalchemy.TIMESTAMP() - return columns - - @overrides - def _write_batch( - self, - stream_name: str, - batch_id: str, - record_batch: pa.Table, - ) -> FileWriterBatchHandle: - """Process a record batch. - - Return the path to the cache file. - """ - return self.file_writer.write_batch(stream_name, batch_id, record_batch) - - def _cleanup_batch( - self, - stream_name: str, - batch_id: str, - batch_handle: BatchHandle, - ) -> None: - """Clean up the cache. - - For SQL caches, we only need to call the cleanup operation on the file writer. - - Subclasses should call super() if they override this method. - """ - self.file_writer.cleanup_batch(stream_name, batch_id, batch_handle) - - @final - @overrides - def _finalize_batches( - self, - stream_name: str, - write_strategy: WriteStrategy, - ) -> dict[str, BatchHandle]: - """Finalize all uncommitted batches. - - This is a generic 'final' implementation, which should not be overridden. - - Returns a mapping of batch IDs to batch handles, for those processed batches. - - TODO: Add a dedupe step here to remove duplicates from the temp table. 
- Some sources will send us duplicate records within the same stream, - although this is a fairly rare edge case we can ignore in V1. - """ - with self._finalizing_batches(stream_name) as batches_to_finalize: - # Make sure the target schema and target table exist. - self._ensure_schema_exists() - final_table_name = self._ensure_final_table_exists( - stream_name, - create_if_missing=True, - ) - self._ensure_compatible_table_schema( - stream_name=stream_name, - raise_on_error=True, - ) - - if not batches_to_finalize: - # If there are no batches to finalize, return after ensuring the table exists. - return {} - - files: list[Path] = [] - # Get a list of all files to finalize from all pending batches. - for batch_handle in batches_to_finalize.values(): - batch_handle = cast(FileWriterBatchHandle, batch_handle) - files += batch_handle.files - # Use the max batch ID as the batch ID for table names. - max_batch_id = max(batches_to_finalize.keys()) - - temp_table_name = self._write_files_to_new_table( - files=files, - stream_name=stream_name, - batch_id=max_batch_id, - ) - try: - self._write_temp_table_to_final_table( - stream_name=stream_name, - temp_table_name=temp_table_name, - final_table_name=final_table_name, - write_strategy=write_strategy, - ) - finally: - self._drop_temp_table(temp_table_name, if_exists=True) - - # Return the batch handles as measure of work completed. - return batches_to_finalize - - @overrides - def _finalize_state_messages( - self, - stream_name: str, - state_messages: list[AirbyteStateMessage], - ) -> None: - """Handle state messages by passing them to the catalog manager.""" - if not self._catalog_manager: - raise exc.AirbyteLibInternalError( - message="Catalog manager should exist but does not.", - ) - if state_messages and self._source_name: - self._catalog_manager.save_state( - source_name=self._source_name, - stream_name=stream_name, - state=state_messages[-1], - ) - - def _execute_sql(self, sql: str | TextClause | Executable) -> CursorResult: - """Execute the given SQL statement.""" - if isinstance(sql, str): - sql = text(sql) - if isinstance(sql, TextClause): - sql = sql.execution_options( - autocommit=True, - ) - - with self.get_sql_connection() as conn: - try: - result = conn.execute(sql) - except ( - sqlalchemy.exc.ProgrammingError, - sqlalchemy.exc.SQLAlchemyError, - ) as ex: - msg = f"Error when executing SQL:\n{sql}\n{type(ex).__name__}{ex!s}" - raise SQLRuntimeError(msg) from None # from ex - - return result - - def _drop_temp_table( - self, - table_name: str, - *, - if_exists: bool = True, - ) -> None: - """Drop the given table.""" - exists_str = "IF EXISTS" if if_exists else "" - self._execute_sql(f"DROP TABLE {exists_str} {self._fully_qualified(table_name)}") - - def _write_files_to_new_table( - self, - files: list[Path], - stream_name: str, - batch_id: str, - ) -> str: - """Write a file(s) to a new table. - - This is a generic implementation, which can be overridden by subclasses - to improve performance. - """ - temp_table_name = self._create_table_for_loading(stream_name, batch_id) - for file_path in files: - dataframe = pd.read_json(file_path, lines=True) - - # Pandas will auto-create the table if it doesn't exist, which we don't want. 
- if not self._table_exists(temp_table_name): - raise exc.AirbyteLibInternalError( - message="Table does not exist after creation.", - context={ - "temp_table_name": temp_table_name, - }, - ) - - dataframe.to_sql( - temp_table_name, - self.get_sql_alchemy_url(), - schema=self.config.schema_name, - if_exists="append", - index=False, - dtype=self._get_sql_column_definitions(stream_name), - ) - return temp_table_name - - @final - def _write_temp_table_to_final_table( - self, - stream_name: str, - temp_table_name: str, - final_table_name: str, - write_strategy: WriteStrategy, - ) -> None: - """Write the temp table into the final table using the provided write strategy.""" - has_pks: bool = bool(self._get_primary_keys(stream_name)) - has_incremental_key: bool = bool(self._get_incremental_key(stream_name)) - if write_strategy == WriteStrategy.MERGE and not has_pks: - raise exc.AirbyteLibInputError( - message="Cannot use merge strategy on a stream with no primary keys.", - context={ - "stream_name": stream_name, - }, - ) - - if write_strategy == WriteStrategy.AUTO: - if has_pks: - write_strategy = WriteStrategy.MERGE - elif has_incremental_key: - write_strategy = WriteStrategy.APPEND - else: - write_strategy = WriteStrategy.REPLACE - - if write_strategy == WriteStrategy.REPLACE: - self._swap_temp_table_with_final_table( - stream_name=stream_name, - temp_table_name=temp_table_name, - final_table_name=final_table_name, - ) - return - - if write_strategy == WriteStrategy.APPEND: - self._append_temp_table_to_final_table( - stream_name=stream_name, - temp_table_name=temp_table_name, - final_table_name=final_table_name, - ) - return - - if write_strategy == WriteStrategy.MERGE: - if not self.supports_merge_insert: - # Fallback to emulated merge if the database does not support merge natively. - self._emulated_merge_temp_table_to_final_table( - stream_name=stream_name, - temp_table_name=temp_table_name, - final_table_name=final_table_name, - ) - return - - self._merge_temp_table_to_final_table( - stream_name=stream_name, - temp_table_name=temp_table_name, - final_table_name=final_table_name, - ) - return - - raise exc.AirbyteLibInternalError( - message="Write strategy is not supported.", - context={ - "write_strategy": write_strategy, - }, - ) - - def _append_temp_table_to_final_table( - self, - temp_table_name: str, - final_table_name: str, - stream_name: str, - ) -> None: - nl = "\n" - columns = [self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)] - self._execute_sql( - f""" - INSERT INTO {self._fully_qualified(final_table_name)} ( - {f',{nl} '.join(columns)} - ) - SELECT - {f',{nl} '.join(columns)} - FROM {self._fully_qualified(temp_table_name)} - """, - ) - - def _get_primary_keys( - self, - stream_name: str, - ) -> list[str]: - pks = self._get_stream_config(stream_name).primary_key - if not pks: - return [] - - joined_pks = [".".join(pk) for pk in pks] - for pk in joined_pks: - if "." in pk: - msg = "Nested primary keys are not yet supported. Found: {pk}" - raise NotImplementedError(msg) - - return joined_pks - - def _get_incremental_key( - self, - stream_name: str, - ) -> str | None: - return self._get_stream_config(stream_name).cursor_field - - def _swap_temp_table_with_final_table( - self, - stream_name: str, - temp_table_name: str, - final_table_name: str, - ) -> None: - """Merge the temp table into the main one. - - This implementation requires MERGE support in the SQL DB. - Databases that do not support this syntax can override this method. 
- """ - if final_table_name is None: - raise exc.AirbyteLibInternalError(message="Arg 'final_table_name' cannot be None.") - if temp_table_name is None: - raise exc.AirbyteLibInternalError(message="Arg 'temp_table_name' cannot be None.") - - _ = stream_name - deletion_name = f"{final_table_name}_deleteme" - commands = "\n".join( - [ - f"ALTER TABLE {final_table_name} RENAME TO {deletion_name};", - f"ALTER TABLE {temp_table_name} RENAME TO {final_table_name};", - f"DROP TABLE {deletion_name};", - ] - ) - self._execute_sql(commands) - - def _merge_temp_table_to_final_table( - self, - stream_name: str, - temp_table_name: str, - final_table_name: str, - ) -> None: - """Merge the temp table into the main one. - - This implementation requires MERGE support in the SQL DB. - Databases that do not support this syntax can override this method. - """ - nl = "\n" - columns = {self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)} - pk_columns = {self._quote_identifier(c) for c in self._get_primary_keys(stream_name)} - non_pk_columns = columns - pk_columns - join_clause = "{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns) - set_clause = "{nl} ".join(f"{col} = tmp.{col}" for col in non_pk_columns) - self._execute_sql( - f""" - MERGE INTO {self._fully_qualified(final_table_name)} final - USING ( - SELECT * - FROM {self._fully_qualified(temp_table_name)} - ) AS tmp - ON {join_clause} - WHEN MATCHED THEN UPDATE - SET - {set_clause} - WHEN NOT MATCHED THEN INSERT - ( - {f',{nl} '.join(columns)} - ) - VALUES ( - tmp.{f',{nl} tmp.'.join(columns)} - ); - """, - ) - - def _get_column_by_name(self, table: str | Table, column_name: str) -> Column: - """Return the column object for the given column name. - - This method is case-insensitive. - """ - if isinstance(table, str): - table = self._get_table_by_name(table) - try: - # Try to get the column in a case-insensitive manner - return next(col for col in table.c if col.name.lower() == column_name.lower()) - except StopIteration: - raise exc.AirbyteLibInternalError( - message="Could not find matching column.", - context={ - "table": table, - "column_name": column_name, - }, - ) from None - - def _emulated_merge_temp_table_to_final_table( - self, - stream_name: str, - temp_table_name: str, - final_table_name: str, - ) -> None: - """Emulate the merge operation using a series of SQL commands. - - This is a fallback implementation for databases that do not support MERGE. 
- """ - final_table = self._get_table_by_name(final_table_name) - temp_table = self._get_table_by_name(temp_table_name) - pk_columns = self._get_primary_keys(stream_name) - - columns_to_update: set[str] = self._get_sql_column_definitions( - stream_name=stream_name - ).keys() - set(pk_columns) - - # Create a dictionary mapping columns in users_final to users_stage for updating - update_values = { - self._get_column_by_name(final_table, column): ( - self._get_column_by_name(temp_table, column) - ) - for column in columns_to_update - } - - # Craft the WHERE clause for composite primary keys - join_conditions = [ - self._get_column_by_name(final_table, pk_column) - == self._get_column_by_name(temp_table, pk_column) - for pk_column in pk_columns - ] - join_clause = and_(*join_conditions) - - # Craft the UPDATE statement - update_stmt = update(final_table).values(update_values).where(join_clause) - - # Define a join between temp_table and final_table - joined_table = temp_table.outerjoin(final_table, join_clause) - - # Define a condition that checks for records in temp_table that do not have a corresponding - # record in final_table - where_not_exists_clause = self._get_column_by_name(final_table, pk_columns[0]) == null() - - # Select records from temp_table that are not in final_table - select_new_records_stmt = ( - select([temp_table]).select_from(joined_table).where(where_not_exists_clause) - ) - - # Craft the INSERT statement using the select statement - insert_new_records_stmt = insert(final_table).from_select( - names=[column.name for column in temp_table.columns], select=select_new_records_stmt - ) - - if DEBUG_MODE: - print(str(update_stmt)) - print(str(insert_new_records_stmt)) - - with self.get_sql_connection() as conn: - conn.execute(update_stmt) - conn.execute(insert_new_records_stmt) - - @final - def _table_exists( - self, - table_name: str, - ) -> bool: - """Return true if the given table exists.""" - return table_name in self._get_tables_list() - - @property - @overrides - def _streams_with_data(self) -> set[str]: - """Return a list of known streams.""" - if not self._catalog_manager: - raise exc.AirbyteLibInternalError( - message="Cannot get streams with data without a catalog.", - ) - return { - stream.stream.name - for stream in self._catalog_manager.source_catalog.streams - if self._table_exists(self.get_sql_table_name(stream.stream.name)) - } - - @abc.abstractmethod - def _get_telemetry_info(self) -> CacheTelemetryInfo: - pass - - -class SQLCacheBase(CacheConfigBase): - """Same as a regular config except it exposes the 'get_sql_alchemy_url()' method.""" + cleanup: bool = True + """Whether to clean up the cache after use.""" schema_name: str = "airbyte_raw" + """The name of the schema to write to.""" table_prefix: str | None = None """ A prefix to add to all table names. 
@@ -930,15 +43,15 @@ class SQLCacheBase(CacheConfigBase): table_suffix: str = "" """A suffix to add to all table names.""" - _sql_processor_class: type[SQLCacheInstanceBase] = PrivateAttr() - _sql_processor: SQLCacheInstanceBase | None = PrivateAttr(default=None) + _sql_processor_class: type[SqlProcessorBase] = PrivateAttr() + _sql_processor: SqlProcessorBase | None = PrivateAttr(default=None) @final @property - def processor(self) -> SQLCacheInstanceBase: + def processor(self) -> SqlProcessorBase: """Return the SQL processor instance.""" if self._sql_processor is None: - self._sql_processor = self._sql_processor_class(config=self) + self._sql_processor = self._sql_processor_class(cache=self) return self._sql_processor @final @@ -963,7 +76,7 @@ def streams( ) -> dict[str, CachedDataset]: """Return a temporary table name.""" result = {} - for stream_name in self._streams_with_data: + for stream_name in self.processor.expected_streams: result[stream_name] = CachedDataset(self, stream_name) return result @@ -972,12 +85,7 @@ def __getitem__(self, stream: str) -> DatasetBase: return self.streams[stream] def __contains__(self, stream: str) -> bool: - return stream in self._streams_with_data + return stream in (self.processor.expected_streams) def __iter__(self) -> Generator[tuple[str, Any], None, None]: return ((name, dataset) for name, dataset in self.streams.items()) - - @property - def _streams_with_data(self) -> set[str]: - """Return a list of known streams.""" - return self.processor._streams_with_data diff --git a/airbyte/caches/duckdb.py b/airbyte/caches/duckdb.py index 4c88faf8..db7cb904 100644 --- a/airbyte/caches/duckdb.py +++ b/airbyte/caches/duckdb.py @@ -1,20 +1,20 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. - """A DuckDB implementation of the cache.""" from __future__ import annotations import warnings -from pathlib import Path -from textwrap import dedent, indent -from typing import cast +from typing import TYPE_CHECKING from overrides import overrides from pydantic import PrivateAttr -from airbyte._file_writers import JsonlWriter, JsonlWriterConfig -from airbyte.caches.base import SQLCacheBase, SQLCacheInstanceBase -from airbyte.telemetry import CacheTelemetryInfo +from airbyte._processors.sql.duckdb import DuckDBSqlProcessor +from airbyte.caches.base import CacheBase + + +if TYPE_CHECKING: + from pathlib import Path # Suppress warnings from DuckDB about reflection on indices. @@ -25,110 +25,7 @@ ) -class DuckDBCacheInstance(SQLCacheInstanceBase): - """A DuckDB implementation of the cache. - - Jsonl is used for local file storage before bulk loading. - Unlike the Snowflake implementation, we can't use the COPY command to load data - so we insert as values instead. - """ - - supports_merge_insert = False - file_writer_class = JsonlWriter - - @overrides - def _setup(self) -> None: - """Create the database parent folder if it doesn't yet exist.""" - config = cast(DuckDBCache, self.config) - - if config.db_path == ":memory:": - return - - Path(config.db_path).parent.mkdir(parents=True, exist_ok=True) - - @overrides - def _ensure_compatible_table_schema( - self, - stream_name: str, - *, - raise_on_error: bool = True, - ) -> bool: - """Return true if the given table is compatible with the stream's schema. - - In addition to the base implementation, this also checks primary keys. 
- """ - # call super - if not super()._ensure_compatible_table_schema( - stream_name=stream_name, - raise_on_error=raise_on_error, - ): - return False - - # TODO: Add validation for primary keys after DuckDB adds support for primary key - # inspection: https://github.com/Mause/duckdb_engine/issues/594 - - return True - - def _write_files_to_new_table( - self, - files: list[Path], - stream_name: str, - batch_id: str, - ) -> str: - """Write a file(s) to a new table. - - We use DuckDB's `read_parquet` function to efficiently read the files and insert - them into the table in a single operation. - - Note: This implementation is fragile in regards to column ordering. However, since - we are inserting into a temp table we have just created, there should be no - drift between the table schema and the file schema. - """ - temp_table_name = self._create_table_for_loading( - stream_name=stream_name, - batch_id=batch_id, - ) - columns_list = list(self._get_sql_column_definitions(stream_name=stream_name).keys()) - columns_list_str = indent( - "\n, ".join([self._quote_identifier(c) for c in columns_list]), - " ", - ) - files_list = ", ".join([f"'{f!s}'" for f in files]) - columns_type_map = indent( - "\n, ".join( - [ - f"{self._quote_identifier(c)}: " - f"{self._get_sql_column_definitions(stream_name)[c]!s}" - for c in columns_list - ] - ), - " ", - ) - insert_statement = dedent( - f""" - INSERT INTO {self.config.schema_name}.{temp_table_name} - ( - {columns_list_str} - ) - SELECT - {columns_list_str} - FROM read_json_auto( - [{files_list}], - format = 'newline_delimited', - union_by_name = true, - columns = {{ { columns_type_map } }} - ) - """ - ) - self._execute_sql(insert_statement) - return temp_table_name - - @overrides - def _get_telemetry_info(self) -> CacheTelemetryInfo: - return CacheTelemetryInfo("duckdb") - - -class DuckDBCache(SQLCacheBase, JsonlWriterConfig): +class DuckDBCache(CacheBase): """A DuckDB cache. Also inherits config from the JsonlWriterConfig, which is responsible for writing files to disk. @@ -140,10 +37,11 @@ class DuckDBCache(SQLCacheBase, JsonlWriterConfig): There are some cases, such as when connecting to MotherDuck, where it could be a string that is not also a path, such as "md:" to connect the user's default MotherDuck DB. """ + schema_name: str = "main" """The name of the schema to write to. Defaults to "main".""" - _sql_processor_class = PrivateAttr(DuckDBCacheInstance) + _sql_processor_class = DuckDBSqlProcessor @overrides def get_sql_alchemy_url(self) -> str: diff --git a/airbyte/caches/generic.py b/airbyte/caches/generic.py index 1a2aad1c..5eed71fa 100644 --- a/airbyte/caches/generic.py +++ b/airbyte/caches/generic.py @@ -1,13 +1,14 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. """A Generic SQL Cache implementation.""" + from __future__ import annotations from overrides import overrides -from airbyte.caches.base import SQLCacheBase +from airbyte.caches.base import CacheBase -class GenericSQLCacheConfig(SQLCacheBase): +class GenericSQLCacheConfig(CacheBase): """Allows configuring 'sql_alchemy_url' directly.""" sql_alchemy_url: str diff --git a/airbyte/caches/postgres.py b/airbyte/caches/postgres.py index 6fafcade..a9cd4727 100644 --- a/airbyte/caches/postgres.py +++ b/airbyte/caches/postgres.py @@ -1,36 +1,15 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
- """A Postgres implementation of the cache.""" from __future__ import annotations from overrides import overrides -from airbyte._file_writers import JsonlWriter, JsonlWriterConfig -from airbyte.caches.base import SQLCacheBase, SQLCacheInstanceBase -from airbyte.telemetry import CacheTelemetryInfo - - -class PostgresCacheInstance(SQLCacheInstanceBase): - """A Postgres implementation of the cache. - - Jsonl is used for local file storage before bulk loading. - Unlike the Snowflake implementation, we can't use the COPY command to load data - so we insert as values instead. - - TOOD: Add optimized bulk load path for Postgres. Could use an alternate file writer - or another import method. (Relatively low priority, since for now it works fine as-is.) - """ - - file_writer_class = JsonlWriter - supports_merge_insert = False # TODO: Add native implementation for merge insert - - @overrides - def _get_telemetry_info(self) -> CacheTelemetryInfo: - return CacheTelemetryInfo("postgres") +from airbyte._processors.sql.postgres import PostgresSqlProcessor +from airbyte.caches.base import CacheBase -class PostgresCache(SQLCacheBase, JsonlWriterConfig): +class PostgresCache(CacheBase): """Configuration for the Postgres cache. Also inherits config from the JsonlWriter, which is responsible for writing files to disk. @@ -42,7 +21,7 @@ class PostgresCache(SQLCacheBase, JsonlWriterConfig): password: str database: str - _sql_processor_class = PostgresCacheInstance + _sql_processor_class = PostgresSqlProcessor @overrides def get_sql_alchemy_url(self) -> str: diff --git a/airbyte/caches/snowflake.py b/airbyte/caches/snowflake.py index 3e1c69be..c272f4a9 100644 --- a/airbyte/caches/snowflake.py +++ b/airbyte/caches/snowflake.py @@ -1,132 +1,17 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. - """A Snowflake implementation of the cache.""" from __future__ import annotations -from textwrap import dedent, indent -from typing import TYPE_CHECKING - -import sqlalchemy from overrides import overrides -from snowflake.sqlalchemy import URL, VARIANT - -from airbyte._file_writers import JsonlWriter, JsonlWriterConfig -from airbyte.caches.base import ( - RecordDedupeMode, - SQLCacheBase, - SQLCacheInstanceBase, -) -from airbyte.telemetry import CacheTelemetryInfo -from airbyte.types import SQLTypeConverter - - -if TYPE_CHECKING: - from pathlib import Path - - from sqlalchemy.engine import Connection - - - -class SnowflakeTypeConverter(SQLTypeConverter): - """A class to convert types for Snowflake.""" - - @overrides - def to_sql_type( - self, - json_schema_property_def: dict[str, str | dict | list], - ) -> sqlalchemy.types.TypeEngine: - """Convert a value to a SQL type. - - We first call the parent class method to get the type. Then if the type JSON, we - replace it with VARIANT. - """ - sql_type = super().to_sql_type(json_schema_property_def) - if isinstance(sql_type, sqlalchemy.types.JSON): - return VARIANT() - - return sql_type - - -class SnowflakeSQLCacheInstance(SQLCacheInstanceBase): - """A Snowflake implementation of the cache. - - Parquet is used for local file storage before bulk loading. 
- """ - - file_writer_class = JsonlWriter - type_converter_class = SnowflakeTypeConverter - - @overrides - def _write_files_to_new_table( - self, - files: list[Path], - stream_name: str, - batch_id: str, - ) -> str: - """Write files to a new table.""" - temp_table_name = self._create_table_for_loading( - stream_name=stream_name, - batch_id=batch_id, - ) - internal_sf_stage_name = f"@%{temp_table_name}" - put_files_statements = "\n".join( - [ - f"PUT 'file://{file_path.absolute()!s}' {internal_sf_stage_name};" - for file_path in files - ] - ) - self._execute_sql(put_files_statements) - - columns_list = [ - self._quote_identifier(c) - for c in list(self._get_sql_column_definitions(stream_name).keys()) - ] - files_list = ", ".join([f"'{f.name}'" for f in files]) - columns_list_str: str = indent("\n, ".join(columns_list), " " * 12) - variant_cols_str: str = ("\n" + " " * 21 + ", ").join([f"$1:{col}" for col in columns_list]) - copy_statement = dedent( - f""" - COPY INTO {temp_table_name} - ( - {columns_list_str} - ) - FROM ( - SELECT {variant_cols_str} - FROM {internal_sf_stage_name} - ) - FILES = ( {files_list} ) - FILE_FORMAT = ( TYPE = JSON ) - ; - """ - ) - self._execute_sql(copy_statement) - return temp_table_name - - @overrides - def _init_connection_settings(self, connection: Connection) -> None: - """We set Snowflake-specific settings for the session. +from snowflake.sqlalchemy import URL - This sets QUOTED_IDENTIFIERS_IGNORE_CASE setting to True, which is necessary because - Snowflake otherwise will treat quoted table and column references as case-sensitive. - More info: https://docs.snowflake.com/en/sql-reference/identifiers-syntax - - This also sets MULTI_STATEMENT_COUNT to 0, which allows multi-statement commands. - """ - connection.execute( - """ - ALTER SESSION SET - QUOTED_IDENTIFIERS_IGNORE_CASE = TRUE - MULTI_STATEMENT_COUNT = 0 - """ - ) - - @overrides - def _get_telemetry_info(self) -> CacheTelemetryInfo: - return CacheTelemetryInfo("snowflake") +from airbyte._processors.sql.base import RecordDedupeMode +from airbyte._processors.sql.snowflake import SnowflakeSQLSqlProcessor +from airbyte.caches.base import CacheBase -class SnowflakeCache(SQLCacheBase, JsonlWriterConfig): +class SnowflakeCache(CacheBase): """Configuration for the Snowflake cache. Also inherits config from the JsonlWriterConfig, which is responsible for writing files to disk. @@ -141,7 +26,7 @@ class SnowflakeCache(SQLCacheBase, JsonlWriterConfig): dedupe_mode = RecordDedupeMode.APPEND - _sql_processor_class = SnowflakeSQLCacheInstance + _sql_processor_class = SnowflakeSQLSqlProcessor # Already defined in base class: # schema_name: str diff --git a/airbyte/config.py b/airbyte/config.py deleted file mode 100644 index 4fd4e603..00000000 --- a/airbyte/config.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2023 Airbyte, Inc., all rights reserved. - -"""Define base Config interface, used by Caches and also File Writers (Processors).""" - -from __future__ import annotations - -from pydantic import BaseModel - - -class CacheConfigBase( - BaseModel -): # TODO: meta=EnforceOverrides (Pydantic doesn't like it currently.) - pass diff --git a/airbyte/datasets/_sql.py b/airbyte/datasets/_sql.py index eae1b407..c9dd19f0 100644 --- a/airbyte/datasets/_sql.py +++ b/airbyte/datasets/_sql.py @@ -1,4 +1,6 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
+"""SQL datasets class.""" + from __future__ import annotations from collections.abc import Mapping @@ -17,7 +19,7 @@ from sqlalchemy import Selectable, Table from sqlalchemy.sql import ClauseElement - from airbyte.caches.base import SQLCacheBase + from airbyte.caches.base import CacheBase class SQLDataset(DatasetBase): @@ -29,12 +31,12 @@ class SQLDataset(DatasetBase): def __init__( self, - cache: SQLCacheBase, + cache: CacheBase, stream_name: str, query_statement: Selectable, ) -> None: self._length: int | None = None - self._cache: SQLCacheBase = cache + self._cache: CacheBase = cache self._stream_name: str = stream_name self._query_statement: Selectable = query_statement super().__init__() @@ -98,7 +100,7 @@ class CachedDataset(SQLDataset): underlying table as a SQLAlchemy Table object. """ - def __init__(self, cache: SQLCacheBase, stream_name: str) -> None: + def __init__(self, cache: CacheBase, stream_name: str) -> None: self._sql_table: Table = cache.processor.get_sql_table(stream_name) super().__init__( cache=cache, diff --git a/airbyte/results.py b/airbyte/results.py index ea06d514..9bb6cb78 100644 --- a/airbyte/results.py +++ b/airbyte/results.py @@ -12,12 +12,15 @@ from sqlalchemy.engine import Engine - from airbyte.caches import SQLCacheBase + from airbyte.caches import CacheBase class ReadResult(Mapping[str, CachedDataset]): def __init__( - self, processed_records: int, cache: SQLCacheBase, processed_streams: list[str] + self, + processed_records: int, + cache: CacheBase, + processed_streams: list[str], ) -> None: self.processed_records = processed_records self._cache = cache @@ -52,5 +55,5 @@ def streams(self) -> Mapping[str, CachedDataset]: } @property - def cache(self) -> SQLCacheBase: + def cache(self) -> CacheBase: return self._cache diff --git a/airbyte/source.py b/airbyte/source.py index f4e051e0..85ec00e6 100644 --- a/airbyte/source.py +++ b/airbyte/source.py @@ -46,7 +46,7 @@ from collections.abc import Generator, Iterable, Iterator from airbyte._executor import Executor - from airbyte.caches import SQLCacheBase + from airbyte.caches import CacheBase @contextmanager @@ -435,7 +435,7 @@ def _read_with_catalog( * Send out telemetry on the performed sync (with information about which source was used and the type of the cache) """ - source_tracking_information = self.executor._get_telemetry_info() + source_tracking_information = self.executor._get_telemetry_info() # noqa: SLF001 send_telemetry(source_tracking_information, cache_info, SyncState.STARTED) sync_failed = False self._processed_records = 0 # Reset the counter before we start @@ -522,7 +522,7 @@ def _tally_records( def read( self, - cache: SQLCacheBase | None = None, + cache: CacheBase | None = None, *, streams: str | list[str] | None = None, write_strategy: str | WriteStrategy = WriteStrategy.AUTO, @@ -579,8 +579,11 @@ def read( incoming_source_catalog=self.configured_catalog, stream_names=set(self._selected_stream_names), ) + if not cache.processor._catalog_manager: # noqa: SLF001 + raise exc.AirbyteLibInternalError(message="Catalog manager should exist but does not.") + state = ( - cache.processor._catalog_manager.get_state( + cache.processor._catalog_manager.get_state( # noqa: SLF001 source_name=self.name, streams=self._selected_stream_names, ) @@ -591,7 +594,7 @@ def read( cache.processor.process_airbyte_messages( self._tally_records( self._read( - cache.processor._get_telemetry_info(), + cache.processor._get_telemetry_info(), # noqa: SLF001 state=state, ), ), diff --git 
a/tests/integration_tests/test_snowflake_cache.py b/tests/integration_tests/test_snowflake_cache.py index c6f196cb..f86756e2 100644 --- a/tests/integration_tests/test_snowflake_cache.py +++ b/tests/integration_tests/test_snowflake_cache.py @@ -123,7 +123,7 @@ def test_replace_strategy( def test_merge_strategy( source_faker_seed_a: ab.Source, source_faker_seed_b: ab.Source, - snowflake_cache: ab.DuckDBCacheInstance, + snowflake_cache: ab.DuckDBSqlProcessor, ) -> None: """Test that the merge strategy works as expected. diff --git a/tests/integration_tests/test_source_faker_integration.py b/tests/integration_tests/test_source_faker_integration.py index 8ea257a5..087f1a30 100644 --- a/tests/integration_tests/test_source_faker_integration.py +++ b/tests/integration_tests/test_source_faker_integration.py @@ -141,7 +141,7 @@ def test_faker_pks( @pytest.mark.slow def test_replace_strategy( source_faker_seed_a: ab.Source, - all_cache_types: ab.DuckDBCacheInstance, + all_cache_types: ab.DuckDBSqlProcessor, ) -> None: """Test that the append strategy works as expected.""" for cache in all_cache_types: # Function-scoped fixtures can't be used in parametrized(). @@ -156,7 +156,7 @@ def test_replace_strategy( @pytest.mark.slow def test_append_strategy( source_faker_seed_a: ab.Source, - all_cache_types: ab.DuckDBCacheInstance, + all_cache_types: ab.DuckDBSqlProcessor, ) -> None: """Test that the append strategy works as expected.""" for cache in all_cache_types: # Function-scoped fixtures can't be used in parametrized(). @@ -172,7 +172,7 @@ def test_merge_strategy( strategy: str, source_faker_seed_a: ab.Source, source_faker_seed_b: ab.Source, - all_cache_types: ab.DuckDBCacheInstance, + all_cache_types: ab.DuckDBSqlProcessor, ) -> None: """Test that the merge strategy works as expected. 
@@ -206,7 +206,7 @@ def test_merge_strategy( def test_incremental_sync( source_faker_seed_a: ab.Source, source_faker_seed_b: ab.Source, - duckdb_cache: ab.DuckDBCacheInstance, + duckdb_cache: ab.DuckDBSqlProcessor, ) -> None: config_a = source_faker_seed_a.get_config() config_b = source_faker_seed_b.get_config() diff --git a/tests/integration_tests/test_source_test_fixture.py b/tests/integration_tests/test_source_test_fixture.py index ae4424a8..953fc1ce 100644 --- a/tests/integration_tests/test_source_test_fixture.py +++ b/tests/integration_tests/test_source_test_fixture.py @@ -9,7 +9,7 @@ from unittest.mock import Mock, call, patch import tempfile from pathlib import Path -from airbyte.caches.base import SQLCacheInstanceBase +from airbyte._processors.sql.base import SqlProcessorBase from sqlalchemy import column, text @@ -202,7 +202,7 @@ def test_file_write_and_cleanup() -> None: assert len(list(Path(temp_dir_2).glob("*.jsonl.gz"))) == 2, "Expected files to exist" -def assert_cache_data(expected_test_stream_data: dict[str, list[dict[str, str | int]]], cache: SQLCacheInstanceBase, streams: list[str] = None): +def assert_cache_data(expected_test_stream_data: dict[str, list[dict[str, str | int]]], cache: SqlProcessorBase, streams: list[str] = None): for stream_name in streams or expected_test_stream_data.keys(): if len(cache[stream_name]) > 0: pd.testing.assert_frame_equal( @@ -292,13 +292,13 @@ def test_read_isolated_by_prefix(expected_test_stream_data: dict[str, list[dict[ db_path = Path(f"./.cache/{cache_name}.duckdb") source = ab.get_source("source-test", config={"apiKey": "test"}) source.select_all_streams() - cache = ab.DuckDBCacheInstance(config=ab.DuckDBCache(db_path=db_path, table_prefix="prefix_")) + cache = ab.DuckDBSqlProcessor(config=ab.DuckDBCache(db_path=db_path, table_prefix="prefix_")) source.read(cache) - same_prefix_cache = ab.DuckDBCacheInstance(config=ab.DuckDBCache(db_path=db_path, table_prefix="prefix_")) - different_prefix_cache = ab.DuckDBCacheInstance(config=ab.DuckDBCache(db_path=db_path, table_prefix="different_prefix_")) - no_prefix_cache = ab.DuckDBCacheInstance(config=ab.DuckDBCache(db_path=db_path, table_prefix=None)) + same_prefix_cache = ab.DuckDBSqlProcessor(config=ab.DuckDBCache(db_path=db_path, table_prefix="prefix_")) + different_prefix_cache = ab.DuckDBSqlProcessor(config=ab.DuckDBCache(db_path=db_path, table_prefix="different_prefix_")) + no_prefix_cache = ab.DuckDBSqlProcessor(config=ab.DuckDBCache(db_path=db_path, table_prefix=None)) # validate that the cache with the same prefix has the data as expected, while the other two are empty assert_cache_data(expected_test_stream_data, same_prefix_cache) @@ -310,9 +310,9 @@ def test_read_isolated_by_prefix(expected_test_stream_data: dict[str, list[dict[ source.read(different_prefix_cache) source.read(no_prefix_cache) - second_same_prefix_cache = ab.DuckDBCacheInstance(config=ab.DuckDBCache(db_path=db_path, table_prefix="prefix_")) - second_different_prefix_cache = ab.DuckDBCacheInstance(config=ab.DuckDBCache(db_path=db_path, table_prefix="different_prefix_")) - second_no_prefix_cache = ab.DuckDBCacheInstance(config=ab.DuckDBCache(db_path=db_path, table_prefix=None)) + second_same_prefix_cache = ab.DuckDBSqlProcessor(config=ab.DuckDBCache(db_path=db_path, table_prefix="prefix_")) + second_different_prefix_cache = ab.DuckDBSqlProcessor(config=ab.DuckDBCache(db_path=db_path, table_prefix="different_prefix_")) + second_no_prefix_cache = ab.DuckDBSqlProcessor(config=ab.DuckDBCache(db_path=db_path, 
table_prefix=None))
 
     # validate that the first cache still has full data, while the other two have partial data
     assert_cache_data(expected_test_stream_data, second_same_prefix_cache)
@@ -423,7 +423,7 @@ def test_cached_dataset(
     not_a_stream_name = "not_a_stream"
 
     # Check that the stream appears in mapping-like attributes
-    assert stream_name in result.cache._streams_with_data
+    assert stream_name in result.cache.processor.expected_streams
     assert stream_name in result
     assert stream_name in result.cache
     assert stream_name in result.cache.streams
diff --git a/tests/unit_tests/test_caches.py b/tests/unit_tests/test_caches.py
index 3b9f0fa0..2c75eed0 100644
--- a/tests/unit_tests/test_caches.py
+++ b/tests/unit_tests/test_caches.py
@@ -4,10 +4,10 @@
 
 import pytest
 
-from airbyte.caches.base import SQLCacheInstanceBase, SQLCacheBase
+from airbyte.caches.base import CacheBase
 from airbyte.caches.duckdb import DuckDBCache
-from airbyte._file_writers import JsonlWriterConfig
-from airbyte.caches.base import SQLCacheBase
+from airbyte._processors.file.jsonl import JsonlWriterConfig
+from airbyte.caches.base import CacheBase
 
 
 def test_duck_db_cache_config_initialization():
@@ -28,7 +28,7 @@ def test_get_sql_alchemy_url_with_default_schema_name():
     assert config.get_sql_alchemy_url() == 'duckdb:///test_path'
 
 def test_duck_db_cache_config_inheritance():
-    assert issubclass(DuckDBCache, SQLCacheBase)
+    assert issubclass(DuckDBCache, CacheBase)
     assert issubclass(DuckDBCache, JsonlWriterConfig)
 
 def test_duck_db_cache_config_get_sql_alchemy_url():
@@ -40,7 +40,7 @@ def test_duck_db_cache_config_get_database_name():
     assert config.get_database_name() == 'test_db'
 
 def test_duck_db_cache_base_inheritance():
-    assert issubclass(DuckDBCache, SQLCacheBase)
+    assert issubclass(DuckDBCache, CacheBase)
 
 def test_duck_db_cache_config_default_schema_name():
     config = DuckDBCache(db_path='test_path')
@@ -55,7 +55,7 @@ def test_duck_db_cache_config_get_database_name_with_default_schema_name():
     assert config.get_database_name() == 'test_db'
 
 def test_duck_db_cache_config_inheritance_from_sql_cache_config_base():
-    assert issubclass(DuckDBCache, SQLCacheBase)
+    assert issubclass(DuckDBCache, CacheBase)
 
 def test_duck_db_cache_config_inheritance_from_parquet_writer_config():
     assert issubclass(DuckDBCache, JsonlWriterConfig)
diff --git a/tests/unit_tests/test_writers.py b/tests/unit_tests/test_writers.py
index 460b1c38..be707e10 100644
--- a/tests/unit_tests/test_writers.py
+++ b/tests/unit_tests/test_writers.py
@@ -2,8 +2,8 @@
 from pathlib import Path
 
 import pytest
-from airbyte._file_writers.base import FileWriterBase, FileWriterBatchHandle, FileWriterConfigBase
-from airbyte._file_writers.parquet import ParquetWriter, ParquetWriterConfig
+from airbyte._processors.file.base import FileWriterBase, FileWriterBatchHandle, FileWriterConfigBase
+from airbyte._processors.file.parquet import ParquetWriter, ParquetWriterConfig
 from numpy import source
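
For orientation, here is a minimal usage sketch of the consolidated cache API this changeset moves toward: a `DuckDBCache` is now a plain Pydantic config object (a `CacheBase`), and the SQL machinery hangs off its `processor` property. This is an illustrative sketch only; the paths, prefix, and printed values below are hypothetical, not part of the diff.

from airbyte.caches.duckdb import DuckDBCache

# Hypothetical local paths; any writable location works.
cache = DuckDBCache(
    db_path=".cache/example/example.duckdb",
    cache_dir=".cache/example",  # field inherited from CacheBase
    table_prefix="demo_",
)

print(cache.schema_name)            # "main" by default for DuckDB
print(cache.get_sql_alchemy_url())  # a duckdb:/// URL pointing at db_path (exact form per implementation)

# After `source.read(cache=cache)`, each stream is exposed as a CachedDataset:
# for name, dataset in cache.streams.items():
#     print(name, len(dataset))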
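
The strategy resolution removed from `caches/base.py` above picks MERGE when the stream has primary keys, APPEND when it only has an incremental cursor, and REPLACE otherwise, and rejects an explicit MERGE request when no primary keys exist. A standalone restatement of that precedence follows; the local `WriteStrategy` enum mirrors `airbyte.strategies`, and `has_pks`/`has_incremental_key` are stand-ins for the stream-config lookups, introduced here only for illustration.

from enum import Enum


class WriteStrategy(Enum):
    AUTO = "auto"
    APPEND = "append"
    REPLACE = "replace"
    MERGE = "merge"


def resolve_write_strategy(
    requested: WriteStrategy,
    *,
    has_pks: bool,
    has_incremental_key: bool,
) -> WriteStrategy:
    """Resolve AUTO to a concrete strategy, mirroring the precedence shown above."""
    if requested is WriteStrategy.MERGE and not has_pks:
        raise ValueError("Cannot use merge strategy on a stream with no primary keys.")
    if requested is not WriteStrategy.AUTO:
        return requested
    if has_pks:
        return WriteStrategy.MERGE
    if has_incremental_key:
        return WriteStrategy.APPEND
    return WriteStrategy.REPLACE


# AUTO resolves to MERGE as soon as primary keys are available.
assert resolve_write_strategy(WriteStrategy.AUTO, has_pks=True, has_incremental_key=False) is WriteStrategy.MERGE
assert resolve_write_strategy(WriteStrategy.AUTO, has_pks=False, has_incremental_key=True) is WriteStrategy.APPEND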
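
The DuckDB fast path removed above (presumably relocated to `airbyte/_processors/sql/duckdb.py`) bulk-loads gzipped JSONL batch files into a freshly created table with a single `INSERT ... SELECT ... FROM read_json_auto(...)` statement. A self-contained sketch of that statement follows, assuming the `duckdb` Python package; the file name, schema, and two-column layout are invented for illustration, whereas the real processor derives the column list and types from the stream's JSON schema.

import gzip

import duckdb

# Write a tiny gzipped JSONL "batch file" to load.
with gzip.open("users_batch.jsonl.gz", "wt") as f:
    f.write('{"id": 1, "name": "Ada"}\n{"id": 2, "name": "Grace"}\n')

conn = duckdb.connect()  # in-memory database
conn.execute("CREATE SCHEMA IF NOT EXISTS airbyte_raw")
conn.execute('CREATE TABLE airbyte_raw.users_tmp ("id" BIGINT, "name" VARCHAR)')

# One-statement load: read_json_auto handles gzip and newline-delimited JSON natively.
conn.execute(
    """
    INSERT INTO airbyte_raw.users_tmp ("id", "name")
    SELECT "id", "name"
    FROM read_json_auto(
        ['users_batch.jsonl.gz'],
        format = 'newline_delimited',
        union_by_name = true,
        columns = {'id': 'BIGINT', 'name': 'VARCHAR'}
    )
    """
)
print(conn.execute("SELECT * FROM airbyte_raw.users_tmp").fetchall())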
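
The removed `_emulated_merge_temp_table_to_final_table` implements merge for databases without native MERGE as two steps: UPDATE rows that match on the primary key, then INSERT the staged rows that have no match. A self-contained sketch of that two-step pattern against throwaway SQLite tables is below; the table and column names are invented, and the cross-table UPDATE assumes a dialect with UPDATE..FROM support (recent SQLAlchemy plus a modern SQLite qualifies).

from sqlalchemy import (
    Column, Integer, MetaData, String, Table,
    and_, create_engine, insert, null, select, update,
)

engine = create_engine("sqlite:///:memory:")
metadata = MetaData()
final = Table("users_final", metadata, Column("id", Integer, primary_key=True), Column("name", String))
stage = Table("users_stage", metadata, Column("id", Integer, primary_key=True), Column("name", String))
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(insert(final), [{"id": 1, "name": "old"}])
    conn.execute(insert(stage), [{"id": 1, "name": "new"}, {"id": 2, "name": "brand new"}])

    join_clause = and_(final.c.id == stage.c.id)

    # Step 1: update final-table rows that already exist in the stage table.
    conn.execute(update(final).values(name=stage.c.name).where(join_clause))

    # Step 2: insert stage rows with no match in the final table (outer join + IS NULL check).
    new_rows = (
        select(stage.c.id, stage.c.name)
        .select_from(stage.outerjoin(final, join_clause))
        .where(final.c.id == null())
    )
    conn.execute(insert(final).from_select(["id", "name"], new_rows))

    print(conn.execute(select(final).order_by(final.c.id)).fetchall())
    # Expected: [(1, 'new'), (2, 'brand new')]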