Skip to content

Commit

Permalink
low-code: Fix type check in DeclarativeStream (#25533)
Browse files Browse the repository at this point in the history
* Set right type

* Update the comment

* Update

* format

* Update comment
  • Loading branch information
girarda authored Apr 26, 2023
1 parent 2500f4c commit e41060c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from dataclasses import InitVar, dataclass, field
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union

from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteTraceMessage, SyncMode
from airbyte_cdk.models import AirbyteMessage, SyncMode
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
from airbyte_cdk.sources.declarative.schema import DefaultSchemaLoader
from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader
from airbyte_cdk.sources.declarative.transformations import RecordTransformation
from airbyte_cdk.sources.declarative.types import Config, StreamSlice
from airbyte_cdk.sources.streams.core import Stream
from airbyte_cdk.sources.streams.core import Stream, StreamData


@dataclass
Expand Down Expand Up @@ -102,17 +102,27 @@ def read_records(

def _apply_transformations(
    self,
    message_or_record_data: StreamData,
    config: Config,
    stream_slice: StreamSlice,
) -> StreamData:
    """Apply every configured transformation to the record carried by the input.

    Three input shapes are handled:

    * an ``AirbyteMessage`` whose ``record`` is set: the record's ``data``
      is passed to the transformations;
    * any other ``AirbyteMessage`` (``record`` unset or falsy): returned
      as is, without running any transformation;
    * a mapping: passed to the transformations directly.

    :param message_or_record_data: the message or raw record data to process
    :param config: forwarded to each transformation as ``config``
    :param stream_slice: forwarded to each transformation as ``stream_slice``
    :return: the same object that was passed in. The return value of
        ``transform`` is ignored, so transformations are expected to
        modify the record in place.
    :raises ValueError: if the input is neither an ``AirbyteMessage`` nor a mapping
    """
    if isinstance(message_or_record_data, AirbyteMessage):
        if not message_or_record_data.record:
            # Non-record messages carry no record data to transform.
            return message_or_record_data
        record = message_or_record_data.record.data
    elif isinstance(message_or_record_data, Mapping):
        # StreamData declares Mapping[str, Any], so accept any mapping rather than only dict.
        record = message_or_record_data
    else:
        # Raise an error because this is unexpected and indicative of a typing problem in the CDK
        raise ValueError(
            f"Unexpected record type. Expected {StreamData}. Got {type(message_or_record_data)}. This is probably due to a bug in the CDK."
        )
    for transformation in self.transformations:
        # Transform the extracted record, never the enclosing message.
        transformation.transform(record, config=config, stream_state=self.state, stream_slice=stream_slice)

    return message_or_record_data

Expand Down
8 changes: 3 additions & 5 deletions airbyte-cdk/python/airbyte_cdk/sources/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union

import airbyte_cdk.sources.utils.casing as casing
from airbyte_cdk.models import AirbyteLogMessage, AirbyteStream, AirbyteTraceMessage, SyncMode
from airbyte_cdk.models import AirbyteMessage, AirbyteStream, SyncMode

# list of all possible HTTP methods which can be used for sending of request bodies
from airbyte_cdk.sources.utils.schema_helpers import ResourceSchemaLoader
Expand All @@ -24,10 +24,8 @@

# A stream's read method can return one of the following types:
# Mapping[str, Any]: The content of an AirbyteRecordMessage
# AirbyteRecordMessage: An AirbyteRecordMessage
# AirbyteLogMessage: A log message
# AirbyteTraceMessage: A trace message
StreamData = Union[Mapping[str, Any], AirbyteLogMessage, AirbyteTraceMessage]
# AirbyteMessage: An AirbyteMessage. Could be of any type
StreamData = Union[Mapping[str, Any], AirbyteMessage]


def package_name_from_class(cls: object) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
from unittest import mock
from unittest.mock import MagicMock, call

from airbyte_cdk.models import AirbyteLogMessage, AirbyteTraceMessage, Level, SyncMode, TraceType
from airbyte_cdk.models import (
AirbyteLogMessage,
AirbyteMessage,
AirbyteRecordMessage,
AirbyteTraceMessage,
Level,
SyncMode,
TraceType,
Type,
)
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.transformations import AddFields, RecordTransformation
from airbyte_cdk.sources.declarative.transformations.add_fields import AddedFieldDefinition
Expand All @@ -24,8 +33,8 @@ def test_declarative_stream():
records = [
{"pk": 1234, "field": "value"},
{"pk": 4567, "field": "different_value"},
AirbyteLogMessage(level=Level.INFO, message="This is a log message"),
AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345),
AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="This is a log message")),
AirbyteMessage(type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345)),
]
stream_slices = [
{"date": "2021-01-01"},
Expand Down Expand Up @@ -84,15 +93,17 @@ def test_declarative_stream_with_add_fields_transform():
retriever_records = [
{"pk": 1234, "field": "value"},
{"pk": 4567, "field": "different_value"},
AirbyteLogMessage(level=Level.INFO, message="This is a log message"),
AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345),
AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(data={"pk": 1357, "field": "a_value"}, emitted_at=12344, stream="stream")),
AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="This is a log message")),
AirbyteMessage(type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345)),
]

expected_records = [
{"pk": 1234, "field": "value", "added_key": "added_value"},
{"pk": 4567, "field": "different_value", "added_key": "added_value"},
AirbyteLogMessage(level=Level.INFO, message="This is a log message"),
AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345),
AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(data={"pk": 1357, "field": "a_value", "added_key": "added_value"}, emitted_at=12344, stream="stream")),
AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.INFO, message="This is a log message")),
AirbyteMessage(type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345)),
]
stream_slices = [
{"date": "2021-01-01"},
Expand Down

0 comments on commit e41060c

Please sign in to comment.