Skip to content

Commit

Permalink
AirbyteLib: Add len() support on SQL datasets and Mapping behaviors f…
Browse files Browse the repository at this point in the history
…or ReadResult (airbytehq#34763)
  • Loading branch information
aaronsteers authored and jatinyadav-cc committed Feb 26, 2024
1 parent 5e5efc1 commit 2e50dfc
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 10 deletions.
1 change: 1 addition & 0 deletions airbyte-lib/airbyte_lib/datasets/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
iterator: Iterator[Mapping[str, Any]],
) -> None:
self._iterator: Iterator[Mapping[str, Any]] = iterator
super().__init__()

@overrides
def __iter__(self) -> Iterator[Mapping[str, Any]]:
Expand Down
27 changes: 22 additions & 5 deletions airbyte-lib/airbyte_lib/datasets/_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any, cast

from overrides import overrides
from sqlalchemy import and_, text
from sqlalchemy import and_, func, select, text

from airbyte_lib.datasets._base import DatasetBase

Expand Down Expand Up @@ -33,9 +33,11 @@ def __init__(
stream_name: str,
query_statement: Selectable,
) -> None:
self._length: int | None = None
self._cache: SQLCacheBase = cache
self._stream_name: str = stream_name
self._query_statement: Selectable = query_statement
super().__init__()

@property
def stream_name(self) -> str:
Expand All @@ -48,6 +50,18 @@ def __iter__(self) -> Iterator[Mapping[str, Any]]:
# https://pydoc.dev/sqlalchemy/latest/sqlalchemy.engine.row.RowMapping.html
yield cast(Mapping[str, Any], row._mapping) # noqa: SLF001

def __len__(self) -> int:
    """Return the number of records in the dataset.

    The count is computed with a COUNT(*) query wrapped around the dataset's
    underlying query statement, and is cached after the first call.
    """
    if self._length is None:
        # SQLAlchemy 1.4+/2.0 style: select() takes column expressions
        # positionally; the legacy select([...]) list form was removed in 2.0.
        count_query = select(func.count()).select_from(self._query_statement.alias())
        with self._cache.get_sql_connection() as conn:
            # .scalar() is typed as Optional; the COUNT(*) query always yields
            # one row, so cast keeps the declared `int` return honest.
            self._length = cast(int, conn.execute(count_query).scalar())

    return self._length

def to_pandas(self) -> DataFrame:
    """Materialize this stream from the cache as a pandas DataFrame."""
    frame = self._cache.get_pandas_dataframe(self._stream_name)
    return frame

Expand Down Expand Up @@ -85,16 +99,19 @@ class CachedDataset(SQLDataset):
"""

def __init__(self, cache: SQLCacheBase, stream_name: str) -> None:
    """Initialize the cached dataset for a single stream.

    NOTE(review): this span interleaves removed and added unified-diff lines.
    The next three assignments read as the pre-change implementation that the
    commit deletes; the surviving implementation appears to be the
    `_sql_table` lookup plus the `super().__init__` delegation below —
    confirm against the repository before relying on this text.
    """
    self._cache: SQLCacheBase = cache
    self._stream_name: str = stream_name
    self._query_statement: Selectable = self.to_sql_table().select()
    # Resolve the backing SQL table once, then delegate the shared
    # cache/stream/query bookkeeping to the SQLDataset base initializer.
    self._sql_table: Table = cache.get_sql_table(stream_name)
    super().__init__(
        cache=cache,
        stream_name=stream_name,
        query_statement=self._sql_table.select(),
    )

@overrides
def to_pandas(self) -> DataFrame:
    """Return the cached stream's rows as a pandas DataFrame."""
    dataframe = self._cache.get_pandas_dataframe(self._stream_name)
    return dataframe

def to_sql_table(self) -> Table:
    # NOTE(review): unified-diff residue — the two returns below are the
    # removed (old) and added (new) sides of the same hunk. As literal code
    # the second return is unreachable; presumably only the
    # `self._sql_table` line should survive — confirm against the repo.
    return self._cache.get_sql_table(self._stream_name)
    return self._sql_table

def __eq__(self, value: object) -> bool:
"""Return True if the value is a CachedDataset with the same cache and stream name.
Expand Down
13 changes: 10 additions & 3 deletions airbyte-lib/airbyte_lib/results.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
from __future__ import annotations

from collections.abc import Mapping
from typing import TYPE_CHECKING

from airbyte_lib.datasets import CachedDataset


if TYPE_CHECKING:
from collections.abc import Iterator, Mapping
from collections.abc import Iterator

from sqlalchemy.engine import Engine

from airbyte_lib.caches import SQLCacheBase


class ReadResult:
class ReadResult(Mapping[str, CachedDataset]):
def __init__(
self, processed_records: int, cache: SQLCacheBase, processed_streams: list[str]
) -> None:
Expand All @@ -28,12 +29,18 @@ def __getitem__(self, stream: str) -> CachedDataset:

return CachedDataset(self._cache, stream)

def __contains__(self, stream: str) -> bool:
def __contains__(self, stream: object) -> bool:
    """Membership test: True only when *stream* is a str naming a processed stream."""
    return isinstance(stream, str) and stream in self._processed_streams

def __iter__(self) -> Iterator[str]:
    """Yield the names of the processed streams."""
    return iter(self._processed_streams)

def __len__(self) -> int:
    """Return how many streams were processed."""
    stream_count = len(self._processed_streams)
    return stream_count

def get_sql_engine(self) -> Engine:
    """Expose the SQLAlchemy engine of the backing cache."""
    engine = self._cache.get_sql_engine()
    return engine

Expand Down
22 changes: 20 additions & 2 deletions airbyte-lib/docs/generated/airbyte_lib.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions airbyte-lib/tests/integration_tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,35 @@ def test_sync_to_duckdb(expected_test_stream_data: dict[str, list[dict[str, str
assert_cache_data(expected_test_stream_data, cache)


def test_read_result_mapping():
    """ReadResult should satisfy the read-only Mapping protocol over stream names."""
    source = ab.get_connector("source-test", config={"apiKey": "test"})
    result: ReadResult = source.read()
    assert isinstance(result, Mapping)
    assert len(result) == 2
    for expected_stream in ("stream1", "stream2"):
        assert expected_stream in result
    assert "stream3" not in result
    assert result.keys() == {"stream1", "stream2"}


def test_dataset_list_and_len(expected_test_stream_data):
    """len() on a dataset must agree with iteration and must not exhaust it."""
    source = ab.get_connector("source-test", config={"apiKey": "test"})
    result: ReadResult = source.read()
    stream_1 = result["stream1"]
    assert len(stream_1) == 2
    assert len(list(stream_1)) == 2
    expected_rows = [
        {"column1": "value1", "column2": 1},
        {"column1": "value2", "column2": 2},
    ]
    # Iterating after len(), and iterating a second time, must both
    # yield the complete data set.
    assert list(stream_1) == expected_rows
    assert list(stream_1) == expected_rows

    assert isinstance(result, Mapping)
    assert "stream1" in result
    assert "stream2" in result
    assert "stream3" not in result
    assert result.keys() == {"stream1", "stream2"}


def test_read_from_cache(expected_test_stream_data: dict[str, list[dict[str, str | int]]]):
"""
Test that we can read from a cache that already has data (identified by name)
Expand Down

0 comments on commit 2e50dfc

Please sign in to comment.