Skip to content

Commit

Permalink
only import tpying.Self if TPYE_CHECKING (#2221)
Browse files Browse the repository at this point in the history
typing.Self is only available in python >=3.11, so this makes the codebase compatible with earlier python versions again
  • Loading branch information
AdeelH authored Aug 13, 2024
1 parent 546297f commit 1f3b91a
Show file tree
Hide file tree
Showing 24 changed files with 170 additions and 127 deletions.
63 changes: 32 additions & 31 deletions rastervision_core/rastervision/core/box.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import (TYPE_CHECKING, Literal, Self)
from typing import (TYPE_CHECKING, Literal)
from collections.abc import Callable
from pydantic import NonNegativeInt as NonNegInt, PositiveInt as PosInt
import math
Expand All @@ -14,6 +14,7 @@
ensure_tuple)

if TYPE_CHECKING:
from typing import Self
from shapely.geometry import MultiPolygon
from shapely.geometry.base import BaseGeometry

Expand Down Expand Up @@ -44,11 +45,11 @@ def __init__(self, ymin: int, xmin: int, ymax: int, xmax: int):
self.ymax = ymax
self.xmax = xmax

def __eq__(self, other: Self) -> bool:
def __eq__(self, other: 'Self') -> bool:
"""Return true if other has same coordinates."""
return self.tuple_format() == other.tuple_format()

def __ne__(self, other: Self):
def __ne__(self, other: 'Self'):
"""Return true if other has different coordinates."""
return self.tuple_format() != other.tuple_format()

Expand All @@ -63,7 +64,7 @@ def width(self) -> int:
return self.xmax - self.xmin

@property
def extent(self) -> Self:
def extent(self) -> 'Self':
"""Return a Box(0, 0, h, w) representing the size of this Box."""
return Box(0, 0, self.height, self.width)

Expand All @@ -77,7 +78,7 @@ def area(self) -> int:
"""Return area of Box."""
return self.height * self.width

def normalize(self) -> Self:
def normalize(self) -> 'Self':
"""Ensure ymin <= ymax and xmin <= xmax."""
ymin, ymax = sorted((self.ymin, self.ymax))
xmin, xmax = sorted((self.xmin, self.xmax))
Expand Down Expand Up @@ -110,7 +111,7 @@ def npbox_format(self) -> np.ndarray:
return np.array(self.tuple_format(), dtype=float)

@staticmethod
def to_npboxes(boxes: list[Self]) -> np.ndarray:
def to_npboxes(boxes: list['Self']) -> np.ndarray:
"""Return nx4 numpy array from list of Box."""
nb_boxes = len(boxes)
npboxes = np.empty((nb_boxes, 4))
Expand Down Expand Up @@ -139,15 +140,15 @@ def geojson_coordinates(self) -> list[tuple[int, int]]:
sw = [self.xmax, self.ymin]
return [nw, ne, se, sw, nw]

def make_random_square_container(self, size: int) -> Self:
def make_random_square_container(self, size: int) -> 'Self':
"""Return a new square Box that contains this Box.
Args:
size: the width and height of the new Box
"""
return self.make_random_box_container(size, size)

def make_random_box_container(self, out_h: int, out_w: int) -> Self:
def make_random_box_container(self, out_h: int, out_w: int) -> 'Self':
"""Return a new rectangular Box that contains this Box.
Args:
Expand All @@ -173,7 +174,7 @@ def make_random_box_container(self, out_h: int, out_w: int) -> Self:

return Box(out_ymin, out_xmin, out_ymin + out_h, out_xmin + out_w)

def make_random_square(self, size: int) -> Self:
def make_random_square(self, size: int) -> 'Self':
"""Return new randomly positioned square Box that lies inside this Box.
Args:
Expand All @@ -197,7 +198,7 @@ def make_random_square(self, size: int) -> Self:

return Box.make_square(rand_y, rand_x, size)

def intersection(self, other: Self) -> Self:
def intersection(self, other: 'Self') -> 'Self':
"""Return the intersection of this Box and the other.
Args:
Expand All @@ -218,7 +219,7 @@ def intersection(self, other: Self) -> Self:
ymax = min(box1.ymax, box2.ymax)
return Box(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)

def intersects(self, other: Self) -> bool:
def intersects(self, other: 'Self') -> bool:
box1 = self.normalize()
box2 = other.normalize()
if box1.ymax <= box2.ymin or box1.ymin >= box2.ymax:
Expand All @@ -228,7 +229,7 @@ def intersects(self, other: Self) -> bool:
return True

@classmethod
def from_npbox(cls, npbox: np.ndarray) -> Self:
def from_npbox(cls, npbox: np.ndarray) -> 'Self':
"""Return new Box based on npbox format.
Args:
Expand All @@ -237,13 +238,13 @@ def from_npbox(cls, npbox: np.ndarray) -> Self:
return Box(*npbox)

@classmethod
def from_shapely(cls, shape: 'BaseGeometry') -> Self:
def from_shapely(cls, shape: 'BaseGeometry') -> 'Self':
"""Instantiate from the bounds of a shapely geometry."""
xmin, ymin, xmax, ymax = shape.bounds
return Box(ymin, xmin, ymax, xmax)

@classmethod
def from_rasterio(cls, rio_window: RioWindow) -> Self:
def from_rasterio(cls, rio_window: RioWindow) -> 'Self':
"""Instantiate from a rasterio window."""
yslice, xslice = rio_window.toslices()
return Box(yslice.start, xslice.start, yslice.stop, xslice.stop)
Expand Down Expand Up @@ -274,12 +275,12 @@ def to_slices(self, h_step: int | None = None,
return slice(self.ymin, self.ymax, h_step), slice(
self.xmin, self.xmax, w_step)

def translate(self, dy: int, dx: int) -> Self:
def translate(self, dy: int, dx: int) -> 'Self':
"""Translate window along y and x axes by the given distances."""
ymin, xmin, ymax, xmax = self
return Box(ymin + dy, xmin + dx, ymax + dy, xmax + dx)

def to_global_coords(self, bbox: Self) -> Self:
def to_global_coords(self, bbox: 'Self') -> 'Self':
"""Go from bbox coords to global coords.
E.g., Given a box Box(20, 20, 40, 40) and bbox Box(20, 20, 100, 100),
Expand All @@ -289,7 +290,7 @@ def to_global_coords(self, bbox: Self) -> Self:
"""
return self.translate(dy=bbox.ymin, dx=bbox.xmin)

def to_local_coords(self, bbox: Self) -> Self:
def to_local_coords(self, bbox: 'Self') -> 'Self':
"""Go from to global coords bbox coords.
E.g., Given a box Box(40, 40, 60, 60) and bbox Box(20, 20, 100, 100),
Expand All @@ -299,7 +300,7 @@ def to_local_coords(self, bbox: Self) -> Self:
"""
return self.translate(dy=-bbox.ymin, dx=-bbox.xmin)

def reproject(self, transform_fn: Callable[[tuple], tuple]) -> Self:
def reproject(self, transform_fn: Callable[[tuple], tuple]) -> 'Self':
"""Reprojects this box based on a transform function.
Args:
Expand All @@ -313,23 +314,23 @@ def reproject(self, transform_fn: Callable[[tuple], tuple]) -> Self:
return Box(ymin, xmin, ymax, xmax)

@staticmethod
def make_square(ymin, xmin, size) -> Self:
def make_square(ymin, xmin, size) -> 'Self':
"""Return new square Box."""
return Box(ymin, xmin, ymin + size, xmin + size)

def center_crop(self, edge_offset_y: int, edge_offset_x: int) -> Self:
def center_crop(self, edge_offset_y: int, edge_offset_x: int) -> 'Self':
"""Return Box whose sides are eroded by the given offsets.
Box(0, 0, 10, 10).center_crop(2, 4) == Box(2, 4, 8, 6)
"""
return Box(self.ymin + edge_offset_y, self.xmin + edge_offset_x,
self.ymax - edge_offset_y, self.xmax - edge_offset_x)

def erode(self, erosion_sz) -> Self:
def erode(self, erosion_sz) -> 'Self':
"""Return new Box whose sides are eroded by erosion_sz."""
return self.center_crop(erosion_sz, erosion_sz)

def buffer(self, buffer_sz: float, max_extent: Self) -> Self:
def buffer(self, buffer_sz: float, max_extent: 'Self') -> 'Self':
"""Return new Box whose sides are buffered by buffer_sz.
The resulting box is clipped so that the values of the corners are
Expand All @@ -351,15 +352,15 @@ def buffer(self, buffer_sz: float, max_extent: Self) -> Self:
min(max_extent.width,
int(self.xmax) + delta_width))

def pad(self, ymin: int, xmin: int, ymax: int, xmax: int) -> Self:
def pad(self, ymin: int, xmin: int, ymax: int, xmax: int) -> 'Self':
"""Pad sides by the given amount."""
return Box(
ymin=self.ymin - ymin,
xmin=self.xmin - xmin,
ymax=self.ymax + ymax,
xmax=self.xmax + xmax)

def copy(self) -> Self:
def copy(self) -> 'Self':
return Box(*self)

def get_windows(
Expand All @@ -368,7 +369,7 @@ def get_windows(
stride: PosInt | tuple[PosInt, PosInt],
padding: NonNegInt | tuple[NonNegInt, NonNegInt] | None = None,
pad_direction: Literal['both', 'start', 'end'] = 'end'
) -> list[Self]:
) -> list['Self']:
"""Return sliding windows for given size, stride, and padding.
Each of size, stride, and padding can be either a positive int or
Expand Down Expand Up @@ -452,13 +453,13 @@ def to_dict(self) -> dict[str, int]:
}

@classmethod
def from_dict(cls, d: dict) -> Self:
def from_dict(cls, d: dict) -> 'Self':
return cls(d['ymin'], d['xmin'], d['ymax'], d['xmax'])

@staticmethod
def filter_by_aoi(windows: list[Self],
def filter_by_aoi(windows: list['Self'],
aoi_polygons: list[Polygon],
within: bool = True) -> list[Self]:
within: bool = True) -> list['Self']:
"""Filters windows by a list of AOI polygons
Args:
Expand All @@ -478,7 +479,7 @@ def filter_by_aoi(windows: list[Self],
return out

@staticmethod
def within_aoi(window: Self,
def within_aoi(window: 'Self',
aoi_polygons: Polygon | list[Polygon]) -> bool:
"""Check if window is within the union of given AOI polygons."""
aoi_polygons: Polygon | MultiPolygon = unary_union(aoi_polygons)
Expand All @@ -487,15 +488,15 @@ def within_aoi(window: Self,
return out

@staticmethod
def intersects_aoi(window: Self,
def intersects_aoi(window: 'Self',
aoi_polygons: Polygon | list[Polygon]) -> bool:
"""Check if window intersects with the union of given AOI polygons."""
aoi_polygons: Polygon | MultiPolygon = unary_union(aoi_polygons)
w = window.to_shapely()
out = aoi_polygons.intersects(w)
return out

def __contains__(self, query: Self | tuple[int, int]) -> bool:
def __contains__(self, query: 'Self | tuple[int, int]') -> bool:
"""Check if box or point is contained within this box.
Args:
Expand Down
9 changes: 6 additions & 3 deletions rastervision_core/rastervision/core/data/class_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Self
from typing import TYPE_CHECKING

from rastervision.pipeline.config import (Config, register_config, ConfigError,
Field, model_validator)
from rastervision.core.data.utils import color_to_triple, normalize_color

if TYPE_CHECKING:
from typing import Self

DEFAULT_NULL_CLASS_NAME = 'null'
DEFAULT_NULL_CLASS_COLOR = 'black'

Expand Down Expand Up @@ -33,7 +36,7 @@ class ClassConfig(Config):
'added automatically.')

@model_validator(mode='after')
def validate_colors(self) -> Self:
def validate_colors(self) -> 'Self':
"""Compare length w/ names. Also auto-generate if not specified."""
names = self.names
colors = self.colors
Expand All @@ -47,7 +50,7 @@ def validate_colors(self) -> Self:
return self

@model_validator(mode='after')
def validate_null_class(self) -> Self:
def validate_null_class(self) -> 'Self':
"""Check if in names. If 'null' in names, use it as null class."""
names = self.names
null_class = self.null_class
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Self
from typing import TYPE_CHECKING, Any
from pyproj import Transformer

import numpy as np
Expand All @@ -9,6 +9,9 @@
from rastervision.core.data.crs_transformer import (CRSTransformer,
IdentityCRSTransformer)

if TYPE_CHECKING:
from typing import Self


class RasterioCRSTransformer(CRSTransformer):
"""Transformer for a RasterioRasterSource."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import (TYPE_CHECKING, Any, Iterable, Self)
from typing import (TYPE_CHECKING, Any, Iterable)
from dataclasses import dataclass

import numpy as np
Expand All @@ -8,6 +8,7 @@
from rastervision.core.utils.types import Vector

if TYPE_CHECKING:
from typing import Self
from rastervision.core.data import (ClassConfig, CRSTransformer)
from shapely.geometry import Polygon

Expand Down Expand Up @@ -38,11 +39,11 @@ def __init__(self,
def __len__(self) -> int:
return len(self.cell_to_label)

def __eq__(self, other: Self) -> bool:
def __eq__(self, other: 'Self') -> bool:
return (isinstance(other, ChipClassificationLabels)
and self.cell_to_label == other.cell_to_label)

def __add__(self, other: Self) -> Self:
def __add__(self, other: 'Self') -> 'Self':
result = ChipClassificationLabels()
result.extend(self)
result.extend(other)
Expand All @@ -66,7 +67,7 @@ def from_predictions(cls, windows: Iterable['Box'],
return super().from_predictions(windows, predictions)

@classmethod
def make_empty(cls) -> Self:
def make_empty(cls) -> 'Self':
return ChipClassificationLabels()

def filter_by_aoi(self, aoi_polygons: Iterable['Polygon']):
Expand Down Expand Up @@ -140,7 +141,7 @@ def get_values(self) -> list[ClassificationLabel]:
"""Return list of class_ids and scores for all cells."""
return list(self.cell_to_label.values())

def extend(self, labels: Self) -> None:
def extend(self, labels: 'Self') -> None:
"""Adds cells contained in labels.
Args:
Expand Down
11 changes: 6 additions & 5 deletions rastervision_core/rastervision/core/data/label/labels.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Defines the abstract Labels class."""

from typing import TYPE_CHECKING, Any, Iterable, Self
from typing import TYPE_CHECKING, Any, Iterable
from abc import ABC, abstractmethod

if TYPE_CHECKING:
from typing import Self
from shapely.geometry import Polygon
from rastervision.core.box import Box

Expand All @@ -16,14 +17,14 @@ class Labels(ABC):
"""

@abstractmethod
def __add__(self, other: Self):
def __add__(self, other: 'Self'):
"""Add labels to these labels.
Returns a concatenation of this and the other labels.
"""

@abstractmethod
def filter_by_aoi(self, aoi_polygons: list['Polygon']) -> Self:
def filter_by_aoi(self, aoi_polygons: list['Polygon']) -> 'Self':
"""Return a copy of these labels filtered by given AOI polygons.
Args:
Expand All @@ -41,7 +42,7 @@ def __setitem__(self, key, value):

@classmethod
@abstractmethod
def make_empty(cls) -> Self:
def make_empty(cls) -> 'Self':
"""Instantiate an empty instance of this class.
Returns:
Expand All @@ -51,7 +52,7 @@ def make_empty(cls) -> Self:

@classmethod
def from_predictions(cls, windows: Iterable['Box'],
predictions: Iterable[Any]) -> Self:
predictions: Iterable[Any]) -> 'Self':
"""Instantiate from windows and their corresponding predictions.
This makes no assumptions about the type or format of the predictions.
Expand Down
Loading

0 comments on commit 1f3b91a

Please sign in to comment.