Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CDK: Embedded reader utils #28873

Merged
merged 16 commits into from
Aug 3, 2023
3 changes: 3 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/embedded/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#
# Copyright (c) 2021 Airbyte, Inc., all rights reserved.
#
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from abc import ABC, abstractmethod
from typing import Generic, Iterable, Optional, TypeVar

from airbyte_cdk.connector import TConfig
from airbyte_cdk.sources.embedded.catalog import create_configured_catalog, get_stream, get_stream_names
from airbyte_cdk.sources.embedded.runner import SourceRunner
from airbyte_cdk.sources.embedded.tools import get_defined_id
from airbyte_protocol.models import AirbyteRecordMessage, AirbyteStateMessage, SyncMode, Type

TOutput = TypeVar("TOutput")


class BaseEmbeddedIntegration(ABC, Generic[TConfig, TOutput]):
    """
    Base class for running a source through a SourceRunner and converting its records in-process.

    Subclasses implement `_handle_record` to map each Airbyte record to their own output type and
    call `_load_data` to iterate over the converted records of a single stream.
    """

    def __init__(self, runner: SourceRunner[TConfig], config: TConfig):
        self.source = runner
        self.config = config

        # Most recent STATE message seen while reading; stays None until `_load_data` encounters one.
        self.last_state: Optional[AirbyteStateMessage] = None

    @abstractmethod
    def _handle_record(self, record: AirbyteRecordMessage, id: Optional[str]) -> Optional[TOutput]:
        """
        Turn an Airbyte record into the appropriate output type for the integration.

        :param record: the record message read from the source
        :param id: identifier derived from the stream's source-defined primary key, or None if the stream has none
        :return: the converted output, or None to skip this record
        """
        pass

    def _load_data(self, stream_name: str, state: Optional[AirbyteStateMessage] = None) -> Iterable[TOutput]:
        """
        Read a single stream from the source and lazily yield the converted records.

        The leading underscore marks this method as protected by convention: subclasses call it,
        but it is not meant to be called on the integration from the outside.

        :param stream_name: name of the stream to read, must be present in the discovered catalog
        :param state: optional state message handed to the runner's `read`
        :return: generator of the non-None values returned by `_handle_record`
        :raises ValueError: if the discovered catalog contains no stream named `stream_name`
        """
        catalog = self.source.discover(self.config)
        stream = get_stream(catalog, stream_name)
        if not stream:
            raise ValueError(f"Stream {stream_name} not found, the following streams are available: {', '.join(get_stream_names(catalog))}")

        # Prefer incremental whenever the stream supports it, otherwise fall back to a full refresh.
        if SyncMode.incremental in stream.supported_sync_modes:
            sync_mode = SyncMode.incremental
        else:
            sync_mode = SyncMode.full_refresh
        configured_catalog = create_configured_catalog(stream, sync_mode=sync_mode)

        for message in self.source.read(self.config, configured_catalog, state):
            if message.type == Type.RECORD:
                output = self._handle_record(message.record, get_defined_id(stream, message.record.data))
                # None is the "skip this record" signal; falsy-but-valid outputs (e.g. {} or 0) must still be yielded.
                if output is not None:
                    yield output
            elif message.type is Type.STATE and message.state:
                # State is only remembered in memory and not emitted to the consumer: there is no
                # destination checkpointing in the embedded case.
                self.last_state = message.state
45 changes: 45 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/embedded/catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from typing import List, Optional

from airbyte_cdk.models import (
AirbyteCatalog,
AirbyteStream,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
DestinationSyncMode,
SyncMode,
)
from airbyte_cdk.sources.embedded.tools import get_first


def get_stream(catalog: AirbyteCatalog, stream_name: str) -> Optional[AirbyteStream]:
    """
    Look up a stream by name; returns the first stream of the catalog named `stream_name`, or None if there is none.
    """

    def has_requested_name(candidate: AirbyteStream) -> bool:
        return candidate.name == stream_name

    return get_first(catalog.streams, has_requested_name)


def get_stream_names(catalog: AirbyteCatalog) -> List[str]:
    """
    Collect the names of all streams in the catalog, in catalog order.
    """
    names: List[str] = []
    for entry in catalog.streams:
        names.append(entry.name)
    return names


def to_configured_stream(
    stream: AirbyteStream,
    sync_mode: SyncMode = SyncMode.full_refresh,
    destination_sync_mode: DestinationSyncMode = DestinationSyncMode.append,
    cursor_field: Optional[List[str]] = None,
    primary_key: Optional[List[List[str]]] = None,
) -> ConfiguredAirbyteStream:
    """
    Wrap a discovered stream into a ConfiguredAirbyteStream, passing every setting through unchanged.
    """
    settings = {
        "stream": stream,
        "sync_mode": sync_mode,
        "destination_sync_mode": destination_sync_mode,
        "cursor_field": cursor_field,
        "primary_key": primary_key,
    }
    return ConfiguredAirbyteStream(**settings)


def to_configured_catalog(configured_streams: List[ConfiguredAirbyteStream]) -> ConfiguredAirbyteCatalog:
    """
    Bundle the given configured streams into a ConfiguredAirbyteCatalog.
    """
    configured_catalog = ConfiguredAirbyteCatalog(streams=configured_streams)
    return configured_catalog


def create_configured_catalog(stream: AirbyteStream, sync_mode: SyncMode = SyncMode.full_refresh) -> ConfiguredAirbyteCatalog:
    """
    Build a configured catalog containing only `stream`, read with `sync_mode` and keyed by the
    stream's source-defined primary key.
    """
    only_stream = to_configured_stream(stream, sync_mode=sync_mode, primary_key=stream.source_defined_primary_key)
    return to_configured_catalog([only_stream])
34 changes: 34 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/embedded/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#


import logging
from abc import ABC, abstractmethod
from typing import Generic, Iterable, Optional

from airbyte_cdk.connector import TConfig
from airbyte_cdk.models import AirbyteCatalog, AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog
from airbyte_cdk.sources.source import Source


class SourceRunner(ABC, Generic[TConfig]):
    """
    Abstract interface for something that can discover and read a source with a given config.
    """

    @abstractmethod
    def discover(self, config: TConfig) -> AirbyteCatalog:
        """
        Return the catalog of streams available for `config`.
        """

    @abstractmethod
    def read(self, config: TConfig, catalog: ConfiguredAirbyteCatalog, state: Optional[AirbyteStateMessage]) -> Iterable[AirbyteMessage]:
        """
        Return the messages produced by reading `catalog` with `config`, optionally starting from `state`.
        """


class CDKRunner(SourceRunner[TConfig]):
    """
    SourceRunner that delegates to a CDK `Source` instance, using a logger named `name`.
    """

    def __init__(self, source: Source, name: str):
        self._logger = logging.getLogger(name)
        self._source = source

    def discover(self, config: TConfig) -> AirbyteCatalog:
        """
        Delegate discovery to the wrapped source.
        """
        discovered = self._source.discover(self._logger, config)
        return discovered

    def read(self, config: TConfig, catalog: ConfiguredAirbyteCatalog, state: Optional[AirbyteStateMessage]) -> Iterable[AirbyteMessage]:
        """
        Delegate reading to the wrapped source; the optional single state message is passed on as a list.
        """
        # The wrapped source is given a list of state messages: one element if a state was given, otherwise empty.
        if state:
            state_messages = [state]
        else:
            state_messages = []
        return self._source.read(self._logger, config, catalog, state=state_messages)
24 changes: 24 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/embedded/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from typing import Any, Callable, Dict, Iterable, Optional

import dpath
from airbyte_cdk.models import AirbyteStream


def get_first(iterable: Iterable[Any], predicate: Callable[[Any], bool] = lambda m: True) -> Optional[Any]:
    """
    Return the first element of `iterable` for which `predicate` is truthy, or None if no element matches.

    With the default predicate every element matches, so the first element is returned (even if it is falsy).
    """
    for candidate in iterable:
        if predicate(candidate):
            return candidate
    return None


def get_defined_id(stream: AirbyteStream, data: Dict[str, Any]) -> Optional[str]:
    """
    Build an identifier for a record from the stream's source-defined primary key.

    Each key path is looked up in `data` and stringified; the parts are joined with "_".
    A path whose lookup raises KeyError contributes the placeholder "__not_found__".
    Returns None if the stream has no source-defined primary key.
    """
    key_paths = stream.source_defined_primary_key
    if not key_paths:
        return None

    parts = []
    for path in key_paths:
        try:
            part = str(dpath.util.get(data, path))
        except KeyError:
            part = "__not_found__"
        parts.append(part)
    return "_".join(parts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import unittest
from typing import Any, Mapping, Optional
from unittest.mock import MagicMock

from airbyte_cdk.sources.embedded.base_integration import BaseEmbeddedIntegration
from airbyte_protocol.models import (
AirbyteCatalog,
AirbyteLogMessage,
AirbyteMessage,
AirbyteRecordMessage,
AirbyteStateMessage,
AirbyteStream,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
DestinationSyncMode,
Level,
SyncMode,
Type,
)


class TestIntegration(BaseEmbeddedIntegration):
    """
    Minimal concrete integration used by the tests: wraps each record's data and its derived id in a dict.
    """

    # This is a helper, not a test case. Without this flag the "Test" prefix makes pytest try to collect
    # the class and emit a collection warning because it has an __init__ constructor.
    __test__ = False

    def _handle_record(self, record: AirbyteRecordMessage, id: Optional[str]) -> Mapping[str, Any]:
        return {"data": record.data, "id": id}


class EmbeddedIntegrationTestCase(unittest.TestCase):
    """
    Tests for `BaseEmbeddedIntegration._load_data` against a mocked source runner.
    """

    def setUp(self):
        self.source_class = MagicMock()
        self.source = MagicMock()
        self.source_class.return_value = self.source
        self.config = MagicMock()
        self.integration = TestIntegration(self.source, self.config)
        # stream1 supports incremental and has a primary key; stream2 is full-refresh only without a key.
        self.stream1 = AirbyteStream(
            name="test",
            source_defined_primary_key=[["test"]],
            json_schema={},
            supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental],
        )
        self.stream2 = AirbyteStream(name="test2", json_schema={}, supported_sync_modes=[SyncMode.full_refresh])
        self.source.discover.return_value = AirbyteCatalog(streams=[self.stream2, self.stream1])

    def _expected_catalog(self, stream, sync_mode, **stream_kwargs):
        """
        Build the single-stream configured catalog the integration is expected to pass to `read`.
        """
        return ConfiguredAirbyteCatalog(
            streams=[
                ConfiguredAirbyteStream(
                    stream=stream,
                    sync_mode=sync_mode,
                    destination_sync_mode=DestinationSyncMode.append,
                    **stream_kwargs,
                )
            ]
        )

    def test_integration(self):
        self.source.read.return_value = [
            AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="test")),
            AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 1}, emitted_at=1)),
            AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 2}, emitted_at=2)),
            AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 3}, emitted_at=3)),
        ]
        result = list(self.integration._load_data("test", None))
        self.assertEqual(
            result,
            [
                {"data": {"test": 1}, "id": "1"},
                {"data": {"test": 2}, "id": "2"},
                {"data": {"test": 3}, "id": "3"},
            ],
        )
        self.source.discover.assert_called_once_with(self.config)
        self.source.read.assert_called_once_with(
            self.config,
            self._expected_catalog(self.stream1, SyncMode.incremental, primary_key=[["test"]]),
            None,
        )

    def test_state(self):
        state = AirbyteStateMessage(data={})
        self.source.read.return_value = [
            AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="test")),
            AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 1}, emitted_at=1)),
            AirbyteMessage(type=Type.STATE, state=state),
        ]
        result = list(self.integration._load_data("test", None))
        self.assertEqual(
            result,
            [
                {"data": {"test": 1}, "id": "1"},
            ],
        )
        # Assert (rather than assign) that the STATE message was captured by the integration.
        self.assertEqual(self.integration.last_state, state)

    def test_incremental(self):
        state = AirbyteStateMessage(data={})
        list(self.integration._load_data("test", state))
        self.source.read.assert_called_once_with(
            self.config,
            self._expected_catalog(self.stream1, SyncMode.incremental, primary_key=[["test"]]),
            state,
        )

    def test_incremental_without_state(self):
        list(self.integration._load_data("test"))
        self.source.read.assert_called_once_with(
            self.config,
            self._expected_catalog(self.stream1, SyncMode.incremental, primary_key=[["test"]]),
            None,
        )

    def test_incremental_unsupported(self):
        # The stream only supports full refresh, so the catalog falls back to it; the state is still passed on.
        state = AirbyteStateMessage(data={})
        list(self.integration._load_data("test2", state))
        self.source.read.assert_called_once_with(
            self.config,
            self._expected_catalog(self.stream2, SyncMode.full_refresh),
            state,
        )