diff --git a/airbyte-cdk/python/CHANGELOG.md b/airbyte-cdk/python/CHANGELOG.md
index d84eb7015181..9e4f9312c012 100644
--- a/airbyte-cdk/python/CHANGELOG.md
+++ b/airbyte-cdk/python/CHANGELOG.md
@@ -1,5 +1,8 @@
 # Changelog
 
+## 0.1.18
+Fix incremental stream not saving state when the internal limit config is set.
+
 ## 0.1.17
 Fix mismatching between number of records actually read and number of records in logs by 1: https://github.com/airbytehq/airbyte/pull/5767
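The substance of the fix follows. Previously `_read_stream` enforced the internal record limit itself and broke out of the iterator before `_read_incremental` could emit a final state message; the patch moves the limit check into the read loops so state is still checkpointed when the cap is hit. Below is a minimal, dependency-free sketch of the corrected control flow; every name in it (`read_incremental_with_limit`, the message dicts) is an illustrative stand-in, not a CDK API:

```python
from typing import Any, Dict, Iterator, List, Optional

def read_incremental_with_limit(
    slices: List[List[Dict[str, Any]]],
    limit: Optional[int] = None,
    checkpoint_interval: Optional[int] = None,
) -> Iterator[Dict[str, Any]]:
    """Yield RECORD messages and still emit a final STATE message
    when the limit stops the read early (the behavior this patch fixes)."""
    state: Dict[str, Any] = {}
    total_records_counter = 0
    for slice_records in slices:
        for record_counter, record in enumerate(slice_records, start=1):
            yield {"type": "RECORD", "data": record}
            state = {"last_seen": record}  # stand-in for get_updated_state()
            if checkpoint_interval and record_counter % checkpoint_interval == 0:
                yield {"type": "STATE", "data": state}
            total_records_counter += 1
            if limit is not None and total_records_counter >= limit:
                break  # leave the record loop; the slice-level STATE below still fires
        # Checkpoint once per slice, even when the cap was hit mid-slice.
        yield {"type": "STATE", "data": state}
        if limit is not None and total_records_counter >= limit:
            return
```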
diff --git a/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py b/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
index c9fc7f361059..c8d2c47883fa 100644
--- a/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
+++ b/airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
@@ -137,29 +137,40 @@ def _read_stream(
         use_incremental = configured_stream.sync_mode == SyncMode.incremental and stream_instance.supports_incremental
         if use_incremental:
-            record_iterator = self._read_incremental(logger, stream_instance, configured_stream, connector_state)
+            record_iterator = self._read_incremental(logger, stream_instance, configured_stream, connector_state, internal_config)
         else:
-            record_iterator = self._read_full_refresh(stream_instance, configured_stream)
+            record_iterator = self._read_full_refresh(stream_instance, configured_stream, internal_config)
 
         record_counter = 0
         stream_name = configured_stream.stream.name
         logger.info(f"Syncing stream: {stream_name} ")
         for record in record_iterator:
             if record.type == MessageType.RECORD:
-                if internal_config.limit and record_counter >= internal_config.limit:
-                    logger.info(f"Reached limit defined by internal config ({internal_config.limit}), stop reading")
-                    break
                 record_counter += 1
             yield record
 
         logger.info(f"Read {record_counter} records from {stream_name} stream")
 
+    @staticmethod
+    def _limit_reached(internal_config: InternalConfig, records_counter: int) -> bool:
+        """
+        Check if the record count has reached the limit set by internal config.
+        :param internal_config - internal CDK configuration separated from user defined config
+        :param records_counter - number of records already read
+        :return True if limit reached, False otherwise
+        """
+        if internal_config.limit:
+            if records_counter >= internal_config.limit:
+                return True
+        return False
+
     def _read_incremental(
         self,
         logger: AirbyteLogger,
         stream_instance: Stream,
         configured_stream: ConfiguredAirbyteStream,
         connector_state: MutableMapping[str, Any],
+        internal_config: InternalConfig,
     ) -> Iterator[AirbyteMessage]:
         stream_name = configured_stream.stream.name
         stream_state = connector_state.get(stream_name, {})
@@ -170,31 +181,46 @@ def _read_incremental(
         slices = stream_instance.stream_slices(
             cursor_field=configured_stream.cursor_field, sync_mode=SyncMode.incremental, stream_state=stream_state
         )
+        total_records_counter = 0
         for slice in slices:
-            record_counter = 0
             records = stream_instance.read_records(
                 sync_mode=SyncMode.incremental,
                 stream_slice=slice,
                 stream_state=stream_state,
                 cursor_field=configured_stream.cursor_field or None,
             )
-            for record_data in records:
-                record_counter += 1
+            for record_counter, record_data in enumerate(records, start=1):
                 yield self._as_airbyte_record(stream_name, record_data)
                 stream_state = stream_instance.get_updated_state(stream_state, record_data)
                 if checkpoint_interval and record_counter % checkpoint_interval == 0:
                     yield self._checkpoint_state(stream_name, stream_state, connector_state, logger)
+                total_records_counter += 1
+                # This functionality should ideally live outside of this method
+                # but since state is managed inside this method, we keep track
+                # of it here.
+                if self._limit_reached(internal_config, total_records_counter):
+                    # Break out of the record loop so the slice-level state below
+                    # is saved before exiting _read_incremental.
+                    break
 
             yield self._checkpoint_state(stream_name, stream_state, connector_state, logger)
+            if self._limit_reached(internal_config, total_records_counter):
+                return
 
-    def _read_full_refresh(self, stream_instance: Stream, configured_stream: ConfiguredAirbyteStream) -> Iterator[AirbyteMessage]:
+    def _read_full_refresh(
+        self, stream_instance: Stream, configured_stream: ConfiguredAirbyteStream, internal_config: InternalConfig
+    ) -> Iterator[AirbyteMessage]:
         slices = stream_instance.stream_slices(sync_mode=SyncMode.full_refresh, cursor_field=configured_stream.cursor_field)
+        total_records_counter = 0
         for slice in slices:
             records = stream_instance.read_records(
                 stream_slice=slice, sync_mode=SyncMode.full_refresh, cursor_field=configured_stream.cursor_field
             )
             for record in records:
                 yield self._as_airbyte_record(configured_stream.stream.name, record)
+                total_records_counter += 1
+                if self._limit_reached(internal_config, total_records_counter):
+                    return
 
     def _checkpoint_state(self, stream_name, stream_state, connector_state, logger):
         logger.info(f"Setting state of {stream_name} stream to {stream_state}")
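One subtlety worth calling out in the new `_limit_reached` helper above: a falsy limit (unset, `None`, or `0`) disables the cap entirely, so the check never trips. A runnable, stand-alone illustration of that contract; `InternalConfigStub` is a made-up stand-in for the CDK's `InternalConfig` model:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class InternalConfigStub:
    """Made-up stand-in for the CDK's InternalConfig; only `limit` matters here."""
    limit: Optional[int] = None

def limit_reached(internal_config: InternalConfigStub, records_counter: int) -> bool:
    # Same logic as the new AbstractSource._limit_reached: a falsy limit
    # (None or 0) means "no limit", so the check never trips.
    if internal_config.limit:
        if records_counter >= internal_config.limit:
            return True
    return False

assert limit_reached(InternalConfigStub(limit=2), 2) is True
assert limit_reached(InternalConfigStub(limit=2), 1) is False
assert limit_reached(InternalConfigStub(limit=0), 100) is False  # 0 disables the cap
assert limit_reached(InternalConfigStub(), 10**6) is False       # unset disables it too
```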
diff --git a/airbyte-cdk/python/setup.py b/airbyte-cdk/python/setup.py
index 735d5dc6253c..7e5223f33a26 100644
--- a/airbyte-cdk/python/setup.py
+++ b/airbyte-cdk/python/setup.py
@@ -35,7 +35,7 @@
 setup(
     name="airbyte-cdk",
-    version="0.1.17",
+    version="0.1.18",
     description="A framework for writing Airbyte Connectors.",
     long_description=README,
     long_description_content_type="text/markdown",
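The new `test_internal_config_limit` in the test diff below pins down the message shape this fix guarantees: with a limit of 2 over a 3-record stream, an incremental read must yield exactly two RECORD messages followed by one trailing STATE message. The same shape falls out of the `read_incremental_with_limit` sketch from earlier:

```python
# Reusing read_incremental_with_limit from the sketch above:
# one slice of three records, capped at two.
messages = list(read_incremental_with_limit(slices=[[{"id": 1}, {"id": 2}, {"id": 3}]], limit=2))
assert [m["type"] for m in messages] == ["RECORD", "RECORD", "STATE"]
```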
diff --git a/airbyte-cdk/python/unit_tests/sources/test_source.py b/airbyte-cdk/python/unit_tests/sources/test_source.py
index b7ec2e0e7a2d..9a297c201549 100644
--- a/airbyte-cdk/python/unit_tests/sources/test_source.py
+++ b/airbyte-cdk/python/unit_tests/sources/test_source.py
@@ -30,7 +30,7 @@
 import pytest
 from airbyte_cdk.logger import AirbyteLogger
-from airbyte_cdk.models import ConfiguredAirbyteCatalog
+from airbyte_cdk.models import ConfiguredAirbyteCatalog, SyncMode, Type
 from airbyte_cdk.sources import AbstractSource, Source
 from airbyte_cdk.sources.streams.core import Stream
 from airbyte_cdk.sources.streams.http.http import HttpStream
@@ -54,6 +54,25 @@ def source():
     return MockSource()
 
 
+@pytest.fixture
+def catalog():
+    configured_catalog = {
+        "streams": [
+            {
+                "stream": {"name": "mock_http_stream", "json_schema": {}},
+                "destination_sync_mode": "overwrite",
+                "sync_mode": "full_refresh",
+            },
+            {
+                "stream": {"name": "mock_stream", "json_schema": {}},
+                "destination_sync_mode": "overwrite",
+                "sync_mode": "full_refresh",
+            },
+        ]
+    }
+    return ConfiguredAirbyteCatalog.parse_obj(configured_catalog)
+
+
 @pytest.fixture
 def abstract_source(mocker):
     mocker.patch.multiple(HttpStream, __abstractmethods__=set())
@@ -63,6 +82,9 @@ class MockHttpStream(MagicMock, HttpStream):
         url_base = "http://example.com"
         path = "/dummy/path"
 
+        def supports_incremental(self):
+            return True
+
         def __init__(self, *args, **kvargs):
             MagicMock.__init__(self)
             HttpStream.__init__(self, *args, kvargs)
@@ -120,22 +142,7 @@ def test_read_catalog(source):
     assert actual == expected
 
 
-def test_internal_config(abstract_source):
-    configured_catalog = {
-        "streams": [
-            {
-                "stream": {"name": "mock_http_stream", "json_schema": {}},
-                "destination_sync_mode": "overwrite",
-                "sync_mode": "full_refresh",
-            },
-            {
-                "stream": {"name": "mock_stream", "json_schema": {}},
-                "destination_sync_mode": "overwrite",
-                "sync_mode": "full_refresh",
-            },
-        ]
-    }
-    catalog = ConfiguredAirbyteCatalog.parse_obj(configured_catalog)
+def test_internal_config(abstract_source, catalog):
     streams = abstract_source.streams(None)
     assert len(streams) == 2
     http_stream = streams[0]
@@ -175,3 +182,37 @@ def test_internal_config(abstract_source):
     assert http_stream.page_size == 2
     # Make sure page_size havent been set for non http streams
     assert not non_http_stream.page_size
+
+
+def test_internal_config_limit(abstract_source, catalog):
+    logger_mock = MagicMock()
+    del catalog.streams[1]
+    STREAM_LIMIT = 2
+    FULL_RECORDS_NUMBER = 3
+    streams = abstract_source.streams(None)
+    http_stream = streams[0]
+    http_stream.read_records.return_value = [{}] * FULL_RECORDS_NUMBER
+    internal_config = {"some_config": 100, "_limit": STREAM_LIMIT}
+
+    catalog.streams[0].sync_mode = SyncMode.full_refresh
+    records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
+    assert len(records) == STREAM_LIMIT
+    logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
+    # Check that the "Read ..." log line reflects the limit
+    read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
+    assert read_log_record[0].startswith(f"Read {STREAM_LIMIT} ")
+
+    # No limit: check that a state record is produced for the incremental stream
+    catalog.streams[0].sync_mode = SyncMode.incremental
+    records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
+    assert len(records) == FULL_RECORDS_NUMBER + 1
+    assert records[-1].type == Type.STATE
+
+    # With the limit set, state should still be produced for the incremental stream
+    logger_mock.reset_mock()
+    records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
+    assert len(records) == STREAM_LIMIT + 1
+    assert records[-1].type == Type.STATE
+    logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
+    read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
+    assert read_log_record[0].startswith(f"Read {STREAM_LIMIT} ")
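A closing note on the log assertions in `test_internal_config_limit`: they pull the first positional argument out of each `logger.info(...)` call and verify that the `Read N records ...` summary line reflects the limit rather than the full record count. That pattern in isolation, runnable with only the standard library:

```python
from unittest.mock import MagicMock

logger_mock = MagicMock()
# Simulate what _read_stream logs during a capped sync.
logger_mock.info("Syncing stream: mock_http_stream ")
logger_mock.info("Read 2 records from mock_http_stream stream")

# call[0] is the positional-args tuple; call[0][0] is the message string.
logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
read_log_record = [line for line in logger_info_args if line.startswith("Read")]
assert read_log_record[0].startswith("Read 2 ")
```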