diff --git a/src/pytypes/onefuzztypes/_monkeypatch.py b/src/pytypes/onefuzztypes/_monkeypatch.py new file mode 100644 index 0000000000..fc784c417d --- /dev/null +++ b/src/pytypes/onefuzztypes/_monkeypatch.py @@ -0,0 +1,52 @@ +# TODO: Remove once `smart_union` like support is added to Pydantic +# +# Written by @PrettyWood +# Code from https://github.com/samuelcolvin/pydantic/pull/2092 +# +# Original project licensed under the MIT License. + +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from pydantic.fields import ModelField +from pydantic.typing import get_origin + +if TYPE_CHECKING: + from pydantic.fields import LocStr, ValidateReturn + from pydantic.types import ModelOrDc + +upstream_validate_singleton = ModelField._validate_singleton + + +# this is a direct port of the functionality from the PR discussed above, though +# *all* unions are considered "smart" for our purposes. +def wrap_validate_singleton( + self: ModelField, + v: Any, + values: Dict[str, Any], + loc: "LocStr", + cls: Optional["ModelOrDc"], +) -> "ValidateReturn": + if self.sub_fields: + if get_origin(self.type_) is Union: + for field in self.sub_fields: + if v.__class__ is field.outer_type_: + return v, None + for field in self.sub_fields: + try: + if isinstance(v, field.outer_type_): + return v, None + except TypeError: + pass + + return upstream_validate_singleton(self, v, values, loc, cls) + + +ModelField._validate_singleton = wrap_validate_singleton # type: ignore + + +# this should be included in any file that defines a pydantic model that uses a +# Union and calls to it should be removed when Pydantic's smart union support +# lands +def _check_hotfix() -> None: + if ModelField._validate_singleton != wrap_validate_singleton: + raise Exception("pydantic Union hotfix not applied") diff --git a/src/pytypes/onefuzztypes/events.py b/src/pytypes/onefuzztypes/events.py index 85e405ee2f..ef23595990 100644 --- a/src/pytypes/onefuzztypes/events.py +++ b/src/pytypes/onefuzztypes/events.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field +from ._monkeypatch import _check_hotfix from .enums import ( OS, Architecture, @@ -319,3 +320,6 @@ def parse_event_message(data: Dict[str, Any]) -> EventMessage: instance_id=instance_id, instance_name=instance_name, ) + + +_check_hotfix() diff --git a/src/pytypes/onefuzztypes/job_templates.py b/src/pytypes/onefuzztypes/job_templates.py index f1dc41b242..c0d32e616c 100644 --- a/src/pytypes/onefuzztypes/job_templates.py +++ b/src/pytypes/onefuzztypes/job_templates.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, root_validator, validator +from ._monkeypatch import _check_hotfix from .enums import OS, ContainerType, UserFieldOperation, UserFieldType from .models import JobConfig, NotificationConfig, TaskConfig, TaskContainers from .primitives import File @@ -184,3 +185,6 @@ class JobTemplateGet(BaseRequest): class JobTemplateRequestParameters(BaseRequest): user_fields: TemplateUserFields + + +_check_hotfix() diff --git a/src/pytypes/onefuzztypes/models.py b/src/pytypes/onefuzztypes/models.py index bd22204748..4ad4999639 100644 --- a/src/pytypes/onefuzztypes/models.py +++ b/src/pytypes/onefuzztypes/models.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, root_validator, validator from pydantic.dataclasses import dataclass +from ._monkeypatch import _check_hotfix from .consts import ONE_HOUR, SEVEN_DAYS from .enums import ( OS, @@ -842,3 +843,6 @@ class Task(BaseModel): events: Optional[List[TaskEventSummary]] nodes: Optional[List[NodeAssignment]] user_info: Optional[UserInfo] + + +_check_hotfix() diff --git a/src/pytypes/onefuzztypes/requests.py b/src/pytypes/onefuzztypes/requests.py index df7c32ebb2..e3c36a9073 100644 --- a/src/pytypes/onefuzztypes/requests.py +++ b/src/pytypes/onefuzztypes/requests.py @@ -8,6 +8,7 @@ from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator +from ._monkeypatch import _check_hotfix from .consts import ONE_HOUR, SEVEN_DAYS from .enums import ( OS, @@ -250,3 +251,6 @@ class WebhookUpdate(BaseModel): class NodeAddSshKey(BaseModel): machine_id: UUID public_key: str + + +_check_hotfix()