diff --git a/rastervision_core/rastervision/core/box.py b/rastervision_core/rastervision/core/box.py
index fe994a5d8..c2f959778 100644
--- a/rastervision_core/rastervision/core/box.py
+++ b/rastervision_core/rastervision/core/box.py
@@ -1,6 +1,6 @@
 from typing import (TYPE_CHECKING, Callable, Dict, List, Literal, Optional,
                     Sequence, Tuple, Union)
-from pydantic import PositiveInt as PosInt, conint
+from pydantic import NonNegativeInt as NonNegInt, PositiveInt as PosInt
 import math
 import random
@@ -11,8 +11,6 @@
 from rastervision.pipeline.utils import repr_with_args

-NonNegInt = conint(ge=0)
-
 if TYPE_CHECKING:
     from shapely.geometry import MultiPolygon
diff --git a/rastervision_core/rastervision/core/data/class_config.py b/rastervision_core/rastervision/core/data/class_config.py
index 3f63de66d..3594a0de5 100644
--- a/rastervision_core/rastervision/core/data/class_config.py
+++ b/rastervision_core/rastervision/core/data/class_config.py
@@ -1,7 +1,7 @@
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Self, Tuple, Union

 from rastervision.pipeline.config import (Config, register_config, ConfigError,
-                                          Field, validator)
+                                          Field, model_validator)
 from rastervision.core.data.utils import color_to_triple, normalize_color

 DEFAULT_NULL_CLASS_NAME = 'null'
@@ -32,42 +32,43 @@ class ClassConfig(Config):
         'Config is part of a SemanticSegmentationConfig, a null class will be '
         'added automatically.')

-    @validator('colors', always=True)
-    def validate_colors(cls, v: Optional[List[Union[str, Tuple]]],
-                        values: dict) -> Optional[List[Union[str, Tuple]]]:
+    @model_validator(mode='after')
+    def validate_colors(self) -> Self:
         """Compare length w/ names. Also auto-generate if not specified."""
-        class_names = values['names']
-        class_colors = v
-        if class_colors is None:
-            class_colors = [color_to_triple() for _ in class_names]
-        elif len(class_names) != len(class_colors):
-            raise ConfigError(f'len(class_names) ({len(class_names)}) != '
-                              f'len(class_colors) ({len(class_colors)})\n'
-                              f'class_names: {class_names}\n'
-                              f'class_colors: {class_colors}')
-        return class_colors
-
-    @validator('null_class', always=True)
-    def validate_null_class(cls, v: Optional[str],
-                            values: dict) -> Optional[str]:
+        names = self.names
+        colors = self.colors
+        if colors is None:
+            self.colors = [color_to_triple() for _ in names]
+        elif len(names) != len(colors):
+            raise ConfigError(f'len(class_names) ({len(names)}) != '
+                              f'len(class_colors) ({len(colors)})\n'
+                              f'class_names: {names}\n'
+                              f'class_colors: {colors}')
+        return self
+
+    @model_validator(mode='after')
+    def validate_null_class(self) -> Self:
         """Check if in names. If 'null' in names, use it as null class."""
-        names = values['names']
-        if v is None:
+        names = self.names
+        null_class = self.null_class
+        if null_class is None:
             if DEFAULT_NULL_CLASS_NAME in names:
-                v = DEFAULT_NULL_CLASS_NAME
+                self.null_class = DEFAULT_NULL_CLASS_NAME
         else:
-            if v not in names:
+            if null_class not in names:
                 raise ConfigError(
-                    f'The null_class, "{v}", must be in list of class names.')
+                    f'The null_class, "{null_class}", must be in list of '
+                    'class names.')

             # edge case
             default_null_class_in_names = (DEFAULT_NULL_CLASS_NAME in names)
-            null_class_neq_default = (v != DEFAULT_NULL_CLASS_NAME)
+            null_class_neq_default = (null_class != DEFAULT_NULL_CLASS_NAME)
             if default_null_class_in_names and null_class_neq_default:
                 raise ConfigError(
                     f'"{DEFAULT_NULL_CLASS_NAME}" is in names but the '
-                    f'specified null_class is something else ("{v}").')
-        return v
+                    'specified null_class is something else '
+                    f'("{null_class}").')
+        return self

     def get_class_id(self, name: str) -> int:
         return self.names.index(name)
diff --git a/rastervision_core/rastervision/core/data/label_source/chip_classification_label_source_config.py b/rastervision_core/rastervision/core/data/label_source/chip_classification_label_source_config.py
index faf188185..a1ceaed1c 100644
--- a/rastervision_core/rastervision/core/data/label_source/chip_classification_label_source_config.py
+++ b/rastervision_core/rastervision/core/data/label_source/chip_classification_label_source_config.py
@@ -1,10 +1,10 @@
-from typing import Optional
+from typing import Optional, Self

 from rastervision.core.data.vector_source import (VectorSourceConfig)
 from rastervision.core.data.label_source import (LabelSourceConfig,
                                                  ChipClassificationLabelSource)
 from rastervision.pipeline.config import (ConfigError, register_config, Field,
-                                          validator, root_validator)
+                                          field_validator, model_validator)
 from rastervision.core.data.vector_transformer import (
     ClassInferenceTransformerConfig, BufferTransformerConfig)
@@ -62,7 +62,8 @@ class ChipClassificationLabelSourceConfig(LabelSourceConfig):
         description='If True, labels will not be populated automatically '
         'during initialization of the label source.')

-    @validator('vector_source')
+    @field_validator('vector_source')
+    @classmethod
     def ensure_required_transformers(
             cls, v: VectorSourceConfig) -> VectorSourceConfig:
         """Add class-inference and buffer transformers if absent."""
@@ -84,14 +85,14 @@ def ensure_required_transformers(

         return v

-    @root_validator(skip_on_failure=True)
-    def ensure_bg_class_id_if_inferring(cls, values: dict) -> dict:
-        infer_cells = values.get('infer_cells')
-        has_bg_class_id = values.get('background_class_id') is not None
+    @model_validator(mode='after')
+    def ensure_bg_class_id_if_inferring(self) -> Self:
+        infer_cells = self.infer_cells
+        has_bg_class_id = self.background_class_id is not None
         if infer_cells and not has_bg_class_id:
             raise ConfigError(
                 'background_class_id is required if infer_cells=True.')
-        return values
+        return self

     def build(self, class_config, crs_transformer, bbox=None,
               tmp_dir=None) -> ChipClassificationLabelSource:
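Note on the pattern applied in the hunks above: pydantic v1 validators written as @validator(..., always=True) or @root_validator classmethods receiving a values dict become pydantic v2 @model_validator(mode='after') instance methods that read and mutate self and then return it. A minimal, self-contained sketch of that v2 pattern, assuming pydantic >= 2 (the ExampleConfig model and its fields are invented for illustration and are not part of this diff):

    from typing import Optional, Self
    from pydantic import BaseModel, model_validator

    class ExampleConfig(BaseModel):
        names: list[str] = ['a', 'b']
        colors: Optional[list[str]] = None

        @model_validator(mode='after')
        def fill_colors(self) -> Self:
            # Runs after field validation; may read and mutate any field.
            if self.colors is None:
                self.colors = ['red'] * len(self.names)
            elif len(self.colors) != len(self.names):
                raise ValueError('len(colors) != len(names)')
            return self

One side effect of mode='after' is that always=True is no longer needed: the validator runs once the whole model has been constructed, defaults included.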
diff --git a/rastervision_core/rastervision/core/data/label_source/object_detection_label_source_config.py b/rastervision_core/rastervision/core/data/label_source/object_detection_label_source_config.py
index dfe7551be..10347c254 100644
--- a/rastervision_core/rastervision/core/data/label_source/object_detection_label_source_config.py
+++ b/rastervision_core/rastervision/core/data/label_source/object_detection_label_source_config.py
@@ -3,7 +3,7 @@
 from rastervision.core.data.vector_source import VectorSourceConfig
 from rastervision.core.data.vector_transformer import (
     ClassInferenceTransformerConfig, BufferTransformerConfig)
-from rastervision.pipeline.config import register_config, validator
+from rastervision.pipeline.config import register_config, field_validator


 @register_config('object_detection_label_source')
@@ -12,7 +12,8 @@ class ObjectDetectionLabelSourceConfig(LabelSourceConfig):

     vector_source: VectorSourceConfig

-    @validator('vector_source')
+    @field_validator('vector_source')
+    @classmethod
     def ensure_required_transformers(
             cls, v: VectorSourceConfig) -> VectorSourceConfig:
         """Add class-inference and buffer transformers if absent."""
diff --git a/rastervision_core/rastervision/core/data/raster_source/multi_raster_source.py b/rastervision_core/rastervision/core/data/raster_source/multi_raster_source.py
index f617e9c78..d86627d10 100644
--- a/rastervision_core/rastervision/core/data/raster_source/multi_raster_source.py
+++ b/rastervision_core/rastervision/core/data/raster_source/multi_raster_source.py
@@ -1,6 +1,6 @@
 from typing import TYPE_CHECKING, Optional, Sequence, Self, Tuple

-from pydantic import conint
+from pydantic import NonNegativeInt as NonNegInt
 import numpy as np
 from pystac import Item
@@ -18,29 +18,28 @@ class MultiRasterSource(RasterSource):

     def __init__(self,
                  raster_sources: Sequence[RasterSource],
-                 primary_source_idx: conint(ge=0) = 0,
+                 primary_source_idx: NonNegInt = 0,
                  force_same_dtype: bool = False,
-                 channel_order: Optional[Sequence[conint(ge=0)]] = None,
+                 channel_order: Sequence[NonNegInt] | None = None,
                  raster_transformers: Sequence = [],
-                 bbox: Optional[Box] = None):
+                 bbox: Box | None = None):
         """Constructor.

         Args:
-            raster_sources (Sequence[RasterSource]): Sequence of RasterSources.
+            raster_sources: Sequence of RasterSources.
             primary_source_idx (0 <= int < len(raster_sources)): Index of the
                 raster source whose CRS, dtype, and other attributes will
                 override those of the other raster sources.
-            force_same_dtype (bool): If true, force all sub-chips to have the
+            force_same_dtype: If true, force all sub-chips to have the
                 same dtype as the primary_source_idx-th sub-chip. No careful
                 conversion is done, just a quick cast. Use with caution.
-            channel_order (Sequence[conint(ge=0)], optional): Channel ordering
-                that will be used by .get_chip(). Defaults to None.
-            raster_transformers (Sequence, optional): Sequence of transformers.
-                Defaults to [].
-            bbox (Optional[Box], optional): User-specified crop of the extent.
-                If given, the primary raster source's bbox is set to this.
-                If None, the full extent available in the source file of the
-                primary raster source is used.
+            channel_order: Channel ordering that will be used by
+                :meth:`MultiRasterSource.get_chip()`. Defaults to ``None``.
+            raster_transformers: List of transformers. Defaults to ``[]``.
+            bbox: User-specified crop of the extent. If specified, the primary
+                raster source's bbox is set to this. If ``None``, the full
+                extent available in the source file of the primary raster
+                source is used.
         """
         num_channels_raw = sum(rs.num_channels for rs in raster_sources)
         if not channel_order:
@@ -78,7 +77,7 @@ def from_stac(
             cls,
             item: Item,
             assets: list[str] | None,
-            primary_source_idx: conint(ge=0) = 0,
+            primary_source_idx: NonNegInt = 0,
             raster_transformers: list['RasterTransformer'] = [],
             force_same_dtype: bool = False,
             channel_order: Sequence[int] | None = None,
diff --git a/rastervision_core/rastervision/core/data/raster_source/multi_raster_source_config.py b/rastervision_core/rastervision/core/data/raster_source/multi_raster_source_config.py
index f223e7c87..19f31c40e 100644
--- a/rastervision_core/rastervision/core/data/raster_source/multi_raster_source_config.py
+++ b/rastervision_core/rastervision/core/data/raster_source/multi_raster_source_config.py
@@ -1,7 +1,10 @@
-from typing import Optional
-from pydantic import conint, conlist
+from typing import List, Optional, Self

-from rastervision.pipeline.config import (Field, register_config, validator)
+from typing_extensions import Annotated
+from pydantic import NonNegativeInt as NonNegInt
+
+from rastervision.pipeline.config import (Field, register_config,
+                                          model_validator)
 from rastervision.core.box import Box
 from rastervision.core.data.raster_source import (RasterSourceConfig,
                                                   MultiRasterSource)
@@ -25,10 +28,10 @@ class MultiRasterSourceConfig(RasterSourceConfig):
     Or :class:`.TemporalMultiRasterSource`, if ``temporal=True``.
     """

-    raster_sources: conlist(
-        RasterSourceConfig, min_items=1) = Field(
+    raster_sources: Annotated[List[
+        RasterSourceConfig], Field(min_length=1)] = Field(
             ..., description='List of RasterSourceConfig to combine.')
-    primary_source_idx: conint(ge=0) = Field(
+    primary_source_idx: NonNegInt = Field(
         0,
         description=
         'Index of the raster source whose CRS, dtype, and other attributes '
@@ -42,21 +45,21 @@ class MultiRasterSourceConfig(RasterSourceConfig):
         description='Stack images from sub raster sources into a time-series '
         'of shape (T, H, W, C) instead of concatenating bands.')

-    @validator('primary_source_idx')
-    def validate_primary_source_idx(cls, v: int, values: dict):
-        raster_sources = values.get('raster_sources', [])
-        if not (0 <= v < len(raster_sources)):
+    @model_validator(mode='after')
+    def validate_primary_source_idx(self) -> Self:
+        primary_source_idx = self.primary_source_idx
+        raster_sources = self.raster_sources
+        if not (0 <= primary_source_idx < len(raster_sources)):
             raise IndexError('primary_source_idx must be in range '
                              '[0, len(raster_sources)].')
-        return v
+        return self

-    @validator('temporal')
-    def validate_temporal(cls, v: int, values: dict):
-        channel_order = values.get('channel_order')
-        if v and channel_order is not None:
+    @model_validator(mode='after')
+    def validate_temporal(self) -> Self:
+        if self.temporal and self.channel_order is not None:
             raise ValueError(
                 'Setting channel_order is not allowed if temporal=True.')
-        return v
+        return self

     def build(self,
               tmp_dir: Optional[str] = None,
diff --git a/rastervision_core/rastervision/core/data/raster_source/rasterized_source_config.py b/rastervision_core/rastervision/core/data/raster_source/rasterized_source_config.py
index ff6449c15..efc30bd35 100644
--- a/rastervision_core/rastervision/core/data/raster_source/rasterized_source_config.py
+++ b/rastervision_core/rastervision/core/data/raster_source/rasterized_source_config.py
@@ -5,7 +5,7 @@
 from rastervision.core.data.vector_transformer import (
     ClassInferenceTransformerConfig, BufferTransformerConfig)
 from rastervision.pipeline.config import (register_config, Config, Field,
-                                          validator)
+                                          field_validator)

 if TYPE_CHECKING:
     from rastervision.core.box import Box
@@ -35,7 +35,8 @@ class RasterizedSourceConfig(Config):
     vector_source: VectorSourceConfig
     rasterizer_config: RasterizerConfig

-    @validator('vector_source')
+    @field_validator('vector_source')
+    @classmethod
     def ensure_required_transformers(
             cls, v: VectorSourceConfig) -> VectorSourceConfig:
         """Add class-inference and buffer transformers if absent."""
diff --git a/rastervision_core/rastervision/core/data/raster_source/temporal_multi_raster_source.py b/rastervision_core/rastervision/core/data/raster_source/temporal_multi_raster_source.py
index e3ad95052..79672a4ba 100644
--- a/rastervision_core/rastervision/core/data/raster_source/temporal_multi_raster_source.py
+++ b/rastervision_core/rastervision/core/data/raster_source/temporal_multi_raster_source.py
@@ -1,6 +1,6 @@
 from typing import Any, Optional, Sequence, Tuple

-from pydantic import conint
+from pydantic import NonNegativeInt as NonNegInt
 import numpy as np

 from rastervision.core.box import Box
@@ -14,7 +14,7 @@ class TemporalMultiRasterSource(MultiRasterSource):

     def __init__(self,
                  raster_sources: Sequence[RasterSource],
-                 primary_source_idx: conint(ge=0) = 0,
+                 primary_source_idx: NonNegInt = 0,
                  force_same_dtype: bool = False,
                  raster_transformers: Sequence = [],
                  bbox: Optional[Box] = None):
diff --git a/rastervision_core/rastervision/core/data/vector_transformer/class_inference_transformer_config.py b/rastervision_core/rastervision/core/data/vector_transformer/class_inference_transformer_config.py
index c0dd71b38..8dda0e00b 100644
--- a/rastervision_core/rastervision/core/data/vector_transformer/class_inference_transformer_config.py
+++ b/rastervision_core/rastervision/core/data/vector_transformer/class_inference_transformer_config.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING

 from rastervision.pipeline.config import register_config, Field
 from rastervision.core.data.vector_transformer import (
@@ -12,13 +12,13 @@ class ClassInferenceTransformerConfig(VectorTransformerConfig):
     """Configure a :class:`.ClassInferenceTransformer`."""

-    default_class_id: Optional[int] = Field(
+    default_class_id: int | None = Field(
        None,
        description='The default ``class_id`` to use if class cannot be '
        'inferred using other mechanisms. If a feature has an inferred '
        '``class_id`` of ``None``, then it will be deleted. '
        'Defaults to ``None``.')
-    class_id_to_filter: Optional[Dict[int, list]] = Field(
+    class_id_to_filter: dict[int, list] | None = Field(
        None,
        description='Map from ``class_id`` to JSON filter used to infer '
        'missing class IDs. Each key should be a class ID, and its value '
@@ -28,7 +28,7 @@ class ClassInferenceTransformerConfig(VectorTransformerConfig):
        'is that described by '
        'https://docs.mapbox.com/mapbox-gl-js/style-spec/other/#other-filter. '
        'Defaults to ``None``.')
-    class_name_mapping: dict[str, str] = Field(
+    class_name_mapping: dict[str, str] | None = Field(
        None,
        description='``old_name --> new_name`` mapping for values in the '
        '``class_name`` or ``label`` property of the GeoJSON features. The '
        'can also be used to merge multiple classes into one e.g.: '
        '``dict(car="vehicle", truck="vehicle")``. Defaults to None.')

-    def build(self, class_config: Optional['ClassConfig'] = None
+    def build(self, class_config: 'ClassConfig | None' = None
              ) -> ClassInferenceTransformer:
         return ClassInferenceTransformer(
             self.default_class_id,
diff --git a/rastervision_core/rastervision/core/rv_pipeline/chip_options.py b/rastervision_core/rastervision/core/rv_pipeline/chip_options.py
index d15440581..ebcc32b9d 100644
--- a/rastervision_core/rastervision/core/rv_pipeline/chip_options.py
+++ b/rastervision_core/rastervision/core/rv_pipeline/chip_options.py
@@ -1,15 +1,13 @@
-from typing import (Any, Dict, Literal, Optional, Tuple, Union)
+from typing import (Any, Dict, Literal, Optional, Self, Tuple, Union)
 from enum import Enum

-from pydantic import (PositiveInt as PosInt, conint)
+from pydantic import NonNegativeInt as NonNegInt, PositiveInt as PosInt
 import numpy as np

-from rastervision.core.utils.misc import Proportion
 from rastervision.core.rv_pipeline.utils import nodata_below_threshold
+from rastervision.core.utils import Proportion
 from rastervision.pipeline.config import (Config, ConfigError, Field,
-                                          register_config, root_validator)
-
-NonNegInt = conint(ge=0)
+                                          register_config, model_validator)


 class WindowSamplingMethod(Enum):
@@ -88,33 +86,25 @@ class WindowSamplingConfig(Config):
         'that lie fully within the AOI. If False, windows only partially '
         'intersecting the AOI will also be allowed.')

-    @root_validator(skip_on_failure=True)
-    def validate_options(cls, values: dict) -> dict:
-        method = values.get('method')
-        size = values.get('size')
+    @model_validator(mode='after')
+    def validate_options(self) -> Self:
+        method = self.method
+        size = self.size
         if method == WindowSamplingMethod.sliding:
-            has_stride = values.get('stride') is not None
-
-            if not has_stride:
-                values['stride'] = size
+            if self.stride is None:
+                self.stride = size
         elif method == WindowSamplingMethod.random:
-            size_lims = values.get('size_lims')
-            h_lims = values.get('h_lims')
-            w_lims = values.get('w_lims')
-
-            has_size_lims = size_lims is not None
-            has_h_lims = h_lims is not None
-            has_w_lims = w_lims is not None
-
+            has_size_lims = self.size_lims is not None
+            has_h_lims = self.h_lims is not None
+            has_w_lims = self.w_lims is not None
             if not (has_size_lims or has_h_lims or has_w_lims):
-                size_lims = (size, size + 1)
+                self.size_lims = (size, size + 1)
                 has_size_lims = True
-                values['size_lims'] = size_lims
             if has_size_lims == (has_w_lims or has_h_lims):
                 raise ConfigError('Specify either size_lims or h and w lims.')
             if has_h_lims != has_w_lims:
                 raise ConfigError('h_lims and w_lims must both be specified')
-        return values
+        return self


 @register_config('chip_options')
diff --git a/rastervision_core/rastervision/core/rv_pipeline/object_detection_config.py b/rastervision_core/rastervision/core/rv_pipeline/object_detection_config.py
index 4e7cc51ac..16cbc4741 100644
--- a/rastervision_core/rastervision/core/rv_pipeline/object_detection_config.py
+++ b/rastervision_core/rastervision/core/rv_pipeline/object_detection_config.py
@@ -1,6 +1,7 @@
-from typing import Optional
+from typing import Optional, Self

-from rastervision.pipeline.config import Field, register_config, validator
+from rastervision.pipeline.config import (Field, register_config,
+                                          model_validator)
 from rastervision.core.rv_pipeline import (
     ChipOptions, RVPipelineConfig, PredictOptions, WindowSamplingConfig)
 from rastervision.core.data.label_store import ObjectDetectionGeoJSONStoreConfig
@@ -59,20 +60,19 @@ class ObjectDetectionPredictOptions(PredictOptions):
         ('Predicted boxes are only output if their score is above score_thresh.'
          ))

-    @validator('stride', always=True)
-    def validate_stride(cls, v: Optional[int], values: dict) -> dict:
-        if v is None:
-            chip_sz: int = values['chip_sz']
-            return chip_sz // 2
-        return v
+    @model_validator(mode='after')
+    def validate_stride(self) -> Self:
+        if self.stride is None:
+            self.stride = self.chip_sz // 2
+        return self


 @register_config('object_detection')
 class ObjectDetectionConfig(RVPipelineConfig):
     """Configure an :class:`.ObjectDetection` pipeline."""

-    chip_options: Optional[ObjectDetectionChipOptions]
-    predict_options: Optional[ObjectDetectionPredictOptions]
+    chip_options: Optional[ObjectDetectionChipOptions] = None
+    predict_options: Optional[ObjectDetectionPredictOptions] = None

     def build(self, tmp_dir):
         from rastervision.core.rv_pipeline.object_detection import ObjectDetection
diff --git a/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline_config.py b/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline_config.py
index 3a1b3fb34..468a7942e 100644
--- a/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline_config.py
+++ b/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline_config.py
@@ -1,4 +1,4 @@
-from typing import (TYPE_CHECKING, List, Optional)
+from typing import (TYPE_CHECKING, List, Optional, Self)
 from os.path import join

 from rastervision.pipeline.pipeline_config import PipelineConfig
@@ -10,7 +10,7 @@
 from rastervision.core.analyzer import AnalyzerConfig
 from rastervision.core.rv_pipeline.chip_options import ChipOptions
 from rastervision.pipeline.config import (Config, Field, register_config,
-                                          validator)
+                                          model_validator)

 if TYPE_CHECKING:
     from rastervision.core.backend.backend import Backend  # noqa
@@ -27,12 +27,11 @@ class PredictOptions(Config):
     batch_sz: int = Field(
         8, description='Batch size to use during prediction.')

-    @validator('stride', always=True)
-    def validate_stride(cls, v: Optional[int], values: dict) -> dict:
-        if v is None:
-            chip_sz: int = values['chip_sz']
-            return chip_sz
-        return v
+    @model_validator(mode='after')
+    def validate_stride(self) -> Self:
+        if self.stride is None:
+            self.stride = self.chip_sz
+        return self


 def rv_pipeline_config_upgrader(cfg_dict: dict, version: int) -> dict:
diff --git a/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation_config.py b/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation_config.py
index 17fda9fe3..a5e8c490b 100644
--- a/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation_config.py
+++ b/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation_config.py
@@ -1,10 +1,11 @@
-from typing import (List, Literal, Optional, Union)
-from pydantic import conint
+from typing import (List, Literal, Optional, Self, Union)
 import logging

+from pydantic import NonNegativeInt as NonNegInt
 import numpy as np

-from rastervision.pipeline.config import (register_config, Field, validator)
+from rastervision.pipeline.config import (register_config, Field,
+                                          model_validator)
 from rastervision.core.rv_pipeline.rv_pipeline_config import (PredictOptions,
                                                               RVPipelineConfig)
 from rastervision.core.rv_pipeline.chip_options import (ChipOptions,
@@ -81,7 +82,7 @@ class SemanticSegmentationPredictOptions(PredictOptions):
         description='Stride of the sliding window for generating chips. '
         'Allows aggregating multiple predictions for each pixel if less than '
         'the chip size. Defaults to ``chip_sz``.')
-    crop_sz: Optional[Union[conint(gt=0), Literal['auto']]] = Field(
+    crop_sz: Optional[Union[NonNegInt, Literal['auto']]] = Field(
         None,
         description=
         'Number of rows/columns of pixels from the edge of prediction '
@@ -90,22 +91,19 @@ class SemanticSegmentationPredictOptions(PredictOptions):
         'near the edges of chips. If "auto", will be set to half the stride '
         'if stride is less than chip_sz. Defaults to None.')

-    @validator('crop_sz')
-    def validate_crop_sz(cls,
-                         v: Optional[Union[conint(gt=0), Literal['auto']]],
-                         values: dict) -> dict:
-        crop_sz = v
-        if crop_sz == 'auto':
-            chip_sz: int = values['chip_sz']
-            stride: int = values['stride']
-            overlap_sz = chip_sz - stride
+    @model_validator(mode='after')
+    def set_auto_crop_sz(self) -> Self:
+        if self.crop_sz == 'auto':
+            if self.stride is None:
+                self.validate_stride()
+            overlap_sz = self.chip_sz - self.stride
             if overlap_sz % 2 == 1:
                 log.warning(
                     'Using crop_sz="auto" but overlap size (chip_sz minus '
                     'stride) is odd. This means that one pixel row/col will '
                     'still overlap after cropping.')
-            crop_sz = overlap_sz // 2
-        return crop_sz
+            self.crop_sz = overlap_sz // 2
+        return self


 def ss_config_upgrader(cfg_dict: dict, version: int) -> dict:
@@ -124,8 +122,8 @@ def ss_config_upgrader(cfg_dict: dict, version: int) -> dict:
 class SemanticSegmentationConfig(RVPipelineConfig):
     """Configure a :class:`.SemanticSegmentation` pipeline."""

-    chip_options: Optional[SemanticSegmentationChipOptions]
-    predict_options: Optional[SemanticSegmentationPredictOptions]
+    chip_options: SemanticSegmentationChipOptions | None = None
+    predict_options: SemanticSegmentationPredictOptions | None = None

     def build(self, tmp_dir):
         from rastervision.core.rv_pipeline.semantic_segmentation import (
diff --git a/rastervision_core/rastervision/core/utils/__init__.py b/rastervision_core/rastervision/core/utils/__init__.py
index fdc6d5d60..114b3024c 100644
--- a/rastervision_core/rastervision/core/utils/__init__.py
+++ b/rastervision_core/rastervision/core/utils/__init__.py
@@ -1,4 +1,4 @@
 # flake8: noqa

 from rastervision.core.utils.stac import *
-from rastervision.core.utils.misc import *
+from rastervision.core.utils.types import *
diff --git a/rastervision_core/rastervision/core/utils/misc.py b/rastervision_core/rastervision/core/utils/misc.py
deleted file mode 100644
index dec94f462..000000000
--- a/rastervision_core/rastervision/core/utils/misc.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from pydantic import confloat
-
-Proportion = confloat(ge=0, le=1)
diff --git a/rastervision_core/rastervision/core/utils/types.py b/rastervision_core/rastervision/core/utils/types.py
new file mode 100644
index 000000000..058347860
--- /dev/null
+++ b/rastervision_core/rastervision/core/utils/types.py
@@ -0,0 +1,8 @@
+from typing_extensions import Annotated
+from pydantic.types import StringConstraints
+
+from rastervision.pipeline.config import Field
+
+NonEmptyStr = Annotated[str,
+                        StringConstraints(strip_whitespace=True, min_length=1)]
+Proportion = Annotated[float, Field(ge=0, le=1)]
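The type changes follow the same theme: pydantic v2 deprecates the con* constructors (conint, confloat, conlist, constr), so the diff swaps them for named types such as NonNegativeInt or for Annotated constraints, as in the new types.py above. A rough sketch of the equivalences, assuming pydantic >= 2 (the Example model and its field names are invented for illustration):

    from typing import Annotated, List
    from pydantic import BaseModel, Field, NonNegativeInt, StringConstraints

    Proportion = Annotated[float, Field(ge=0, le=1)]
    NonEmptyStr = Annotated[str, StringConstraints(strip_whitespace=True, min_length=1)]

    class Example(BaseModel):
        idx: NonNegativeInt = 0          # v1: conint(ge=0)
        ratio: Proportion = 0.5          # v1: confloat(ge=0, le=1)
        name: NonEmptyStr = 'x'          # v1: constr(strip_whitespace=True, min_length=1)
        # v1: conlist(int, min_items=1)
        items: Annotated[List[int], Field(min_length=1)] = [0]

Note that sequence length constraints are now spelled min_length/max_length rather than min_items/max_items, and constr's regex keyword becomes pattern (as in the github_repo field further down).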
diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py
index e0aec3912..50d32bfb0 100644
--- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py
+++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py
@@ -1,5 +1,5 @@
 from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
-                    Literal, Optional, Sequence, Tuple, Union)
+                    Literal, Optional, Self, Sequence, Tuple, Union)
 import os
 from os.path import join, isdir
 from enum import Enum
@@ -7,10 +7,10 @@
 import uuid
 import logging

-from pydantic import (PositiveFloat, PositiveInt as PosInt, constr, confloat,
-                      conint)
-from pydantic.utils import sequence_like
-
+from typing_extensions import Annotated
+from pydantic import (NonNegativeInt as NonNegInt, PositiveFloat, PositiveInt
+                      as PosInt, StringConstraints)
+from pydantic.v1.utils import sequence_like
 import albumentations as A
 import torch
 from torch import (nn, optim)
@@ -18,13 +18,15 @@
 from torch.utils.data import Dataset, ConcatDataset, Subset

 from rastervision.pipeline.config import (Config, register_config, ConfigError,
-                                          Field, validator, root_validator)
+                                          Field, field_validator,
+                                          model_validator)
 from rastervision.pipeline.file_system import (list_paths, download_if_needed,
                                                unzip, file_exists,
                                                get_local_path, sync_from_dir)
 from rastervision.core.data import (ClassConfig, Scene, DatasetConfig as
                                     SceneDatasetConfig)
 from rastervision.core.rv_pipeline import (WindowSamplingConfig)
+from rastervision.core.utils import NonEmptyStr, Proportion
 from rastervision.pytorch_learner.utils import (
     validate_albumentation_transform, MinMaxNormalize,
     deserialize_albumentation_transform, get_hubconf_dir_from_cfg,
@@ -43,14 +45,11 @@
 ]

 # types
-Proportion = confloat(ge=0, le=1)
-NonEmptyStr = constr(strip_whitespace=True, min_length=1)
-NonNegInt = conint(ge=0)
 RGBTuple = Tuple[int, int, int]
 ChannelInds = Sequence[NonNegInt]


-class Backbone(Enum):
+class Backbone(str, Enum):
     alexnet = 'alexnet'
     densenet121 = 'densenet121'
     densenet169 = 'densenet169'
@@ -136,9 +135,10 @@ class ExternalModuleConfig(Config):
         None,
         description=('Local uri of a zip file, or local uri of a directory,'
                      'or remote uri of zip file.'))
-    github_repo: Optional[constr(
-        strip_whitespace=True, regex=r'.+/.+')] = Field(
-            None, description='/[:tag]')
+    github_repo: Optional[Annotated[
+        str, StringConstraints(
+            strip_whitespace=True, pattern=r'.+/.+')]] = Field(
+                None, description='/[:tag]')
     name: Optional[NonEmptyStr] = Field(
         None,
         description=
@@ -157,14 +157,14 @@
     force_reload: bool = Field(
         False, description='Force reload of module definition.')

-    @root_validator(skip_on_failure=True)
-    def check_either_uri_or_repo(cls, values: dict) -> dict:
-        has_uri = values.get('uri') is not None
-        has_repo = values.get('github_repo') is not None
+    @model_validator(mode='after')
+    def check_either_uri_or_repo(self) -> Self:
+        has_uri = self.uri is not None
+        has_repo = self.github_repo is not None
         if has_uri == has_repo:
             raise ConfigError(
                 'Must specify one (and only one) of github_repo and uri.')
-        return values
+        return self

     def build(self,
               save_dir: str,
@@ -366,11 +366,11 @@ class SolverConfig(Config):
         description='If specified, the loss will be built from the definition '
         'from this external source, using Torch Hub.')

-    @root_validator(skip_on_failure=True)
-    def check_no_loss_opts_if_external(cls, values: dict) -> dict:
-        has_external_loss_def = values.get('external_loss_def') is not None
-        has_ignore_class_index = values.get('ignore_class_index') is not None
-        has_class_loss_weights = values.get('class_loss_weights') is not None
+    @model_validator(mode='after')
+    def check_no_loss_opts_if_external(self) -> Self:
+        has_external_loss_def = self.external_loss_def is not None
+        has_ignore_class_index = self.ignore_class_index is not None
+        has_class_loss_weights = self.class_loss_weights is not None

         if has_external_loss_def:
             if has_ignore_class_index:
@@ -379,7 +379,7 @@ def check_no_loss_opts_if_external(cls, values: dict) -> dict:
             if has_class_loss_weights:
                 raise ConfigError('class_loss_weights is not supported '
                                   'with external_loss_def.')
-        return values
+        return self

     def build_loss(self,
                    num_classes: int,
@@ -576,8 +576,7 @@ class PlotOptions(Config):
         'for that group.'))

     # validators
-    _tf = validator(
-        'transform', allow_reuse=True)(validate_albumentation_transform)
+    _tf = field_validator('transform')(validate_albumentation_transform)

     def update(self, **kwargs) -> None:
         super().update()
@@ -586,7 +585,8 @@ def update(self, **kwargs) -> None:
             self.channel_display_groups = get_default_channel_display_groups(
                 img_channels)

-    @validator('channel_display_groups')
+    @field_validator('channel_display_groups')
+    @classmethod
     def validate_channel_display_groups(
             cls, v: Optional[Union[Dict[str, Sequence[NonNegInt]], Sequence[
                 Sequence[NonNegInt]]]]
@@ -671,26 +671,24 @@ def num_classes(self):
         return len(self.class_config)

     # validators
-    _base_tf = validator(
-        'base_transform', allow_reuse=True)(validate_albumentation_transform)
-    _aug_tf = validator(
-        'aug_transform', allow_reuse=True)(validate_albumentation_transform)
-
-    @validator('augmentors', each_item=True)
-    def validate_augmentors(cls, v: str) -> str:
-        if v not in augmentors:
-            raise ConfigError(f'Unsupported augmentor "{v}"')
+    _base_tf = field_validator('base_transform')(
+        validate_albumentation_transform)
+    _aug_tf = field_validator('aug_transform')(
+        validate_albumentation_transform)
+
+    @field_validator('augmentors')
+    @classmethod
+    def validate_augmentors(cls, v: list[str]) -> str:
+        for aug_name in v:
+            if aug_name not in augmentors:
+                raise ConfigError(f'Unsupported augmentor "{aug_name}"')
         return v

-    @root_validator(skip_on_failure=True)
-    def validate_plot_options(cls, values: dict) -> dict:
-        plot_options: Optional[PlotOptions] = values.get('plot_options')
-        if plot_options is None:
-            return None
-        img_channels: Optional[PosInt] = values.get('img_channels')
-        if img_channels is not None:
-            plot_options.update(img_channels=img_channels)
-        return values
+    @model_validator(mode='after')
+    def validate_plot_options(self) -> Self:
+        if self.plot_options is not None and self.img_channels is not None:
+            self.plot_options.update(img_channels=self.img_channels)
+        return self

     def get_custom_albumentations_transforms(self) -> List[dict]:
         """Returns all custom transforms found in this config.
@@ -833,11 +831,11 @@ class ImageDataConfig(DataConfig):
         'that will be used for all groups or a list of values '
         '(one for each group).')

-    @root_validator(skip_on_failure=True)
-    def validate_group_uris(cls, values: dict) -> dict:
-        group_train_sz = values.get('group_train_sz')
-        group_train_sz_rel = values.get('group_train_sz_rel')
-        group_uris = values.get('group_uris')
+    @model_validator(mode='after')
+    def validate_group_uris(self) -> Self:
+        group_train_sz = self.group_train_sz
+        group_train_sz_rel = self.group_train_sz_rel
+        group_uris = self.group_uris

         has_group_train_sz = group_train_sz is not None
         has_group_train_sz_rel = group_train_sz_rel is not None
@@ -858,7 +856,7 @@ def validate_group_uris(cls, values: dict) -> dict:
             if len(group_train_sz_rel) != len(group_uris):
                 raise ConfigError(
                     'len(group_train_sz_rel) != len(group_uris).')
-        return values
+        return self

     def _build_dataset(self,
                        dirs: Iterable[str],
@@ -1188,35 +1186,31 @@ def __repr_args__(self):
         out = [('scene_dataset', ds_str), ('sampling', sampling_str)]
         return out

-    @validator('sampling')
-    def validate_sampling(
-            cls,
-            v: Union[WindowSamplingConfig, Dict[str, WindowSamplingConfig]],
-            values: dict
-    ) -> Union[WindowSamplingConfig, Dict[str, WindowSamplingConfig]]:
-        if isinstance(v, dict):
-            if len(v) == 0:
-                return v
-            scene_dataset: Optional['SceneDatasetConfig'] = values.get(
-                'scene_dataset')
-            if scene_dataset is None:
-                raise ConfigError('sampling is a non-empty dict but '
-                                  'scene_dataset is None.')
-            for s in scene_dataset.all_scenes:
-                if s.id not in v:
-                    raise ConfigError(
-                        f'Window config not found for scene {s.id}')
-        return v
+    @model_validator(mode='after')
+    def validate_sampling(self) -> Self:
+        if not isinstance(self.sampling, dict):
+            return self
+
+        # empty dict
+        if len(self.sampling) == 0:
+            return self
+
+        if self.scene_dataset is None:
+            raise ConfigError('sampling is a non-empty dict but '
+                              'scene_dataset is None.')
+
+        for s in self.scene_dataset.all_scenes:
+            if s.id not in self.sampling:
+                raise ConfigError(
+                    f'Window sampling config not found for scene: {s.id}')

-    @root_validator(skip_on_failure=True)
-    def get_class_config_from_dataset_if_needed(cls, values: dict) -> dict:
-        has_class_config = values.get('class_config') is not None
-        if has_class_config:
-            return values
-        has_scene_dataset = values.get('scene_dataset') is not None
-        if has_scene_dataset:
-            values['class_config'] = values['scene_dataset'].class_config
-        return values
+        return self
+
+    @model_validator(mode='after')
+    def get_class_config_from_dataset_if_needed(self) -> Self:
+        if self.class_config is None and self.scene_dataset is not None:
+            self.class_config = self.scene_dataset.class_config
+        return self

     def build_scenes(self,
                      scene_configs: Iterable['SceneConfig'],
@@ -1386,28 +1380,26 @@ class LearnerConfig(Config):
             'last epoch are stored as `model-ckpt-epoch-{N}.pth` where `N` '
             'is the epoch number.'))

-    @validator('run_tensorboard')
-    def validate_run_tensorboard(cls, v: bool, values: dict) -> bool:
-        if v and not values.get('log_tensorboard'):
+    @model_validator(mode='after')
+    def validate_run_tensorboard(self) -> Self:
+        if self.run_tensorboard and not self.log_tensorboard:
             raise ConfigError(
                 'Cannot run tensorboard if log_tensorboard is False')
-        return v
+        return self

-    @root_validator(skip_on_failure=True)
-    def validate_class_loss_weights(cls, values: dict) -> dict:
-        solver: Optional[SolverConfig] = values.get('solver')
-        if solver is None:
-            return values
-        class_loss_weights = solver.class_loss_weights
+    @model_validator(mode='after')
+    def validate_class_loss_weights(self) -> Self:
+        if self.solver is None:
+            return self
+        class_loss_weights = self.solver.class_loss_weights
         if class_loss_weights is not None:
-            data: DataConfig = values.get('data')
             num_weights = len(class_loss_weights)
-            num_classes = data.num_classes
+            num_classes = self.data.num_classes
             if num_weights != num_classes:
                 raise ConfigError(
                     f'class_loss_weights ({num_weights}) must be same length as '
                     f'the number of classes ({num_classes})')
-        return values
+        return self

     def build(self,
               tmp_dir: Optional[str] = None,
diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner_config.py
index af4c1a5dd..271c1645f 100644
--- a/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner_config.py
+++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner_config.py
@@ -11,7 +11,7 @@
 from rastervision.core.data import Scene
 from rastervision.core.rv_pipeline import WindowSamplingMethod
 from rastervision.pipeline.config import (Config, register_config, Field,
-                                          validator, ConfigError)
+                                          field_validator, ConfigError)
 from rastervision.pytorch_learner.learner_config import (
     LearnerConfig, ModelConfig, Backbone, ImageDataConfig, GeoDataConfig)
 from rastervision.pytorch_learner.dataset import (
@@ -131,7 +131,8 @@ class ObjectDetectionModelConfig(ModelConfig):
         ('The torchvision.models backbone to use, which must be in the resnet* '
          'family.'))

-    @validator('backbone')
+    @field_validator('backbone')
+    @classmethod
     def only_valid_backbones(cls, v):
         if v not in [
                 Backbone.resnet18, Backbone.resnet34, Backbone.resnet50,
@@ -221,7 +222,8 @@ def build(self,
             loss_def_path=loss_def_path,
             training=training)

-    @validator('solver')
+    @field_validator('solver')
+    @classmethod
     def validate_solver_config(cls, v: 'SolverConfig') -> 'SolverConfig':
         if v.ignore_class_index is not None:
             raise ConfigError(
diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner_config.py
index aa59b6221..7f8439687 100644
--- a/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner_config.py
+++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner_config.py
@@ -150,7 +150,7 @@ def forward(self, x: 'torch.Tensor') -> 'torch.Tensor':
 class RegressionModelConfig(ModelConfig):
     """Configure a regression model."""

-    output_multiplier: List[float] = None
+    output_multiplier: list[float] | None = None

     def update(self, learner=None):
         if learner is not None and self.output_multiplier is None:
diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner_config.py
index f7131432e..a8371a0c8 100644
--- a/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner_config.py
+++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner_config.py
@@ -11,7 +11,7 @@
 from rastervision.core.data import Scene
 from rastervision.core.rv_pipeline import WindowSamplingMethod
 from rastervision.pipeline.config import (Config, register_config, Field,
-                                          validator, ConfigError)
+                                          field_validator, ConfigError)
 from rastervision.pytorch_learner.learner_config import (
     Backbone, LearnerConfig, ModelConfig, ImageDataConfig, GeoDataConfig)
 from rastervision.pytorch_learner.dataset import (
@@ -165,7 +165,8 @@ class SemanticSegmentationModelConfig(ModelConfig):
         description='The torchvision.models backbone to use. Currently, only '
         'resnet50 and resnet101 are supported.')

-    @validator('backbone')
+    @field_validator('backbone')
+    @classmethod
     def only_valid_backbones(cls, v):
         if v not in [Backbone.resnet50, Backbone.resnet101]:
             raise ValueError(
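For per-field validators, the remaining hunks apply the v2 field_validator API: the decorator is stacked with @classmethod, allow_reuse is gone, and each_item=True has no direct equivalent, so validators that used to check individual list items now loop over the whole list themselves (as validate_augmentors does above). A minimal sketch of that shape, assuming pydantic >= 2 and with invented names:

    from pydantic import BaseModel, field_validator

    class ModelCfg(BaseModel):
        backbones: list[str] = []

        @field_validator('backbones')
        @classmethod
        def check_backbones(cls, v: list[str]) -> list[str]:
            # v2 drops each_item=True, so iterate over the full value here.
            for name in v:
                if not name.startswith('resnet'):
                    raise ValueError(f'Unsupported backbone "{name}"')
            return v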