
Commit

misc type fixes
AdeelH committed Aug 5, 2024
1 parent a4ea0dd commit d5e3b1c
Showing 12 changed files with 109 additions and 45 deletions.
10 changes: 5 additions & 5 deletions rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py
@@ -37,14 +37,14 @@ def get_matching_s3_objects(
)
while True:
resp: dict = s3.list_objects_v2(**kwargs)
dirs: list[dict] = resp.get('CommonPrefixes', {})
files: list[dict] = resp.get('Contents', {})
dirs: list[dict[str, Any]] = resp.get('CommonPrefixes', {})
files: list[dict[str, Any]] = resp.get('Contents', {})
for obj in dirs:
key = obj['Prefix']
key: str = obj['Prefix']
if key.startswith(prefix) and key.endswith(suffix):
yield key, obj
for obj in files:
key = obj['Key']
key: str = obj['Key']
if key.startswith(prefix) and key.endswith(suffix):
yield key, obj
# The S3 API is paginated, returning up to 1000 keys at a time.
@@ -214,7 +214,7 @@ def write_bytes(uri: str, data: bytes) -> None:
file_size = len(data)
with io.BytesIO(data) as str_buffer:
try:
with progressbar(file_size, desc=f'Uploading') as bar:
with progressbar(file_size, desc='Uploading') as bar:
s3.upload_fileobj(
Fileobj=str_buffer,
Bucket=bucket,
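For context on the pattern above: boto3's list_objects_v2 returns an untyped dict, so values pulled from it get explicit annotations (list[dict[str, Any]], str) instead of a bare dict. A minimal standalone sketch of that pattern (the iter_keys helper and its arguments are hypothetical, not Raster Vision's API):

from collections.abc import Iterator
from typing import Any

import boto3


def iter_keys(bucket: str,
              prefix: str = '') -> Iterator[tuple[str, dict[str, Any]]]:
    """Yield (key, object metadata) pairs under a prefix, page by page."""
    s3 = boto3.client('s3')
    kwargs: dict[str, Any] = dict(Bucket=bucket, Prefix=prefix)
    while True:
        resp: dict[str, Any] = s3.list_objects_v2(**kwargs)
        # 'Contents' is absent when a page is empty, hence the default.
        files: list[dict[str, Any]] = resp.get('Contents', [])
        for obj in files:
            key: str = obj['Key']
            yield key, obj
        if not resp.get('IsTruncated'):
            break
        kwargs['ContinuationToken'] = resp['NextContinuationToken']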
2 changes: 1 addition & 1 deletion rastervision_core/rastervision/core/data/class_config.py
@@ -82,7 +82,7 @@ def null_class_id(self) -> int:
raise ValueError('null_class is not set')
return self.get_class_id(self.null_class)

def get_color_to_class_id(self) -> dict:
def get_color_to_class_id(self) -> dict[str | tuple[int, int, int], int]:
return dict([(self.colors[i], i) for i in range(len(self.colors))])

def ensure_null_class(self) -> None:
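As a standalone illustration of the narrower return type (not the library's code): keys of the mapping may be color names/hex strings or RGB tuples, which is exactly what dict[str | tuple[int, int, int], int] states, and enumerate keeps the comprehension simple.

def get_color_to_class_id(
        colors: list[str | tuple[int, int, int]]
) -> dict[str | tuple[int, int, int], int]:
    # Map each color to its index, which doubles as the class id.
    return {color: class_id for class_id, color in enumerate(colors)}


colors = ['red', '#00ff00', (0, 0, 255)]
assert get_color_to_class_id(colors)[(0, 0, 255)] == 2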
Changed file (path not shown):
@@ -63,8 +63,10 @@ def __repr__(self) -> str:
"""
return out

def _map_to_pixel(self, map_point: tuple[float, float] | np.ndarray
) -> tuple[int, int] | np.ndarray:
def _map_to_pixel(
self,
map_point: tuple[float, float] | tuple[np.ndarray, np.ndarray]
) -> tuple[int, int] | tuple[np.ndarray, np.ndarray]:
"""Transform point from map to pixel-based coordinates.
Args:
@@ -82,8 +84,9 @@ def _map_to_pixel(self, map_point: tuple[float, float] | np.ndarray
pixel_point = (col, row)
return pixel_point

def _pixel_to_map(self, pixel_point: tuple[int, int] | np.ndarray
) -> tuple[float, float] | np.ndarray:
def _pixel_to_map(
self, pixel_point: tuple[int, int] | tuple[np.ndarray, np.ndarray]
) -> tuple[float, float] | tuple[np.ndarray, np.ndarray]:
"""Transform point from pixel to map-based coordinates.
Args:
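For reference, a self-contained sketch of the signature style this hunk moves to: one function accepts either a scalar (x, y) pair or a pair of coordinate arrays, and the union of tuple types says the return mirrors the input. The affine_shift function below is hypothetical.

import numpy as np


def affine_shift(
    point: tuple[float, float] | tuple[np.ndarray, np.ndarray],
    dx: float,
    dy: float,
) -> tuple[float, float] | tuple[np.ndarray, np.ndarray]:
    """Shift a point, or arrays of points, by (dx, dy)."""
    x, y = point
    return x + dx, y + dy


# Scalar pair in, scalar pair out; array pair in, array pair out.
print(affine_shift((1.0, 2.0), 0.5, -0.5))
print(affine_shift((np.arange(3.0), np.arange(3.0)), 0.5, -0.5))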
Changed file (path not shown):
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Sequence, overload
from os.path import join
import logging

@@ -333,11 +333,17 @@ def write_vector_output(self, vo: 'VectorOutputConfig', mask: np.ndarray,
out_uri = vo.get_uri(vector_output_dir, self.class_config)
json_to_file(geojson, out_uri)

def _clip_to_extent(self,
extent: Box,
window: Box,
arr: np.ndarray | None = None
) -> tuple[Box, np.ndarray | None]:
@overload
def _clip_to_extent(self, extent: Box, window: Box,
arr: np.ndarray) -> tuple[Box, np.ndarray]:
...

@overload
def _clip_to_extent(self, extent: Box, window: Box,
arr: None = ...) -> tuple[Box, None]:
...

def _clip_to_extent(self, extent, window, arr=None):
clipped_window = window.intersection(extent)
if arr is not None:
h, w = clipped_window.size
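A minimal standalone sketch of the typing.overload idiom adopted here (not the actual Raster Vision implementation): the two stubs tell the checker that passing an array yields an array and passing None yields None, while the logic lives in a single untyped implementation.

from typing import overload

import numpy as np


@overload
def crop(arr: np.ndarray, size: int) -> np.ndarray: ...
@overload
def crop(arr: None, size: int) -> None: ...


def crop(arr, size):
    # The implementation is written once; the overloads above only refine types.
    if arr is None:
        return None
    return arr[:size, :size]


a: np.ndarray = crop(np.zeros((4, 4)), 2)  # checker infers ndarray here
b: None = crop(None, 2)                    # and None here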
Changed file (path not shown):
@@ -63,7 +63,7 @@ def sample(self, n: int = 1) -> np.ndarray:

def triangulate_polygon(self, polygon: Polygon) -> dict:
"""Triangulate polygon.
Extracts vertices and edges from the polygon (and its holes, if any)
and passes them to the Triangle library for triangulation.
"""
Changed file (path not shown):
@@ -69,6 +69,8 @@ def enough_target_pixels(self, label_arr: np.ndarray) -> bool:
True (the window does contain interesting pixels) or False.
"""
target_count = 0
if self.target_class_ids is None:
raise ValueError('target_class_ids not specified.')
for class_id in self.target_class_ids:
target_count += (label_arr == class_id).sum()
enough_target_pixels = target_count >= self.target_count_threshold
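The added guard is the usual fail-fast check for an optional config field; a small hypothetical sketch of the same idea:

import numpy as np


class WindowSampler:
    """Hypothetical sampler with an optional list of target class ids."""

    def __init__(self, target_class_ids: list[int] | None = None) -> None:
        self.target_class_ids = target_class_ids

    def count_target_pixels(self, label_arr: np.ndarray) -> int:
        # Fail fast with a clear message instead of letting the loop below
        # raise a TypeError when iterating over None.
        if self.target_class_ids is None:
            raise ValueError('target_class_ids not specified.')
        return sum(int((label_arr == class_id).sum())
                   for class_id in self.target_class_ids)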
9 changes: 6 additions & 3 deletions rastervision_pipeline/rastervision/pipeline/cli.py
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import sys
import os
import logging
@@ -44,7 +44,7 @@ def convert_bool_args(args: dict) -> dict:

def get_configs(cfg_module_path: str,
runner: str | None = None,
args: dict[str, any] | None = None) -> list[PipelineConfig]:
args: dict[str, Any] | None = None) -> list[PipelineConfig]:
"""Get PipelineConfigs from a module.
Calls a get_config(s) function with some arguments from the CLI
@@ -74,14 +74,17 @@ def get_configs(cfg_module_path: str,


def get_configs_from_module(cfg_module_path: str, runner: str,
args: dict[str, any]) -> list[PipelineConfig]:
args: dict[str, Any]) -> list[PipelineConfig]:
import importlib
import importlib.util

if cfg_module_path.endswith('.py'):
# From https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path # noqa
spec = importlib.util.spec_from_file_location('rastervision.pipeline',
cfg_module_path)
if spec is None:
raise ImportError(
f'Failed to read module spec from {cfg_module_path}.')
cfg_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cfg_module)
else:
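For reference, a standalone sketch of loading a module from a file path with the same None check (the helper name is hypothetical): importlib.util.spec_from_file_location is typed as returning ModuleSpec | None, so checking the result both satisfies the type checker and gives a clearer error for unreadable paths.

import importlib.util
from types import ModuleType


def load_module_from_path(module_name: str, path: str) -> ModuleType:
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f'Failed to read module spec from {path}.')
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module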
Changed file (path not shown):
@@ -9,14 +9,16 @@
from rastervision.pipeline.file_system import (FileSystem, NotReadableError)


def make_dir(path, check_empty=False, force_empty=False, use_dirname=False):
def make_dir(path: str,
check_empty: bool = False,
force_empty: bool = False,
use_dirname: bool = False):
"""Make a local directory.
Args:
path: path to directory
check_empty: if True, check that directory is empty
force_empty: if True, delete files if necessary to make directory
empty
force_empty: if True, delete files if necessary to make directory empty
use_dirname: if True, use the parent directory as path
Raises:
6 changes: 3 additions & 3 deletions rastervision_pipeline/rastervision/pipeline/rv_config.py
@@ -138,9 +138,9 @@ def get_cache_dir(self) -> TemporaryDirectory:
return cache_dir

def set_everett_config(self,
profile: str = None,
rv_home: str = None,
config_overrides: dict[str, str] = None):
profile: str | None = None,
rv_home: str | None = None,
config_overrides: dict[str, str] | None = None):
"""Set Everett config.
This sets up any other configuration using the Everett library.
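A small sketch of the convention adopted here: parameters defaulting to None are annotated X | None explicitly, since modern type checkers no longer infer Optional from a None default. The function below is illustrative only.

def set_config(profile: str | None = None,
               overrides: dict[str, str] | None = None) -> dict[str, str]:
    # Normalize the optional arguments before use.
    config: dict[str, str] = dict(overrides or {})
    if profile is not None:
        config['profile'] = profile
    return config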
Changed file (path not shown):
@@ -83,23 +83,23 @@ def get_config(runner,
by RV, with arguments from the command line, when this example is run.
Args:
runner (Runner): Runner for the pipeline.
raw_uri (str): Directory where the raw data resides
processed_uri (str): Directory for storing processed data.
runner: Runner for the pipeline.
raw_uri: Directory where the raw data resides
processed_uri: Directory for storing processed data.
E.g. crops for testing.
root_uri (str): Directory where all the output will be written.
nochip (bool): If True, read directly from the TIFF during
root_uri: Directory where all the output will be written.
nochip: If True, read directly from the TIFF during
training instead of from pre-generated chips. The analyze and chip
commands should not be run, if this is set to True. Defaults to
True.
test (bool): If True, does the following simplifications:
test: If True, does the following simplifications:
(1) Uses only the first 2 scenes
(2) Uses only a 600x600 crop of the scenes
(3) Trains for only 2 epochs and uses a batch size of 2.
Defaults to False.
Returns:
SemanticSegmentationConfig: A pipeline config.
A pipeline config.
"""
if not test:
train_ids, val_ids = TRAIN_IDS, VAL_IDS
@@ -256,7 +256,7 @@ def make_multi_raster_source(

def make_crop(processed_uri: UriPath,
raster_uri: UriPath,
label_uri: UriPath = None) -> tuple[UriPath, UriPath]:
label_uri: UriPath | None = None) -> tuple[UriPath, UriPath]:
crop_uri = processed_uri / TEST_CROP_DIR / raster_uri.name
if label_uri is not None:
label_crop_uri = processed_uri / TEST_CROP_DIR / label_uri.name
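The docstring edits above drop types from Args/Returns once the signature carries annotations, per the usual Google-style convention; a short illustrative sketch with hypothetical names:

def make_crop_example(raster_uri: str, size: int = 600) -> str:
    """Write a test crop of a raster.

    Args:
        raster_uri: URI of the source raster.
        size: Side length of the square crop in pixels. Defaults to 600.

    Returns:
        URI of the cropped raster.
    """
    ...  # body elided; only the docstring layout matters here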
Changed file (path not shown):
@@ -245,6 +245,7 @@ def __init__(
self.stride: tuple[PosInt, PosInt] = ensure_tuple(stride)
self.padding = padding
self.pad_direction = pad_direction
self.windows = []
self.init_windows()

def init_windows(self) -> None:
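The added self.windows = [] guarantees the attribute exists (and has a stable type) before init_windows() fills it; a generic sketch of that initialization pattern, with hypothetical names:

class SlidingWindows:
    def __init__(self, extent: tuple[int, int], stride: int) -> None:
        self.extent = extent
        self.stride = stride
        # Declare (and type) the attribute up front so it is always defined,
        # even if init_windows() is overridden or returns early.
        self.windows: list[tuple[int, int]] = []
        self.init_windows()

    def init_windows(self) -> None:
        h, w = self.extent
        self.windows = [(r, c) for r in range(0, h, self.stride)
                        for c in range(0, w, self.stride)]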
Changed file (path not shown):
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, overload
from collections.abc import Callable
from enum import Enum

@@ -58,9 +58,21 @@ def apply_transform(transform: A.BasicTransform, **kwargs) -> dict:
return out


def classification_transformer(inp: tuple[np.ndarray, int | None],
transform=A.BasicTransform | None
) -> tuple[np.ndarray, np.ndarray | None]:
@overload
def classification_transformer(
inp: tuple[np.ndarray, int],
transform: A.BasicTransform | None) -> tuple[np.ndarray, np.ndarray]:
...


@overload
def classification_transformer(
inp: tuple[np.ndarray, None],
transform: A.BasicTransform | None) -> tuple[np.ndarray, None]:
...


def classification_transformer(inp, transform):
"""Apply transform to image only."""
x, y = inp
x = np.array(x)
@@ -72,9 +84,21 @@ def classification_transformer(inp: tuple[np.ndarray, int | None],
return x, y


def regression_transformer(inp: tuple[np.ndarray, Any | None],
transform=A.BasicTransform | None
) -> tuple[np.ndarray, np.ndarray | None]:
@overload
def regression_transformer(
inp: tuple[np.ndarray, Any],
transform: A.BasicTransform | None) -> tuple[np.ndarray, np.ndarray]:
...


@overload
def regression_transformer(
inp: tuple[np.ndarray, None],
transform: A.BasicTransform | None) -> tuple[np.ndarray, None]:
...


def regression_transformer(inp, transform):
"""Apply transform to image only."""
x, y = inp
x = np.array(x)
@@ -143,10 +167,22 @@ def albu_to_yxyx(xyxy: np.ndarray,
return yxyx


@overload
def object_detection_transformer(
inp: tuple[np.ndarray, tuple[np.ndarray, np.ndarray, str] | None],
transform: A.BasicTransform | None = None
inp: tuple[np.ndarray, tuple[np.ndarray, np.ndarray, str]],
transform: A.BasicTransform | None
) -> tuple[torch.Tensor, BoxList | None]:
...


@overload
def object_detection_transformer(
inp: tuple[np.ndarray, None],
transform: A.BasicTransform | None) -> tuple[torch.Tensor, None]:
...


def object_detection_transformer(inp, transform):
"""Apply transform to image, bounding boxes, and labels. Also perform
normalization and conversion to pytorch tensors.
@@ -214,10 +250,21 @@ def object_detection_transformer(
return x, y


@overload
def semantic_segmentation_transformer(
inp: tuple[np.ndarray, np.ndarray | None],
transform=A.BasicTransform | None
inp: tuple[np.ndarray, np.ndarray], transform: A.BasicTransform | None
) -> tuple[np.ndarray, np.ndarray | None]:
...


@overload
def semantic_segmentation_transformer(
inp: tuple[np.ndarray, None],
transform: A.BasicTransform | None) -> tuple[np.ndarray, None]:
...


def semantic_segmentation_transformer(inp, transform):
"""Apply transform to image and mask."""
x, y = inp
x = np.array(x)
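All of the overloads added in this file share one shape: the label half of the input tuple may be None, and its None-ness carries through to the output. A compact self-contained sketch of that shape (not the library's actual transformer):

from typing import overload

import numpy as np


@overload
def normalize_pair(
        inp: tuple[np.ndarray, np.ndarray]) -> tuple[np.ndarray, np.ndarray]:
    ...


@overload
def normalize_pair(inp: tuple[np.ndarray, None]) -> tuple[np.ndarray, None]:
    ...


def normalize_pair(inp):
    """Scale the image to [0, 1]; pass the label through unchanged."""
    x, y = inp
    return x.astype(np.float32) / 255., y


img = np.zeros((2, 2), dtype=np.uint8)
x1, y1 = normalize_pair((img, np.ones((2, 2))))  # y1 inferred as np.ndarray
x2, y2 = normalize_pair((img, None))  # y2 inferred as None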
