From d5e3b1c4b0a03498a309d8b63207b124144a939b Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Fri, 26 Jul 2024 13:09:31 -0400 Subject: [PATCH] misc type fixes --- .../rastervision/aws_s3/s3_file_system.py | 10 +-- .../rastervision/core/data/class_config.py | 2 +- .../rasterio_crs_transformer.py | 11 +-- .../semantic_segmentation_label_store.py | 18 +++-- .../core/data/utils/aoi_sampler.py | 2 +- .../semantic_segmentation_config.py | 2 + .../rastervision/pipeline/cli.py | 9 ++- .../pipeline/file_system/local_file_system.py | 8 ++- .../rastervision/pipeline/rv_config.py | 6 +- .../isprs_potsdam_multi_source.py | 16 ++--- .../pytorch_learner/dataset/dataset.py | 1 + .../pytorch_learner/dataset/transform.py | 69 ++++++++++++++++--- 12 files changed, 109 insertions(+), 45 deletions(-) diff --git a/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py b/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py index e45ac6c65..23237dcb2 100644 --- a/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py +++ b/rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py @@ -37,14 +37,14 @@ def get_matching_s3_objects( ) while True: resp: dict = s3.list_objects_v2(**kwargs) - dirs: list[dict] = resp.get('CommonPrefixes', {}) - files: list[dict] = resp.get('Contents', {}) + dirs: list[dict[str, Any]] = resp.get('CommonPrefixes', {}) + files: list[dict[str, Any]] = resp.get('Contents', {}) for obj in dirs: - key = obj['Prefix'] + key: str = obj['Prefix'] if key.startswith(prefix) and key.endswith(suffix): yield key, obj for obj in files: - key = obj['Key'] + key: str = obj['Key'] if key.startswith(prefix) and key.endswith(suffix): yield key, obj # The S3 API is paginated, returning up to 1000 keys at a time. @@ -214,7 +214,7 @@ def write_bytes(uri: str, data: bytes) -> None: file_size = len(data) with io.BytesIO(data) as str_buffer: try: - with progressbar(file_size, desc=f'Uploading') as bar: + with progressbar(file_size, desc='Uploading') as bar: s3.upload_fileobj( Fileobj=str_buffer, Bucket=bucket, diff --git a/rastervision_core/rastervision/core/data/class_config.py b/rastervision_core/rastervision/core/data/class_config.py index 8ebc6437e..89b542ddb 100644 --- a/rastervision_core/rastervision/core/data/class_config.py +++ b/rastervision_core/rastervision/core/data/class_config.py @@ -82,7 +82,7 @@ def null_class_id(self) -> int: raise ValueError('null_class is not set') return self.get_class_id(self.null_class) - def get_color_to_class_id(self) -> dict: + def get_color_to_class_id(self) -> dict[str | tuple[int, int, int], int]: return dict([(self.colors[i], i) for i in range(len(self.colors))]) def ensure_null_class(self) -> None: diff --git a/rastervision_core/rastervision/core/data/crs_transformer/rasterio_crs_transformer.py b/rastervision_core/rastervision/core/data/crs_transformer/rasterio_crs_transformer.py index d974eec9e..fa352fd37 100644 --- a/rastervision_core/rastervision/core/data/crs_transformer/rasterio_crs_transformer.py +++ b/rastervision_core/rastervision/core/data/crs_transformer/rasterio_crs_transformer.py @@ -63,8 +63,10 @@ def __repr__(self) -> str: """ return out - def _map_to_pixel(self, map_point: tuple[float, float] | np.ndarray - ) -> tuple[int, int] | np.ndarray: + def _map_to_pixel( + self, + map_point: tuple[float, float] | tuple[np.ndarray, np.ndarray] + ) -> tuple[int, int] | tuple[np.ndarray, np.ndarray]: """Transform point from map to pixel-based coordinates. Args: @@ -82,8 +84,9 @@ def _map_to_pixel(self, map_point: tuple[float, float] | np.ndarray pixel_point = (col, row) return pixel_point - def _pixel_to_map(self, pixel_point: tuple[int, int] | np.ndarray - ) -> tuple[float, float] | np.ndarray: + def _pixel_to_map( + self, pixel_point: tuple[int, int] | tuple[np.ndarray, np.ndarray] + ) -> tuple[float, float] | tuple[np.ndarray, np.ndarray]: """Transform point from pixel to map-based coordinates. Args: diff --git a/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store.py b/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store.py index e43d90680..2161e61bb 100644 --- a/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store.py +++ b/rastervision_core/rastervision/core/data/label_store/semantic_segmentation_label_store.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Sequence, overload from os.path import join import logging @@ -333,11 +333,17 @@ def write_vector_output(self, vo: 'VectorOutputConfig', mask: np.ndarray, out_uri = vo.get_uri(vector_output_dir, self.class_config) json_to_file(geojson, out_uri) - def _clip_to_extent(self, - extent: Box, - window: Box, - arr: np.ndarray | None = None - ) -> tuple[Box, np.ndarray | None]: + @overload + def _clip_to_extent(self, extent: Box, window: Box, + arr: np.ndarray) -> tuple[Box, np.ndarray]: + ... + + @overload + def _clip_to_extent(self, extent: Box, window: Box, + arr: None = ...) -> tuple[Box, None]: + ... + + def _clip_to_extent(self, extent, window, arr=None): clipped_window = window.intersection(extent) if arr is not None: h, w = clipped_window.size diff --git a/rastervision_core/rastervision/core/data/utils/aoi_sampler.py b/rastervision_core/rastervision/core/data/utils/aoi_sampler.py index 9cf049602..4051fd378 100644 --- a/rastervision_core/rastervision/core/data/utils/aoi_sampler.py +++ b/rastervision_core/rastervision/core/data/utils/aoi_sampler.py @@ -63,7 +63,7 @@ def sample(self, n: int = 1) -> np.ndarray: def triangulate_polygon(self, polygon: Polygon) -> dict: """Triangulate polygon. - + Extracts vertices and edges from the polygon (and its holes, if any) and passes them to the Triangle library for triangulation. """ diff --git a/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation_config.py b/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation_config.py index 1046c83e1..a84ab1900 100644 --- a/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation_config.py +++ b/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation_config.py @@ -69,6 +69,8 @@ def enough_target_pixels(self, label_arr: np.ndarray) -> bool: True (the window does contain interesting pixels) or False. """ target_count = 0 + if self.target_class_ids is None: + raise ValueError('target_class_ids not specified.') for class_id in self.target_class_ids: target_count += (label_arr == class_id).sum() enough_target_pixels = target_count >= self.target_count_threshold diff --git a/rastervision_pipeline/rastervision/pipeline/cli.py b/rastervision_pipeline/rastervision/pipeline/cli.py index 25263448f..c185f16fd 100644 --- a/rastervision_pipeline/rastervision/pipeline/cli.py +++ b/rastervision_pipeline/rastervision/pipeline/cli.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import sys import os import logging @@ -44,7 +44,7 @@ def convert_bool_args(args: dict) -> dict: def get_configs(cfg_module_path: str, runner: str | None = None, - args: dict[str, any] | None = None) -> list[PipelineConfig]: + args: dict[str, Any] | None = None) -> list[PipelineConfig]: """Get PipelineConfigs from a module. Calls a get_config(s) function with some arguments from the CLI @@ -74,7 +74,7 @@ def get_configs(cfg_module_path: str, def get_configs_from_module(cfg_module_path: str, runner: str, - args: dict[str, any]) -> list[PipelineConfig]: + args: dict[str, Any]) -> list[PipelineConfig]: import importlib import importlib.util @@ -82,6 +82,9 @@ def get_configs_from_module(cfg_module_path: str, runner: str, # From https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path # noqa spec = importlib.util.spec_from_file_location('rastervision.pipeline', cfg_module_path) + if spec is None: + raise ImportError( + f'Failed to read module spec from {cfg_module_path}.') cfg_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(cfg_module) else: diff --git a/rastervision_pipeline/rastervision/pipeline/file_system/local_file_system.py b/rastervision_pipeline/rastervision/pipeline/file_system/local_file_system.py index fb4444eed..8c750f7d5 100644 --- a/rastervision_pipeline/rastervision/pipeline/file_system/local_file_system.py +++ b/rastervision_pipeline/rastervision/pipeline/file_system/local_file_system.py @@ -9,14 +9,16 @@ from rastervision.pipeline.file_system import (FileSystem, NotReadableError) -def make_dir(path, check_empty=False, force_empty=False, use_dirname=False): +def make_dir(path: str, + check_empty: bool = False, + force_empty: bool = False, + use_dirname: bool = False): """Make a local directory. Args: path: path to directory check_empty: if True, check that directory is empty - force_empty: if True, delete files if necessary to make directory - empty + force_empty: if True, delete files if necessary to make directory empty use_dirname: if True, use the the parent directory as path Raises: diff --git a/rastervision_pipeline/rastervision/pipeline/rv_config.py b/rastervision_pipeline/rastervision/pipeline/rv_config.py index df19c8ebd..f2681929e 100644 --- a/rastervision_pipeline/rastervision/pipeline/rv_config.py +++ b/rastervision_pipeline/rastervision/pipeline/rv_config.py @@ -138,9 +138,9 @@ def get_cache_dir(self) -> TemporaryDirectory: return cache_dir def set_everett_config(self, - profile: str = None, - rv_home: str = None, - config_overrides: dict[str, str] = None): + profile: str | None = None, + rv_home: str | None = None, + config_overrides: dict[str, str] | None = None): """Set Everett config. This sets up any other configuration using the Everett library. diff --git a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/semantic_segmentation/isprs_potsdam_multi_source.py b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/semantic_segmentation/isprs_potsdam_multi_source.py index b8092d047..c0e6a50e7 100644 --- a/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/semantic_segmentation/isprs_potsdam_multi_source.py +++ b/rastervision_pytorch_backend/rastervision/pytorch_backend/examples/semantic_segmentation/isprs_potsdam_multi_source.py @@ -83,23 +83,23 @@ def get_config(runner, by RV, with arguments from the command line, when this example is run. Args: - runner (Runner): Runner for the pipeline. - raw_uri (str): Directory where the raw data resides - processed_uri (str): Directory for storing processed data. + runner: Runner for the pipeline. + raw_uri: Directory where the raw data resides + processed_uri: Directory for storing processed data. E.g. crops for testing. - root_uri (str): Directory where all the output will be written. - nochip (bool): If True, read directly from the TIFF during + root_uri: Directory where all the output will be written. + nochip: If True, read directly from the TIFF during training instead of from pre-generated chips. The analyze and chip commands should not be run, if this is set to True. Defaults to True. - test (bool): If True, does the following simplifications: + test: If True, does the following simplifications: (1) Uses only the first 2 scenes (2) Uses only a 600x600 crop of the scenes (3) Trains for only 2 epochs and uses a batch size of 2. Defaults to False. Returns: - SemanticSegmentationConfig: A pipeline config. + A pipeline config. """ if not test: train_ids, val_ids = TRAIN_IDS, VAL_IDS @@ -256,7 +256,7 @@ def make_multi_raster_source( def make_crop(processed_uri: UriPath, raster_uri: UriPath, - label_uri: UriPath = None) -> tuple[UriPath, UriPath]: + label_uri: UriPath | None = None) -> tuple[UriPath, UriPath]: crop_uri = processed_uri / TEST_CROP_DIR / raster_uri.name if label_uri is not None: label_crop_uri = processed_uri / TEST_CROP_DIR / label_uri.name diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/dataset.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/dataset.py index 1a1d6c5ea..a8b225609 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/dataset.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/dataset.py @@ -245,6 +245,7 @@ def __init__( self.stride: tuple[PosInt, PosInt] = ensure_tuple(stride) self.padding = padding self.pad_direction = pad_direction + self.windows = [] self.init_windows() def init_windows(self) -> None: diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/transform.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/transform.py index c6f110d0e..0951be573 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/transform.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/transform.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, overload from collections.abc import Callable from enum import Enum @@ -58,9 +58,21 @@ def apply_transform(transform: A.BasicTransform, **kwargs) -> dict: return out -def classification_transformer(inp: tuple[np.ndarray, int | None], - transform=A.BasicTransform | None - ) -> tuple[np.ndarray, np.ndarray | None]: +@overload +def classification_transformer( + inp: tuple[np.ndarray, int], + transform: A.BasicTransform | None) -> tuple[np.ndarray, np.ndarray]: + ... + + +@overload +def classification_transformer( + inp: tuple[np.ndarray, None], + transform: A.BasicTransform | None) -> tuple[np.ndarray, None]: + ... + + +def classification_transformer(inp, transform): """Apply transform to image only.""" x, y = inp x = np.array(x) @@ -72,9 +84,21 @@ def classification_transformer(inp: tuple[np.ndarray, int | None], return x, y -def regression_transformer(inp: tuple[np.ndarray, Any | None], - transform=A.BasicTransform | None - ) -> tuple[np.ndarray, np.ndarray | None]: +@overload +def regression_transformer( + inp: tuple[np.ndarray, Any], + transform: A.BasicTransform | None) -> tuple[np.ndarray, np.ndarray]: + ... + + +@overload +def regression_transformer( + inp: tuple[np.ndarray, None], + transform: A.BasicTransform | None) -> tuple[np.ndarray, None]: + ... + + +def regression_transformer(inp, transform): """Apply transform to image only.""" x, y = inp x = np.array(x) @@ -143,10 +167,22 @@ def albu_to_yxyx(xyxy: np.ndarray, return yxyx +@overload def object_detection_transformer( - inp: tuple[np.ndarray, tuple[np.ndarray, np.ndarray, str] | None], - transform: A.BasicTransform | None = None + inp: tuple[np.ndarray, tuple[np.ndarray, np.ndarray, str]], + transform: A.BasicTransform | None ) -> tuple[torch.Tensor, BoxList | None]: + ... + + +@overload +def object_detection_transformer( + inp: tuple[np.ndarray, None], + transform: A.BasicTransform | None) -> tuple[torch.Tensor, None]: + ... + + +def object_detection_transformer(inp, transform): """Apply transform to image, bounding boxes, and labels. Also perform normalization and conversion to pytorch tensors. @@ -214,10 +250,21 @@ def object_detection_transformer( return x, y +@overload def semantic_segmentation_transformer( - inp: tuple[np.ndarray, np.ndarray | None], - transform=A.BasicTransform | None + inp: tuple[np.ndarray, np.ndarray], transform: A.BasicTransform | None ) -> tuple[np.ndarray, np.ndarray | None]: + ... + + +@overload +def semantic_segmentation_transformer( + inp: tuple[np.ndarray, None], + transform: A.BasicTransform | None) -> tuple[np.ndarray, None]: + ... + + +def semantic_segmentation_transformer(inp, transform): """Apply transform to image and mask.""" x, y = inp x = np.array(x)