Stream returns AirbyteMessages #18572
File: airbyte_cdk/sources/abstract_source.py

@@ -4,15 +4,12 @@
 import logging
 from abc import ABC, abstractmethod
-from datetime import datetime
-from functools import lru_cache
 from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union

 from airbyte_cdk.models import (
     AirbyteCatalog,
     AirbyteConnectionStatus,
     AirbyteMessage,
-    AirbyteRecordMessage,
     AirbyteStateMessage,
     ConfiguredAirbyteCatalog,
     ConfiguredAirbyteStream,

@@ -25,7 +22,6 @@
 from airbyte_cdk.sources.streams import Stream
 from airbyte_cdk.sources.streams.http.http import HttpStream
 from airbyte_cdk.sources.utils.schema_helpers import InternalConfig, split_config
-from airbyte_cdk.sources.utils.transform import TypeTransformer
 from airbyte_cdk.utils.event_timing import create_timer
 from airbyte_cdk.utils.traced_exception import AirbyteTracedException
@@ -235,26 +231,28 @@ def _read_incremental(
         for _slice in slices:
             has_slices = True
             logger.debug("Processing stream slice", extra={"slice": _slice})
-            records = stream_instance.read_records(
+            records = stream_instance.read_records_as_messages(
                 sync_mode=SyncMode.incremental,
                 stream_slice=_slice,
                 stream_state=stream_state,
                 cursor_field=configured_stream.cursor_field or None,
             )
-            for record_counter, record_data in enumerate(records, start=1):
-                yield self._as_airbyte_record(stream_name, record_data)
-                stream_state = stream_instance.get_updated_state(stream_state, record_data)
-                checkpoint_interval = stream_instance.state_checkpoint_interval
-                if checkpoint_interval and record_counter % checkpoint_interval == 0:
-                    yield self._checkpoint_state(stream_instance, stream_state, state_manager)
-
-                total_records_counter += 1
-                # This functionality should ideally live outside of this method
-                # but since state is managed inside this method, we keep track
-                # of it here.
-                if self._limit_reached(internal_config, total_records_counter):
-                    # Break from slice loop to save state and exit from _read_incremental function.
-                    break
+            for record_counter, message in enumerate(records, start=1):
+                yield message
+                if message.type == MessageType.RECORD:
+                    record = message.record
+                    stream_state = stream_instance.get_updated_state(stream_state, record.data)
+                    checkpoint_interval = stream_instance.state_checkpoint_interval
+                    if checkpoint_interval and record_counter % checkpoint_interval == 0:
+                        yield self._checkpoint_state(stream_instance, stream_state, state_manager)
+
+                    total_records_counter += 1
+                    # This functionality should ideally live outside of this method
+                    # but since state is managed inside this method, we keep track
+                    # of it here.
+                    if self._limit_reached(internal_config, total_records_counter):
+                        # Break from slice loop to save state and exit from _read_incremental function.
+                        break

         yield self._checkpoint_state(stream_instance, stream_state, state_manager)
         if self._limit_reached(internal_config, total_records_counter):
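Both read paths now yield full AirbyteMessages rather than bare record mappings, so record-only bookkeeping (state updates, record counters) has to be guarded by the message type. A small self-contained sketch of that invariant (the message contents below are invented; the model names come from `airbyte_cdk.models`):

```python
# Why the counters moved inside the MessageType.RECORD check: a stream can now
# interleave other message types (e.g. logs) with records, and only records
# should advance state or record counts. Message contents here are made up.
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteRecordMessage, Level
from airbyte_cdk.models import Type as MessageType

messages = [
    AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message="fetching page 1")),
    AirbyteMessage(type=MessageType.RECORD, record=AirbyteRecordMessage(stream="s", data={"id": 1}, emitted_at=0)),
]

record_count = sum(1 for m in messages if m.type == MessageType.RECORD)
assert record_count == 1  # the LOG message is yielded downstream but never counted
```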
@@ -277,16 +275,17 @@ def _read_full_refresh(
         total_records_counter = 0
         for _slice in slices:
             logger.debug("Processing stream slice", extra={"slice": _slice})
-            records = stream_instance.read_records(
+            messages = stream_instance.read_records_as_messages(
                 stream_slice=_slice,
                 sync_mode=SyncMode.full_refresh,
                 cursor_field=configured_stream.cursor_field,
             )
-            for record in records:
-                yield self._as_airbyte_record(configured_stream.stream.name, record)
-                total_records_counter += 1
-                if self._limit_reached(internal_config, total_records_counter):
-                    return
+            for message in messages:
+                yield message
+                if message.type == MessageType.RECORD:
+                    total_records_counter += 1
+                    if self._limit_reached(internal_config, total_records_counter):
+                        return

     def _checkpoint_state(self, stream: Stream, stream_state, state_manager: ConnectorStateManager):
         # First attempt to retrieve the current state using the stream's state property. We receive an AttributeError if the state
@@ -298,29 +297,6 @@ def _checkpoint_state(self, stream: Stream, stream_state, state_manager: Connect
         state_manager.update_state_for_stream(stream.name, stream.namespace, stream_state)
         return state_manager.create_state_message(stream.name, stream.namespace, send_per_stream_state=self.per_stream_state_enabled)

-    @lru_cache(maxsize=None)
-    def _get_stream_transformer_and_schema(self, stream_name: str) -> Tuple[TypeTransformer, Mapping[str, Any]]:
-        """
-        Lookup stream's transform object and jsonschema based on stream name.
-        This function would be called a lot so using caching to save on costly
-        get_json_schema operation.
-        :param stream_name name of stream from catalog.
-        :return tuple with stream transformer object and discover json schema.
-        """
-        stream_instance = self._stream_to_instance_map[stream_name]
-        return stream_instance.transformer, stream_instance.get_json_schema()
-
-    def _as_airbyte_record(self, stream_name: str, data: Mapping[str, Any]):
-        now_millis = int(datetime.now().timestamp() * 1000)
-        transformer, schema = self._get_stream_transformer_and_schema(stream_name)
-        # Transform object fields according to config. Most likely you will
-        # need it to normalize values against json schema. By default no action
-        # taken unless configured. See
-        # docs/connector-development/cdk-python/schemas.md for details.
-        transformer.transform(data, schema)  # type: ignore
-        message = AirbyteRecordMessage(stream=stream_name, data=data, emitted_at=now_millis)
-        return AirbyteMessage(type=MessageType.RECORD, record=message)

     @staticmethod
     def _apply_log_level_to_stream_logger(logger: logging.Logger, stream_instance: Stream):
         """

Review comment on `_as_airbyte_record`: extracted to a function so it can be reused by the SimpleRetriever.
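The extracted helper `data_to_airbyte_record` is not itself part of this diff. Based on the removed `_as_airbyte_record` body above and the call site added in `Stream.read_records_as_messages` below, a plausible reconstruction looks like this sketch (the actual body of `record_helper.py` may differ):

```python
# Plausible reconstruction of airbyte_cdk/sources/utils/record_helper.py,
# inferred from the removed AbstractSource._as_airbyte_record and the call
# data_to_airbyte_record(self.name, record_mapping, self.transformer, self.get_json_schema()).
import datetime
from typing import Any, Mapping

from airbyte_cdk.models import AirbyteMessage, AirbyteRecordMessage
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.utils.transform import TypeTransformer


def data_to_airbyte_record(
    stream_name: str, data: Mapping[str, Any], transformer: TypeTransformer, schema: Mapping[str, Any]
) -> AirbyteMessage:
    # Referenced through the datetime module (not `from datetime import datetime`)
    # so the unit test further down can monkeypatch datetime.datetime.
    now_millis = int(datetime.datetime.now().timestamp() * 1000)
    # Normalize field values against the JSON schema; a no-op unless the stream
    # configured a TransformConfig (see docs/connector-development/cdk-python/schemas.md).
    transformer.transform(data, schema)  # type: ignore  # mutates data in place
    message = AirbyteRecordMessage(stream=stream_name, data=data, emitted_at=now_millis)
    return AirbyteMessage(type=MessageType.RECORD, record=message)
```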
File: airbyte_cdk/sources/streams/core.py

@@ -6,10 +6,14 @@
 import inspect
 import logging
 from abc import ABC, abstractmethod
+from functools import lru_cache
 from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union

 import airbyte_cdk.sources.utils.casing as casing
-from airbyte_cdk.models import AirbyteStream, SyncMode
+from airbyte_cdk.models import AirbyteMessage, AirbyteStream, SyncMode

 # list of all possible HTTP methods which can be used for sending of request bodies
+from airbyte_cdk.sources.utils.record_helper import data_to_airbyte_record
 from airbyte_cdk.sources.utils.schema_helpers import ResourceSchemaLoader
 from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
 from deprecated.classic import deprecated
@@ -87,6 +91,19 @@ def get_error_display_message(self, exception: BaseException) -> Optional[str]:
         """
         return None

+    def read_records_as_messages(
+        self,
+        sync_mode: SyncMode,
+        cursor_field: List[str] = None,
+        stream_slice: Mapping[str, Any] = None,
+        stream_state: Mapping[str, Any] = None,
+    ) -> Iterable[AirbyteMessage]:
+        """ """
+        for record_mapping in self.read_records(
+            sync_mode=sync_mode, cursor_field=cursor_field, stream_slice=stream_slice, stream_state=stream_state
+        ):
+            yield data_to_airbyte_record(self.name, record_mapping, self.transformer, self.get_json_schema())
+
     @abstractmethod
     def read_records(
         self,

Review thread on `read_records_as_messages`:

Review comment: Something that gives me slight pause is that this new function is defined to make it seem like […]. If we were to pull it back into […]
Reply: It's a bit smelly that I can override […]. Are we just saying that a developer can return either […]?
Reply: @brianjlai the […] @sherifnada […]

Review comment on `read_records`: the method name is a little misleading now because it doesn't only return records anymore.
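The thread above circles around whether `read_records` itself should be allowed to yield either raw mappings or fully-formed AirbyteMessages. A rough sketch of that alternative (hypothetical; not what this PR implements, and `_normalize` is an invented name) would normalize the union in one place:

```python
# Hypothetical sketch of the alternative the reviewers raise: let read_records
# yield Union[Mapping, AirbyteMessage] and normalize once, instead of adding a
# parallel read_records_as_messages method. Not part of this PR.
from typing import Any, Iterable, Mapping, Union

from airbyte_cdk.models import AirbyteMessage
from airbyte_cdk.sources.utils.record_helper import data_to_airbyte_record


def _normalize(stream, items: Iterable[Union[Mapping[str, Any], AirbyteMessage]]) -> Iterable[AirbyteMessage]:
    for item in items:
        if isinstance(item, AirbyteMessage):
            yield item  # already a message (e.g. a log or trace): pass through
        else:
            # a plain mapping: wrap it using the helper this PR introduces
            yield data_to_airbyte_record(stream.name, item, stream.transformer, stream.get_json_schema())
```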
@@ -99,6 +116,7 @@ def read_records(
         This method should be overridden by subclasses to read records based on the inputs
         """

+    @lru_cache(maxsize=None)
     def get_json_schema(self) -> Mapping[str, Any]:
         """
         :return: A dict of the JSON schema representing this stream.
File: airbyte_cdk/sources/utils/__init__.py

@@ -1,5 +1,7 @@
 #
-# Copyright (c) 2021 Airbyte, Inc., all rights reserved.
+# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
 #

 # Initialize Utils Package
+
+__all__ = ["record_helper"]
File: HttpStream unit tests (test_http.py)

@@ -3,20 +3,28 @@
 #

+import datetime
 import json
+import unittest
 from http import HTTPStatus
 from typing import Any, Iterable, Mapping, Optional
 from unittest.mock import ANY, MagicMock, patch

 import pytest
 import requests
-from airbyte_cdk.models import SyncMode
+from airbyte_cdk.models import AirbyteMessage, AirbyteRecordMessage, SyncMode, Type
 from airbyte_cdk.sources.streams.http import HttpStream, HttpSubStream
 from airbyte_cdk.sources.streams.http.auth import NoAuth
 from airbyte_cdk.sources.streams.http.auth import TokenAuthenticator as HttpTokenAuthenticator
 from airbyte_cdk.sources.streams.http.exceptions import DefaultBackoffException, RequestBodyException, UserDefinedBackoffException
 from airbyte_cdk.sources.streams.http.requests_native_auth import TokenAuthenticator

+datetime_format = "%Y-%m-%dT%H:%M:%S.%f%z"
+FAKE_NOW = datetime.datetime(2022, 1, 1, tzinfo=datetime.timezone.utc)
+
 config = {"start_date": "2021-01-01T00:00:00.000000+0000", "start_date_ymd": "2021-01-01"}
 timezone = datetime.timezone.utc


 class StubBasicReadHttpStream(HttpStream):
     url_base = "https://test_base_url.com"
@@ -37,6 +45,16 @@ def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapp
         self.resp_counter += 1
         yield stubResp

+    def get_json_schema(self) -> Mapping[str, Any]:
+        return {}
+

+@pytest.fixture()
+def mock_datetime_now(monkeypatch):
+    datetime_mock = unittest.mock.MagicMock(wraps=datetime.datetime)
+    datetime_mock.now.return_value = FAKE_NOW
+    monkeypatch.setattr(datetime, "datetime", datetime_mock)
+

 def test_default_authenticator():
     stream = StubBasicReadHttpStream()
@@ -79,6 +97,20 @@ def test_stub_basic_read_http_stream_read_records(mocker):
     assert [{"data": 1}] == records


+def test_stub_basic_read_http_stream_read_records_as_messages(mocker, mock_datetime_now):
+    stream = StubBasicReadHttpStream()
+    blank_response = {}  # Sending a blank response is fine as we ignore the response in `parse_response` anyway.
+    mocker.patch.object(StubBasicReadHttpStream, "_send_request", return_value=blank_response)
+
+    records = list(stream.read_records_as_messages(SyncMode.full_refresh))
+
+    assert [
+        AirbyteMessage(
+            type=Type.RECORD, record=AirbyteRecordMessage(stream="stub_basic_read_http_stream", data={"data": 1}, emitted_at=1640995200000)
+        )
+    ] == records
+

 class StubNextPageTokenHttpStream(StubBasicReadHttpStream):
     current_page = 0

Review comment on the `_send_request` mock: I feel like mocking private methods is a smell. I know it didn't originate in this PR, but it feels like it's telling us we need a dependency injection of the requester component. (Mostly meaning to start a discussion in this comment, not asking that it be fixed in this PR.)
Reply: +100
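The dependency-injection idea floated above might look roughly like the following sketch. This is a hypothetical design, not part of the CDK or this PR; `Requester`, `FakeRequester`, and their methods are invented names used only to illustrate the seam the reviewer is asking for.

```python
# Hypothetical sketch: inject a requester so tests substitute a fake object
# instead of patching private methods like _send_request.
import requests


class Requester:
    """Thin seam around the actual HTTP call."""

    def send(self, request: requests.PreparedRequest) -> requests.Response:
        with requests.Session() as session:
            return session.send(request)


class FakeRequester(Requester):
    """Test double: returns a canned response, no network, no monkeypatching."""

    def __init__(self, canned: requests.Response):
        self.canned = canned

    def send(self, request: requests.PreparedRequest) -> requests.Response:
        return self.canned
```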
File: AbstractSource unit tests (test_abstract_source.py)

@@ -425,8 +425,8 @@ def test_source_config_no_transform(abstract_source, catalog):
     records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
     assert len(records) == 2 * 5
     assert [r.record.data for r in records] == [{"value": 23}] * 2 * 5
-    assert http_stream.get_json_schema.call_count == 1
-    assert non_http_stream.get_json_schema.call_count == 1
+    assert http_stream.get_json_schema.call_count == 5
+    assert non_http_stream.get_json_schema.call_count == 5


 def test_source_config_transform(abstract_source, catalog):

Review comment: this assert verified that we used the cached transformer and schema instead of calling […]
Reply: makes sense! but side question, could we add a test into […] only gets invoked once? like given that it's cached, we'd only expect […]
Reply: done 0cbd28e
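The cache test requested in this thread (added in 0cbd28e, which is not shown in this diff) might look roughly like the sketch below. The stream stub and test name are invented, and it assumes the pytest-mock `mocker` fixture already used by these tests. It patches below the `lru_cache` boundary, since mocking `get_json_schema` itself would replace the cached wrapper and prove nothing.

```python
# Hypothetical sketch of a test that get_json_schema is only invoked once.
from typing import Any, Iterable, Mapping

from airbyte_cdk.sources.streams import Stream


class StubDefaultSchemaStream(Stream):
    """Minimal stream that keeps the base (cached) get_json_schema."""

    primary_key = None

    def read_records(self, sync_mode, cursor_field=None, stream_slice=None, stream_state=None) -> Iterable[Mapping[str, Any]]:
        yield from []


def test_get_json_schema_is_cached(mocker):
    # Count calls to the schema loader underneath the lru_cache wrapper.
    mocked = mocker.patch(
        "airbyte_cdk.sources.utils.schema_helpers.ResourceSchemaLoader.get_schema", return_value={}
    )
    stream = StubDefaultSchemaStream()
    for _ in range(5):
        stream.get_json_schema()
    assert mocked.call_count == 1  # lru_cache served the other four calls
```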
Review comment on the `_read_incremental` change: same logic moved inside the `if` block because we only want to update the state and record counter on records.