Skip to content

Commit

Permalink
bump to pydantic 2.0
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <50656860+ThibaultFy@users.noreply.github.com>
  • Loading branch information
ThibaultFy committed Jul 25, 2023
1 parent f783727 commit 429fc44
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 22 deletions.
8 changes: 4 additions & 4 deletions substra/sdk/backends/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _check_metadata(metadata: Optional[Dict[str, str]]):
if metadata is not None:
if any([len(key) > _MAX_LEN_KEY_METADATA for key in metadata]):
raise exceptions.InvalidRequest("The key in metadata cannot be more than 50 characters", 400)
if any([len(value) > _MAX_LEN_VALUE_METADATA or len(value) == 0 for value in metadata.values()]):
if any([len(str(value)) > _MAX_LEN_VALUE_METADATA or len(str(value)) == 0 for value in metadata.values()]):
raise exceptions.InvalidRequest("Values in metadata cannot be empty or more than 100 characters", 400)
if any("__" in key for key in metadata):
raise exceptions.InvalidRequest(
Expand Down Expand Up @@ -327,7 +327,7 @@ def _add_dataset(self, key, spec, spec_options=None):
"storage_address": dataset_description_path,
},
metadata=spec.metadata if spec.metadata else dict(),
logs_permission=logs_permission.dict(),
logs_permission=logs_permission.model_dump(),
)
return self._db.add(asset)

Expand Down Expand Up @@ -555,7 +555,7 @@ def update(self, key, spec, spec_options=None):
asset_type = spec.__class__.type_
asset = self.get(asset_type, key)
data = asset.dict()
data.update(spec.dict())
data.update(spec.model_dump())
updated_asset = models.SCHEMA_TO_MODEL[asset_type](**data)
self._db.update(updated_asset)
return
Expand Down Expand Up @@ -588,7 +588,7 @@ def _output_from_spec(outputs: Dict[str, schemas.ComputeTaskOutputSpec]) -> Dict
"""Convert a dict of schemas.ComputeTaskOutputSpec to a dict of models.ComputeTaskOutput"""
return {
identifier: models.ComputeTaskOutput(
permissions=models.Permissions(process=output.permissions), transient=output.is_transient, value=None
permissions=models.Permissions(process=dict(output.permissions)), transient=output.is_transient, value=None
)
# default is None (= outputs are not computed yet)
for identifier, output in outputs.items()
Expand Down
2 changes: 1 addition & 1 deletion substra/sdk/backends/local/compute/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _save_cp_performances_as_json(self, compute_plan_key: str, path: Path):
performances = self._db.get_performances(compute_plan_key)

with (path).open("w", encoding="UTF-8") as json_file:
json.dump(performances.dict(), json_file, default=str)
json.dump(performances.model_dump(), json_file, default=str)

def _get_asset_unknown_type(self, asset_key, possible_types: List[schemas.Type]) -> Tuple[Any, schemas.Type]:
for asset_type in possible_types:
Expand Down
13 changes: 7 additions & 6 deletions substra/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ class Function(_Model):
permissions: Permissions
metadata: Dict[str, str]
creation_date: datetime
inputs: List[FunctionInput]
outputs: List[FunctionOutput]
inputs: List[schemas.FunctionInputSpec]
outputs: List[schemas.FunctionOutputSpec]

description: _File
function: _File
Expand All @@ -190,7 +190,7 @@ def dict_input_to_list(cls, v):
if isinstance(v, dict):
# Transform the inputs dict to a list
return [
FunctionInput(
schemas.FunctionInputSpec(
identifier=identifier,
kind=function_input["kind"],
optional=function_input["optional"],
Expand Down Expand Up @@ -246,7 +246,7 @@ class InputRef(schemas._PydanticConfig):
parent_task_output_identifier: Optional[str] = None

# either (asset_key) or (parent_task_key, parent_task_output_identifier) must be specified
_check_asset_key_or_parent_ref = pydantic.root_validator(allow_reuse=True)(schemas.check_asset_key_or_parent_ref)
_check_asset_key_or_parent_ref = pydantic.model_validator(mode="before")(schemas.check_asset_key_or_parent_ref)


class ComputeTaskOutput(schemas._PydanticConfig):
Expand All @@ -262,7 +262,7 @@ class Task(_Model):
function: Function
owner: str
compute_plan_key: str
metadata: Dict[str, str]
metadata: Union[Dict[str, str], Dict[str, int]]
status: Status
worker: str
rank: Optional[int] = None
Expand Down Expand Up @@ -361,7 +361,8 @@ class Organization(schemas._PydanticConfig):
type_: ClassVar[str] = schemas.Type.Organization


class OrganizationInfoConfig(schemas._PydanticConfig):
class OrganizationInfoConfig(schemas._PydanticConfig, extra="allow"):
model_config = ConfigDict(protected_namespaces=())
model_export_enabled: bool


Expand Down
26 changes: 15 additions & 11 deletions substra/sdk/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

import pydantic
from pydantic.fields import Field
Expand Down Expand Up @@ -76,6 +77,7 @@ def __str__(self):

class _PydanticConfig(pydantic.BaseModel):
"""Shared configuration for all schemas here"""

model_config = ConfigDict(extra="ignore")


Expand Down Expand Up @@ -143,7 +145,7 @@ class DataSampleSpec(_Spec):
def is_many(self):
return self.paths and len(self.paths) > 0

@pydantic.root_validator(pre=True)
@pydantic.model_validator(mode="before")
def exclusive_paths(cls, values):
"""Check that one and only one path(s) field is defined."""
if "paths" in values and "path" in values:
Expand All @@ -167,8 +169,8 @@ def build_request_kwargs(self, local):
def check_asset_key_or_parent_ref(cls, values):
"""Check that either (asset key) or (parent_task_key, parent_task_output_identifier) are set, but not both."""

has_asset_key = bool(values.get("asset_key"))
has_parent = bool(values.get("parent_task_key")) and bool(values.get("parent_task_output_identifier"))
has_asset_key = bool(dict(values).get("asset_key"))
has_parent = bool(dict(values).get("parent_task_key")) and bool(dict(values).get("parent_task_output_identifier"))

if has_asset_key != has_parent: # xor
return values
Expand All @@ -185,7 +187,7 @@ class InputRef(_PydanticConfig):
parent_task_output_identifier: Optional[str] = None

# either (asset_key) or (parent_task_key, parent_task_output_identifier) must be specified
_check_asset_key_or_parent_ref = pydantic.root_validator(allow_reuse=True)(check_asset_key_or_parent_ref)
_check_asset_key_or_parent_ref = pydantic.model_validator(mode="before")(check_asset_key_or_parent_ref)


class ComputeTaskOutputSpec(_PydanticConfig):
Expand All @@ -206,7 +208,7 @@ class ComputePlanTaskSpec(_Spec):
function_key: str
worker: str
tag: Optional[str] = None
metadata: Optional[Dict[str, str]] = None
metadata: Optional[Union[Dict[str, str], Dict[str, int]]] = None
inputs: Optional[List[InputRef]] = None
outputs: Optional[Dict[str, ComputeTaskOutputSpec]] = None

Expand Down Expand Up @@ -289,7 +291,7 @@ class FunctionInputSpec(_Spec):
optional: bool
kind: AssetKind

@pydantic.root_validator
@pydantic.model_validator(mode="before")
def _check_identifiers(cls, values):
"""Checks that the multiplicity and the optionality of a data manager are always set to False"""
if values["kind"] == AssetKind.data_manager:
Expand Down Expand Up @@ -323,7 +325,7 @@ class FunctionOutputSpec(_Spec):
kind: AssetKind
multiple: bool

@pydantic.root_validator
@pydantic.model_validator(mode="before")
def _check_performance(cls, values):
"""Checks that the performance is always set to False"""
if values.get("kind") == AssetKind.performance and values.get("multiple"):
Expand Down Expand Up @@ -373,10 +375,12 @@ def build_request_kwargs(self):
# Computed fields using `@property` are not dumped when `exclude_unset` flag is enabled,
# this is why we need to reimplement this custom function.
data["inputs"] = (
{input.identifier: input.dict(exclude={"identifier"}) for input in self.inputs} if self.inputs else {}
{input.identifier: input.dict(exclude={"identifier"}) for input in self.inputs} if self.inputs else dict()
)
data["outputs"] = (
{output.identifier: output.dict(exclude={"identifier"}) for output in self.outputs} if self.outputs else {}
{output.identifier: output.dict(exclude={"identifier"}) for output in self.outputs}
if self.outputs
else dict()
)

if self.Meta.file_attributes:
Expand Down Expand Up @@ -404,7 +408,7 @@ class TaskSpec(_Spec):
key: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
tag: Optional[str] = None
compute_plan_key: Optional[str] = None
metadata: Optional[Dict[str, str]] = None
metadata: Optional[Union[Dict[str, str], Dict[str, int]]] = None
function_key: str
worker: str
rank: Optional[int] = None
Expand All @@ -419,7 +423,7 @@ def build_request_kwargs(self):
# this is why we need to reimplement this custom function.
data = json.loads(self.json(exclude_unset=True))
data["key"] = self.key
data["inputs"] = [input.dict() for input in self.inputs] if self.inputs else []
data["inputs"] = [input.model_dump() for input in self.inputs] if self.inputs else []
data["outputs"] = {k: v.dict(by_alias=True) for k, v in self.outputs.items()} if self.outputs else {}
yield data, None

Expand Down

0 comments on commit 429fc44

Please sign in to comment.