Skip to content

Commit

Permalink
update base Config class
Browse files Browse the repository at this point in the history
Note: default serialization behavior has changed in pydantic v2. See https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing. We use model_dump(serialize_as_any=True) to get the old behavior.
  • Loading branch information
AdeelH committed Jul 1, 2024
1 parent 9b5e8cd commit 743133c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 41 deletions.
50 changes: 20 additions & 30 deletions rastervision_pipeline/rastervision/pipeline/config.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
from typing import (TYPE_CHECKING, Callable, Dict, List, Literal, Optional,
Type, Union)
Self, Type, Union)
import inspect
import logging
import json

from pydantic import ( # noqa
BaseModel, create_model, Field, root_validator, validate_model,
ValidationError, validator)
ConfigDict, BaseModel, create_model, Field, model_validator,
ValidationError, field_validator)

from rastervision.pipeline import (registry_ as registry, rv_config_ as
rv_config)
from rastervision.pipeline.file_system import (str_to_file, json_to_file,
file_to_json)
from rastervision.pipeline.file_system import (file_to_json, json_to_file,
str_to_file)

if TYPE_CHECKING:
from typing import Self
from rastervision.pipeline.pipeline_config import PipelineConfig

log = logging.getLogger(__name__)
Expand All @@ -37,12 +35,7 @@ class Config(BaseModel):
Validation, serialization, deserialization, and IDE support is
provided automatically based on this schema.
"""

# This is here to forbid instantiating Configs with fields that do not
# exist in the schema, which helps avoid a common source of bugs.
class Config:
extra = 'forbid'
validate_assignment = True
model_config = ConfigDict(extra='forbid', validate_assignment=True)

def update(self, *args, **kwargs):
"""Update any fields before validation.
Expand Down Expand Up @@ -70,13 +63,8 @@ def revalidate(self):
"""Re-validate an instantiated Config.
Runs all Pydantic validators plus self.validate_config().
Adapted from:
https://github.com/samuelcolvin/pydantic/issues/1864#issuecomment-679044432
"""
*_, validation_error = validate_model(self.__class__, self.__dict__)
if validation_error:
raise validation_error
self.model_validate(self.__dict__)
self.validate_config()

def recursive_validate_config(self):
Expand Down Expand Up @@ -115,13 +103,18 @@ def validate_list(self, field: str, valid_options: List[str]):
if val not in valid_options:
raise ConfigError(f'{val} is not a valid option for {field}')

def copy(self) -> Self:
return self.model_copy()

def dict(self, with_rv_metadata: bool = False, **kwargs) -> dict:
cfg_json = self.json(**kwargs)
cfg_dict = json.loads(cfg_json)
cfg_dict = self.model_dump(serialize_as_any=True, **kwargs)
if with_rv_metadata:
cfg_dict['plugin_versions'] = registry.plugin_versions
return cfg_dict

def json(self, **kwargs) -> str:
    """Serialize this Config to a JSON string.

    Delegates to ``model_dump_json()`` with ``serialize_as_any=True``, which
    the commit note describes as restoring the pre-pydantic-v2 (duck-typed)
    serialization behavior.

    Args:
        **kwargs: Extra keyword arguments passed through to
            ``model_dump_json()``.

    Returns:
        str: The JSON document returned by ``model_dump_json()``.
    """
    return self.model_dump_json(serialize_as_any=True, **kwargs)

def to_file(self, uri: str, with_rv_metadata: bool = True) -> None:
"""Save a Config to a JSON file, optionally with RV metadata.
Expand All @@ -131,11 +124,11 @@ def to_file(self, uri: str, with_rv_metadata: bool = True) -> None:
``plugin_versions``, so that the config can be upgraded when
loaded.
"""
cfg_dict = self.dict(with_rv_metadata=with_rv_metadata)
json_to_file(cfg_dict, uri)
cfg_json = self.dict(with_rv_metadata=with_rv_metadata)
json_to_file(cfg_json, uri)

@classmethod
def deserialize(cls, inp: 'str | dict | Config') -> 'Self':
def deserialize(cls, inp: 'str | dict | Config') -> Self:
"""Deserialize Config from a JSON file or dict, upgrading if possible.
If ``inp`` is already a :class:`.Config`, it is returned as is.
Expand All @@ -152,7 +145,7 @@ def deserialize(cls, inp: 'str | dict | Config') -> 'Self':
raise TypeError(f'Cannot deserialize Config from type: {type(inp)}.')

@classmethod
def from_file(cls, uri: str) -> 'Self':
def from_file(cls, uri: str) -> Self:
"""Deserialize Config from a JSON file, upgrading if possible.
Args:
Expand All @@ -163,7 +156,7 @@ def from_file(cls, uri: str) -> 'Self':
return cfg

@classmethod
def from_dict(cls, cfg_dict: dict) -> 'Self':
def from_dict(cls, cfg_dict: dict) -> Self:
"""Deserialize Config from a dict.
Args:
Expand Down Expand Up @@ -211,10 +204,7 @@ def build_config(x: Union[dict, List[Union[dict, Config]], Config]
Config: the corresponding Config(s)
"""
if isinstance(x, dict):
new_x = {
k: build_config(v)
for k, v in x.items() if k not in ('plugin_versions', 'rv_config')
}
new_x = {k: build_config(v) for k, v in x.items()}
type_hint = new_x.get('type_hint')
if type_hint is not None:
config_cls = registry.get_config(type_hint)
Expand Down
22 changes: 16 additions & 6 deletions rastervision_pipeline/rastervision/pipeline/pipeline_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from os.path import join
from typing import TYPE_CHECKING, Optional, Dict
from typing import TYPE_CHECKING, Self

from rastervision.pipeline.config import Config, Field
from rastervision.pipeline.config import register_config

if TYPE_CHECKING:
from rastervision.pipeline.pipeline import Pipeline # noqa
from rastervision.pipeline.pipeline import Pipeline


@register_config('pipeline')
Expand All @@ -14,15 +14,15 @@ class PipelineConfig(Config):
This should be subclassed to configure new Pipelines.
"""
root_uri: str = Field(
root_uri: str | None = Field(
None, description='The root URI for output generated by the pipeline')
rv_config: dict = Field(
rv_config: dict | None = Field(
None,
description='Used to store serialized RVConfig so pipeline can '
'run in remote environment with the local RVConfig. This should '
'not be set explicitly by users -- it is only used by the runner '
'when running a remote pipeline.')
plugin_versions: Optional[Dict[str, int]] = Field(
plugin_versions: dict[str, int] | None = Field(
None,
description=
('Used to store a mapping of plugin module paths to the latest '
Expand All @@ -46,4 +46,14 @@ def build(self, tmp_dir: str) -> 'Pipeline':
return Pipeline(self, tmp_dir)

def dict(self, with_rv_metadata: bool = True, **kwargs) -> dict:
return super().dict(with_rv_metadata=with_rv_metadata, **kwargs)
# override to have with_rv_metadata=True by default
cfg_dict = super().dict(with_rv_metadata=with_rv_metadata, **kwargs)
return cfg_dict

@classmethod
def from_dict(cls, cfg_dict: dict) -> Self:
    """Deserialize a config from a dict, upgrading it first.

    Override whose stated purpose is to retain ``plugin_versions``. The dict
    is passed through ``upgrade_config()`` and the result is handed to
    ``build_config()``.

    Args:
        cfg_dict: Dict representation of the config.

    Returns:
        Self: Whatever ``build_config()`` returns for the upgraded dict.
    """
    # Imported here rather than at module level, as in the original.
    from rastervision.pipeline.config import build_config, upgrade_config
    return build_config(upgrade_config(cfg_dict))
8 changes: 3 additions & 5 deletions tests/pipeline/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
from os.path import join
import unittest

from pydantic.error_wrappers import ValidationError

from rastervision.pipeline.file_system.utils import get_tmp_dir, json_to_file
from rastervision.pipeline.config import (Config, register_config,
build_config, upgrade_config)
from rastervision.pipeline.config import (
Config, register_config, build_config, upgrade_config, ValidationError)
from rastervision.pipeline.pipeline_config import (PipelineConfig)
from rastervision.pipeline import registry_ as registry

Expand Down Expand Up @@ -123,7 +121,7 @@ def test_to_from(self):
'x'
}

self.assertDictEqual(cfg.dict(), exp_dict)
self.assertDictEqual(cfg.dict(with_rv_metadata=False), exp_dict)
self.assertEqual(build_config(exp_dict), cfg)

def test_no_extras(self):
Expand Down

0 comments on commit 743133c

Please sign in to comment.