[ISSUE #6548] make all fields nullable except from pk and cursor field (
maxi297 authored Mar 20, 2024
1 parent a38fdac commit 2f34f08
Showing 7 changed files with 442 additions and 86 deletions.
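At a high level, this change makes every field in the inferred schema nullable except the primary key and the cursor field, which are additionally marked as required. A minimal sketch of the intended effect; the record and field names are illustrative and not taken from this diff:

# One accumulated record for a stream whose primary key is "id":
record_data = {"id": "u1", "comment": "hello"}

# Previously, a field was only nullable if a null value was actually observed,
# and no field was marked as required. With this change, the inferred schema
# for the record above would look like:
expected_schema = {
    "$schema": "http://json-schema.org/schema#",
    "type": "object",
    "required": ["id"],
    "properties": {
        "id": {"type": "string"},                 # pk: non-null and required
        "comment": {"type": ["string", "null"]},  # everything else: nullable
    },
}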
@@ -24,7 +24,7 @@
from airbyte_cdk.sources.utils.types import JsonType
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException
from airbyte_protocol.models.airbyte_protocol import (
AirbyteControlMessage,
AirbyteLogMessage,
@@ -45,6 +45,32 @@ def __init__(self, max_pages_per_slice: int, max_slices: int, max_record_limit:
self._max_slices = max_slices
self._max_record_limit = max_record_limit

def _pk_to_nested_and_composite_field(self, field: Optional[Union[str, List[str], List[List[str]]]]) -> List[List[str]]:
if not field:
return [[]]

if isinstance(field, str):
return [[field]]

is_composite_key = isinstance(field[0], str)
if is_composite_key:
return [[i] for i in field] # type: ignore # the type of field is expected to be List[str] here

return field # type: ignore # the type of field is expected to be List[List[str]] here

def _cursor_field_to_nested_and_composite_field(self, field: Union[str, List[str]]) -> List[List[str]]:
if not field:
return [[]]

if isinstance(field, str):
return [[field]]

is_nested_key = isinstance(field[0], str)
if is_nested_key:
return [field] # type: ignore # the type of field is expected to be List[str] here

raise ValueError(f"Unknown type for cursor field `{field}")

def get_message_groups(
self,
source: DeclarativeSource,
@@ -54,7 +80,11 @@
) -> StreamRead:
if record_limit is not None and not (1 <= record_limit <= self._max_record_limit):
raise ValueError(f"Record limit must be between 1 and {self._max_record_limit}. Got {record_limit}")
schema_inferrer = SchemaInferrer()
stream = source.streams(config)[0] # The connector builder currently only supports reading from a single stream at a time
schema_inferrer = SchemaInferrer(
self._pk_to_nested_and_composite_field(stream.primary_key),
self._cursor_field_to_nested_and_composite_field(stream.cursor_field),
)
datetime_format_inferrer = DatetimeFormatInferrer()

if record_limit is None:
@@ -88,14 +118,20 @@
else:
raise ValueError(f"Unknown message group type: {type(message_group)}")

try:
configured_stream = configured_catalog.streams[0] # The connector builder currently only supports reading from a single stream at a time
schema = schema_inferrer.get_stream_schema(configured_stream.stream.name)
except SchemaValidationException as exception:
for validation_error in exception.validation_errors:
log_messages.append(LogMessage(validation_error, "ERROR"))
schema = exception.schema

return StreamRead(
logs=log_messages,
slices=slices,
test_read_limit_reached=self._has_reached_limit(slices),
auxiliary_requests=auxiliary_requests,
inferred_schema=schema_inferrer.get_stream_schema(
configured_catalog.streams[0].stream.name
), # The connector builder currently only supports reading from a single stream at a time
inferred_schema=schema,
latest_config_update=self._clean_config(latest_config_update.connectorConfig.config) if latest_config_update else None,
inferred_datetime_formats=datetime_format_inferrer.get_inferred_datetime_formats(),
)
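For context, the two helpers above normalise every accepted pk/cursor shape into the List[List[str]] form the schema inferrer now receives. A sketch with illustrative values:

# _pk_to_nested_and_composite_field(None)             -> [[]]
# _pk_to_nested_and_composite_field("id")             -> [["id"]]
# _pk_to_nested_and_composite_field(["id", "ts"])     -> [["id"], ["ts"]]   # composite key
# _pk_to_nested_and_composite_field([["data", "id"]]) -> [["data", "id"]]   # nested key
#
# _cursor_field_to_nested_and_composite_field("updated_at")       -> [["updated_at"]]
# _cursor_field_to_nested_and_composite_field(["data", "cursor"]) -> [["data", "cursor"]]

schema_inferrer = SchemaInferrer(
    pk=[["id"], ["ts"]],            # composite primary key
    cursor_field=[["updated_at"]],  # single top-level cursor field
)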
67 changes: 50 additions & 17 deletions airbyte-cdk/python/airbyte_cdk/test/catalog_builder.py
@@ -1,29 +1,62 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.

from typing import Any, Dict, List
from typing import List, Union, overload

from airbyte_protocol.models import ConfiguredAirbyteCatalog, SyncMode
from airbyte_protocol.models import ConfiguredAirbyteCatalog, ConfiguredAirbyteStream, SyncMode


class ConfiguredAirbyteStreamBuilder:
def __init__(self) -> None:
self._stream = {
"stream": {
"name": "any name",
"json_schema": {},
"supported_sync_modes": ["full_refresh", "incremental"],
"source_defined_primary_key": [["id"]],
},
"primary_key": [["id"]],
"sync_mode": "full_refresh",
"destination_sync_mode": "overwrite",
}

def with_name(self, name: str) -> "ConfiguredAirbyteStreamBuilder":
self._stream["stream"]["name"] = name # type: ignore # we assume that self._stream["stream"] is a Dict[str, Any]
return self

def with_sync_mode(self, sync_mode: SyncMode) -> "ConfiguredAirbyteStreamBuilder":
self._stream["sync_mode"] = sync_mode.name
return self

def with_primary_key(self, pk: List[List[str]]) -> "ConfiguredAirbyteStreamBuilder":
self._stream["primary_key"] = pk
self._stream["stream"]["source_defined_primary_key"] = pk # type: ignore # we assume that self._stream["stream"] is a Dict[str, Any]
return self

def build(self) -> ConfiguredAirbyteStream:
return ConfiguredAirbyteStream.parse_obj(self._stream)


class CatalogBuilder:
def __init__(self) -> None:
self._streams: List[Dict[str, Any]] = []
self._streams: List[ConfiguredAirbyteStreamBuilder] = []

@overload
def with_stream(self, name: ConfiguredAirbyteStreamBuilder) -> "CatalogBuilder":
...

@overload
def with_stream(self, name: str, sync_mode: SyncMode) -> "CatalogBuilder":
self._streams.append(
{
"stream": {
"name": name,
"json_schema": {},
"supported_sync_modes": ["full_refresh", "incremental"],
"source_defined_primary_key": [["id"]],
},
"primary_key": [["id"]],
"sync_mode": sync_mode.name,
"destination_sync_mode": "overwrite",
}
)
...

def with_stream(self, name: Union[str, ConfiguredAirbyteStreamBuilder], sync_mode: Union[SyncMode, None] = None) -> "CatalogBuilder":
        # As we are introducing a fully fledged ConfiguredAirbyteStreamBuilder, we would like to deprecate the previous interface
        # with_stream(str, SyncMode).

        # To avoid a breaking change, `name` needs to stay in the API, but it can be either a name or a builder.
name_or_builder = name
builder = name_or_builder if isinstance(name_or_builder, ConfiguredAirbyteStreamBuilder) else ConfiguredAirbyteStreamBuilder().with_name(name_or_builder).with_sync_mode(sync_mode)
self._streams.append(builder)
return self

def build(self) -> ConfiguredAirbyteCatalog:
return ConfiguredAirbyteCatalog.parse_obj({"streams": self._streams})
return ConfiguredAirbyteCatalog(streams=list(map(lambda builder: builder.build(), self._streams)))
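A usage sketch of the new builder interface; the stream name and sync modes are illustrative:

from airbyte_cdk.test.catalog_builder import CatalogBuilder, ConfiguredAirbyteStreamBuilder
from airbyte_protocol.models import SyncMode

# New style: pass a fully configured stream builder.
catalog = (
    CatalogBuilder()
    .with_stream(
        ConfiguredAirbyteStreamBuilder()
        .with_name("users")
        .with_sync_mode(SyncMode.incremental)
        .with_primary_key([["id"]])
    )
    .build()
)

# Deprecated style, still supported to avoid a breaking change:
legacy_catalog = CatalogBuilder().with_stream("users", SyncMode.full_refresh).build()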
148 changes: 130 additions & 18 deletions airbyte-cdk/python/airbyte_cdk/utils/schema_inferrer.py
@@ -3,18 +3,23 @@
#

from collections import defaultdict
from typing import Any, Dict, Mapping, Optional
from typing import Any, Dict, List, Mapping, Optional

from airbyte_cdk.models import AirbyteRecordMessage
from genson import SchemaBuilder, SchemaNode
from genson.schema.strategies.object import Object
from genson.schema.strategies.scalar import Number

_NULL_TYPE = "null"


class NoRequiredObj(Object):
"""
This class has Object behaviour, but it does not generate "required[]" fields
every time it parses object. So we dont add unnecessary extra field.
    every time it parses an object, so we don't add an unnecessary extra field.
    The logic is that even after reading all the data from a source, another record could still appear later with any of those fields
    missing. Hence, we make everything nullable.
"""

def to_schema(self) -> Mapping[str, Any]:
@@ -41,6 +46,25 @@ class NoRequiredSchemaBuilder(SchemaBuilder):
InferredSchema = Dict[str, Any]


class SchemaValidationException(Exception):
@classmethod
def merge_exceptions(cls, exceptions: List["SchemaValidationException"]) -> "SchemaValidationException":
        # We assume the schema is the same across all SchemaValidationException instances
return SchemaValidationException(exceptions[0].schema, [x for exception in exceptions for x in exception._validation_errors])

def __init__(self, schema: InferredSchema, validation_errors: List[Exception]):
self._schema = schema
self._validation_errors = validation_errors

@property
def schema(self) -> InferredSchema:
return self._schema

@property
def validation_errors(self) -> List[str]:
return list(map(lambda error: str(error), self._validation_errors))
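Mirroring the connector builder usage earlier in this commit, a caller can keep the partially inferred schema when validation fails. A minimal sketch, assuming a schema_inferrer built with a pk as above and an illustrative stream name:

try:
    schema = schema_inferrer.get_stream_schema("users")
except SchemaValidationException as exception:
    # The inferred schema is still usable; the errors are human-readable strings.
    schema = exception.schema
    for error in exception.validation_errors:
        print(error)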


class SchemaInferrer:
"""
This class is used to infer a JSON schema which fits all the records passed into it
@@ -53,23 +77,15 @@ class SchemaInferrer:

stream_to_builder: Dict[str, SchemaBuilder]

def __init__(self) -> None:
def __init__(self, pk: Optional[List[List[str]]] = None, cursor_field: Optional[List[List[str]]] = None) -> None:
self.stream_to_builder = defaultdict(NoRequiredSchemaBuilder)
self._pk = [] if pk is None else pk
self._cursor_field = [] if cursor_field is None else cursor_field

def accumulate(self, record: AirbyteRecordMessage) -> None:
"""Uses the input record to add to the inferred schemas maintained by this object"""
self.stream_to_builder[record.stream].add_object(record.data)

def get_inferred_schemas(self) -> Dict[str, InferredSchema]:
"""
Returns the JSON schemas for all encountered streams inferred by inspecting all records
passed via the accumulate method
"""
schemas = {}
for stream_name, builder in self.stream_to_builder.items():
schemas[stream_name] = self._clean(builder.to_schema())
return schemas

def _clean(self, node: InferredSchema) -> InferredSchema:
"""
Recursively cleans up a produced schema:
@@ -78,23 +94,119 @@ def _clean(self, node: InferredSchema) -> InferredSchema:
"""
if isinstance(node, dict):
if "anyOf" in node:
if len(node["anyOf"]) == 2 and {"type": "null"} in node["anyOf"]:
real_type = node["anyOf"][1] if node["anyOf"][0]["type"] == "null" else node["anyOf"][0]
if len(node["anyOf"]) == 2 and {"type": _NULL_TYPE} in node["anyOf"]:
real_type = node["anyOf"][1] if node["anyOf"][0]["type"] == _NULL_TYPE else node["anyOf"][0]
node.update(real_type)
node["type"] = [node["type"], "null"]
node["type"] = [node["type"], _NULL_TYPE]
node.pop("anyOf")
if "properties" in node and isinstance(node["properties"], dict):
for key, value in list(node["properties"].items()):
if isinstance(value, dict) and value.get("type", None) == "null":
if isinstance(value, dict) and value.get("type", None) == _NULL_TYPE:
node["properties"].pop(key)
else:
self._clean(value)
if "items" in node:
self._clean(node["items"])

# this check needs to follow the "anyOf" cleaning as it might populate `type`
if isinstance(node["type"], list):
if _NULL_TYPE in node["type"]:
# we want to make sure null is always at the end as it makes schemas more readable
node["type"].remove(_NULL_TYPE)
node["type"].append(_NULL_TYPE)
else:
node["type"] = [node["type"], _NULL_TYPE]
return node

def _add_required_properties(self, node: InferredSchema) -> InferredSchema:
"""
        This method takes the properties that should be marked as required (self._pk and self._cursor_field) and traverses the schema
        to mark every node on those paths as required.
"""
        # Remove nullability from the root, as `_clean` made everything nullable
node["type"] = "object"

exceptions = []
for field in [x for x in [self._pk, self._cursor_field] if x]:
try:
self._add_fields_as_required(node, field)
except SchemaValidationException as exception:
exceptions.append(exception)

if exceptions:
raise SchemaValidationException.merge_exceptions(exceptions)

return node

def _add_fields_as_required(self, node: InferredSchema, composite_key: List[List[str]]) -> None:
"""
        Take a list of nested keys (this list represents a composite key) and traverse the schema to mark every node on the path as required.
"""
errors: List[Exception] = []

for path in composite_key:
try:
self._add_field_as_required(node, path)
except ValueError as exception:
errors.append(exception)

if errors:
raise SchemaValidationException(node, errors)

def _add_field_as_required(self, node: InferredSchema, path: List[str], traveled_path: Optional[List[str]] = None) -> None:
"""
        Take a nested key and traverse the schema, marking every node along the path as required.
"""
self._remove_null_from_type(node)
if self._is_leaf(path):
return

if not traveled_path:
traveled_path = []

if "properties" not in node:
# This validation is only relevant when `traveled_path` is empty
raise ValueError(
f"Path {traveled_path} does not refer to an object but is `{node}` and hence {path} can't be marked as required."
)

next_node = path[0]
if next_node not in node["properties"]:
raise ValueError(f"Path {traveled_path} does not have field `{next_node}` in the schema and hence can't be marked as required.")

if "type" not in node:
# We do not expect this case to happen but we added a specific error message just in case
raise ValueError(
f"Unknown schema error: {traveled_path} is expected to have a type but did not. Schema inferrence is probably broken"
)

if node["type"] not in ["object", ["null", "object"], ["object", "null"]]:
raise ValueError(f"Path {traveled_path} is expected to be an object but was of type `{node['properties'][next_node]['type']}`")

if "required" not in node or not node["required"]:
node["required"] = [next_node]
elif next_node not in node["required"]:
node["required"].append(next_node)

traveled_path.append(next_node)
self._add_field_as_required(node["properties"][next_node], path[1:], traveled_path)

def _is_leaf(self, path: List[str]) -> bool:
return len(path) == 0

def _remove_null_from_type(self, node: InferredSchema) -> None:
if isinstance(node["type"], list):
if "null" in node["type"]:
node["type"].remove("null")
if len(node["type"]) == 1:
node["type"] = node["type"][0]

def get_stream_schema(self, stream_name: str) -> Optional[InferredSchema]:
"""
Returns the inferred JSON schema for the specified stream. Might be `None` if there were no records for the given stream name.
"""
return self._clean(self.stream_to_builder[stream_name].to_schema()) if stream_name in self.stream_to_builder else None
return (
self._add_required_properties(self._clean(self.stream_to_builder[stream_name].to_schema()))
if stream_name in self.stream_to_builder
else None
)
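Putting it together, an end-to-end sketch of the new inference behaviour; the record content and stream name are illustrative:

from airbyte_cdk.models import AirbyteRecordMessage
from airbyte_cdk.utils.schema_inferrer import SchemaInferrer

inferrer = SchemaInferrer(pk=[["id"]], cursor_field=[["updated_at"]])
inferrer.accumulate(
    AirbyteRecordMessage(
        stream="users",
        data={"id": "u1", "updated_at": "2024-03-20T00:00:00Z", "nickname": "max"},
        emitted_at=0,
    )
)

schema = inferrer.get_stream_schema("users")
# Everything except the pk and cursor field is nullable; pk and cursor are required:
assert schema["required"] == ["id", "updated_at"]
assert schema["properties"]["nickname"]["type"] == ["string", "null"]
assert schema["properties"]["id"]["type"] == "string"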
@@ -499,7 +499,19 @@ def test_config_update():

@patch("traceback.TracebackException.from_exception")
def test_read_returns_error_response(mock_from_exception):
class MockDeclarativeStream:
@property
def primary_key(self):
return [[]]

@property
def cursor_field(self):
return []

class MockManifestDeclarativeSource:
def streams(self, config):
return [MockDeclarativeStream()]

def read(self, logger, config, catalog, state):
raise ValueError("error_message")
