Skip to content

Commit

Permalink
CDK: Embedded reader utils (#28873)
Browse files Browse the repository at this point in the history
* relax pydantic dep

* Automated Commit - Format and Process Resources Changes

* wip

* wrap up base integration

* add init file

* introduce CDK runner and improve error message

* make state param optional

* update protocol models

* review comments

* always run incremental if possible

* fix

---------

Co-authored-by: flash1293 <flash1293@users.noreply.github.com>
  • Loading branch information
2 people authored and jbfbell committed Aug 5, 2023
1 parent 5cd576a commit b91acef
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 0 deletions.
3 changes: 3 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/embedded/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#
# Copyright (c) 2021 Airbyte, Inc., all rights reserved.
#
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from abc import ABC, abstractmethod
from typing import Generic, Iterable, Optional, TypeVar

from airbyte_cdk.connector import TConfig
from airbyte_cdk.sources.embedded.catalog import create_configured_catalog, get_stream, get_stream_names
from airbyte_cdk.sources.embedded.runner import SourceRunner
from airbyte_cdk.sources.embedded.tools import get_defined_id
from airbyte_protocol.models import AirbyteRecordMessage, AirbyteStateMessage, SyncMode, Type

# Output type produced by a concrete integration's `_handle_record`.
TOutput = TypeVar("TOutput")


class BaseEmbeddedIntegration(ABC, Generic[TConfig, TOutput]):
    """
    Base class for integrations that read a single stream through a SourceRunner
    and convert every record into an integration-specific output.

    Subclasses implement `_handle_record`; `_load_data` drives discovery and reading.
    """

    def __init__(self, runner: SourceRunner[TConfig], config: TConfig):
        """
        :param runner: runner used to discover the catalog and to read messages
        :param config: configuration handed to the runner on every call
        """
        self.source = runner
        self.config = config

        # Most recent non-empty state message seen by `_load_data`; None until one is received.
        self.last_state: Optional[AirbyteStateMessage] = None

    @abstractmethod
    def _handle_record(self, record: AirbyteRecordMessage, id: Optional[str]) -> Optional[TOutput]:
        """
        Turn an Airbyte record into the appropriate output type for the integration.

        :param record: the record message emitted by the source
        :param id: identifier derived from the stream's source-defined primary key, or None if the stream has none
        :return: the converted output, or None to skip the record
        """
        pass

    def _load_data(self, stream_name: str, state: Optional[AirbyteStateMessage] = None) -> Iterable[TOutput]:
        """
        Read `stream_name` from the source and yield one output per handled record.

        The stream is read incrementally whenever it supports that sync mode and in
        full refresh otherwise. State messages are not yielded; the latest one is
        stored on `self.last_state`.

        This is a generator: discovery, validation and reading only happen once
        iteration starts.

        :param stream_name: name of the stream to read, as listed in the discovered catalog
        :param state: optional state message forwarded to the runner's `read`
        :raises ValueError: if the discovered catalog has no stream called `stream_name`
        """
        catalog = self.source.discover(self.config)
        stream = get_stream(catalog, stream_name)
        if not stream:
            raise ValueError(f"Stream {stream_name} not found, the following streams are available: {', '.join(get_stream_names(catalog))}")

        # Always prefer incremental when the stream offers it.
        if SyncMode.incremental in stream.supported_sync_modes:
            sync_mode = SyncMode.incremental
        else:
            sync_mode = SyncMode.full_refresh
        configured_catalog = create_configured_catalog(stream, sync_mode=sync_mode)

        for message in self.source.read(self.config, configured_catalog, state):
            if message.type == Type.RECORD:
                output = self._handle_record(message.record, get_defined_id(stream, message.record.data))
                # Only None means "skip": falsy but valid outputs (empty mapping, 0, "") must still be emitted.
                if output is not None:
                    yield output
            elif message.type is Type.STATE and message.state:
                self.last_state = message.state
45 changes: 45 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/embedded/catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from typing import List, Optional

from airbyte_cdk.models import (
AirbyteCatalog,
AirbyteStream,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
DestinationSyncMode,
SyncMode,
)
from airbyte_cdk.sources.embedded.tools import get_first


def get_stream(catalog: AirbyteCatalog, stream_name: str) -> Optional[AirbyteStream]:
    """
    Look up a stream in the catalog by name.

    :param catalog: catalog whose streams are searched in order
    :param stream_name: exact name to match against each stream's `name`
    :return: the first stream with a matching name, or None when there is none
    """

    def _has_requested_name(candidate: AirbyteStream) -> bool:
        return candidate.name == stream_name

    return get_first(catalog.streams, _has_requested_name)


def get_stream_names(catalog: AirbyteCatalog) -> List[str]:
    """
    Collect the names of all streams in the catalog, preserving catalog order.

    :param catalog: catalog whose streams are listed
    :return: one name per stream
    """
    return [entry.name for entry in catalog.streams]


def to_configured_stream(
    stream: AirbyteStream,
    sync_mode: SyncMode = SyncMode.full_refresh,
    destination_sync_mode: DestinationSyncMode = DestinationSyncMode.append,
    cursor_field: Optional[List[str]] = None,
    primary_key: Optional[List[List[str]]] = None,
) -> ConfiguredAirbyteStream:
    """
    Wrap a stream in a ConfiguredAirbyteStream.

    :param stream: the stream to configure
    :param sync_mode: sync mode to read with; defaults to full refresh
    :param destination_sync_mode: destination sync mode; defaults to append
    :param cursor_field: optional cursor field, passed through unchanged
    :param primary_key: optional primary key, passed through unchanged
    :return: the configured stream built from exactly these values
    """
    return ConfiguredAirbyteStream(
        stream=stream,
        sync_mode=sync_mode,
        destination_sync_mode=destination_sync_mode,
        cursor_field=cursor_field,
        primary_key=primary_key,
    )


def to_configured_catalog(configured_streams: List[ConfiguredAirbyteStream]) -> ConfiguredAirbyteCatalog:
    """
    Bundle already-configured streams into a ConfiguredAirbyteCatalog.

    :param configured_streams: streams to place in the catalog, in the given order
    :return: a configured catalog containing those streams
    """
    configured_catalog = ConfiguredAirbyteCatalog(streams=configured_streams)
    return configured_catalog


def create_configured_catalog(stream: AirbyteStream, sync_mode: SyncMode = SyncMode.full_refresh) -> ConfiguredAirbyteCatalog:
    """
    Build a configured catalog that contains only the given stream.

    The stream's source-defined primary key is used as the configured primary key.

    :param stream: the single stream to include
    :param sync_mode: sync mode for that stream; defaults to full refresh
    :return: a configured catalog with exactly one configured stream
    """
    only_stream = to_configured_stream(
        stream,
        sync_mode=sync_mode,
        primary_key=stream.source_defined_primary_key,
    )
    return to_configured_catalog([only_stream])
34 changes: 34 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/embedded/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#


import logging
from abc import ABC, abstractmethod
from typing import Generic, Iterable, Optional

from airbyte_cdk.connector import TConfig
from airbyte_cdk.models import AirbyteCatalog, AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog
from airbyte_cdk.sources.source import Source


class SourceRunner(ABC, Generic[TConfig]):
    """
    Abstract interface for running a source: catalog discovery plus reading messages.
    """

    @abstractmethod
    def discover(self, config: TConfig) -> AirbyteCatalog:
        """
        Return the catalog for the given config.
        """

    @abstractmethod
    def read(self, config: TConfig, catalog: ConfiguredAirbyteCatalog, state: Optional[AirbyteStateMessage]) -> Iterable[AirbyteMessage]:
        """
        Return the messages produced for the configured catalog, given an optional state message.
        """


class CDKRunner(SourceRunner[TConfig]):
    """
    SourceRunner that delegates to a `Source` object, supplying it with a named logger.
    """

    def __init__(self, source: Source, name: str):
        """
        :param source: the source instance that performs discover and read
        :param name: name of the logger passed to the source on every call
        """
        self._logger = logging.getLogger(name)
        self._source = source

    def discover(self, config: TConfig) -> AirbyteCatalog:
        """
        Ask the wrapped source for its catalog.
        """
        discovered = self._source.discover(self._logger, config)
        return discovered

    def read(self, config: TConfig, catalog: ConfiguredAirbyteCatalog, state: Optional[AirbyteStateMessage]) -> Iterable[AirbyteMessage]:
        """
        Read from the wrapped source and return its messages.
        """
        # The source is given state as a list: one element when a state is provided, empty otherwise.
        if state:
            state_list = [state]
        else:
            state_list = []
        return self._source.read(self._logger, config, catalog, state=state_list)
24 changes: 24 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/embedded/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from typing import Any, Callable, Dict, Iterable, Optional

import dpath
from airbyte_cdk.models import AirbyteStream


def get_first(iterable: Iterable[Any], predicate: Callable[[Any], bool] = lambda m: True) -> Optional[Any]:
    """
    Return the first item of `iterable` for which `predicate` is truthy.

    :param iterable: items to scan in order; consumed only up to the first match
    :param predicate: test applied to each item; the default accepts every item
    :return: the first matching item, or None when nothing matches
    """
    for candidate in iterable:
        if predicate(candidate):
            return candidate
    return None


def get_defined_id(stream: AirbyteStream, data: Dict[str, Any]) -> Optional[str]:
    """
    Build a string identifier for a record from the stream's source-defined primary key.

    Each primary-key entry is looked up in `data` with dpath and converted to a
    string; the parts are joined with underscores. An entry whose lookup raises
    KeyError contributes the placeholder "__not_found__".

    :param stream: stream whose `source_defined_primary_key` lists the key paths
    :param data: the record data to read the key values from
    :return: the joined identifier, or None when the stream defines no primary key
    """
    key_paths = stream.source_defined_primary_key
    if not key_paths:
        return None

    id_parts = []
    for key_path in key_paths:
        try:
            part = str(dpath.util.get(data, key_path))
        except KeyError:
            part = "__not_found__"
        id_parts.append(part)
    return "_".join(id_parts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import unittest
from typing import Any, Mapping, Optional
from unittest.mock import MagicMock

from airbyte_cdk.sources.embedded.base_integration import BaseEmbeddedIntegration
from airbyte_protocol.models import (
AirbyteCatalog,
AirbyteLogMessage,
AirbyteMessage,
AirbyteRecordMessage,
AirbyteStateMessage,
AirbyteStream,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
DestinationSyncMode,
Level,
SyncMode,
Type,
)


class TestIntegration(BaseEmbeddedIntegration):
    """
    Minimal concrete integration used by the tests below: echoes each record's data and id.
    """

    # The name starts with "Test" and the class inherits an __init__, so pytest would
    # try to collect it and emit a collection warning. This is a helper, not a test case.
    __test__ = False

    def _handle_record(self, record: AirbyteRecordMessage, id: Optional[str]) -> Mapping[str, Any]:
        """
        Return the record's data together with the id computed for it.
        """
        return {"data": record.data, "id": id}


class EmbeddedIntegrationTestCase(unittest.TestCase):
    """
    Tests for `BaseEmbeddedIntegration._load_data` against a mocked source runner.
    """

    def setUp(self):
        """
        Create a mocked runner whose catalog has two streams: "test" (supports
        incremental, primary key ["test"]) and "test2" (full refresh only, no primary key).
        """
        self.source_class = MagicMock()
        self.source = MagicMock()
        self.source_class.return_value = self.source
        self.config = MagicMock()
        self.integration = TestIntegration(self.source, self.config)
        self.stream1 = AirbyteStream(
            name="test",
            source_defined_primary_key=[["test"]],
            json_schema={},
            supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental],
        )
        self.stream2 = AirbyteStream(name="test2", json_schema={}, supported_sync_modes=[SyncMode.full_refresh])
        self.source.discover.return_value = AirbyteCatalog(streams=[self.stream2, self.stream1])

    def test_integration(self):
        """Records are converted with their primary-key id; log messages are ignored."""
        self.source.read.return_value = [
            AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="test")),
            AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 1}, emitted_at=1)),
            AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 2}, emitted_at=2)),
            AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 3}, emitted_at=3)),
        ]
        result = list(self.integration._load_data("test", None))
        self.assertEqual(
            result,
            [
                {"data": {"test": 1}, "id": "1"},
                {"data": {"test": 2}, "id": "2"},
                {"data": {"test": 3}, "id": "3"},
            ],
        )
        self.source.discover.assert_called_once_with(self.config)
        self.source.read.assert_called_once_with(
            self.config,
            ConfiguredAirbyteCatalog(
                streams=[
                    ConfiguredAirbyteStream(
                        stream=self.stream1,
                        sync_mode=SyncMode.incremental,
                        destination_sync_mode=DestinationSyncMode.append,
                        primary_key=[["test"]],
                    )
                ]
            ),
            None,
        )

    def test_state(self):
        """A state message is not yielded but is stored as `last_state`."""
        state = AirbyteStateMessage(data={})
        self.source.read.return_value = [
            AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="test")),
            AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream="test", data={"test": 1}, emitted_at=1)),
            AirbyteMessage(type=Type.STATE, state=state),
        ]
        result = list(self.integration._load_data("test", None))
        self.assertEqual(
            result,
            [
                {"data": {"test": 1}, "id": "1"},
            ],
        )
        # Verify the state was captured; a plain assignment here would check nothing.
        self.assertEqual(self.integration.last_state, state)

    def test_incremental(self):
        """A provided state is forwarded to the runner and the stream is read incrementally."""
        state = AirbyteStateMessage(data={})
        list(self.integration._load_data("test", state))
        self.source.read.assert_called_once_with(
            self.config,
            ConfiguredAirbyteCatalog(
                streams=[
                    ConfiguredAirbyteStream(
                        stream=self.stream1,
                        sync_mode=SyncMode.incremental,
                        destination_sync_mode=DestinationSyncMode.append,
                        primary_key=[["test"]],
                    )
                ]
            ),
            state,
        )

    def test_incremental_without_state(self):
        """Without a state argument the stream is still read incrementally, with state None."""
        list(self.integration._load_data("test"))
        self.source.read.assert_called_once_with(
            self.config,
            ConfiguredAirbyteCatalog(
                streams=[
                    ConfiguredAirbyteStream(
                        stream=self.stream1,
                        sync_mode=SyncMode.incremental,
                        destination_sync_mode=DestinationSyncMode.append,
                        primary_key=[["test"]],
                    )
                ]
            ),
            None,
        )

    def test_incremental_unsupported(self):
        """A stream without incremental support falls back to full refresh; state is still forwarded."""
        state = AirbyteStateMessage(data={})
        list(self.integration._load_data("test2", state))
        self.source.read.assert_called_once_with(
            self.config,
            ConfiguredAirbyteCatalog(
                streams=[
                    ConfiguredAirbyteStream(
                        stream=self.stream2,
                        sync_mode=SyncMode.full_refresh,
                        destination_sync_mode=DestinationSyncMode.append,
                    )
                ]
            ),
            state,
        )

0 comments on commit b91acef

Please sign in to comment.