diff --git a/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py b/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
index c0c2cf3dbc85..842757997f24 100644
--- a/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
+++ b/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
@@ -4,15 +4,12 @@
 
 import logging
 from abc import ABC, abstractmethod
-from datetime import datetime
-from functools import lru_cache
 from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union
 
 from airbyte_cdk.models import (
     AirbyteCatalog,
     AirbyteConnectionStatus,
     AirbyteMessage,
-    AirbyteRecordMessage,
     AirbyteStateMessage,
     ConfiguredAirbyteCatalog,
     ConfiguredAirbyteStream,
@@ -24,8 +21,8 @@
 from airbyte_cdk.sources.source import Source
 from airbyte_cdk.sources.streams import Stream
 from airbyte_cdk.sources.streams.http.http import HttpStream
+from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message
 from airbyte_cdk.sources.utils.schema_helpers import InternalConfig, split_config
-from airbyte_cdk.sources.utils.transform import TypeTransformer
 from airbyte_cdk.utils.event_timing import create_timer
 from airbyte_cdk.utils.traced_exception import AirbyteTracedException
 
@@ -241,20 +238,27 @@ def _read_incremental(
                 stream_state=stream_state,
                 cursor_field=configured_stream.cursor_field or None,
             )
-            for record_counter, record_data in enumerate(records, start=1):
-                yield self._as_airbyte_record(stream_name, record_data)
-                stream_state = stream_instance.get_updated_state(stream_state, record_data)
-                checkpoint_interval = stream_instance.state_checkpoint_interval
-                if checkpoint_interval and record_counter % checkpoint_interval == 0:
-                    yield self._checkpoint_state(stream_instance, stream_state, state_manager)
-
-                total_records_counter += 1
-                # This functionality should ideally live outside of this method
-                # but since state is managed inside this method, we keep track
-                # of it here.
-                if self._limit_reached(internal_config, total_records_counter):
-                    # Break from slice loop to save state and exit from _read_incremental function.
-                    break
+            record_counter = 0
+            for message_counter, record_data_or_message in enumerate(records, start=1):
+                message = stream_data_to_airbyte_message(
+                    stream_name, record_data_or_message, stream_instance.transformer, stream_instance.get_json_schema()
+                )
+                yield message
+                if message.type == MessageType.RECORD:
+                    record = message.record
+                    stream_state = stream_instance.get_updated_state(stream_state, record.data)
+                    checkpoint_interval = stream_instance.state_checkpoint_interval
+                    record_counter += 1
+                    if checkpoint_interval and record_counter % checkpoint_interval == 0:
+                        yield self._checkpoint_state(stream_instance, stream_state, state_manager)
+
+                    total_records_counter += 1
+                    # This functionality should ideally live outside of this method
+                    # but since state is managed inside this method, we keep track
+                    # of it here.
+                    if self._limit_reached(internal_config, total_records_counter):
+                        # Break from slice loop to save state and exit from _read_incremental function.
+                        break
 
             yield self._checkpoint_state(stream_instance, stream_state, state_manager)
             if self._limit_reached(internal_config, total_records_counter):
@@ -277,16 +281,20 @@ def _read_full_refresh(
         total_records_counter = 0
         for _slice in slices:
             logger.debug("Processing stream slice", extra={"slice": _slice})
-            records = stream_instance.read_records(
+            record_data_or_messages = stream_instance.read_records(
                 stream_slice=_slice,
                 sync_mode=SyncMode.full_refresh,
                 cursor_field=configured_stream.cursor_field,
             )
-            for record in records:
-                yield self._as_airbyte_record(configured_stream.stream.name, record)
-                total_records_counter += 1
-                if self._limit_reached(internal_config, total_records_counter):
-                    return
+            for record_data_or_message in record_data_or_messages:
+                message = stream_data_to_airbyte_message(
+                    stream_instance.name, record_data_or_message, stream_instance.transformer, stream_instance.get_json_schema()
+                )
+                yield message
+                if message.type == MessageType.RECORD:
+                    total_records_counter += 1
+                    if self._limit_reached(internal_config, total_records_counter):
+                        return
 
     def _checkpoint_state(self, stream: Stream, stream_state, state_manager: ConnectorStateManager):
         # First attempt to retrieve the current state using the stream's state property. We receive an AttributeError if the state
@@ -294,33 +302,11 @@ def _checkpoint_state(self, stream: Stream, stream_state, state_manager: Connect
         # instance's deprecated get_updated_state() method.
         try:
             state_manager.update_state_for_stream(stream.name, stream.namespace, stream.state)
+
         except AttributeError:
             state_manager.update_state_for_stream(stream.name, stream.namespace, stream_state)
         return state_manager.create_state_message(stream.name, stream.namespace, send_per_stream_state=self.per_stream_state_enabled)
 
-    @lru_cache(maxsize=None)
-    def _get_stream_transformer_and_schema(self, stream_name: str) -> Tuple[TypeTransformer, Mapping[str, Any]]:
-        """
-        Lookup stream's transform object and jsonschema based on stream name.
-        This function would be called a lot so using caching to save on costly
-        get_json_schema operation.
-        :param stream_name name of stream from catalog.
-        :return tuple with stream transformer object and discover json schema.
-        """
-        stream_instance = self._stream_to_instance_map[stream_name]
-        return stream_instance.transformer, stream_instance.get_json_schema()
-
-    def _as_airbyte_record(self, stream_name: str, data: Mapping[str, Any]):
-        now_millis = int(datetime.now().timestamp() * 1000)
-        transformer, schema = self._get_stream_transformer_and_schema(stream_name)
-        # Transform object fields according to config. Most likely you will
-        # need it to normalize values against json schema. By default no action
-        # taken unless configured. See
-        # docs/connector-development/cdk-python/schemas.md for details.
-        transformer.transform(data, schema)  # type: ignore
-        message = AirbyteRecordMessage(stream=stream_name, data=data, emitted_at=now_millis)
-        return AirbyteMessage(type=MessageType.RECORD, record=message)
-
     @staticmethod
     def _apply_log_level_to_stream_logger(logger: logging.Logger, stream_instance: Stream):
         """
diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/core.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/core.py
index 02199df40c31..2c34c336ebbb 100644
--- a/airbyte-cdk/python/airbyte_cdk/sources/streams/core.py
+++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/core.py
@@ -6,14 +6,22 @@
 import inspect
 import logging
 from abc import ABC, abstractmethod
+from functools import lru_cache
 from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union
 
 import airbyte_cdk.sources.utils.casing as casing
-from airbyte_cdk.models import AirbyteStream, SyncMode
+from airbyte_cdk.models import AirbyteLogMessage, AirbyteRecordMessage, AirbyteStream, AirbyteTraceMessage, SyncMode
 from airbyte_cdk.sources.utils.schema_helpers import ResourceSchemaLoader
 from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
 from deprecated.classic import deprecated
 
+# A stream's read method can return one of the following types:
+# Mapping[str, Any]: The content of an AirbyteRecordMessage
+# AirbyteRecordMessage: An AirbyteRecordMessage
+# AirbyteLogMessage: A log message
+# AirbyteTraceMessage: A trace message
+StreamData = Union[Mapping[str, Any], AirbyteRecordMessage, AirbyteLogMessage, AirbyteTraceMessage]
+
 
 def package_name_from_class(cls: object) -> str:
     """Find the package name given a class name"""
@@ -94,11 +102,12 @@ def read_records(
         cursor_field: List[str] = None,
         stream_slice: Mapping[str, Any] = None,
         stream_state: Mapping[str, Any] = None,
-    ) -> Iterable[Mapping[str, Any]]:
+    ) -> Iterable[StreamData]:
         """
         This method should be overridden by subclasses to read records based on the inputs
         """
 
+    @lru_cache(maxsize=None)
     def get_json_schema(self) -> Mapping[str, Any]:
         """
         :return: A dict of the JSON schema representing this stream.
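Editor's note (not part of the patch): with the new `StreamData` alias, `Stream.read_records` may yield plain mappings as well as Airbyte record/log/trace messages. Below is a minimal, hypothetical sketch of such a stream; the class name, fields, and the trivial `get_json_schema` override are invented for illustration only.

```python
# Hypothetical example only; it is not included in this patch.
from typing import Any, Iterable, List, Mapping, Optional

from airbyte_cdk.models import AirbyteLogMessage, Level, SyncMode
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.core import StreamData


class GreetingsStream(Stream):
    """Toy stream that interleaves a log message with plain record dicts."""

    @property
    def primary_key(self) -> Optional[str]:
        return None

    def get_json_schema(self) -> Mapping[str, Any]:
        # Keep the sketch self-contained instead of loading a schema file.
        return {}

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: List[str] = None,
        stream_slice: Mapping[str, Any] = None,
        stream_state: Mapping[str, Any] = None,
    ) -> Iterable[StreamData]:
        # Non-record items are forwarded as-is by AbstractSource; dicts become RECORD messages.
        yield AirbyteLogMessage(level=Level.INFO, message="about to emit records")
        yield {"id": 1, "greeting": "hello"}
        yield {"id": 2, "greeting": "bonjour"}
```

When the source reads such a stream, each yielded item is converted by the `stream_data_to_airbyte_message` helper introduced below, so the log line is emitted as a LOG message while the dicts are wrapped into RECORD messages and counted toward checkpointing.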
diff --git a/airbyte-cdk/python/airbyte_cdk/sources/utils/__init__.py b/airbyte-cdk/python/airbyte_cdk/sources/utils/__init__.py
index 5adf292dff0c..8edf89696e66 100644
--- a/airbyte-cdk/python/airbyte_cdk/sources/utils/__init__.py
+++ b/airbyte-cdk/python/airbyte_cdk/sources/utils/__init__.py
@@ -1,5 +1,7 @@
 #
-# Copyright (c) 2021 Airbyte, Inc., all rights reserved.
+# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
 #
 
 # Initialize Utils Package
+
+__all__ = ["record_helper"]
diff --git a/airbyte-cdk/python/airbyte_cdk/sources/utils/record_helper.py b/airbyte-cdk/python/airbyte_cdk/sources/utils/record_helper.py
new file mode 100644
index 000000000000..e3c3223bd66a
--- /dev/null
+++ b/airbyte-cdk/python/airbyte_cdk/sources/utils/record_helper.py
@@ -0,0 +1,40 @@
+#
+# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
+#
+
+import datetime
+from typing import Any, Mapping, Optional
+
+from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteRecordMessage, AirbyteTraceMessage
+from airbyte_cdk.models import Type as MessageType
+from airbyte_cdk.sources.streams.core import StreamData
+from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
+
+
+def stream_data_to_airbyte_message(
+    stream_name: str,
+    data_or_message: StreamData,
+    transformer: TypeTransformer = TypeTransformer(TransformConfig.NoTransform),
+    schema: Optional[Mapping[str, Any]] = None,
+) -> AirbyteMessage:
+    if schema is None:
+        schema = {}
+
+    if isinstance(data_or_message, dict):
+        data = data_or_message
+        now_millis = int(datetime.datetime.now().timestamp() * 1000)
+        # Transform object fields according to config. Most likely you will
+        # need it to normalize values against json schema. By default no action
+        # taken unless configured. See
+        # docs/connector-development/cdk-python/schemas.md for details.
+        transformer.transform(data, schema)  # type: ignore
+        message = AirbyteRecordMessage(stream=stream_name, data=data, emitted_at=now_millis)
+        return AirbyteMessage(type=MessageType.RECORD, record=message)
+    elif isinstance(data_or_message, AirbyteRecordMessage):
+        return AirbyteMessage(type=MessageType.RECORD, record=data_or_message)
+    elif isinstance(data_or_message, AirbyteTraceMessage):
+        return AirbyteMessage(type=MessageType.TRACE, trace=data_or_message)
+    elif isinstance(data_or_message, AirbyteLogMessage):
+        return AirbyteMessage(type=MessageType.LOG, log=data_or_message)
+    else:
+        raise ValueError(f"Unexpected type for data_or_message: {type(data_or_message)}")
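Editor's note (not part of the patch): a short usage sketch of the helper added above, based only on the behavior visible in record_helper.py; the stream and field names are illustrative.

```python
# Hypothetical example only; it is not included in this patch.
from airbyte_cdk.models import AirbyteLogMessage, AirbyteRecordMessage, Level
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message

# A plain mapping is wrapped into a RECORD message (and run through the transformer,
# which is a no-op with the default NoTransform configuration).
record = stream_data_to_airbyte_message("my_stream", {"id": 1})
assert record.type == MessageType.RECORD
assert record.record.data == {"id": 1}

# An AirbyteRecordMessage is passed through untouched, keeping its emitted_at.
passthrough = stream_data_to_airbyte_message(
    "my_stream", AirbyteRecordMessage(stream="my_stream", data={"id": 2}, emitted_at=0)
)
assert passthrough.record.emitted_at == 0

# Log (and trace) messages are wrapped in the matching envelope type.
log = stream_data_to_airbyte_message("my_stream", AirbyteLogMessage(level=Level.INFO, message="hi"))
assert log.type == MessageType.LOG
```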
diff --git a/airbyte-cdk/python/unit_tests/sources/streams/test_streams_core.py b/airbyte-cdk/python/unit_tests/sources/streams/test_streams_core.py
index 82fe96d412c2..1ae9214079c9 100644
--- a/airbyte-cdk/python/unit_tests/sources/streams/test_streams_core.py
+++ b/airbyte-cdk/python/unit_tests/sources/streams/test_streams_core.py
@@ -4,6 +4,7 @@
 
 from typing import Any, Iterable, List, Mapping
+from unittest import mock
 
 import pytest
 from airbyte_cdk.models import AirbyteStream, SyncMode
@@ -173,3 +174,11 @@ def test_wrapped_primary_key_various_argument(test_input, expected):
 
     wrapped = Stream._wrapped_primary_key(test_input)
     assert wrapped == expected
+
+
+@mock.patch("airbyte_cdk.sources.utils.schema_helpers.ResourceSchemaLoader.get_schema")
+def test_get_json_schema_is_cached(mocked_method):
+    stream = StreamStubFullRefresh()
+    for i in range(5):
+        stream.get_json_schema()
+    assert mocked_method.call_count == 1
diff --git a/airbyte-cdk/python/unit_tests/sources/test_abstract_source.py b/airbyte-cdk/python/unit_tests/sources/test_abstract_source.py
index 7696a058383f..bbb808a0ec74 100644
--- a/airbyte-cdk/python/unit_tests/sources/test_abstract_source.py
+++ b/airbyte-cdk/python/unit_tests/sources/test_abstract_source.py
@@ -2,6 +2,7 @@
 # Copyright (c) 2022 Airbyte, Inc., all rights reserved.
 #
 
+import copy
 import logging
 from collections import defaultdict
 from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
@@ -11,6 +12,7 @@
 from airbyte_cdk.models import (
     AirbyteCatalog,
     AirbyteConnectionStatus,
+    AirbyteLogMessage,
     AirbyteMessage,
     AirbyteRecordMessage,
     AirbyteStateBlob,
@@ -21,6 +23,7 @@
     ConfiguredAirbyteCatalog,
     ConfiguredAirbyteStream,
     DestinationSyncMode,
+    Level,
     Status,
     StreamDescriptor,
     SyncMode,
@@ -29,6 +32,7 @@
 from airbyte_cdk.sources import AbstractSource
 from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
 from airbyte_cdk.sources.streams import IncrementalMixin, Stream
+from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message
 from airbyte_cdk.utils.traced_exception import AirbyteTracedException
 
 logger = logging.getLogger("airbyte")
@@ -149,6 +153,31 @@ def state(self, value):
         pass
 
 
+class MockStreamEmittingAirbyteMessages(MockStreamWithState):
+    def __init__(
+        self, inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[AirbyteMessage]]] = None, name: str = None, state=None
+    ):
+        super().__init__(inputs_and_mocked_outputs, name, state)
+        self._inputs_and_mocked_outputs = inputs_and_mocked_outputs
+        self._name = name
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
+        return "pk"
+
+    @property
+    def state(self) -> MutableMapping[str, Any]:
+        return {self.cursor_field: self._cursor_value} if self._cursor_value else {}
+
+    @state.setter
+    def state(self, value: MutableMapping[str, Any]):
+        self._cursor_value = value.get(self.cursor_field, self.start_date)
+
+
 def test_discover(mocker):
     """Tests that the appropriate AirbyteCatalog is returned from the discover method"""
     airbyte_stream1 = AirbyteStream(
@@ -783,6 +812,115 @@ def test_with_slices_and_interval(self, mocker, use_legacy, per_stream_enabled):
 
         assert expected == messages
 
+    @pytest.mark.parametrize(
+        "per_stream_enabled",
+        [
+            pytest.param(False, id="test_source_emits_state_as_legacy_format"),
+        ],
+    )
+    def test_emit_non_records(self, mocker, per_stream_enabled):
+        """
+        Tests that non-record messages (e.g. AirbyteLogMessages) yielded by a stream are emitted as-is:
+        1. outputs all records and non-record messages
+        2. outputs a state message every N records (N=checkpoint_interval), counting only records
+        3. outputs a state message after reading the entire slice
+        """
+
+        input_state = []
+        slices = [{"1": "1"}, {"2": "2"}]
+        stream_output = [
+            {"k1": "v1"},
+            AirbyteLogMessage(level=Level.INFO, message="HELLO"),
+            {"k2": "v2"},
+            {"k3": "v3"},
+        ]
+        stream_1 = MockStreamEmittingAirbyteMessages(
+            [
+                (
+                    {
+                        "sync_mode": SyncMode.incremental,
+                        "stream_slice": s,
+                        "stream_state": mocker.ANY,
+                    },
+                    stream_output,
+                )
+                for s in slices
+            ],
+            name="s1",
+            state=copy.deepcopy(input_state),
+        )
+        stream_2 = MockStreamEmittingAirbyteMessages(
+            [
+                (
+                    {
+                        "sync_mode": SyncMode.incremental,
+                        "stream_slice": s,
+                        "stream_state": mocker.ANY,
+                    },
+                    stream_output,
+                )
+                for s in slices
+            ],
+            name="s2",
+            state=copy.deepcopy(input_state),
+        )
+        state = {"cursor": "value"}
+        mocker.patch.object(MockStream, "get_updated_state", return_value=state)
+        mocker.patch.object(MockStream, "supports_incremental", return_value=True)
+        mocker.patch.object(MockStream, "get_json_schema", return_value={})
+        mocker.patch.object(MockStream, "stream_slices", return_value=slices)
+        mocker.patch.object(
+            MockStream,
+            "state_checkpoint_interval",
+            new_callable=mocker.PropertyMock,
+            return_value=2,
+        )
+
+        src = MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)
+        catalog = ConfiguredAirbyteCatalog(
+            streams=[
+                _configured_stream(stream_1, SyncMode.incremental),
+                _configured_stream(stream_2, SyncMode.incremental),
+            ]
+        )
+
+        expected = _fix_emitted_at(
+            [
+                # stream 1 slice 1
+                stream_data_to_airbyte_message("s1", stream_output[0]),
+                stream_data_to_airbyte_message("s1", stream_output[1]),
+                stream_data_to_airbyte_message("s1", stream_output[2]),
+                _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
+                stream_data_to_airbyte_message("s1", stream_output[3]),
+                _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
+                # stream 1 slice 2
+                stream_data_to_airbyte_message("s1", stream_output[0]),
+                stream_data_to_airbyte_message("s1", stream_output[1]),
+                stream_data_to_airbyte_message("s1", stream_output[2]),
+                _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
+                stream_data_to_airbyte_message("s1", stream_output[3]),
+                _as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
+                # stream 2 slice 1
+                stream_data_to_airbyte_message("s2", stream_output[0]),
+                stream_data_to_airbyte_message("s2", stream_output[1]),
+                stream_data_to_airbyte_message("s2", stream_output[2]),
+                _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
+                stream_data_to_airbyte_message("s2", stream_output[3]),
+                _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
+                # stream 2 slice 2
+                stream_data_to_airbyte_message("s2", stream_output[0]),
+                stream_data_to_airbyte_message("s2", stream_output[1]),
+                stream_data_to_airbyte_message("s2", stream_output[2]),
+                _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
+                stream_data_to_airbyte_message("s2", stream_output[3]),
+                _as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
+            ]
+        )
+
+        messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))
+
+        assert expected == messages
+
 
 def test_checkpoint_state_from_stream_instance():
     teams_stream = MockStreamOverridesStateMethod()
diff --git a/airbyte-cdk/python/unit_tests/sources/test_source.py b/airbyte-cdk/python/unit_tests/sources/test_source.py
index c81b794e5af6..1034975c1892 100644
--- a/airbyte-cdk/python/unit_tests/sources/test_source.py
+++ b/airbyte-cdk/python/unit_tests/sources/test_source.py
@@ -425,8 +425,8 @@ def test_source_config_no_transform(abstract_source, catalog):
     records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
     assert len(records) == 2 * 5
     assert [r.record.data for r in records] == [{"value": 23}] * 2 * 5
-    assert http_stream.get_json_schema.call_count == 1
-    assert non_http_stream.get_json_schema.call_count == 1
+    assert http_stream.get_json_schema.call_count == 5
+    assert non_http_stream.get_json_schema.call_count == 5
 
 
 def test_source_config_transform(abstract_source, catalog):
diff --git a/airbyte-cdk/python/unit_tests/sources/utils/test_record_helper.py b/airbyte-cdk/python/unit_tests/sources/utils/test_record_helper.py
new file mode 100644
index 000000000000..ffb3248333da
--- /dev/null
+++ b/airbyte-cdk/python/unit_tests/sources/utils/test_record_helper.py
@@ -0,0 +1,93 @@
+#
+# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
+#
+
+from unittest.mock import MagicMock
+
+import pytest
+from airbyte_cdk.models import (
+    AirbyteLogMessage,
+    AirbyteMessage,
+    AirbyteRecordMessage,
+    AirbyteStateMessage,
+    AirbyteStateType,
+    AirbyteTraceMessage,
+    Level,
+    TraceType,
+)
+from airbyte_cdk.models import Type as MessageType
+from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message
+
+NOW = 1234567
+STREAM_NAME = "my_stream"
+
+
+@pytest.mark.parametrize(
+    "test_name, data, expected_message",
+    [
+        (
+            "test_data_to_airbyte_record",
+            {"id": 0, "field_A": 1.0, "field_B": "airbyte"},
+            AirbyteMessage(
+                type=MessageType.RECORD,
+                record=AirbyteRecordMessage(stream="my_stream", data={"id": 0, "field_A": 1.0, "field_B": "airbyte"}, emitted_at=NOW),
+            ),
+        ),
+        (
+            "test_record_to_airbyte_record",
+            AirbyteRecordMessage(stream="my_stream", data={"id": 0, "field_A": 1.0, "field_B": "airbyte"}, emitted_at=NOW),
+            AirbyteMessage(
+                type=MessageType.RECORD,
+                record=AirbyteRecordMessage(stream="my_stream", data={"id": 0, "field_A": 1.0, "field_B": "airbyte"}, emitted_at=NOW),
+            ),
+        ),
+    ],
+)
+def test_data_or_record_to_airbyte_record(test_name, data, expected_message):
+    transformer = MagicMock()
+    schema = {}
+    message = stream_data_to_airbyte_message(STREAM_NAME, data, transformer, schema)
+    message.record.emitted_at = NOW
+
+    if isinstance(data, dict):
+        transformer.transform.assert_called_with(data, schema)
+    else:
+        assert not transformer.transform.called
+    assert expected_message == message
+
+
+@pytest.mark.parametrize(
+    "test_name, data, expected_message",
+    [
+        (
+            "test_log_message_to_airbyte_record",
+            AirbyteLogMessage(level=Level.INFO, message="Hello, this is a log message"),
+            AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message="Hello, this is a log message")),
+        ),
+        (
+            "test_trace_message_to_airbyte_record",
+            AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=101),
+            AirbyteMessage(type=MessageType.TRACE, trace=AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=101)),
+        ),
+    ],
+)
+def test_log_or_trace_to_message(test_name, data, expected_message):
+    transformer = MagicMock()
+    schema = {}
+    message = stream_data_to_airbyte_message(STREAM_NAME, data, transformer, schema)
+
+    assert not transformer.transform.called
+    assert expected_message == message
+
+
+@pytest.mark.parametrize(
+    "test_name, data",
+    [
+        ("test_state_message_to_airbyte_message", AirbyteStateMessage(type=AirbyteStateType.STREAM)),
+    ],
+)
+def test_state_message_to_message(test_name, data):
+    transformer = MagicMock()
+    schema = {}
+    with pytest.raises(ValueError):
+        stream_data_to_airbyte_message(STREAM_NAME, data, transformer, schema)