diff --git a/airbyte-lib/airbyte_lib/datasets/_lazy.py b/airbyte-lib/airbyte_lib/datasets/_lazy.py
index a8bb8173711d..83d67cec0043 100644
--- a/airbyte-lib/airbyte_lib/datasets/_lazy.py
+++ b/airbyte-lib/airbyte_lib/datasets/_lazy.py
@@ -20,6 +20,7 @@ def __init__(
         iterator: Iterator[Mapping[str, Any]],
     ) -> None:
         self._iterator: Iterator[Mapping[str, Any]] = iterator
+        super().__init__()

     @overrides
     def __iter__(self) -> Iterator[Mapping[str, Any]]:
diff --git a/airbyte-lib/airbyte_lib/datasets/_sql.py b/airbyte-lib/airbyte_lib/datasets/_sql.py
index c6195dc6e280..7dfb22482146 100644
--- a/airbyte-lib/airbyte_lib/datasets/_sql.py
+++ b/airbyte-lib/airbyte_lib/datasets/_sql.py
@@ -5,7 +5,7 @@
 from typing import TYPE_CHECKING, Any, cast

 from overrides import overrides
-from sqlalchemy import and_, text
+from sqlalchemy import and_, func, select, text

 from airbyte_lib.datasets._base import DatasetBase

@@ -33,9 +33,11 @@ def __init__(
         stream_name: str,
         query_statement: Selectable,
     ) -> None:
+        self._length: int | None = None
         self._cache: SQLCacheBase = cache
         self._stream_name: str = stream_name
         self._query_statement: Selectable = query_statement
+        super().__init__()

     @property
     def stream_name(self) -> str:
@@ -48,6 +50,18 @@ def __iter__(self) -> Iterator[Mapping[str, Any]]:
                 # https://pydoc.dev/sqlalchemy/latest/sqlalchemy.engine.row.RowMapping.html
                 yield cast(Mapping[str, Any], row._mapping)  # noqa: SLF001

+    def __len__(self) -> int:
+        """Return the number of records in the dataset.
+
+        This method caches the length of the dataset after the first call.
+        """
+        if self._length is None:
+            count_query = select([func.count()]).select_from(self._query_statement.alias())
+            with self._cache.get_sql_connection() as conn:
+                self._length = conn.execute(count_query).scalar()
+
+        return self._length
+
     def to_pandas(self) -> DataFrame:
         return self._cache.get_pandas_dataframe(self._stream_name)

@@ -85,16 +99,19 @@ class CachedDataset(SQLDataset):
     """

     def __init__(self, cache: SQLCacheBase, stream_name: str) -> None:
-        self._cache: SQLCacheBase = cache
-        self._stream_name: str = stream_name
-        self._query_statement: Selectable = self.to_sql_table().select()
+        self._sql_table: Table = cache.get_sql_table(stream_name)
+        super().__init__(
+            cache=cache,
+            stream_name=stream_name,
+            query_statement=self._sql_table.select(),
+        )

     @overrides
     def to_pandas(self) -> DataFrame:
         return self._cache.get_pandas_dataframe(self._stream_name)

     def to_sql_table(self) -> Table:
-        return self._cache.get_sql_table(self._stream_name)
+        return self._sql_table

     def __eq__(self, value: object) -> bool:
         """Return True if the value is a CachedDataset with the same cache and stream name.
diff --git a/airbyte-lib/airbyte_lib/results.py b/airbyte-lib/airbyte_lib/results.py
index 18861629ba24..5c5021fc8afc 100644
--- a/airbyte-lib/airbyte_lib/results.py
+++ b/airbyte-lib/airbyte_lib/results.py
@@ -1,20 +1,21 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 from __future__ import annotations

+from collections.abc import Mapping
 from typing import TYPE_CHECKING

 from airbyte_lib.datasets import CachedDataset


 if TYPE_CHECKING:
-    from collections.abc import Iterator, Mapping
+    from collections.abc import Iterator

     from sqlalchemy.engine import Engine

     from airbyte_lib.caches import SQLCacheBase


-class ReadResult:
+class ReadResult(Mapping[str, CachedDataset]):
     def __init__(
         self, processed_records: int, cache: SQLCacheBase, processed_streams: list[str]
     ) -> None:
@@ -28,12 +29,18 @@ def __getitem__(self, stream: str) -> CachedDataset:

         return CachedDataset(self._cache, stream)

-    def __contains__(self, stream: str) -> bool:
+    def __contains__(self, stream: object) -> bool:
+        if not isinstance(stream, str):
+            return False
+
         return stream in self._processed_streams

     def __iter__(self) -> Iterator[str]:
         return self._processed_streams.__iter__()

+    def __len__(self) -> int:
+        return len(self._processed_streams)
+
     def get_sql_engine(self) -> Engine:
         return self._cache.get_sql_engine()
diff --git a/airbyte-lib/docs/generated/airbyte_lib.html b/airbyte-lib/docs/generated/airbyte_lib.html
index ba7c11e54e3d..59867359c01e 100644
--- a/airbyte-lib/docs/generated/airbyte_lib.html
+++ b/airbyte-lib/docs/generated/airbyte_lib.html
[Regenerated pdoc HTML, hunks @@ -325,13 +325,19 @@ and @@ -390,6 +396,18 @@. Recoverable content of the diff:
the rendered signature changes from "class ReadResult:" to "class ReadResult(collections.abc.Mapping[str, airbyte_lib.datasets._sql.CachedDataset]):";
the standard Mapping docstring is added ("A Mapping is a generic container for associating key/value pairs. ... This class provides concrete generic implementations of all methods except for __getitem__, __iter__, and __len__.");
and an "Inherited Members" section now lists collections.abc.Mapping.get, keys, items, and values. HTML markup omitted.]
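The regenerated documentation reflects what subclassing collections.abc.Mapping provides automatically: once __getitem__, __iter__, and __len__ are implemented, the ABC derives get, keys, items, values, and a default __contains__. A minimal sketch of that standard-library mechanism follows; TinyResult is a toy stand-in for illustration only, not code from this change:

from collections.abc import Iterator, Mapping


class TinyResult(Mapping[str, int]):
    """Toy Mapping subclass: only the three abstract methods are implemented."""

    def __init__(self, data: dict[str, int]) -> None:
        self._data = data

    def __getitem__(self, key: str) -> int:
        return self._data[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)


r = TinyResult({"stream1": 1, "stream2": 2})
assert r.keys() == {"stream1", "stream2"}                # keys() comes from Mapping
assert r.get("stream3") is None                          # get() comes from Mapping
assert "stream1" in r                                    # __contains__ comes from Mapping
assert dict(r.items()) == {"stream1": 1, "stream2": 2}   # items() comes from Mapping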
diff --git a/airbyte-lib/tests/integration_tests/test_integration.py b/airbyte-lib/tests/integration_tests/test_integration.py
index fbe10e57bd70..a122df84899e 100644
--- a/airbyte-lib/tests/integration_tests/test_integration.py
+++ b/airbyte-lib/tests/integration_tests/test_integration.py
@@ -217,6 +217,35 @@ def test_sync_to_duckdb(expected_test_stream_data: dict[str, list[dict[str, str
     assert_cache_data(expected_test_stream_data, cache)


+def test_read_result_mapping():
+    source = ab.get_connector("source-test", config={"apiKey": "test"})
+    result: ReadResult = source.read()
+    assert len(result) == 2
+    assert isinstance(result, Mapping)
+    assert "stream1" in result
+    assert "stream2" in result
+    assert "stream3" not in result
+    assert result.keys() == {"stream1", "stream2"}
+
+
+def test_dataset_list_and_len(expected_test_stream_data):
+    source = ab.get_connector("source-test", config={"apiKey": "test"})
+    result: ReadResult = source.read()
+    stream_1 = result["stream1"]
+    assert len(stream_1) == 2
+    assert len(list(stream_1)) == 2
+    # Make sure we can iterate over the stream after calling len
+    assert list(stream_1) == [{"column1": "value1", "column2": 1}, {"column1": "value2", "column2": 2}]
+    # Make sure we can iterate over the stream a second time
+    assert list(stream_1) == [{"column1": "value1", "column2": 1}, {"column1": "value2", "column2": 2}]
+
+    assert isinstance(result, Mapping)
+    assert "stream1" in result
+    assert "stream2" in result
+    assert "stream3" not in result
+    assert result.keys() == {"stream1", "stream2"}
+
+
 def test_read_from_cache(expected_test_stream_data: dict[str, list[dict[str, str | int]]]):
     """
     Test that we can read from a cache that already has data (identigier by name)
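For orientation, a minimal usage sketch of the Mapping-style ReadResult and the new dataset __len__ exercised by the tests above. This is illustrative rather than part of the change: it assumes airbyte_lib is importable as ab and reuses the source-test fixture connector with its {"apiKey": "test"} config from the integration tests, so substitute any installed source connector outside the test environment.

import airbyte_lib as ab

# Assumption: the source-test fixture connector is available, as in the tests above.
source = ab.get_connector("source-test", config={"apiKey": "test"})
result = source.read()

# ReadResult now behaves like a read-only Mapping of stream name -> CachedDataset.
assert "stream1" in result
assert result.keys() == {"stream1", "stream2"}
assert len(result) == 2

dataset = result["stream1"]

# SQLDataset.__len__ issues a COUNT(*) over the cached query and memoizes the result,
# so calling len() does not consume anything and repeated iteration still works.
assert len(dataset) == 2
assert list(dataset) == list(dataset)

# Datasets can still be materialized through the cache.
df = dataset.to_pandas()
print(df.shape)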