Skip to content

Commit

Permalink
1534 type hints numpy 1 20 (#1536)
Browse files Browse the repository at this point in the history
* numpy dtype alias

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* is_tensor => isinstance

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* type of dtype

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* relaxes some typing constraints

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* fixes unit tests

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* update based on the comments

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

* fixes docstring typos

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli authored Feb 2, 2021
1 parent a6cb37c commit c415509
Show file tree
Hide file tree
Showing 60 changed files with 306 additions and 261 deletions.
2 changes: 1 addition & 1 deletion monai/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
print_gpu_info,
print_system_info,
)
from .type_definitions import IndexSelection, KeysCollection
from .type_definitions import DtypeLike, IndexSelection, KeysCollection, NdarrayTensor
2 changes: 1 addition & 1 deletion monai/config/deviceconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def get_system_info() -> OrderedDict:
_dict_append(
output,
"Avg. sensor temp. (Celsius)",
lambda: round(
lambda: np.round(
np.mean([item.current for sublist in psutil.sensors_temperatures().values() for item in sublist], 1)
),
)
Expand Down
20 changes: 18 additions & 2 deletions monai/config/type_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Collection, Hashable, Iterable, Union
from typing import Collection, Hashable, Iterable, TypeVar, Union

__all__ = ["KeysCollection", "IndexSelection"]
import numpy as np
import torch

__all__ = ["KeysCollection", "IndexSelection", "DtypeLike", "NdarrayTensor"]

"""Commonly used concepts
This module provides naming and type specifications for commonly used concepts
Expand Down Expand Up @@ -51,3 +54,16 @@
The indices must be integers, and if a container of indices is specified, the
container must be iterable.
"""

DtypeLike = Union[
np.dtype,
type,
None,
]
"""Type of datatypes
adapted from https://github.com/numpy/numpy/blob/master/numpy/typing/_dtype_like.py
"""

# Generic type which can represent either a numpy.ndarray or a torch.Tensor
# Unlike Union can create a dependence between parameter(s) / return(s)
NdarrayTensor = TypeVar("NdarrayTensor", np.ndarray, torch.Tensor)
2 changes: 1 addition & 1 deletion monai/data/csv_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
"""
save_key = meta_data["filename_or_obj"] if meta_data else str(self._data_index)
self._data_index += 1
if torch.is_tensor(data):
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()
if not isinstance(data, np.ndarray):
raise AssertionError
Expand Down
3 changes: 2 additions & 1 deletion monai/data/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
from torch.utils.data import Dataset

from monai.config import DtypeLike
from monai.data.image_reader import ImageReader
from monai.transforms import LoadImage, Randomizable, apply_transform
from monai.utils import MAX_SEED, get_seed
Expand All @@ -36,7 +37,7 @@ def __init__(
transform: Optional[Callable] = None,
seg_transform: Optional[Callable] = None,
image_only: bool = True,
dtype: Optional[np.dtype] = np.float32,
dtype: DtypeLike = np.float32,
reader: Optional[Union[ImageReader, str]] = None,
*args,
**kwargs,
Expand Down
10 changes: 5 additions & 5 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from monai.config import KeysCollection
from monai.config import DtypeLike, KeysCollection
from monai.data.utils import correct_nifti_header_if_necessary
from monai.utils import ensure_tuple, optional_import

Expand Down Expand Up @@ -244,7 +244,7 @@ def _get_affine(self, img) -> np.ndarray:
affine = np.eye(direction.shape[0] + 1)
affine[(slice(-1), slice(-1))] = direction @ np.diag(spacing)
affine[(slice(-1), -1)] = origin
return affine
return np.asarray(affine)

def _get_spatial_shape(self, img) -> np.ndarray:
"""
Expand All @@ -258,7 +258,7 @@ def _get_spatial_shape(self, img) -> np.ndarray:
shape.reverse()
return np.asarray(shape)

def _get_array_data(self, img) -> np.ndarray:
def _get_array_data(self, img):
"""
Get the raw array data of the image, converted to Numpy array.
Expand Down Expand Up @@ -295,7 +295,7 @@ class NibabelReader(ImageReader):
"""

def __init__(self, as_closest_canonical: bool = False, dtype: Optional[np.dtype] = np.float32, **kwargs):
def __init__(self, as_closest_canonical: bool = False, dtype: DtypeLike = np.float32, **kwargs):
super().__init__()
self.as_closest_canonical = as_closest_canonical
self.dtype = dtype
Expand Down Expand Up @@ -385,7 +385,7 @@ def _get_affine(self, img) -> np.ndarray:
img: a Nibabel image object loaded from a image file.
"""
return img.affine.copy()
return np.array(img.affine, copy=True)

def _get_spatial_shape(self, img) -> np.ndarray:
"""
Expand Down
9 changes: 5 additions & 4 deletions monai/data/nifti_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch

from monai.config import DtypeLike
from monai.data.nifti_writer import write_nifti
from monai.data.utils import create_file_basename
from monai.utils import GridSampleMode, GridSamplePadMode
Expand All @@ -36,8 +37,8 @@ def __init__(
mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
align_corners: bool = False,
dtype: Optional[np.dtype] = np.float64,
output_dtype: Optional[np.dtype] = np.float32,
dtype: DtypeLike = np.float64,
output_dtype: DtypeLike = np.float32,
) -> None:
"""
Args:
Expand Down Expand Up @@ -100,7 +101,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
affine = meta_data.get("affine", None) if meta_data else None
spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None

if torch.is_tensor(data):
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()

filename = create_file_basename(self.output_postfix, filename, self.output_dir)
Expand All @@ -109,7 +110,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
while len(data.shape) < 4:
data = np.expand_dims(data, -1)
# change data to "channel last" format and write to nifti format file
data = np.moveaxis(data, 0, -1)
data = np.moveaxis(np.asarray(data), 0, -1)
write_nifti(
data,
file_name=filename,
Expand Down
9 changes: 5 additions & 4 deletions monai/data/nifti_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch

from monai.config import DtypeLike
from monai.data.utils import compute_shape_offset, to_affine_nd
from monai.networks.layers import AffineTransform
from monai.utils import GridSampleMode, GridSamplePadMode, optional_import
Expand All @@ -27,12 +28,12 @@ def write_nifti(
affine: Optional[np.ndarray] = None,
target_affine: Optional[np.ndarray] = None,
resample: bool = True,
output_spatial_shape: Optional[Sequence[int]] = None,
output_spatial_shape: Union[Sequence[int], np.ndarray, None] = None,
mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
align_corners: bool = False,
dtype: Optional[np.dtype] = np.float64,
output_dtype: Optional[np.dtype] = np.float32,
dtype: DtypeLike = np.float64,
output_dtype: DtypeLike = np.float32,
) -> None:
"""
Write numpy data into NIfTI files to disk. This function converts data
Expand Down Expand Up @@ -126,7 +127,7 @@ def write_nifti(
transform = np.linalg.inv(_affine) @ target_affine
if output_spatial_shape is None:
output_spatial_shape, _ = compute_shape_offset(data.shape, _affine, target_affine)
output_spatial_shape_ = list(output_spatial_shape)
output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else []
if data.ndim > 3: # multi channel, resampling each channel
while len(output_spatial_shape_) < 3:
output_spatial_shape_ = output_spatial_shape_ + [1]
Expand Down
6 changes: 3 additions & 3 deletions monai/data/png_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
self._data_index += 1
spatial_shape = meta_data.get("spatial_shape", None) if meta_data and self.resample else None

if torch.is_tensor(data):
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()

filename = create_file_basename(self.output_postfix, filename, self.output_dir)
Expand All @@ -95,12 +95,12 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
if data.shape[0] == 1:
data = data.squeeze(0)
elif 2 < data.shape[0] < 5:
data = np.moveaxis(data, 0, -1)
data = np.moveaxis(np.asarray(data), 0, -1)
else:
raise ValueError(f"Unsupported number of channels: {data.shape[0]}, available options are [1, 3, 4]")

write_png(
data,
np.asarray(data),
file_name=filename,
output_spatial_shape=spatial_shape,
mode=self.mode,
Expand Down
4 changes: 2 additions & 2 deletions monai/data/png_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def write_png(
data = np.expand_dims(data, 0) # make a channel
data = xform(data)[0] # first channel
if mode != InterpolateMode.NEAREST:
data = np.clip(data, _min, _max)
data = np.clip(data, _min, _max) # type: ignore

if scale is not None:
data = np.clip(data, 0.0, 1.0) # png writer only can scale data in range [0, 1]
data = np.clip(data, 0.0, 1.0) # type: ignore # png writer only can scale data in range [0, 1]
if scale == np.iinfo(np.uint8).max:
data = (scale * data).astype(np.uint8)
elif scale == np.iinfo(np.uint16).max:
Expand Down
4 changes: 2 additions & 2 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def rectify_header_sform_qform(img_nii):
return img_nii


def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = True) -> np.ndarray:
def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = True):
"""
To make column norm of `affine` the same as `scale`. If diagonal is False,
returns an affine that combines orthogonal rotation and the new scale.
Expand Down Expand Up @@ -379,7 +379,7 @@ def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = Tru


def compute_shape_offset(
spatial_shape: np.ndarray, in_affine: np.ndarray, out_affine: np.ndarray
spatial_shape: Union[np.ndarray, Sequence[int]], in_affine: np.ndarray, out_affine: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
"""
Given input and output affine, compute appropriate shapes
Expand Down
2 changes: 1 addition & 1 deletion monai/engines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

class IterationEvents(EventEnum):
"""
Addtional Events engine can register and trigger in the iteration process.
Additional Events engine can register and trigger in the iteration process.
Refer to the example in ignite: https://github.com/pytorch/ignite/blob/master/ignite/engine/events.py#L146
These Events can be triggered during training iteration:
`FORWARD_COMPLETED` is the Event when `network(image, label)` completed.
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/iteration_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def compute(self) -> Any:
# save score of every image into engine.state for other components
if self.save_details:
if self._engine is None or self._name is None:
raise RuntimeError("plesae call the attach() function to connect expected engine first.")
raise RuntimeError("please call the attach() function to connect expected engine first.")
self._engine.state.metric_details[self._name] = _scores

result: torch.Tensor = torch.zeros(1)
Expand All @@ -108,7 +108,7 @@ def compute(self) -> Any:
# broadcast result to all processes
result = idist.broadcast(result, src=0)

return result.item() if torch.is_tensor(result) else result
return result.item() if isinstance(result, torch.Tensor) else result

def _reduce(self, scores) -> Any:
return do_metric_reduction(scores, MetricReduction.MEAN)[0]
Expand Down
2 changes: 1 addition & 1 deletion monai/handlers/metrics_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class MetricsSaver:
should be within this list: [`mean`, `median`, `max`, `min`, `90percent`, `std`].
default to None.
save_rank: only the handler on specified rank will save to files in multi-gpus validation, default to 0.
delimiter: the delimiter charactor in CSV file, default to "\t".
delimiter: the delimiter character in CSV file, default to "\t".
output_type: expected output file type, supported types: ["csv"], default to "csv".
"""
Expand Down
5 changes: 3 additions & 2 deletions monai/handlers/segmentation_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np

from monai.config import DtypeLike
from monai.data import NiftiSaver, PNGSaver
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import

Expand All @@ -38,8 +39,8 @@ def __init__(
mode: Union[GridSampleMode, InterpolateMode, str] = "nearest",
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
scale: Optional[int] = None,
dtype: Optional[np.dtype] = np.float64,
output_dtype: Optional[np.dtype] = np.float32,
dtype: DtypeLike = np.float64,
output_dtype: DtypeLike = np.float32,
batch_transform: Callable = lambda x: x,
output_transform: Callable = lambda x: x,
name: Optional[str] = None,
Expand Down
6 changes: 4 additions & 2 deletions monai/handlers/stats_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,12 @@ def _default_iteration_print(self, engine: Engine) -> None:
" {}:{}".format(name, type(value))
)
continue # not printing multi dimensional output
out_str += self.key_var_format.format(name, value.item() if torch.is_tensor(value) else value)
out_str += self.key_var_format.format(name, value.item() if isinstance(value, torch.Tensor) else value)
else:
if is_scalar(loss): # not printing multi dimensional output
out_str += self.key_var_format.format(self.tag_name, loss.item() if torch.is_tensor(loss) else loss)
out_str += self.key_var_format.format(
self.tag_name, loss.item() if isinstance(loss, torch.Tensor) else loss
)
else:
warnings.warn(
"ignoring non-scalar output in StatsHandler,"
Expand Down
14 changes: 9 additions & 5 deletions monai/handlers/tensorboard_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,13 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No
" {}:{}".format(name, type(value))
)
continue # not plot multi dimensional output
writer.add_scalar(name, value.item() if torch.is_tensor(value) else value, engine.state.iteration)
writer.add_scalar(
name, value.item() if isinstance(value, torch.Tensor) else value, engine.state.iteration
)
elif is_scalar(loss): # not printing multi dimensional output
writer.add_scalar(self.tag_name, loss.item() if torch.is_tensor(loss) else loss, engine.state.iteration)
writer.add_scalar(
self.tag_name, loss.item() if isinstance(loss, torch.Tensor) else loss, engine.state.iteration
)
else:
warnings.warn(
"ignoring non-scalar output in TensorBoardStatsHandler,"
Expand Down Expand Up @@ -261,7 +265,7 @@ def __call__(self, engine: Engine) -> None:
"""
step = self.global_iter_transform(engine.state.epoch if self.epoch_level else engine.state.iteration)
show_images = self.batch_transform(engine.state.batch)[0]
if torch.is_tensor(show_images):
if isinstance(show_images, torch.Tensor):
show_images = show_images.detach().cpu().numpy()
if show_images is not None:
if not isinstance(show_images, np.ndarray):
Expand All @@ -274,7 +278,7 @@ def __call__(self, engine: Engine) -> None:
)

show_labels = self.batch_transform(engine.state.batch)[1]
if torch.is_tensor(show_labels):
if isinstance(show_labels, torch.Tensor):
show_labels = show_labels.detach().cpu().numpy()
if show_labels is not None:
if not isinstance(show_labels, np.ndarray):
Expand All @@ -287,7 +291,7 @@ def __call__(self, engine: Engine) -> None:
)

show_outputs = self.output_transform(engine.state.output)
if torch.is_tensor(show_outputs):
if isinstance(show_outputs, torch.Tensor):
show_outputs = show_outputs.detach().cpu().numpy()
if show_outputs is not None:
if not isinstance(show_outputs, np.ndarray):
Expand Down
8 changes: 4 additions & 4 deletions monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor:
data: source tensor to pad and execute all_gather in distributed data parallel.
"""
if not torch.is_tensor(data):
if not isinstance(data, torch.Tensor):
raise ValueError("input data must be PyTorch Tensor.")

if idist.get_world_size() <= 1:
Expand Down Expand Up @@ -110,7 +110,7 @@ def write_metrics_reports(
list of strings - generate summary report for every metric_details with specified operations, they
should be within this list: [`mean`, `median`, `max`, `min`, `90percent`, `std`].
default to None.
deli: the delimiter charactor in the file, default to "\t".
deli: the delimiter character in the file, default to "\t".
output_type: expected output file type, supported types: ["csv"], default to "csv".
"""
Expand All @@ -127,7 +127,7 @@ def write_metrics_reports(

if metric_details is not None and len(metric_details) > 0:
for k, v in metric_details.items():
if torch.is_tensor(v):
if isinstance(v, torch.Tensor):
v = v.cpu().numpy()
if v.ndim == 0:
# reshape to [1, 1] if no batch and class dims
Expand Down Expand Up @@ -162,5 +162,5 @@ def write_metrics_reports(

with open(os.path.join(save_dir, f"{k}_summary.csv"), "w") as f:
f.write(f"class{deli}{deli.join(ops)}\n")
for i, c in enumerate(v.transpose()):
for i, c in enumerate(np.transpose(v)):
f.write(f"{class_labels[i]}{deli}{deli.join([f'{supported_ops[k](c):.4f}' for k in ops])}\n")
2 changes: 1 addition & 1 deletion monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def wasserstein_distance_map(self, flat_proba: torch.Tensor, flat_target: torch.
flat_target: the target tensor.
"""
# Turn the distance matrix to a map of identical matrix
m = torch.clone(self.m).to(flat_proba.device)
m = torch.clone(torch.as_tensor(self.m)).to(flat_proba.device)
m_extended = torch.unsqueeze(m, dim=0)
m_extended = torch.unsqueeze(m_extended, dim=3)
m_extended = m_extended.expand((flat_proba.size(0), m_extended.size(1), m_extended.size(2), flat_proba.size(2)))
Expand Down
Loading

0 comments on commit c415509

Please sign in to comment.