From 743133c82cb427e0cfb531e908b4322e472cf861 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Mon, 1 Jul 2024 10:23:40 -0400 Subject: [PATCH] update base Config class Note: default serialization behavior has changed in pydantic v2. See https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing. We use model_dump(serialize_as_any=True) to get the old behavior. --- .../rastervision/pipeline/config.py | 50 ++++++++----------- .../rastervision/pipeline/pipeline_config.py | 22 +++++--- tests/pipeline/test_config.py | 8 ++- 3 files changed, 39 insertions(+), 41 deletions(-) diff --git a/rastervision_pipeline/rastervision/pipeline/config.py b/rastervision_pipeline/rastervision/pipeline/config.py index 7d5766ec8..f655bf2df 100644 --- a/rastervision_pipeline/rastervision/pipeline/config.py +++ b/rastervision_pipeline/rastervision/pipeline/config.py @@ -1,20 +1,18 @@ from typing import (TYPE_CHECKING, Callable, Dict, List, Literal, Optional, - Type, Union) + Self, Type, Union) import inspect import logging -import json from pydantic import ( # noqa - BaseModel, create_model, Field, root_validator, validate_model, - ValidationError, validator) + ConfigDict, BaseModel, create_model, Field, model_validator, + ValidationError, field_validator) from rastervision.pipeline import (registry_ as registry, rv_config_ as rv_config) -from rastervision.pipeline.file_system import (str_to_file, json_to_file, - file_to_json) +from rastervision.pipeline.file_system import (file_to_json, json_to_file, + str_to_file) if TYPE_CHECKING: - from typing import Self from rastervision.pipeline.pipeline_config import PipelineConfig log = logging.getLogger(__name__) @@ -37,12 +35,7 @@ class Config(BaseModel): Validation, serialization, deserialization, and IDE support is provided automatically based on this schema. """ - - # This is here to forbid instantiating Configs with fields that do not - # exist in the schema, which helps avoid a command source of bugs. 
- class Config: - extra = 'forbid' - validate_assignment = True + model_config = ConfigDict(extra='forbid', validate_assignment=True) def update(self, *args, **kwargs): """Update any fields before validation. @@ -70,13 +63,8 @@ def revalidate(self): """Re-validate an instantiated Config. Runs all Pydantic validators plus self.validate_config(). - - Adapted from: - https://github.com/samuelcolvin/pydantic/issues/1864#issuecomment-679044432 """ - *_, validation_error = validate_model(self.__class__, self.__dict__) - if validation_error: - raise validation_error + self.model_validate(self.__dict__) self.validate_config() def recursive_validate_config(self): @@ -115,13 +103,18 @@ def validate_list(self, field: str, valid_options: List[str]): if val not in valid_options: raise ConfigError(f'{val} is not a valid option for {field}') + def copy(self) -> Self: + return self.model_copy() + def dict(self, with_rv_metadata: bool = False, **kwargs) -> dict: - cfg_json = self.json(**kwargs) - cfg_dict = json.loads(cfg_json) + cfg_dict = self.model_dump(serialize_as_any=True, **kwargs) if with_rv_metadata: cfg_dict['plugin_versions'] = registry.plugin_versions return cfg_dict + def json(self, **kwargs) -> str: + return self.model_dump_json(serialize_as_any=True, **kwargs) + def to_file(self, uri: str, with_rv_metadata: bool = True) -> None: """Save a Config to a JSON file, optionally with RV metadata. @@ -131,11 +124,11 @@ def to_file(self, uri: str, with_rv_metadata: bool = True) -> None: ``plugin_versions``, so that the config can be upgraded when loaded. """ - cfg_dict = self.dict(with_rv_metadata=with_rv_metadata) - json_to_file(cfg_dict, uri) + cfg_json = self.dict(with_rv_metadata=with_rv_metadata) + json_to_file(cfg_json, uri) @classmethod - def deserialize(cls, inp: 'str | dict | Config') -> 'Self': + def deserialize(cls, inp: 'str | dict | Config') -> Self: """Deserialize Config from a JSON file or dict, upgrading if possible. 
If ``inp`` is already a :class:`.Config`, it is returned as is. @@ -152,7 +145,7 @@ def deserialize(cls, inp: 'str | dict | Config') -> 'Self': raise TypeError(f'Cannot deserialize Config from type: {type(inp)}.') @classmethod - def from_file(cls, uri: str) -> 'Self': + def from_file(cls, uri: str) -> Self: """Deserialize Config from a JSON file, upgrading if possible. Args: @@ -163,7 +156,7 @@ def from_file(cls, uri: str) -> 'Self': return cfg @classmethod - def from_dict(cls, cfg_dict: dict) -> 'Self': + def from_dict(cls, cfg_dict: dict) -> Self: """Deserialize Config from a dict. Args: @@ -211,10 +204,7 @@ def build_config(x: Union[dict, List[Union[dict, Config]], Config] Config: the corresponding Config(s) """ if isinstance(x, dict): - new_x = { - k: build_config(v) - for k, v in x.items() if k not in ('plugin_versions', 'rv_config') - } + new_x = {k: build_config(v) for k, v in x.items()} type_hint = new_x.get('type_hint') if type_hint is not None: config_cls = registry.get_config(type_hint) diff --git a/rastervision_pipeline/rastervision/pipeline/pipeline_config.py b/rastervision_pipeline/rastervision/pipeline/pipeline_config.py index b051ef998..b3c4c9d6f 100644 --- a/rastervision_pipeline/rastervision/pipeline/pipeline_config.py +++ b/rastervision_pipeline/rastervision/pipeline/pipeline_config.py @@ -1,11 +1,11 @@ from os.path import join -from typing import TYPE_CHECKING, Optional, Dict +from typing import TYPE_CHECKING, Self from rastervision.pipeline.config import Config, Field from rastervision.pipeline.config import register_config if TYPE_CHECKING: - from rastervision.pipeline.pipeline import Pipeline # noqa + from rastervision.pipeline.pipeline import Pipeline @register_config('pipeline') @@ -14,15 +14,15 @@ class PipelineConfig(Config): This should be subclassed to configure new Pipelines. 
""" - root_uri: str = Field( + root_uri: str | None = Field( None, description='The root URI for output generated by the pipeline') - rv_config: dict = Field( + rv_config: dict | None = Field( None, description='Used to store serialized RVConfig so pipeline can ' 'run in remote environment with the local RVConfig. This should ' 'not be set explicitly by users -- it is only used by the runner ' 'when running a remote pipeline.') - plugin_versions: Optional[Dict[str, int]] = Field( + plugin_versions: dict[str, int] | None = Field( None, description= ('Used to store a mapping of plugin module paths to the latest ' @@ -46,4 +46,14 @@ def build(self, tmp_dir: str) -> 'Pipeline': return Pipeline(self, tmp_dir) def dict(self, with_rv_metadata: bool = True, **kwargs) -> dict: - return super().dict(with_rv_metadata=with_rv_metadata, **kwargs) + # override to have with_rv_metadata=True by default + cfg_dict = super().dict(with_rv_metadata=with_rv_metadata, **kwargs) + return cfg_dict + + @classmethod + def from_dict(cls, cfg_dict: dict) -> Self: + # override to retain plugin_versions + from rastervision.pipeline.config import build_config, upgrade_config + cfg_dict: dict = upgrade_config(cfg_dict) + cfg = build_config(cfg_dict) + return cfg diff --git a/tests/pipeline/test_config.py b/tests/pipeline/test_config.py index 58aff4635..9dc99115b 100644 --- a/tests/pipeline/test_config.py +++ b/tests/pipeline/test_config.py @@ -2,11 +2,9 @@ from os.path import join import unittest -from pydantic.error_wrappers import ValidationError - from rastervision.pipeline.file_system.utils import get_tmp_dir, json_to_file -from rastervision.pipeline.config import (Config, register_config, - build_config, upgrade_config) +from rastervision.pipeline.config import ( + Config, register_config, build_config, upgrade_config, ValidationError) from rastervision.pipeline.pipeline_config import (PipelineConfig) from rastervision.pipeline import registry_ as registry @@ -123,7 +121,7 @@ def 
test_to_from(self): 'x' } - self.assertDictEqual(cfg.dict(), exp_dict) + self.assertDictEqual(cfg.dict(with_rv_metadata=False), exp_dict) self.assertEqual(build_config(exp_dict), cfg) def test_no_extras(self):