From a1996e0a06b9fd18ba0596a2927ee0e77d8fa43a Mon Sep 17 00:00:00 2001 From: elanou Date: Thu, 11 Sep 2025 16:38:54 -0400 Subject: [PATCH] Update trp2.py BaseSchema such that unknown or new fields from textract model don't cause marshmallow validation errors and kill existing functionality of this library. --- src-python/tests/test_trp2.py | 35 +++++++++++++++++++++++++++++++++++ src-python/trp/trp2.py | 16 ++++------------ 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/src-python/tests/test_trp2.py b/src-python/tests/test_trp2.py index 3e19702..8fc5faa 100644 --- a/src-python/tests/test_trp2.py +++ b/src-python/tests/test_trp2.py @@ -428,6 +428,41 @@ def test_tbbox_union(): assert (b_union == b_gt) +def test_geometry_schema_ignores_unknown_fields(): + """ + Ensure unknown fields are excluded during load for Geometry schema. + This validates BaseSchema Meta.unknown = EXCLUDE across nested schemas. + """ + geometry_input = { + "BoundingBox": { + "Width": 1.0, + "Height": 1.0, + "Left": 0.0, + "Top": 0.0, + # Unknown field inside nested schema + "BogusField": "should be ignored", + }, + "Polygon": [ + {"X": 0.0, "Y": 0.0}, + {"X": 1.0, "Y": 0.0}, + {"X": 1.0, "Y": 1.0}, + {"X": 0.0, "Y": 1.0}, + ], + # Unknown field at top-level geometry + "Foo": "bar", + "AnotherUnknown": {"nested": 1}, + } + + geom = t2.TGeometrySchema().load(geometry_input) # type: ignore + assert isinstance(geom, t2.TGeometry) + + dumped = t2.TGeometrySchema().dump(geom) + # Unknown fields should not appear in dump + assert "Foo" not in dumped + assert "AnotherUnknown" not in dumped + assert "BogusField" not in dumped.get("BoundingBox", {}) + + def test_get_blocks_for_relationship(caplog): caplog.set_level(logging.DEBUG) diff --git a/src-python/trp/trp2.py b/src-python/trp/trp2.py index 7020203..fd86d79 100644 --- a/src-python/trp/trp2.py +++ b/src-python/trp/trp2.py @@ -26,6 +26,10 @@ class BaseSchema(m.Schema): """ SKIP_VALUES = set([None]) + class Meta: + # Exclude unknown fields during load to be lenient with inputs + unknown = m.EXCLUDE + @m.post_dump def remove_skip_values(self, data, many, pass_many=False): return { @@ -839,10 +843,6 @@ def link_tables(self, table_array_ids: List[List[str]]): class THttpHeadersSchema(BaseSchema): - - class Meta: - unknown = m.EXCLUDE - date = m.fields.String(data_key="date", required=False) x_amzn_request_id = m.fields.String(data_key="x-amzn-requestid", required=False, allow_none=False) content_type = m.fields.String(data_key="content-type", required=False, allow_none=False) @@ -855,10 +855,6 @@ def make_thttp_headers(self, data, **kwargs): class TResponseMetadataSchema(BaseSchema): - - class Meta: - unknown = m.EXCLUDE - request_id = m.fields.String(data_key="RequestId", required=False, allow_none=False) http_status_code = m.fields.Int(data_key="HTTPStatusCode", required=False, allow_none=False) retry_attempts = m.fields.Int(data_key="RetryAttempts", required=False, allow_none=False) @@ -870,10 +866,6 @@ def make_tresponse_metadata(self, data, **kwargs): class TDocumentSchema(BaseSchema): - - class Meta: - unknown = m.EXCLUDE - document_metadata = m.fields.Nested(TDocumentMetadataSchema, data_key="DocumentMetadata", required=False,