Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

add monkeypatch to hotfix pydantic Union issues #982

Merged
6 commits merged into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions src/pytypes/onefuzztypes/_monkeypatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# TODO: Remove once `smart_union` like support is added to Pydantic
#
# Written by @PrettyWood
# Code from https://github.com/samuelcolvin/pydantic/pull/2092
#
# Original project licensed under the MIT License.

from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

from pydantic.fields import ModelField
from pydantic.typing import get_origin

if TYPE_CHECKING:
bmc-msft marked this conversation as resolved.
Show resolved Hide resolved
from pydantic.error_wrappers import ErrorList
from pydantic.types import ModelOrDc

ValidateReturn = Tuple[Optional[Any], Optional[ErrorList]]
LocStr = Union[Tuple[Union[int, str], ...], str]

orig = ModelField._validate_singleton
bmc-msft marked this conversation as resolved.
Show resolved Hide resolved


def wrapper(
bmc-msft marked this conversation as resolved.
Show resolved Hide resolved
self: ModelField,
v: Any,
values: Dict[str, Any],
loc: "LocStr",
cls: Optional["ModelOrDc"],
) -> "ValidateReturn":
if self.sub_fields:
if get_origin(self.type_) is Union:
for field in self.sub_fields:
bmc-msft marked this conversation as resolved.
Show resolved Hide resolved
if v.__class__ is field.outer_type_:
return v, None
for field in self.sub_fields:
try:
if isinstance(v, field.outer_type_):
return v, None
except TypeError:
pass

return orig(self, v, values, loc, cls)


ModelField._validate_singleton = wrapper # type: ignore


def _check_hotfix() -> None:
bmc-msft marked this conversation as resolved.
Show resolved Hide resolved
if ModelField._validate_singleton != wrapper:
raise Exception("pydantic Union hotfix not applied")
4 changes: 4 additions & 0 deletions src/pytypes/onefuzztypes/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from pydantic import BaseModel, Field

from ._monkeypatch import _check_hotfix
from .enums import (
OS,
Architecture,
Expand Down Expand Up @@ -319,3 +320,6 @@ def parse_event_message(data: Dict[str, Any]) -> EventMessage:
instance_id=instance_id,
instance_name=instance_name,
)


_check_hotfix()
4 changes: 4 additions & 0 deletions src/pytypes/onefuzztypes/job_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from pydantic import BaseModel, Field, root_validator, validator

from ._monkeypatch import _check_hotfix
from .enums import OS, ContainerType, UserFieldOperation, UserFieldType
from .models import JobConfig, NotificationConfig, TaskConfig, TaskContainers
from .primitives import File
Expand Down Expand Up @@ -184,3 +185,6 @@ class JobTemplateGet(BaseRequest):

class JobTemplateRequestParameters(BaseRequest):
user_fields: TemplateUserFields


_check_hotfix()
4 changes: 4 additions & 0 deletions src/pytypes/onefuzztypes/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pydantic import BaseModel, Field, root_validator, validator
from pydantic.dataclasses import dataclass

from ._monkeypatch import _check_hotfix
from .consts import ONE_HOUR, SEVEN_DAYS
from .enums import (
OS,
Expand Down Expand Up @@ -842,3 +843,6 @@ class Task(BaseModel):
events: Optional[List[TaskEventSummary]]
nodes: Optional[List[NodeAssignment]]
user_info: Optional[UserInfo]


_check_hotfix()
4 changes: 4 additions & 0 deletions src/pytypes/onefuzztypes/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator

from ._monkeypatch import _check_hotfix
from .consts import ONE_HOUR, SEVEN_DAYS
from .enums import (
OS,
Expand Down Expand Up @@ -250,3 +251,6 @@ class WebhookUpdate(BaseModel):
class NodeAddSshKey(BaseModel):
machine_id: UUID
public_key: str


_check_hotfix()