Skip to content

Commit

Permalink
CAT: set additionalProperties recursively for objects (#34448)
Browse files Browse the repository at this point in the history
  • Loading branch information
artem1205 authored Jan 29, 2024
1 parent b37efe9 commit dccb2fa
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## 3.3.2
Fix TestBasicRead.test_read.validate_schema: set `additionalProperties` to False recursively for objects

## 3.3.1
Fix TestSpec.test_oauth_is_default_method to skip connectors that doesn't have predicate_key object.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from collections import defaultdict
from typing import Any, Dict, List, Mapping

import dpath.util
import pendulum
from airbyte_protocol.models import AirbyteRecordMessage, ConfiguredAirbyteCatalog
from jsonschema import Draft7Validator, FormatChecker, FormatError, ValidationError, validators
Expand All @@ -26,6 +25,40 @@
Draft7ValidatorWithStrictInteger = validators.extend(Draft7Validator, type_checker=strict_integer_type_checker)


class NoAdditionalPropertiesValidator(Draft7Validator):
def __init__(self, schema, **kwargs):
schema = self._enforce_false_additional_properties(schema)
super().__init__(schema, **kwargs)

@staticmethod
def _enforce_false_additional_properties(json_schema: Dict[str, Any]) -> Dict[str, Any]:
"""Create a copy of the schema in which `additionalProperties` is set to False for all non-null object properties.
This method will override the value of `additionalProperties` if it is set,
or will create the property and set it to False if it does not exist.
"""
new_schema = copy.deepcopy(json_schema)
new_schema["additionalProperties"] = False

def add_properties(properties):
for prop_name, prop_value in properties.items():
if "type" in prop_value and "object" in prop_value["type"] and len(prop_value.get("properties", [])):
prop_value["additionalProperties"] = False
add_properties(prop_value.get("properties", {}))
elif "type" in prop_value and "array" in prop_value["type"]:
if (
prop_value.get("items")
and "object" in prop_value.get("items", {}).get("type")
and len(prop_value.get("items", {}).get("properties", []))
):
prop_value["items"]["additionalProperties"] = False
if prop_value.get("items", {}).get("properties"):
add_properties(prop_value["items"]["properties"])

add_properties(new_schema.get("properties", {}))
return new_schema


class CustomFormatChecker(FormatChecker):
@staticmethod
def check_datetime(value: str) -> bool:
Expand All @@ -46,17 +79,6 @@ def check(self, instance, format):
return super().check(instance, format)


def _enforce_no_additional_top_level_properties(json_schema: Dict[str, Any]):
"""Create a copy of the schema in which `additionalProperties` is set to False for the dict of top-level properties.
This method will override the value of `additionalProperties` if it is set,
or will create the property and set it to False if it does not exist.
"""
enforced_schema = copy.deepcopy(json_schema)
dpath.util.new(enforced_schema, "additionalProperties", False)
return enforced_schema


def verify_records_schema(
records: List[AirbyteRecordMessage], catalog: ConfiguredAirbyteCatalog, fail_on_extra_columns: bool
) -> Mapping[str, Mapping[str, ValidationError]]:
Expand All @@ -66,11 +88,8 @@ def verify_records_schema(
stream_validators = {}
for stream in catalog.streams:
schema_to_validate_against = stream.stream.json_schema
if fail_on_extra_columns:
schema_to_validate_against = _enforce_no_additional_top_level_properties(schema_to_validate_against)
stream_validators[stream.stream.name] = Draft7ValidatorWithStrictInteger(
schema_to_validate_against, format_checker=CustomFormatChecker()
)
validator = NoAdditionalPropertiesValidator if fail_on_extra_columns else Draft7ValidatorWithStrictInteger
stream_validators[stream.stream.name] = validator(schema_to_validate_against, format_checker=CustomFormatChecker())
stream_errors = defaultdict(dict)
for record in records:
validator = stream_validators.get(record.stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,62 @@ def test_verify_records_schema(configured_catalog: ConfiguredAirbyteCatalog):
]


@pytest.mark.parametrize(
"json_schema, record, should_fail",
[
(
{"type": "object", "properties": {"a": {"type": "string"}}},
{"a": "str", "b": "extra_string"},
True
),
(
{"type": "object", "properties": {"a": {"type": "string"}, "some_obj": {"type": ["null", "object"]}}},
{"a": "str", "some_obj": {"b": "extra_string"}},
False
),
(
{
"type": "object",
"properties": {"a": {"type": "string"}, "some_obj": {"type": ["null", "object"], "properties": {"a": {"type": "string"}}}},
},
{"a": "str", "some_obj": {"a": "str", "b": "extra_string"}},
True
),
(
{"type": "object", "properties": {"a": {"type": "string"}, "b": {"type": "array", "items": {"type": "object"}}}},
{"a": "str", "b": [{"a": "extra_string"}]},
False
),
(
{
"type": "object",
"properties": {
"a": {"type": "string"},
"b": {"type": "array", "items": {"type": "object", "properties": {"a": {"type": "string"}}}},
}
},
{"a": "str", "b": [{"a": "string", "b": "extra_string"}]},
True
),
],
ids=[
"simple_schema_and_record_with_extra_property",
"schema_with_object_without_properties_and_record_with_object_with_property",
"schema_with_object_with_properties_and_record_with_object_with_extra_property",
"schema_with_array_of_objects_without_properties_and_record_with_array_of_objects_with_property",
"schema_with_array_of_objects_with_properties_and_record_with_array_of_objects_with_extra_property",
],
)
def test_verify_records_schema_with_fail_on_extra_columns(configured_catalog: ConfiguredAirbyteCatalog, json_schema, record, should_fail):
"""Test that fail_on_extra_columns works correctly with nested objects, array of objects"""
configured_catalog.streams[0].stream.json_schema =json_schema
records = [AirbyteRecordMessage(stream="my_stream", data=record, emitted_at=0)]
streams_with_errors = verify_records_schema(records, configured_catalog, fail_on_extra_columns=True)
errors = [error.message for error in streams_with_errors["my_stream"].values()]
assert errors if should_fail else not errors


@pytest.mark.parametrize(
"record, configured_catalog, valid",
[
Expand Down

0 comments on commit dccb2fa

Please sign in to comment.