Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AirbyteLib: Add len() support on SQL datasets and Mapping behaviors for ReadResult (#34763) #34763

Merged
merged 8 commits into from
Feb 2, 2024
13 changes: 13 additions & 0 deletions airbyte-lib/airbyte_lib/datasets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,24 @@
class DatasetBase(ABC):
"""Base implementation for all datasets."""

def __init__(self) -> None:
    """Initialize the dataset with no cached record count."""
    # Populated lazily by the first ``len()`` call on the dataset.
    self._length: int | None = None

@abstractmethod
def __iter__(self) -> Iterator[Mapping[str, Any]]:
    """Iterate over the dataset's records, one mapping per record.

    Abstract: concrete datasets must provide their own implementation.
    """
    raise NotImplementedError

def __len__(self) -> int:
    """Return the number of records in the dataset.

    The count is obtained by iterating over the dataset once and is then
    cached, so later calls return the stored value without iterating again.

    NOTE(review): counting consumes one full pass of ``__iter__``; for a
    dataset backed by a one-shot iterator this presumably exhausts it.
    Confirm against the concrete subclasses.
    """
    if self._length is None:
        self._length = sum(1 for _ in self)

    return self._length

def to_pandas(self) -> DataFrame:
"""Return a pandas DataFrame representation of the dataset.

Expand Down
1 change: 1 addition & 0 deletions airbyte-lib/airbyte_lib/datasets/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
iterator: Iterator[Mapping[str, Any]],
) -> None:
self._iterator: Iterator[Mapping[str, Any]] = iterator
super().__init__()

@overrides
def __iter__(self) -> Iterator[Mapping[str, Any]]:
Expand Down
18 changes: 14 additions & 4 deletions airbyte-lib/airbyte_lib/datasets/_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any, cast

from overrides import overrides
from sqlalchemy import and_, text
from sqlalchemy import and_, func, select, text

from airbyte_lib.datasets._base import DatasetBase

Expand Down Expand Up @@ -36,6 +36,7 @@ def __init__(
self._cache: SQLCacheBase = cache
self._stream_name: str = stream_name
self._query_statement: Selectable = query_statement
super().__init__()

@property
def stream_name(self) -> str:
Expand All @@ -48,6 +49,13 @@ def __iter__(self) -> Iterator[Mapping[str, Any]]:
# https://pydoc.dev/sqlalchemy/latest/sqlalchemy.engine.row.RowMapping.html
yield cast(Mapping[str, Any], row._mapping) # noqa: SLF001

def __len__(self) -> int:
    """Return the number of rows produced by the dataset's query.

    The count is computed by the database (a ``count`` over the dataset's
    query statement) rather than by iterating the rows, and is cached after
    the first call.
    """
    if self._length is None:
        # Positional ``select(...)`` works on SQLAlchemy 1.4 and 2.x; the
        # legacy list form ``select([...])`` was removed in 2.0.
        count_query = select(func.count()).select_from(self._query_statement.alias())
        with self._cache.get_sql_connection() as conn:
            # ``scalar()`` is typed as optional; normalise to a plain int so
            # the cached value always matches the declared return type.
            self._length = int(conn.execute(count_query).scalar() or 0)

    return self._length

def to_pandas(self) -> DataFrame:
    """Return the DataFrame the cache provides for this dataset's stream.

    Delegates to the cache's ``get_pandas_dataframe`` using the stream name.
    """
    frame = self._cache.get_pandas_dataframe(self._stream_name)
    return frame

Expand Down Expand Up @@ -85,9 +93,11 @@ class CachedDataset(SQLDataset):
"""

def __init__(self, cache: SQLCacheBase, stream_name: str) -> None:
    """Create a dataset that selects from the SQL table of one stream.

    Args:
        cache: SQL cache passed through to the parent dataset.
        stream_name: Name of the stream whose table is selected.
    """
    # ``to_sql_table()`` runs before the parent initializer, so the
    # attributes it presumably reads are assigned first (verify against
    # ``to_sql_table``).
    self._cache: SQLCacheBase = cache
    self._stream_name: str = stream_name
    super().__init__(
        cache=cache,
        stream_name=stream_name,
        query_statement=self.to_sql_table().select(),
    )

@overrides
def to_pandas(self) -> DataFrame:
Expand Down
8 changes: 6 additions & 2 deletions airbyte-lib/airbyte_lib/results.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
from __future__ import annotations

from collections.abc import Mapping
from typing import TYPE_CHECKING

from airbyte_lib.datasets import CachedDataset


if TYPE_CHECKING:
from collections.abc import Iterator, Mapping
from collections.abc import Iterator

from sqlalchemy.engine import Engine

from airbyte_lib.caches import SQLCacheBase


class ReadResult:
class ReadResult(Mapping[str, CachedDataset]):
def __init__(
self, processed_records: int, cache: SQLCacheBase, processed_streams: list[str]
) -> None:
Expand All @@ -34,6 +35,9 @@ def __contains__(self, stream: str) -> bool:
def __iter__(self) -> Iterator[str]:
    """Iterate over the names of the processed streams."""
    return iter(self._processed_streams)

def __len__(self) -> int:
    """Return how many streams were processed."""
    stream_count = len(self._processed_streams)
    return stream_count

def get_sql_engine(self) -> Engine:
    """Return the SQL engine exposed by the underlying cache."""
    engine = self._cache.get_sql_engine()
    return engine

Expand Down
Loading