-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CDK: Embedded reader utils #28873
CDK: Embedded reader utils #28873
Changes from all commits
35fe28a
b96bb29
36874e3
7bc3fb0
abbc620
fd28a42
e289972
3431345
34b5e9d
c60f3ee
8b675f5
c27fd9b
37f55bd
b23abc4
6aa5b13
eeced02
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# | ||
# Copyright (c) 2021 Airbyte, Inc., all rights reserved. | ||
# |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Generic, Iterable, Optional, TypeVar | ||
|
||
from airbyte_cdk.connector import TConfig | ||
from airbyte_cdk.sources.embedded.catalog import create_configured_catalog, get_stream, get_stream_names | ||
from airbyte_cdk.sources.embedded.runner import SourceRunner | ||
from airbyte_cdk.sources.embedded.tools import get_defined_id | ||
from airbyte_protocol.models import AirbyteRecordMessage, AirbyteStateMessage, SyncMode, Type | ||
|
||
# Type of the value an integration produces for each handled record.
TOutput = TypeVar("TOutput")


class BaseEmbeddedIntegration(ABC, Generic[TConfig, TOutput]):
    """
    Base class for embedded integrations that read one stream through a `SourceRunner`
    and convert each record into an integration-specific output value.

    Subclasses implement `_handle_record`; `_load_data` drives discovery and reading.
    """

    def __init__(self, runner: SourceRunner[TConfig], config: TConfig):
        self.source = runner
        self.config = config

        # Most recent non-empty state message seen by `_load_data`; None until one arrives.
        self.last_state: Optional[AirbyteStateMessage] = None

    @abstractmethod
    def _handle_record(self, record: AirbyteRecordMessage, id: Optional[str]) -> Optional[TOutput]:
        """
        Turn an Airbyte record into the appropriate output type for the integration.

        :param record: the record message emitted by the source
        :param id: identifier derived from the stream's source-defined primary key, or None if the stream has none
        :return: the converted output; falsy return values are skipped by `_load_data`
        """

    def _load_data(self, stream_name: str, state: Optional[AirbyteStateMessage] = None) -> Iterable[TOutput]:
        """
        Read `stream_name` from the source and yield the converted records.

        This is a generator: discovery, the stream lookup (and its ValueError) and the
        read only happen once iteration starts.

        :param stream_name: name of the stream to read, as listed in the discovered catalog
        :param state: optional state message forwarded to the runner's `read`
        :raises ValueError: if the discovered catalog has no stream named `stream_name`
        """
        catalog = self.source.discover(self.config)
        stream = get_stream(catalog, stream_name)
        if not stream:
            raise ValueError(f"Stream {stream_name} not found, the following streams are available: {', '.join(get_stream_names(catalog))}")

        # Use incremental whenever the stream supports it, otherwise fall back to full refresh.
        supports_incremental = SyncMode.incremental in stream.supported_sync_modes
        sync_mode = SyncMode.incremental if supports_incremental else SyncMode.full_refresh
        configured_catalog = create_configured_catalog(stream, sync_mode=sync_mode)

        for message in self.source.read(self.config, configured_catalog, state):
            if message.type == Type.RECORD:
                record_id = get_defined_id(stream, message.record.data)
                output = self._handle_record(message.record, record_id)
                if output:
                    yield output
                continue

            # NOTE (from review): state messages are not emitted downstream for checkpointing;
            # the latest one is only kept in memory on the integration. Users who need
            # checkpointing guarantees are expected to use hosted Airbyte instead.
            if message.type is Type.STATE and message.state:
                self.last_state = message.state
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
from typing import List, Optional | ||
|
||
from airbyte_cdk.models import ( | ||
AirbyteCatalog, | ||
AirbyteStream, | ||
ConfiguredAirbyteCatalog, | ||
ConfiguredAirbyteStream, | ||
DestinationSyncMode, | ||
SyncMode, | ||
) | ||
from airbyte_cdk.sources.embedded.tools import get_first | ||
|
||
|
||
def get_stream(catalog: AirbyteCatalog, stream_name: str) -> Optional[AirbyteStream]:
    """Return the first stream in `catalog` whose name equals `stream_name`, or None if there is none."""

    def _has_requested_name(candidate: AirbyteStream) -> bool:
        return candidate.name == stream_name

    return get_first(catalog.streams, _has_requested_name)
|
||
|
||
def get_stream_names(catalog: AirbyteCatalog) -> List[str]:
    """Return the names of all streams in `catalog`, in catalog order."""
    names = [entry.name for entry in catalog.streams]
    return names
|
||
|
||
def to_configured_stream(
    stream: AirbyteStream,
    sync_mode: SyncMode = SyncMode.full_refresh,
    destination_sync_mode: DestinationSyncMode = DestinationSyncMode.append,
    cursor_field: Optional[List[str]] = None,
    primary_key: Optional[List[List[str]]] = None,
) -> ConfiguredAirbyteStream:
    """
    Wrap `stream` in a `ConfiguredAirbyteStream` with the given sync settings.

    All arguments are forwarded unchanged; by default the stream is configured for
    full refresh with append as the destination sync mode.
    """
    configured = ConfiguredAirbyteStream(
        stream=stream,
        sync_mode=sync_mode,
        destination_sync_mode=destination_sync_mode,
        cursor_field=cursor_field,
        primary_key=primary_key,
    )
    return configured
|
||
|
||
def to_configured_catalog(configured_streams: List[ConfiguredAirbyteStream]) -> ConfiguredAirbyteCatalog:
    """Build a `ConfiguredAirbyteCatalog` containing exactly `configured_streams`."""
    catalog = ConfiguredAirbyteCatalog(streams=configured_streams)
    return catalog
|
||
|
||
def create_configured_catalog(stream: AirbyteStream, sync_mode: SyncMode = SyncMode.full_refresh) -> ConfiguredAirbyteCatalog:
    """
    Build a single-stream configured catalog for `stream`.

    The stream's source-defined primary key is used as the configured primary key.
    """
    only_stream = to_configured_stream(
        stream,
        sync_mode=sync_mode,
        primary_key=stream.source_defined_primary_key,
    )
    return to_configured_catalog([only_stream])
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
|
||
import logging | ||
from abc import ABC, abstractmethod | ||
from typing import Generic, Iterable, Optional | ||
|
||
from airbyte_cdk.connector import TConfig | ||
from airbyte_cdk.models import AirbyteCatalog, AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog | ||
from airbyte_cdk.sources.source import Source | ||
|
||
|
||
class SourceRunner(ABC, Generic[TConfig]):
    """Abstract interface for discovering and reading a source with a given config."""

    @abstractmethod
    def discover(self, config: TConfig) -> AirbyteCatalog:
        """Return the catalog of streams the source offers for `config`."""

    @abstractmethod
    def read(self, config: TConfig, catalog: ConfiguredAirbyteCatalog, state: Optional[AirbyteStateMessage]) -> Iterable[AirbyteMessage]:
        """Read the streams in `catalog` using `config`, optionally starting from `state`, and return the emitted messages."""
|
||
|
||
class CDKRunner(SourceRunner[TConfig]):
    """`SourceRunner` that delegates to a CDK `Source` object, passing it a logger named `name`."""

    def __init__(self, source: Source, name: str):
        self._logger = logging.getLogger(name)
        self._source = source

    def discover(self, config: TConfig) -> AirbyteCatalog:
        """Delegate discovery to the wrapped source."""
        discovered = self._source.discover(self._logger, config)
        return discovered

    def read(self, config: TConfig, catalog: ConfiguredAirbyteCatalog, state: Optional[AirbyteStateMessage]) -> Iterable[AirbyteMessage]:
        """Delegate reading to the wrapped source and return its messages."""
        # The wrapped source receives state as a list: one element if a state was given, otherwise empty.
        state_list = [state] if state else []
        return self._source.read(self._logger, config, catalog, state=state_list)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
from typing import Any, Callable, Dict, Iterable, Optional | ||
|
||
import dpath | ||
from airbyte_cdk.models import AirbyteStream | ||
|
||
|
||
def get_first(iterable: Iterable[Any], predicate: Callable[[Any], bool] = lambda m: True) -> Optional[Any]:
    """
    Return the first item of `iterable` for which `predicate` is truthy.

    With the default predicate this is simply the first item. Returns None when
    no item matches (including when `iterable` is empty).
    """
    for candidate in iterable:
        if predicate(candidate):
            return candidate
    return None
|
||
|
||
def get_defined_id(stream: AirbyteStream, data: Dict[str, Any]) -> Optional[str]:
    """
    Build an identifier for `data` from the stream's source-defined primary key.

    Each primary-key path is looked up in `data` with dpath and converted to a string;
    a path that raises KeyError contributes the placeholder "__not_found__". The parts
    are joined with "_". Returns None when the stream defines no primary key.
    """
    key_paths = stream.source_defined_primary_key
    if not key_paths:
        return None

    def _part(path: Any) -> str:
        # A missing key must not fail the whole record, so substitute a placeholder.
        try:
            return str(dpath.util.get(data, path))
        except KeyError:
            return "__not_found__"

    return "_".join(_part(path) for path in key_paths)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
import unittest | ||
from typing import Any, Mapping, Optional | ||
from unittest.mock import MagicMock | ||
|
||
from airbyte_cdk.sources.embedded.base_integration import BaseEmbeddedIntegration | ||
from airbyte_protocol.models import ( | ||
AirbyteCatalog, | ||
AirbyteLogMessage, | ||
AirbyteMessage, | ||
AirbyteRecordMessage, | ||
AirbyteStateMessage, | ||
AirbyteStream, | ||
ConfiguredAirbyteCatalog, | ||
ConfiguredAirbyteStream, | ||
DestinationSyncMode, | ||
Level, | ||
SyncMode, | ||
Type, | ||
) | ||
|
||
|
||
class TestIntegration(BaseEmbeddedIntegration):
    """Minimal concrete integration for the tests: returns each record's data together with its id."""

    def _handle_record(self, record: AirbyteRecordMessage, id: Optional[str]) -> Mapping[str, Any]:
        handled = {"data": record.data, "id": id}
        return handled
|
||
|
||
class EmbeddedIntegrationTestCase(unittest.TestCase):
    """Tests for `BaseEmbeddedIntegration._load_data` using a mocked source runner."""

    def setUp(self):
        self.source_class = MagicMock()
        self.source = MagicMock()
        self.source_class.return_value = self.source
        self.config = MagicMock()
        self.integration = TestIntegration(self.source, self.config)
        # stream1 supports incremental and has a primary key; stream2 is full-refresh only with no key.
        self.stream1 = AirbyteStream(
            name="test",
            source_defined_primary_key=[["test"]],
            json_schema={},
            supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental],
        )
        self.stream2 = AirbyteStream(name="test2", json_schema={}, supported_sync_modes=[SyncMode.full_refresh])
        self.source.discover.return_value = AirbyteCatalog(streams=[self.stream2, self.stream1])

    def test_integration(self):
        """Records are converted and yielded in order; non-record messages are ignored."""
        self.source.read.return_value = [
            AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="test")),
            AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 1}, emitted_at=1)),
            AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 2}, emitted_at=2)),
            AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 3}, emitted_at=3)),
        ]
        result = list(self.integration._load_data("test", None))
        self.assertEqual(
            result,
            [
                {"data": {"test": 1}, "id": "1"},
                {"data": {"test": 2}, "id": "2"},
                {"data": {"test": 3}, "id": "3"},
            ],
        )
        self.source.discover.assert_called_once_with(self.config)
        self.source.read.assert_called_once_with(
            self.config,
            ConfiguredAirbyteCatalog(
                streams=[
                    ConfiguredAirbyteStream(
                        stream=self.stream1,
                        sync_mode=SyncMode.incremental,
                        destination_sync_mode=DestinationSyncMode.append,
                        primary_key=[["test"]],
                    )
                ]
            ),
            None,
        )

    def test_state(self):
        """A state message emitted by the source is stored on the integration as `last_state`."""
        state = AirbyteStateMessage(data={})
        self.source.read.return_value = [
            AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="test")),
            AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 1}, emitted_at=1)),
            AirbyteMessage(type=Type.STATE, state=state),
        ]
        result = list(self.integration._load_data("test", None))
        self.assertEqual(
            result,
            [
                {"data": {"test": 1}, "id": "1"},
            ],
        )
        # Assert (rather than assign) so the test actually verifies the state was captured.
        self.assertEqual(self.integration.last_state, state)

    def test_incremental(self):
        """A provided state is forwarded to the runner and the stream is read incrementally."""
        state = AirbyteStateMessage(data={})
        list(self.integration._load_data("test", state))
        self.source.read.assert_called_once_with(
            self.config,
            ConfiguredAirbyteCatalog(
                streams=[
                    ConfiguredAirbyteStream(
                        stream=self.stream1,
                        sync_mode=SyncMode.incremental,
                        destination_sync_mode=DestinationSyncMode.append,
                        primary_key=[["test"]],
                    )
                ]
            ),
            state,
        )

    def test_incremental_without_state(self):
        """Without a state argument the stream is still configured as incremental and None is forwarded."""
        list(self.integration._load_data("test"))
        self.source.read.assert_called_once_with(
            self.config,
            ConfiguredAirbyteCatalog(
                streams=[
                    ConfiguredAirbyteStream(
                        stream=self.stream1,
                        sync_mode=SyncMode.incremental,
                        destination_sync_mode=DestinationSyncMode.append,
                        primary_key=[["test"]],
                    )
                ]
            ),
            None,
        )

    def test_incremental_unsupported(self):
        """A stream that does not support incremental is read in full refresh; state is still forwarded."""
        state = AirbyteStateMessage(data={})
        list(self.integration._load_data("test2", state))
        self.source.read.assert_called_once_with(
            self.config,
            ConfiguredAirbyteCatalog(
                streams=[
                    ConfiguredAirbyteStream(
                        stream=self.stream2,
                        sync_mode=SyncMode.full_refresh,
                        destination_sync_mode=DestinationSyncMode.append,
                    )
                ]
            ),
            state,
        )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method probably shouldn't be private since it's used by child classes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understood the underscore as meaning protected, not private (https://jellis18.github.io/post/2022-01-15-access-modifiers-python/#:~:text=Private%20vs%20Protected,and%20not%20a%20private%20value.) - ChatGPT agrees with me as well ;) I'm going to leave it like this for now as you shouldn't call it on the integration from the outside.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. I guess we don't have many private fields in our codebase 🤷