🐛 CDK: fix bug with limit parameter for incremental stream #5833

Merged 10 commits on Sep 9, 2021
41 changes: 32 additions & 9 deletions airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
@@ -137,29 +137,40 @@ def _read_stream(

use_incremental = configured_stream.sync_mode == SyncMode.incremental and stream_instance.supports_incremental
if use_incremental:
record_iterator = self._read_incremental(logger, stream_instance, configured_stream, connector_state)
record_iterator = self._read_incremental(logger, stream_instance, configured_stream, connector_state, internal_config)
else:
record_iterator = self._read_full_refresh(stream_instance, configured_stream)
record_iterator = self._read_full_refresh(stream_instance, configured_stream, internal_config)

record_counter = 0
stream_name = configured_stream.stream.name
logger.info(f"Syncing stream: {stream_name} ")
for record in record_iterator:
if record.type == MessageType.RECORD:
if internal_config.limit and record_counter >= internal_config.limit:
Contributor

I think we still need this because we might have limit > size(slice)

Contributor Author

Got your idea, but I don't think it's good to put this condition back; check out my proposal (updated PR).

logger.info(f"Reached limit defined by internal config ({internal_config.limit}), stop reading")
break
record_counter += 1
yield record

logger.info(f"Read {record_counter} records from {stream_name} stream")

@staticmethod
def _limit_reached(internal_config: InternalConfig, records_counter: int) -> bool:
"""
Check if the record count has reached the limit set by the internal config.
:param internal_config - internal CDK configuration separated from user-defined config
:param records_counter - number of records already read
:return True if the limit is reached, False otherwise
"""
if internal_config.limit:
if records_counter >= internal_config.limit:
return True
return False

def _read_incremental(
self,
logger: AirbyteLogger,
stream_instance: Stream,
configured_stream: ConfiguredAirbyteStream,
connector_state: MutableMapping[str, Any],
internal_config: InternalConfig,
) -> Iterator[AirbyteMessage]:
stream_name = configured_stream.stream.name
stream_state = connector_state.get(stream_name, {})
@@ -170,31 +181,43 @@ def _read_incremental(
slices = stream_instance.stream_slices(
cursor_field=configured_stream.cursor_field, sync_mode=SyncMode.incremental, stream_state=stream_state
)
total_records_counter = 0
for slice in slices:
record_counter = 0
records = stream_instance.read_records(
sync_mode=SyncMode.incremental,
stream_slice=slice,
stream_state=stream_state,
cursor_field=configured_stream.cursor_field or None,
)
for record_data in records:
record_counter += 1
for record_counter, record_data in enumerate(records, start=1):
yield self._as_airbyte_record(stream_name, record_data)
stream_state = stream_instance.get_updated_state(stream_state, record_data)
if checkpoint_interval and record_counter % checkpoint_interval == 0:
yield self._checkpoint_state(stream_name, stream_state, connector_state, logger)

total_records_counter += 1
if self._limit_reached(internal_config, total_records_counter):
# Break from the slice loop to save state and exit from the _read_incremental function.
break
Contributor
@keu keu Sep 9, 2021

So we're still going to read all slices, right?

Contributor Author

updated


yield self._checkpoint_state(stream_name, stream_state, connector_state, logger)
if self._limit_reached(internal_config, total_records_counter):
return
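
To make the limit behaviour concrete, here is a minimal standalone sketch (simplified, with hypothetical names; not the actual CDK code) of the pattern used above: the record count is accumulated across slices, so a limit larger than a single slice still stops reading inside the current slice and then skips the remaining slices.

```python
# Sketch only: `slices`, `read_records` and `limit` are placeholders, not the CDK's real API.
from typing import Callable, Iterable, Iterator, Optional


def read_with_limit(slices: Iterable, read_records: Callable, limit: Optional[int] = None) -> Iterator:
    total = 0
    for slice_ in slices:
        for record in read_records(slice_):
            yield record
            total += 1
            if limit and total >= limit:
                break  # stop within the current slice (state can still be checkpointed)
        if limit and total >= limit:
            return  # skip the remaining slices entirely


# A limit of 5 spans two slices of 3 records each and stops midway through the second one.
assert len(list(read_with_limit([range(3), range(3)], lambda s: s, limit=5))) == 5
```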

def _read_full_refresh(self, stream_instance: Stream, configured_stream: ConfiguredAirbyteStream) -> Iterator[AirbyteMessage]:
def _read_full_refresh(
self, stream_instance: Stream, configured_stream: ConfiguredAirbyteStream, internal_config: InternalConfig
) -> Iterator[AirbyteMessage]:
slices = stream_instance.stream_slices(sync_mode=SyncMode.full_refresh, cursor_field=configured_stream.cursor_field)
total_records_counter = 0
for slice in slices:
records = stream_instance.read_records(
stream_slice=slice, sync_mode=SyncMode.full_refresh, cursor_field=configured_stream.cursor_field
)
for record in records:
yield self._as_airbyte_record(configured_stream.stream.name, record)
total_records_counter += 1
if self._limit_reached(internal_config, total_records_counter):
return

def _checkpoint_state(self, stream_name, stream_state, connector_state, logger):
logger.info(f"Setting state of {stream_name} stream to {stream_state}")
75 changes: 58 additions & 17 deletions airbyte-cdk/python/unit_tests/sources/test_source.py
@@ -30,7 +30,7 @@

import pytest
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import ConfiguredAirbyteCatalog
from airbyte_cdk.models import ConfiguredAirbyteCatalog, SyncMode, Type
from airbyte_cdk.sources import AbstractSource, Source
from airbyte_cdk.sources.streams.core import Stream
from airbyte_cdk.sources.streams.http.http import HttpStream
@@ -54,6 +54,25 @@ def source():
return MockSource()


@pytest.fixture
def catalog():
configured_catalog = {
"streams": [
{
"stream": {"name": "mock_http_stream", "json_schema": {}},
"destination_sync_mode": "overwrite",
"sync_mode": "full_refresh",
},
{
"stream": {"name": "mock_stream", "json_schema": {}},
"destination_sync_mode": "overwrite",
"sync_mode": "full_refresh",
},
]
}
return ConfiguredAirbyteCatalog.parse_obj(configured_catalog)


@pytest.fixture
def abstract_source(mocker):
mocker.patch.multiple(HttpStream, __abstractmethods__=set())
@@ -63,6 +82,9 @@ class MockHttpStream(MagicMock, HttpStream):
url_base = "http://example.com"
path = "/dummy/path"

def supports_incremental(self):
return True

def __init__(self, *args, **kvargs):
MagicMock.__init__(self)
HttpStream.__init__(self, *args, kvargs)
@@ -120,22 +142,7 @@ def test_read_catalog(source):
assert actual == expected


def test_internal_config(abstract_source):
configured_catalog = {
"streams": [
{
"stream": {"name": "mock_http_stream", "json_schema": {}},
"destination_sync_mode": "overwrite",
"sync_mode": "full_refresh",
},
{
"stream": {"name": "mock_stream", "json_schema": {}},
"destination_sync_mode": "overwrite",
"sync_mode": "full_refresh",
},
]
}
catalog = ConfiguredAirbyteCatalog.parse_obj(configured_catalog)
def test_internal_config(abstract_source, catalog):
streams = abstract_source.streams(None)
assert len(streams) == 2
http_stream = streams[0]
@@ -175,3 +182,37 @@ def test_internal_config(abstract_source):
assert http_stream.page_size == 2
# Make sure page_size hasn't been set for non-HTTP streams
assert not non_http_stream.page_size


def test_internal_config_limit(abstract_source, catalog):
logger_mock = MagicMock()
del catalog.streams[1]
STREAM_LIMIT = 2
FULL_RECORDS_NUMBER = 3
streams = abstract_source.streams(None)
http_stream = streams[0]
http_stream.read_records.return_value = [{}] * FULL_RECORDS_NUMBER
internal_config = {"some_config": 100, "_limit": STREAM_LIMIT}

catalog.streams[0].sync_mode = SyncMode.full_refresh
records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
assert len(records) == STREAM_LIMIT
logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
# Check that the log line reflects the configured limit
read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
assert read_log_record[0].startswith(f"Read {STREAM_LIMIT} ")

# No limit: check that a state record is produced for an incremental stream
catalog.streams[0].sync_mode = SyncMode.incremental
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == FULL_RECORDS_NUMBER + 1
assert records[-1].type == Type.STATE

# Set the limit and check that state is still produced for an incremental stream
logger_mock.reset_mock()
records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
assert len(records) == STREAM_LIMIT + 1
assert records[-1].type == Type.STATE
logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
assert read_log_record[0].startswith(f"Read {STREAM_LIMIT} ")
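
For context on the "_limit" option these tests exercise, here is a simplified, hypothetical sketch of the internal-config idea: keys prefixed with an underscore are split away from the user-facing config before the source reads data. This is an illustration only, not the CDK's actual InternalConfig handling.

```python
# Hypothetical helper for illustration; the real CDK has its own internal-config parsing.
from typing import Any, Dict, Tuple


def split_config(config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a raw config into (user_config, internal_config) on the leading underscore."""
    user_config = {k: v for k, v in config.items() if not k.startswith("_")}
    internal_config = {k[1:]: v for k, v in config.items() if k.startswith("_")}
    return user_config, internal_config


user_config, internal_config = split_config({"some_config": 100, "_limit": 2})
assert user_config == {"some_config": 100}
assert internal_config == {"limit": 2}
```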