diff --git a/airbyte-cdk/python/airbyte_cdk/connector_builder/message_grouper.py b/airbyte-cdk/python/airbyte_cdk/connector_builder/message_grouper.py
index b30a3a3744f48..bd7bc0bc59b0f 100644
--- a/airbyte-cdk/python/airbyte_cdk/connector_builder/message_grouper.py
+++ b/airbyte-cdk/python/airbyte_cdk/connector_builder/message_grouper.py
@@ -24,7 +24,7 @@ from airbyte_cdk.sources.utils.types import JsonType
 from airbyte_cdk.utils import AirbyteTracedException
 from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer
-from airbyte_cdk.utils.schema_inferrer import SchemaInferrer
+from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException
 from airbyte_protocol.models.airbyte_protocol import (
     AirbyteControlMessage,
     AirbyteLogMessage,
@@ -45,6 +45,32 @@ def __init__(self, max_pages_per_slice: int, max_slices: int, max_record_limit:
         self._max_slices = max_slices
         self._max_record_limit = max_record_limit
 
+    def _pk_to_nested_and_composite_field(self, field: Optional[Union[str, List[str], List[List[str]]]]) -> List[List[str]]:
+        if not field:
+            return [[]]
+
+        if isinstance(field, str):
+            return [[field]]
+
+        is_composite_key = isinstance(field[0], str)
+        if is_composite_key:
+            return [[i] for i in field]  # type: ignore  # the type of field is expected to be List[str] here
+
+        return field  # type: ignore  # the type of field is expected to be List[List[str]] here
+
+    def _cursor_field_to_nested_and_composite_field(self, field: Union[str, List[str]]) -> List[List[str]]:
+        if not field:
+            return [[]]
+
+        if isinstance(field, str):
+            return [[field]]
+
+        is_nested_key = isinstance(field[0], str)
+        if is_nested_key:
+            return [field]  # type: ignore  # the type of field is expected to be List[str] here
+
+        raise ValueError(f"Unknown type for cursor field `{field}`")
+
     def get_message_groups(
         self,
         source: DeclarativeSource,
@@ -54,7 +80,11 @@ def get_message_groups(
     ) -> StreamRead:
         if record_limit is not None and not (1 <= record_limit <= self._max_record_limit):
             raise ValueError(f"Record limit must be between 1 and {self._max_record_limit}. Got {record_limit}")
-        schema_inferrer = SchemaInferrer()
+        stream = source.streams(config)[0]  # The connector builder currently only supports reading from a single stream at a time
+        schema_inferrer = SchemaInferrer(
+            self._pk_to_nested_and_composite_field(stream.primary_key),
+            self._cursor_field_to_nested_and_composite_field(stream.cursor_field),
+        )
         datetime_format_inferrer = DatetimeFormatInferrer()
 
         if record_limit is None:
@@ -88,14 +118,20 @@ def get_message_groups(
             else:
                 raise ValueError(f"Unknown message group type: {type(message_group)}")
 
+        try:
+            configured_stream = configured_catalog.streams[0]  # The connector builder currently only supports reading from a single stream at a time
+            schema = schema_inferrer.get_stream_schema(configured_stream.stream.name)
+        except SchemaValidationException as exception:
+            for validation_error in exception.validation_errors:
+                log_messages.append(LogMessage(validation_error, "ERROR"))
+            schema = exception.schema
+
         return StreamRead(
             logs=log_messages,
             slices=slices,
             test_read_limit_reached=self._has_reached_limit(slices),
             auxiliary_requests=auxiliary_requests,
-            inferred_schema=schema_inferrer.get_stream_schema(
-                configured_catalog.streams[0].stream.name
-            ),  # The connector builder currently only supports reading from a single stream at a time
+            inferred_schema=schema,
             latest_config_update=self._clean_config(latest_config_update.connectorConfig.config) if latest_config_update else None,
             inferred_datetime_formats=datetime_format_inferrer.get_inferred_datetime_formats(),
         )
diff --git a/airbyte-cdk/python/airbyte_cdk/test/catalog_builder.py b/airbyte-cdk/python/airbyte_cdk/test/catalog_builder.py
index 522e3dd68ab28..6b88bb898c7dd 100644
--- a/airbyte-cdk/python/airbyte_cdk/test/catalog_builder.py
+++ b/airbyte-cdk/python/airbyte_cdk/test/catalog_builder.py
@@ -1,29 +1,62 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 
-from typing import Any, Dict, List
+from typing import List, Union, overload
 
-from airbyte_protocol.models import ConfiguredAirbyteCatalog, SyncMode
+from airbyte_protocol.models import ConfiguredAirbyteCatalog, ConfiguredAirbyteStream, SyncMode
+
+
+class ConfiguredAirbyteStreamBuilder:
+    def __init__(self) -> None:
+        self._stream = {
+            "stream": {
+                "name": "any name",
+                "json_schema": {},
+                "supported_sync_modes": ["full_refresh", "incremental"],
+                "source_defined_primary_key": [["id"]],
+            },
+            "primary_key": [["id"]],
+            "sync_mode": "full_refresh",
+            "destination_sync_mode": "overwrite",
+        }
+
+    def with_name(self, name: str) -> "ConfiguredAirbyteStreamBuilder":
+        self._stream["stream"]["name"] = name  # type: ignore  # we assume that self._stream["stream"] is a Dict[str, Any]
+        return self
+
+    def with_sync_mode(self, sync_mode: SyncMode) -> "ConfiguredAirbyteStreamBuilder":
+        self._stream["sync_mode"] = sync_mode.name
+        return self
+
+    def with_primary_key(self, pk: List[List[str]]) -> "ConfiguredAirbyteStreamBuilder":
+        self._stream["primary_key"] = pk
+        self._stream["stream"]["source_defined_primary_key"] = pk  # type: ignore  # we assume that self._stream["stream"] is a Dict[str, Any]
+        return self
+
+    def build(self) -> ConfiguredAirbyteStream:
+        return ConfiguredAirbyteStream.parse_obj(self._stream)
 
 
 class CatalogBuilder:
    def __init__(self) -> None:
-        self._streams: List[Dict[str, Any]] = []
+        self._streams: List[ConfiguredAirbyteStreamBuilder] = []
 
+    @overload
+    def with_stream(self, name: ConfiguredAirbyteStreamBuilder) -> "CatalogBuilder":
+        ...
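+
+    # Both call styles below produce the same kind of catalog, e.g. (hypothetical stream name "users"):
+    #     CatalogBuilder().with_stream(ConfiguredAirbyteStreamBuilder().with_name("users")).build()
+    #     CatalogBuilder().with_stream("users", SyncMode.full_refresh).build()  # deprecated form, see below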
+
+    @overload
     def with_stream(self, name: str, sync_mode: SyncMode) -> "CatalogBuilder":
-        self._streams.append(
-            {
-                "stream": {
-                    "name": name,
-                    "json_schema": {},
-                    "supported_sync_modes": ["full_refresh", "incremental"],
-                    "source_defined_primary_key": [["id"]],
-                },
-                "primary_key": [["id"]],
-                "sync_mode": sync_mode.name,
-                "destination_sync_mode": "overwrite",
-            }
-        )
+        ...
+
+    def with_stream(self, name: Union[str, ConfiguredAirbyteStreamBuilder], sync_mode: Union[SyncMode, None] = None) -> "CatalogBuilder":
+        # As we are introducing a fully fledged ConfiguredAirbyteStreamBuilder, we would like to deprecate the previous interface
+        #   with_stream(str, SyncMode)
+
+        # To avoid a breaking change, `name` needs to stay in the API, but it can be either a stream name or a builder
+        name_or_builder = name
+        builder = name_or_builder if isinstance(name_or_builder, ConfiguredAirbyteStreamBuilder) else ConfiguredAirbyteStreamBuilder().with_name(name_or_builder).with_sync_mode(sync_mode)
+        self._streams.append(builder)
         return self
 
     def build(self) -> ConfiguredAirbyteCatalog:
-        return ConfiguredAirbyteCatalog.parse_obj({"streams": self._streams})
+        return ConfiguredAirbyteCatalog(streams=list(map(lambda builder: builder.build(), self._streams)))
diff --git a/airbyte-cdk/python/airbyte_cdk/utils/schema_inferrer.py b/airbyte-cdk/python/airbyte_cdk/utils/schema_inferrer.py
index 41f8e179469e0..134068e212bfa 100644
--- a/airbyte-cdk/python/airbyte_cdk/utils/schema_inferrer.py
+++ b/airbyte-cdk/python/airbyte_cdk/utils/schema_inferrer.py
@@ -3,18 +3,23 @@
 #
 
 from collections import defaultdict
-from typing import Any, Dict, Mapping, Optional
+from typing import Any, Dict, List, Mapping, Optional
 
 from airbyte_cdk.models import AirbyteRecordMessage
 from genson import SchemaBuilder, SchemaNode
 from genson.schema.strategies.object import Object
 from genson.schema.strategies.scalar import Number
 
+_NULL_TYPE = "null"
+
 
 class NoRequiredObj(Object):
     """
     This class has Object behaviour, but it does not generate "required[]" fields
-    every time it parses object. So we dont add unnecessary extra field.
+    every time it parses an object, so we don't add unnecessary extra fields.
+
+    The rationale is that even after reading all the data from a source, a new record could still show up later with some of
+    those fields missing. Hence, we make everything nullable.
""" def to_schema(self) -> Mapping[str, Any]: @@ -41,6 +46,25 @@ class NoRequiredSchemaBuilder(SchemaBuilder): InferredSchema = Dict[str, Any] +class SchemaValidationException(Exception): + @classmethod + def merge_exceptions(cls, exceptions: List["SchemaValidationException"]) -> "SchemaValidationException": + # We assume the schema is the same for all SchemaValidationException + return SchemaValidationException(exceptions[0].schema, [x for exception in exceptions for x in exception._validation_errors]) + + def __init__(self, schema: InferredSchema, validation_errors: List[Exception]): + self._schema = schema + self._validation_errors = validation_errors + + @property + def schema(self) -> InferredSchema: + return self._schema + + @property + def validation_errors(self) -> List[str]: + return list(map(lambda error: str(error), self._validation_errors)) + + class SchemaInferrer: """ This class is used to infer a JSON schema which fits all the records passed into it @@ -53,23 +77,15 @@ class SchemaInferrer: stream_to_builder: Dict[str, SchemaBuilder] - def __init__(self) -> None: + def __init__(self, pk: Optional[List[List[str]]] = None, cursor_field: Optional[List[List[str]]] = None) -> None: self.stream_to_builder = defaultdict(NoRequiredSchemaBuilder) + self._pk = [] if pk is None else pk + self._cursor_field = [] if cursor_field is None else cursor_field def accumulate(self, record: AirbyteRecordMessage) -> None: """Uses the input record to add to the inferred schemas maintained by this object""" self.stream_to_builder[record.stream].add_object(record.data) - def get_inferred_schemas(self) -> Dict[str, InferredSchema]: - """ - Returns the JSON schemas for all encountered streams inferred by inspecting all records - passed via the accumulate method - """ - schemas = {} - for stream_name, builder in self.stream_to_builder.items(): - schemas[stream_name] = self._clean(builder.to_schema()) - return schemas - def _clean(self, node: InferredSchema) -> InferredSchema: """ Recursively cleans up a produced schema: @@ -78,23 +94,119 @@ def _clean(self, node: InferredSchema) -> InferredSchema: """ if isinstance(node, dict): if "anyOf" in node: - if len(node["anyOf"]) == 2 and {"type": "null"} in node["anyOf"]: - real_type = node["anyOf"][1] if node["anyOf"][0]["type"] == "null" else node["anyOf"][0] + if len(node["anyOf"]) == 2 and {"type": _NULL_TYPE} in node["anyOf"]: + real_type = node["anyOf"][1] if node["anyOf"][0]["type"] == _NULL_TYPE else node["anyOf"][0] node.update(real_type) - node["type"] = [node["type"], "null"] + node["type"] = [node["type"], _NULL_TYPE] node.pop("anyOf") if "properties" in node and isinstance(node["properties"], dict): for key, value in list(node["properties"].items()): - if isinstance(value, dict) and value.get("type", None) == "null": + if isinstance(value, dict) and value.get("type", None) == _NULL_TYPE: node["properties"].pop(key) else: self._clean(value) if "items" in node: self._clean(node["items"]) + + # this check needs to follow the "anyOf" cleaning as it might populate `type` + if isinstance(node["type"], list): + if _NULL_TYPE in node["type"]: + # we want to make sure null is always at the end as it makes schemas more readable + node["type"].remove(_NULL_TYPE) + node["type"].append(_NULL_TYPE) + else: + node["type"] = [node["type"], _NULL_TYPE] + return node + + def _add_required_properties(self, node: InferredSchema) -> InferredSchema: + """ + This method takes properties that should be marked as required (self._pk and self._cursor_field) and travel the 
schema to mark every + node as required. + """ + # Removing nullable for the root as when we call `_clean`, we make everything nullable + node["type"] = "object" + + exceptions = [] + for field in [x for x in [self._pk, self._cursor_field] if x]: + try: + self._add_fields_as_required(node, field) + except SchemaValidationException as exception: + exceptions.append(exception) + + if exceptions: + raise SchemaValidationException.merge_exceptions(exceptions) + return node + def _add_fields_as_required(self, node: InferredSchema, composite_key: List[List[str]]) -> None: + """ + Take a list of nested keys (this list represents a composite key) and travel the schema to mark every node as required. + """ + errors: List[Exception] = [] + + for path in composite_key: + try: + self._add_field_as_required(node, path) + except ValueError as exception: + errors.append(exception) + + if errors: + raise SchemaValidationException(node, errors) + + def _add_field_as_required(self, node: InferredSchema, path: List[str], traveled_path: Optional[List[str]] = None) -> None: + """ + Take a nested key and travel the schema to mark every node as required. + """ + self._remove_null_from_type(node) + if self._is_leaf(path): + return + + if not traveled_path: + traveled_path = [] + + if "properties" not in node: + # This validation is only relevant when `traveled_path` is empty + raise ValueError( + f"Path {traveled_path} does not refer to an object but is `{node}` and hence {path} can't be marked as required." + ) + + next_node = path[0] + if next_node not in node["properties"]: + raise ValueError(f"Path {traveled_path} does not have field `{next_node}` in the schema and hence can't be marked as required.") + + if "type" not in node: + # We do not expect this case to happen but we added a specific error message just in case + raise ValueError( + f"Unknown schema error: {traveled_path} is expected to have a type but did not. Schema inferrence is probably broken" + ) + + if node["type"] not in ["object", ["null", "object"], ["object", "null"]]: + raise ValueError(f"Path {traveled_path} is expected to be an object but was of type `{node['properties'][next_node]['type']}`") + + if "required" not in node or not node["required"]: + node["required"] = [next_node] + elif next_node not in node["required"]: + node["required"].append(next_node) + + traveled_path.append(next_node) + self._add_field_as_required(node["properties"][next_node], path[1:], traveled_path) + + def _is_leaf(self, path: List[str]) -> bool: + return len(path) == 0 + + def _remove_null_from_type(self, node: InferredSchema) -> None: + if isinstance(node["type"], list): + if "null" in node["type"]: + node["type"].remove("null") + if len(node["type"]) == 1: + node["type"] = node["type"][0] + def get_stream_schema(self, stream_name: str) -> Optional[InferredSchema]: """ Returns the inferred JSON schema for the specified stream. Might be `None` if there were no records for the given stream name. 
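+    # Illustrative example (hypothetical record): with pk [["id"]], accumulating {"id": 1, "name": "a"}
+    # infers, besides the top-level "$schema" key:
+    #     {"type": "object", "required": ["id"],
+    #      "properties": {"id": {"type": "number"}, "name": {"type": ["string", "null"]}}}
+    # i.e. the primary key has "null" stripped and is marked required while other fields stay nullable.
+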
""" - return self._clean(self.stream_to_builder[stream_name].to_schema()) if stream_name in self.stream_to_builder else None + return ( + self._add_required_properties(self._clean(self.stream_to_builder[stream_name].to_schema())) + if stream_name in self.stream_to_builder + else None + ) diff --git a/airbyte-cdk/python/unit_tests/connector_builder/test_connector_builder_handler.py b/airbyte-cdk/python/unit_tests/connector_builder/test_connector_builder_handler.py index 190f8d4bcb56b..fa1cf13a09214 100644 --- a/airbyte-cdk/python/unit_tests/connector_builder/test_connector_builder_handler.py +++ b/airbyte-cdk/python/unit_tests/connector_builder/test_connector_builder_handler.py @@ -499,7 +499,19 @@ def test_config_update(): @patch("traceback.TracebackException.from_exception") def test_read_returns_error_response(mock_from_exception): + class MockDeclarativeStream: + @property + def primary_key(self): + return [[]] + + @property + def cursor_field(self): + return [] + class MockManifestDeclarativeSource: + def streams(self, config): + return [MockDeclarativeStream()] + def read(self, logger, config, catalog, state): raise ValueError("error_message") diff --git a/airbyte-cdk/python/unit_tests/connector_builder/test_message_grouper.py b/airbyte-cdk/python/unit_tests/connector_builder/test_message_grouper.py index 437a775dd8dee..f2602b960e73e 100644 --- a/airbyte-cdk/python/unit_tests/connector_builder/test_message_grouper.py +++ b/airbyte-cdk/python/unit_tests/connector_builder/test_message_grouper.py @@ -21,6 +21,9 @@ from airbyte_cdk.models import Type as MessageType from unit_tests.connector_builder.utils import create_configured_catalog +_NO_PK = [[]] +_NO_CURSOR_FIELD = [] + MAX_PAGES_PER_SLICE = 4 MAX_SLICES = 3 @@ -96,7 +99,7 @@ def test_get_grouped_messages(mock_entrypoint_read: Mock) -> None: response = {"status_code": 200, "headers": {"field": "value"}, "body": {"content": '{"name": "field"}'}} expected_schema = { "$schema": "http://json-schema.org/schema#", - "properties": {"name": {"type": "string"}, "date": {"type": "string"}}, + "properties": {"name": {"type": ["string", "null"]}, "date": {"type": ["string", "null"]}}, "type": "object", } expected_datetime_fields = {"date": "%Y-%m-%d"} @@ -537,6 +540,7 @@ def test_get_grouped_messages_given_maximum_number_of_pages_then_test_read_limit def test_read_stream_returns_error_if_stream_does_not_exist() -> None: mock_source = MagicMock() mock_source.read.side_effect = ValueError("error") + mock_source.streams.return_value = [make_mock_stream()] full_config: Mapping[str, Any] = {**CONFIG, **{"__injected_declarative_manifest": MANIFEST}} @@ -636,12 +640,58 @@ def test_given_no_slices_then_return_empty_slices(mock_entrypoint_read: Mock) -> assert len(stream_read.slices) == 0 +@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +def test_given_pk_then_ensure_pk_is_pass_to_schema_inferrence(mock_entrypoint_read: Mock) -> None: + mock_source = make_mock_source(mock_entrypoint_read, iter([ + request_response_log_message({"request": 1}, {"response": 2}, "http://any_url.com"), + record_message("hashiras", {"id": "Shinobu Kocho", "date": "2023-03-03"}), + record_message("hashiras", {"id": "Muichiro Tokito", "date": "2023-03-04"}), + ])) + mock_source.streams.return_value = [Mock()] + mock_source.streams.return_value[0].primary_key = [["id"]] + mock_source.streams.return_value[0].cursor_field = _NO_CURSOR_FIELD + connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) + + stream_read: StreamRead = 
connector_builder_handler.get_message_groups( + source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras") + ) + + assert stream_read.inferred_schema["required"] == ["id"] + + +@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +def test_given_cursor_field_then_ensure_cursor_field_is_pass_to_schema_inferrence(mock_entrypoint_read: Mock) -> None: + mock_source = make_mock_source(mock_entrypoint_read, iter([ + request_response_log_message({"request": 1}, {"response": 2}, "http://any_url.com"), + record_message("hashiras", {"id": "Shinobu Kocho", "date": "2023-03-03"}), + record_message("hashiras", {"id": "Muichiro Tokito", "date": "2023-03-04"}), + ])) + mock_source.streams.return_value = [Mock()] + mock_source.streams.return_value[0].primary_key = _NO_PK + mock_source.streams.return_value[0].cursor_field = ["date"] + connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) + + stream_read: StreamRead = connector_builder_handler.get_message_groups( + source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras") + ) + + assert stream_read.inferred_schema["required"] == ["date"] + + def make_mock_source(mock_entrypoint_read: Mock, return_value: Iterator[AirbyteMessage]) -> MagicMock: mock_source = MagicMock() mock_entrypoint_read.return_value = return_value + mock_source.streams.return_value = [make_mock_stream()] return mock_source +def make_mock_stream(): + mock_stream = MagicMock() + mock_stream.primary_key = [] + mock_stream.cursor_field = [] + return mock_stream + + def request_log_message(request: Mapping[str, Any]) -> AirbyteMessage: return AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"request:{json.dumps(request)}")) diff --git a/airbyte-cdk/python/unit_tests/utils/test_schema_inferrer.py b/airbyte-cdk/python/unit_tests/utils/test_schema_inferrer.py index 3666d89c9f66a..6a1943f0279b9 100644 --- a/airbyte-cdk/python/unit_tests/utils/test_schema_inferrer.py +++ b/airbyte-cdk/python/unit_tests/utils/test_schema_inferrer.py @@ -6,7 +6,7 @@ import pytest from airbyte_cdk.models.airbyte_protocol import AirbyteRecordMessage -from airbyte_cdk.utils.schema_inferrer import SchemaInferrer +from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException NOW = 1234567 @@ -19,7 +19,7 @@ {"stream": "my_stream", "data": {"field_A": "abc"}}, {"stream": "my_stream", "data": {"field_A": "def"}}, ], - {"my_stream": {"field_A": {"type": "string"}}}, + {"my_stream": {"field_A": {"type": ["string", "null"]}}}, id="test_basic", ), pytest.param( @@ -27,7 +27,7 @@ {"stream": "my_stream", "data": {"field_A": 1.0}}, {"stream": "my_stream", "data": {"field_A": "abc"}}, ], - {"my_stream": {"field_A": {"type": ["number", "string"]}}}, + {"my_stream": {"field_A": {"type": ["number", "string", "null"]}}}, id="test_deriving_schema_refine", ), pytest.param( @@ -38,10 +38,10 @@ { "my_stream": { "obj": { - "type": "object", + "type": ["object", "null"], "properties": { - "data": {"type": "array", "items": {"type": "number"}}, - "other_key": {"type": "string"}, + "data": {"type": ["array", "null"], "items": {"type": ["number", "null"]}}, + "other_key": {"type": ["string", "null"]}, }, } } @@ -53,7 +53,7 @@ {"stream": "my_stream", "data": {"field_A": 1}}, {"stream": "my_stream", "data": {"field_A": 2}}, ], - {"my_stream": {"field_A": {"type": "number"}}}, + {"my_stream": {"field_A": {"type": ["number", "null"]}}}, id="test_integer_number", ), 
pytest.param( @@ -68,7 +68,7 @@ {"stream": "my_stream", "data": {"field_A": None}}, {"stream": "my_stream", "data": {"field_A": "abc"}}, ], - {"my_stream": {"field_A": {"type": ["null", "string"]}}}, + {"my_stream": {"field_A": {"type": ["string", "null"]}}}, id="test_null_optional", ), pytest.param( @@ -76,7 +76,7 @@ {"stream": "my_stream", "data": {"field_A": None}}, {"stream": "my_stream", "data": {"field_A": {"nested": "abc"}}}, ], - {"my_stream": {"field_A": {"type": ["object", "null"], "properties": {"nested": {"type": "string"}}}}}, + {"my_stream": {"field_A": {"type": ["object", "null"], "properties": {"nested": {"type": ["string", "null"]}}}}}, id="test_any_of", ), pytest.param( @@ -84,7 +84,7 @@ {"stream": "my_stream", "data": {"field_A": None}}, {"stream": "my_stream", "data": {"field_A": {"nested": "abc", "nully": None}}}, ], - {"my_stream": {"field_A": {"type": ["object", "null"], "properties": {"nested": {"type": "string"}}}}}, + {"my_stream": {"field_A": {"type": ["object", "null"], "properties": {"nested": {"type": ["string", "null"]}}}}}, id="test_any_of_with_null", ), pytest.param( @@ -97,7 +97,7 @@ "my_stream": { "field_A": { "type": ["object", "null"], - "properties": {"nested": {"type": "string"}, "nully": {"type": ["null", "string"]}}, + "properties": {"nested": {"type": ["string", "null"]}, "nully": {"type": ["string", "null"]}}, } } }, @@ -113,7 +113,7 @@ "my_stream": { "field_A": { "type": ["object", "null"], - "properties": {"nested": {"type": "string"}, "nully": {"type": ["null", "string"]}}, + "properties": {"nested": {"type": ["string", "null"]}, "nully": {"type": ["string", "null"]}}, } } }, @@ -123,7 +123,7 @@ [ {"stream": "my_stream", "data": {"field_A": "abc", "nested": {"field_B": None}}}, ], - {"my_stream": {"field_A": {"type": "string"}, "nested": {"type": "object", "properties": {}}}}, + {"my_stream": {"field_A": {"type": ["string", "null"]}, "nested": {"type": ["object", "null"], "properties": {}}}}, id="test_nested_null", ), pytest.param( @@ -132,8 +132,8 @@ ], { "my_stream": { - "field_A": {"type": "string"}, - "nested": {"type": "array", "items": {"type": "object", "properties": {"field_C": {"type": "string"}}}}, + "field_A": {"type": ["string", "null"]}, + "nested": {"type": ["array", "null"], "items": {"type": ["object", "null"], "properties": {"field_C": {"type": ["string", "null"]}}}}, } }, id="test_array_nested_null", @@ -145,8 +145,8 @@ ], { "my_stream": { - "field_A": {"type": "string"}, - "nested": {"type": ["array", "null"], "items": {"type": "object", "properties": {"field_C": {"type": "string"}}}}, + "field_A": {"type": ["string", "null"]}, + "nested": {"type": ["array", "null"], "items": {"type": ["object", "null"], "properties": {"field_C": {"type": ["string", "null"]}}}}, } }, id="test_array_top_level_null", @@ -156,7 +156,7 @@ {"stream": "my_stream", "data": {"field_A": None}}, {"stream": "my_stream", "data": {"field_A": "abc"}}, ], - {"my_stream": {"field_A": {"type": ["null", "string"]}}}, + {"my_stream": {"field_A": {"type": ["string", "null"]}}}, id="test_null_string", ), ], @@ -167,36 +167,127 @@ def test_schema_derivation(input_records: List, expected_schemas: Mapping): inferrer.accumulate(AirbyteRecordMessage(stream=record["stream"], data=record["data"], emitted_at=NOW)) for stream_name, expected_schema in expected_schemas.items(): - assert inferrer.get_inferred_schemas()[stream_name] == { + assert inferrer.get_stream_schema(stream_name) == { "$schema": "http://json-schema.org/schema#", "type": "object", "properties": 
expected_schema, } -def test_deriving_schema_multiple_streams(): - inferrer = SchemaInferrer() - inferrer.accumulate(AirbyteRecordMessage(stream="my_stream", data={"field_A": 1.0}, emitted_at=NOW)) - inferrer.accumulate(AirbyteRecordMessage(stream="my_stream2", data={"field_A": "abc"}, emitted_at=NOW)) - inferred_schemas = inferrer.get_inferred_schemas() - assert inferred_schemas["my_stream"] == { - "$schema": "http://json-schema.org/schema#", - "type": "object", - "properties": {"field_A": {"type": "number"}}, - } - assert inferred_schemas["my_stream2"] == { - "$schema": "http://json-schema.org/schema#", - "type": "object", - "properties": {"field_A": {"type": "string"}}, - } - - -def test_get_individual_schema(): - inferrer = SchemaInferrer() - inferrer.accumulate(AirbyteRecordMessage(stream="my_stream", data={"field_A": 1.0}, emitted_at=NOW)) - assert inferrer.get_stream_schema("my_stream") == { - "$schema": "http://json-schema.org/schema#", - "type": "object", - "properties": {"field_A": {"type": "number"}}, - } - assert inferrer.get_stream_schema("another_stream") is None +_STREAM_NAME = "a stream name" +_ANY_VALUE = "any value" +_IS_PK = True +_IS_CURSOR_FIELD = True + + +def _create_inferrer_with_required_field(is_pk: bool, field: List[List[str]]) -> SchemaInferrer: + if is_pk: + return SchemaInferrer(field) + return SchemaInferrer([[]], field) + + +@pytest.mark.parametrize( + "is_pk", + [ + pytest.param(_IS_PK, id="required_field_is_pk"), + pytest.param(_IS_CURSOR_FIELD, id="required_field_is_cursor_field"), + ] +) +def test_field_is_on_root(is_pk: bool): + inferrer = _create_inferrer_with_required_field(is_pk, [["property"]]) + + inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"property": _ANY_VALUE}, emitted_at=NOW)) + + assert inferrer.get_stream_schema(_STREAM_NAME)["required"] == ["property"] + assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property"]["type"] == "string" + + +@pytest.mark.parametrize( + "is_pk", + [ + pytest.param(_IS_PK, id="required_field_is_pk"), + pytest.param(_IS_CURSOR_FIELD, id="required_field_is_cursor_field"), + ] +) +def test_field_is_nested(is_pk: bool): + inferrer = _create_inferrer_with_required_field(is_pk, [["property", "nested_property"]]) + + inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"property": {"nested_property": _ANY_VALUE}}, emitted_at=NOW)) + + assert inferrer.get_stream_schema(_STREAM_NAME)["required"] == ["property"] + assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property"]["type"] == "object" + assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property"]["required"] == ["nested_property"] + + +@pytest.mark.parametrize( + "is_pk", + [ + pytest.param(_IS_PK, id="required_field_is_pk"), + pytest.param(_IS_CURSOR_FIELD, id="required_field_is_cursor_field"), + ] +) +def test_field_is_composite(is_pk: bool): + inferrer = _create_inferrer_with_required_field(is_pk, [["property 1"], ["property 2"]]) + inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"property 1": _ANY_VALUE, "property 2": _ANY_VALUE}, emitted_at=NOW)) + assert inferrer.get_stream_schema(_STREAM_NAME)["required"] == ["property 1", "property 2"] + + +@pytest.mark.parametrize( + "is_pk", + [ + pytest.param(_IS_PK, id="required_field_is_pk"), + pytest.param(_IS_CURSOR_FIELD, id="required_field_is_cursor_field"), + ] +) +def test_field_is_composite_and_nested(is_pk: bool): + inferrer = _create_inferrer_with_required_field(is_pk, [["property 1", "nested"], ["property 2"]]) 
+ + inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"property 1": {"nested": _ANY_VALUE}, "property 2": _ANY_VALUE}, emitted_at=NOW)) + + assert inferrer.get_stream_schema(_STREAM_NAME)["required"] == ["property 1", "property 2"] + assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 1"]["type"] == "object" + assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 2"]["type"] == "string" + assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 1"]["required"] == ["nested"] + assert inferrer.get_stream_schema(_STREAM_NAME)["properties"]["property 1"]["properties"]["nested"]["type"] == "string" + + +def test_given_pk_does_not_exist_when_get_inferred_schemas_then_raise_error(): + inferrer = SchemaInferrer([["pk does not exist"]]) + inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"id": _ANY_VALUE}, emitted_at=NOW)) + + with pytest.raises(SchemaValidationException) as exception: + inferrer.get_stream_schema(_STREAM_NAME) + + assert len(exception.value.validation_errors) == 1 + + +def test_given_pk_path_is_partially_valid_when_get_inferred_schemas_then_validation_error_mentions_where_the_issue_is(): + inferrer = SchemaInferrer([["id", "nested pk that does not exist"]]) + inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"id": _ANY_VALUE}, emitted_at=NOW)) + + with pytest.raises(SchemaValidationException) as exception: + inferrer.get_stream_schema(_STREAM_NAME) + + assert len(exception.value.validation_errors) == 1 + assert "Path ['id']" in exception.value.validation_errors[0] + + +def test_given_composite_pk_but_only_one_path_valid_when_get_inferred_schemas_then_valid_path_is_required(): + inferrer = SchemaInferrer([["id 1"], ["id 2"]]) + inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"id 1": _ANY_VALUE}, emitted_at=NOW)) + + with pytest.raises(SchemaValidationException) as exception: + inferrer.get_stream_schema(_STREAM_NAME) + + assert exception.value.schema["required"] == ["id 1"] + + +def test_given_composite_pk_but_only_one_path_valid_when_get_inferred_schemas_then_validation_error_mentions_where_the_issue_is(): + inferrer = SchemaInferrer([["id 1"], ["id 2"]]) + inferrer.accumulate(AirbyteRecordMessage(stream=_STREAM_NAME, data={"id 1": _ANY_VALUE}, emitted_at=NOW)) + + with pytest.raises(SchemaValidationException) as exception: + inferrer.get_stream_schema(_STREAM_NAME) + + assert len(exception.value.validation_errors) == 1 + assert "id 2" in exception.value.validation_errors[0] diff --git a/airbyte-integrations/bases/connector-acceptance-test/unit_tests/test_backward_compatibility.py b/airbyte-integrations/bases/connector-acceptance-test/unit_tests/test_backward_compatibility.py index 306622325f2e3..3119a8d43511e 100644 --- a/airbyte-integrations/bases/connector-acceptance-test/unit_tests/test_backward_compatibility.py +++ b/airbyte-integrations/bases/connector-acceptance-test/unit_tests/test_backward_compatibility.py @@ -1597,6 +1597,28 @@ def test_validate_previous_configs(previous_connector_spec, actual_connector_spe ) }, ), + Transition( + name="Given the same types, the order does not matter", + should_fail=False, + previous={ + "test_stream": AirbyteStream.parse_obj( + { + "name": "test_stream", + "json_schema": {"properties": {"user": {"type": "object", "properties": {"username": {"type": ["null", "string"]}}}}}, + "supported_sync_modes": ["full_refresh"], + } + ) + }, + current={ + "test_stream": AirbyteStream.parse_obj( + { + "name": "test_stream", + 
"json_schema": {"properties": {"user": {"type": "object", "properties": {"username": {"type": ["string", "null"]}}}}}, + "supported_sync_modes": ["full_refresh"], + } + ) + }, + ), Transition( name="Changing 'type' field to list should not fail.", should_fail=False,