From 19d46b9a33c8842b48a2df51d966796b8231a45f Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 6 Sep 2023 11:25:33 -0400 Subject: [PATCH 1/5] q --- .github/workflows/test.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2f16f7f..9ee81b1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -3,6 +3,18 @@ on: push: branches: [main] pull_request: + workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI + +# If another push to the same PR or branch happens while this workflow is still running, +# cancel the earlier run in favor of the next run. +# +# There's no point in testing an outdated version of the code. GitHub only allows +# a limited number of job runners to be active at the same time, so it's better to cancel +# pointless jobs early so that more useful jobs can run sooner. +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + env: POETRY_VERSION: "1.3.1" @@ -31,3 +43,10 @@ jobs: - name: Run unit tests run: | poetry run poe test + + pydantic-compatibility: + uses: + ./.github/workflows/_pydantic_compatibility.yml + with: + working-directory: . + secrets: inherit From ff4cc8733cff94431b30a85987583384d8aebd07 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 6 Sep 2023 11:28:15 -0400 Subject: [PATCH 2/5] x --- .github/workflows/_pydantic_compatibility.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/_pydantic_compatibility.yml b/.github/workflows/_pydantic_compatibility.yml index 7d8fe26..68f142a 100644 --- a/.github/workflows/_pydantic_compatibility.yml +++ b/.github/workflows/_pydantic_compatibility.yml @@ -78,4 +78,5 @@ jobs: echo "Found pydantic version ${CURRENT_VERSION}, as expected" - name: Run pydantic compatibility tests shell: bash - run: make test + run: poetry run poe test + From c1a4dd4e5f03db5d68912682b6f0f06b984b7227 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 6 Sep 2023 11:44:35 -0400 Subject: [PATCH 3/5] x --- kor/nodes.py | 13 ++++++++++++- tests/test_examples.py | 8 ++++---- tests/test_serialization.py | 8 ++++---- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/kor/nodes.py b/kor/nodes.py index 8d8dad0..9f73c22 100644 --- a/kor/nodes.py +++ b/kor/nodes.py @@ -119,8 +119,11 @@ class ExtractionSchemaNode(AbstractSchemaNode, abc.ABC): def __init__(self, **kwargs: Any) -> None: """Initialize.""" - kwargs[TYPE_DISCRIMINATOR_FIELD] = type(self).__name__ + if PYDANTIC_MAJOR_VERSION == 2: + kwargs[TYPE_DISCRIMINATOR_FIELD] = type(self).__name__ super().__init__(**kwargs) + if PYDANTIC_MAJOR_VERSION == 1: + self.__dict__[TYPE_DISCRIMINATOR_FIELD] = type(self).__name__ @classmethod def parse_obj(cls, data: dict) -> ExtractionSchemaNode: @@ -146,6 +149,10 @@ def validate(cls: Type[ExtractionSchemaNode], v: Any) -> ExtractionSchemaNode: class Number(ExtractionSchemaNode): """Built-in number input.""" + examples: Sequence[ + Tuple[str, Union[int, float, Sequence[Union[float, int]]]] + ] = tuple() + def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: """Accept a visitor.""" return visitor.visit_number(self, **kwargs) @@ -154,6 +161,8 @@ def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: class Text(ExtractionSchemaNode): """Built-in text input.""" + examples: Sequence[Tuple[str, Union[Sequence[str], str]]] = tuple() + def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: """Accept a visitor.""" return visitor.visit_text(self, **kwargs) @@ -162,6 +171,8 @@ def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: class Bool(ExtractionSchemaNode): """Built-in bool input.""" + examples: Sequence[Tuple[str, Union[Sequence[bool], bool]]] = tuple() + def accept(self, visitor: AbstractVisitor[T], **kwargs: Any) -> T: """Accept a visitor.""" return visitor.visit_bool(self, **kwargs) diff --git a/tests/test_examples.py b/tests/test_examples.py index f539b1e..83c5f45 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -9,7 +9,7 @@ def test_example_generation() -> None: """ option = Option(id="option", description="Option", examples=["selection"]) number = Number( - id="number", description="Number", examples=[("number", "2")], many=True + id="number", description="Number", examples=[("number", 2)], many=True ) text = Text(id="text", description="Text", examples=[("text", "3")], many=True) @@ -24,15 +24,15 @@ def test_example_generation() -> None: obj = Object( id="object", description="object", - examples=[("another number", {"number": "1"})], + examples=[("another number", {"number": 1})], attributes=[number, text, selection], many=True, ) examples = generate_examples(obj) assert examples == [ - ("another number", {"object": [{"number": "1"}]}), - ("number", {"object": [{"number": ["2"]}]}), + ("another number", {"object": [{"number": 1}]}), + ("number", {"object": [{"number": [2]}]}), ("text", {"object": [{"text": ["3"]}]}), ("selection", {"object": [{"selection": ["option"]}]}), ("foo", {}), diff --git a/tests/test_serialization.py b/tests/test_serialization.py index d8034e9..b09199c 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -33,8 +33,8 @@ def test_serialize_deserialize_equals() -> None: id="root", description="root-object", attributes=[ - Text(id="text", description="text description", examples=[]), Number(id="number", description="Number description", examples=[]), + Text(id="text", description="text description", examples=[]), Bool(id="bool", description="bool description", examples=[]), ], examples=[], @@ -50,21 +50,21 @@ def test_serialize_deserialize_equals() -> None: "examples": [], "id": "number", "many": False, - "type_": "Number", + "$type": "Number", }, { "description": "text description", "examples": [], "id": "text", "many": False, - "type_": "Text", + "$type": "Text", }, { "description": "bool description", "examples": [], "id": "bool", "many": False, - "type_": "Bool", + "$type": "Bool", }, ], "description": "root-object", From 17f3d783f9559e24766c4024aedc1b517fc4c8d4 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 6 Sep 2023 12:42:05 -0400 Subject: [PATCH 4/5] x --- kor/adapters.py | 13 +++++++------ kor/validators.py | 14 ++++++++++---- tests/test_adapters.py | 2 +- tests/utils.py | 2 +- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/kor/adapters.py b/kor/adapters.py index 02678fb..fc3a329 100644 --- a/kor/adapters.py +++ b/kor/adapters.py @@ -105,16 +105,17 @@ def _translate_pydantic_to_kor( if PYDANTIC_MAJOR_VERSION == 1: field_info = field.field_info extra = field_info.extra - field_examples = extra.get("examples", tuple()) - field_description = field_info.description or "" - type_ = field.type_ + field_examples = extra.get( # type: ignore[attr-defined] + "examples", tuple() + ) + field_description = getattr(field_info, "description", "") + type_ = field.outer_type_ else: type_ = field.annotation - field_examples = field.examples or tuple() - field_description = field.description or "" + field_examples = field.examples or tuple() # type: ignore[attr-defined] + field_description = getattr(field, "description", "") field_many = _is_many(type_) - get_origin(type_) attribute: Union[ExtractionSchemaNode, Selection, "Object"] diff --git a/kor/validators.py b/kor/validators.py index 3acf92c..9dca1c5 100644 --- a/kor/validators.py +++ b/kor/validators.py @@ -55,9 +55,13 @@ def clean_data( for item in data: try: if PYDANTIC_MAJOR_VERSION == 1: - record = self.model_class.parse_obj(item) + record = self.model_class.parse_obj( # type: ignore[attr-defined] + item + ) else: - record = self.model_class.model_validate(item) + record = self.model_class.model_validate( # type: ignore[attr-defined] + item + ) records.append(record) except ValidationError as e: @@ -65,10 +69,12 @@ def clean_data( return records, exceptions else: try: + model_class = self.model_class if PYDANTIC_MAJOR_VERSION == 1: - record = self.model_class.parse_obj(data) + _loader = model_class.parse_obj # type: ignore[attr-defined] else: - record = self.model_class.model_validate(data) + _loader = model_class.model_validate # type: ignore[attr-defined] + record = _loader(data) return record, [] except ValidationError as e: return None, [e] diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 5772303..9e3b238 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -82,7 +82,7 @@ class Toy(pydantic.BaseModel): attributes=[ Text(id="a", description="hello"), Number(id="b", examples=[("b is 1", 1)]), - Number(id="c"), + Number(id="c", many=False), Bool(id="d"), # We don't have optional yet internally, so we don't check the # optional setting. diff --git a/tests/utils.py b/tests/utils.py index 4ba9b8e..91e1467 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,7 +12,7 @@ if PYDANTIC_MAJOR_VERSION == 1: from pydantic import Extra # type: ignore[assignment] else: - from pydantic.v1 import Extra # type: ignore[assignment] + from pydantic.v1 import Extra # type: ignore[assignment,no-redef] class ToyChatModel(BaseChatModel): From 1af76fa7fea254ee8f2604447ab5ed66171d48c9 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 6 Sep 2023 12:47:19 -0400 Subject: [PATCH 5/5] x --- kor/adapters.py | 4 ++-- kor/nodes.py | 4 ++-- kor/validators.py | 14 ++++++-------- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/kor/adapters.py b/kor/adapters.py index fc3a329..aec5445 100644 --- a/kor/adapters.py +++ b/kor/adapters.py @@ -108,12 +108,12 @@ def _translate_pydantic_to_kor( field_examples = extra.get( # type: ignore[attr-defined] "examples", tuple() ) - field_description = getattr(field_info, "description", "") + field_description = getattr(field_info, "description") or "" type_ = field.outer_type_ else: type_ = field.annotation field_examples = field.examples or tuple() # type: ignore[attr-defined] - field_description = getattr(field, "description", "") + field_description = getattr(field, "description") or "" field_many = _is_many(type_) diff --git a/kor/nodes.py b/kor/nodes.py index 9f73c22..414cc10 100644 --- a/kor/nodes.py +++ b/kor/nodes.py @@ -119,8 +119,6 @@ class ExtractionSchemaNode(AbstractSchemaNode, abc.ABC): def __init__(self, **kwargs: Any) -> None: """Initialize.""" - if PYDANTIC_MAJOR_VERSION == 2: - kwargs[TYPE_DISCRIMINATOR_FIELD] = type(self).__name__ super().__init__(**kwargs) if PYDANTIC_MAJOR_VERSION == 1: self.__dict__[TYPE_DISCRIMINATOR_FIELD] = type(self).__name__ @@ -128,6 +126,8 @@ def __init__(self, **kwargs: Any) -> None: @classmethod def parse_obj(cls, data: dict) -> ExtractionSchemaNode: """Parse an object.""" + if PYDANTIC_MAJOR_VERSION != 1: + raise NotImplementedError("Only supported for pydantic 1.x") type_ = data.pop(TYPE_DISCRIMINATOR_FIELD, None) if type_ is None: raise ValueError(f"Need to specify type ({TYPE_DISCRIMINATOR_FIELD})") diff --git a/kor/validators.py b/kor/validators.py index 9dca1c5..1d32357 100644 --- a/kor/validators.py +++ b/kor/validators.py @@ -48,6 +48,8 @@ def clean_data( Returns: cleaned data instantiated as the corresponding pydantic model """ + model_ = self.model_class # a proxy to make code fit in char limit + if self.many: exceptions: List[Exception] = [] records: List[BaseModel] = [] @@ -55,11 +57,9 @@ def clean_data( for item in data: try: if PYDANTIC_MAJOR_VERSION == 1: - record = self.model_class.parse_obj( # type: ignore[attr-defined] - item - ) + record = model_.parse_obj(item) # type: ignore[attr-defined] else: - record = self.model_class.model_validate( # type: ignore[attr-defined] + record = model_.model_validate( # type: ignore[attr-defined] item ) @@ -69,12 +69,10 @@ def clean_data( return records, exceptions else: try: - model_class = self.model_class if PYDANTIC_MAJOR_VERSION == 1: - _loader = model_class.parse_obj # type: ignore[attr-defined] + record = model_.parse_obj(data) # type: ignore[attr-defined] else: - _loader = model_class.model_validate # type: ignore[attr-defined] - record = _loader(data) + record = model_.model_validate(data) # type: ignore[attr-defined] return record, [] except ValidationError as e: return None, [e]