Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1534 type hints numpy 1 20 #1536

Merged
merged 7 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
print_gpu_info,
print_system_info,
)
from .type_definitions import IndexSelection, KeysCollection
from .type_definitions import DtypeLike, IndexSelection, KeysCollection, NdarrayTensor
2 changes: 1 addition & 1 deletion monai/config/deviceconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def get_system_info() -> OrderedDict:
_dict_append(
output,
"Avg. sensor temp. (Celsius)",
lambda: round(
lambda: np.round(
np.mean([item.current for sublist in psutil.sensors_temperatures().values() for item in sublist], 1)
),
)
Expand Down
20 changes: 18 additions & 2 deletions monai/config/type_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Collection, Hashable, Iterable, Union
from typing import Collection, Hashable, Iterable, TypeVar, Union

__all__ = ["KeysCollection", "IndexSelection"]
import numpy as np
import torch

__all__ = ["KeysCollection", "IndexSelection", "DtypeLike", "NdarrayTensor"]

"""Commonly used concepts
This module provides naming and type specifications for commonly used concepts
Expand Down Expand Up @@ -51,3 +54,16 @@
The indices must be integers, and if a container of indices is specified, the
container must be iterable.
"""

DtypeLike = Union[
np.dtype,
type,
None,
]
"""Type of datatypes
adapted from https://github.com/numpy/numpy/blob/master/numpy/typing/_dtype_like.py
"""

# Generic type which can represent either a numpy.ndarray or a torch.Tensor
# Unlike Union can create a dependence between parameter(s) / return(s)
NdarrayTensor = TypeVar("NdarrayTensor", np.ndarray, torch.Tensor)
2 changes: 1 addition & 1 deletion monai/data/csv_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
"""
save_key = meta_data["filename_or_obj"] if meta_data else str(self._data_index)
self._data_index += 1
if torch.is_tensor(data):
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()
if not isinstance(data, np.ndarray):
raise AssertionError
Expand Down
3 changes: 2 additions & 1 deletion monai/data/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
from torch.utils.data import Dataset

from monai.config import DtypeLike
from monai.data.image_reader import ImageReader
from monai.transforms import LoadImage, Randomizable, apply_transform
from monai.utils import MAX_SEED, get_seed
Expand All @@ -36,7 +37,7 @@ def __init__(
transform: Optional[Callable] = None,
seg_transform: Optional[Callable] = None,
image_only: bool = True,
dtype: Optional[np.dtype] = np.float32,
dtype: DtypeLike = np.float32,
reader: Optional[Union[ImageReader, str]] = None,
*args,
**kwargs,
Expand Down
10 changes: 5 additions & 5 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from monai.config import KeysCollection
from monai.config import DtypeLike, KeysCollection
from monai.data.utils import correct_nifti_header_if_necessary
from monai.utils import ensure_tuple, optional_import

Expand Down Expand Up @@ -244,7 +244,7 @@ def _get_affine(self, img) -> np.ndarray:
affine = np.eye(direction.shape[0] + 1)
affine[(slice(-1), slice(-1))] = direction @ np.diag(spacing)
affine[(slice(-1), -1)] = origin
return affine
return np.asarray(affine)

def _get_spatial_shape(self, img) -> np.ndarray:
"""
Expand All @@ -258,7 +258,7 @@ def _get_spatial_shape(self, img) -> np.ndarray:
shape.reverse()
return np.asarray(shape)

def _get_array_data(self, img) -> np.ndarray:
def _get_array_data(self, img):
"""
Get the raw array data of the image, converted to Numpy array.

Expand Down Expand Up @@ -295,7 +295,7 @@ class NibabelReader(ImageReader):

"""

def __init__(self, as_closest_canonical: bool = False, dtype: Optional[np.dtype] = np.float32, **kwargs):
def __init__(self, as_closest_canonical: bool = False, dtype: DtypeLike = np.float32, **kwargs):
super().__init__()
self.as_closest_canonical = as_closest_canonical
self.dtype = dtype
Expand Down Expand Up @@ -385,7 +385,7 @@ def _get_affine(self, img) -> np.ndarray:
img: a Nibabel image object loaded from a image file.

"""
return img.affine.copy()
return np.array(img.affine, copy=True)

def _get_spatial_shape(self, img) -> np.ndarray:
"""
Expand Down
9 changes: 5 additions & 4 deletions monai/data/nifti_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch

from monai.config import DtypeLike
from monai.data.nifti_writer import write_nifti
from monai.data.utils import create_file_basename
from monai.utils import GridSampleMode, GridSamplePadMode
Expand All @@ -36,8 +37,8 @@ def __init__(
mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
align_corners: bool = False,
dtype: Optional[np.dtype] = np.float64,
output_dtype: Optional[np.dtype] = np.float32,
dtype: DtypeLike = np.float64,
output_dtype: DtypeLike = np.float32,
) -> None:
"""
Args:
Expand Down Expand Up @@ -100,7 +101,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
affine = meta_data.get("affine", None) if meta_data else None
spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None

if torch.is_tensor(data):
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()

filename = create_file_basename(self.output_postfix, filename, self.output_dir)
Expand All @@ -109,7 +110,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
while len(data.shape) < 4:
data = np.expand_dims(data, -1)
# change data to "channel last" format and write to nifti format file
data = np.moveaxis(data, 0, -1)
data = np.moveaxis(np.asarray(data), 0, -1)
write_nifti(
data,
file_name=filename,
Expand Down
9 changes: 5 additions & 4 deletions monai/data/nifti_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch

from monai.config import DtypeLike
from monai.data.utils import compute_shape_offset, to_affine_nd
from monai.networks.layers import AffineTransform
from monai.utils import GridSampleMode, GridSamplePadMode, optional_import
Expand All @@ -27,12 +28,12 @@ def write_nifti(
affine: Optional[np.ndarray] = None,
target_affine: Optional[np.ndarray] = None,
resample: bool = True,
output_spatial_shape: Optional[Sequence[int]] = None,
output_spatial_shape: Union[Sequence[int], np.ndarray, None] = None,
mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
align_corners: bool = False,
dtype: Optional[np.dtype] = np.float64,
output_dtype: Optional[np.dtype] = np.float32,
dtype: DtypeLike = np.float64,
output_dtype: DtypeLike = np.float32,
) -> None:
"""
Write numpy data into NIfTI files to disk. This function converts data
Expand Down Expand Up @@ -126,7 +127,7 @@ def write_nifti(
transform = np.linalg.inv(_affine) @ target_affine
if output_spatial_shape is None:
output_spatial_shape, _ = compute_shape_offset(data.shape, _affine, target_affine)
output_spatial_shape_ = list(output_spatial_shape)
output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else []
if data.ndim > 3: # multi channel, resampling each channel
while len(output_spatial_shape_) < 3:
output_spatial_shape_ = output_spatial_shape_ + [1]
Expand Down
6 changes: 3 additions & 3 deletions monai/data/png_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
self._data_index += 1
spatial_shape = meta_data.get("spatial_shape", None) if meta_data and self.resample else None

if torch.is_tensor(data):
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()

filename = create_file_basename(self.output_postfix, filename, self.output_dir)
Expand All @@ -95,12 +95,12 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
if data.shape[0] == 1:
data = data.squeeze(0)
elif 2 < data.shape[0] < 5:
data = np.moveaxis(data, 0, -1)
data = np.moveaxis(np.asarray(data), 0, -1)
else:
raise ValueError(f"Unsupported number of channels: {data.shape[0]}, available options are [1, 3, 4]")

write_png(
data,
np.asarray(data),
file_name=filename,
output_spatial_shape=spatial_shape,
mode=self.mode,
Expand Down
4 changes: 2 additions & 2 deletions monai/data/png_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def write_png(
data = np.expand_dims(data, 0) # make a channel
data = xform(data)[0] # first channel
if mode != InterpolateMode.NEAREST:
data = np.clip(data, _min, _max)
data = np.clip(data, _min, _max) # type: ignore

if scale is not None:
data = np.clip(data, 0.0, 1.0) # png writer only can scale data in range [0, 1]
data = np.clip(data, 0.0, 1.0) # type: ignore # png writer only can scale data in range [0, 1]
if scale == np.iinfo(np.uint8).max:
data = (scale * data).astype(np.uint8)
elif scale == np.iinfo(np.uint16).max:
Expand Down
4 changes: 2 additions & 2 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def rectify_header_sform_qform(img_nii):
return img_nii


def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = True) -> np.ndarray:
def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = True):
"""
To make column norm of `affine` the same as `scale`. If diagonal is False,
returns an affine that combines orthogonal rotation and the new scale.
Expand Down Expand Up @@ -379,7 +379,7 @@ def zoom_affine(affine: np.ndarray, scale: Sequence[float], diagonal: bool = Tru


def compute_shape_offset(
spatial_shape: np.ndarray, in_affine: np.ndarray, out_affine: np.ndarray
spatial_shape: Union[np.ndarray, Sequence[int]], in_affine: np.ndarray, out_affine: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
"""
Given input and output affine, compute appropriate shapes
Expand Down
2 changes: 1 addition & 1 deletion monai/handlers/iteration_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def compute(self) -> Any:
# broadcast result to all processes
result = idist.broadcast(result, src=0)

return result.item() if torch.is_tensor(result) else result
return result.item() if isinstance(result, torch.Tensor) else result

def _reduce(self, scores) -> Any:
return do_metric_reduction(scores, MetricReduction.MEAN)[0]
Expand Down
5 changes: 3 additions & 2 deletions monai/handlers/segmentation_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np

from monai.config import DtypeLike
from monai.data import NiftiSaver, PNGSaver
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import

Expand All @@ -38,8 +39,8 @@ def __init__(
mode: Union[GridSampleMode, InterpolateMode, str] = "nearest",
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
scale: Optional[int] = None,
dtype: Optional[np.dtype] = np.float64,
output_dtype: Optional[np.dtype] = np.float32,
dtype: DtypeLike = np.float64,
output_dtype: DtypeLike = np.float32,
batch_transform: Callable = lambda x: x,
output_transform: Callable = lambda x: x,
name: Optional[str] = None,
Expand Down
6 changes: 4 additions & 2 deletions monai/handlers/stats_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,12 @@ def _default_iteration_print(self, engine: Engine) -> None:
" {}:{}".format(name, type(value))
)
continue # not printing multi dimensional output
out_str += self.key_var_format.format(name, value.item() if torch.is_tensor(value) else value)
out_str += self.key_var_format.format(name, value.item() if isinstance(value, torch.Tensor) else value)
else:
if is_scalar(loss): # not printing multi dimensional output
out_str += self.key_var_format.format(self.tag_name, loss.item() if torch.is_tensor(loss) else loss)
out_str += self.key_var_format.format(
self.tag_name, loss.item() if isinstance(loss, torch.Tensor) else loss
)
else:
warnings.warn(
"ignoring non-scalar output in StatsHandler,"
Expand Down
14 changes: 9 additions & 5 deletions monai/handlers/tensorboard_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,13 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No
" {}:{}".format(name, type(value))
)
continue # not plot multi dimensional output
writer.add_scalar(name, value.item() if torch.is_tensor(value) else value, engine.state.iteration)
writer.add_scalar(
name, value.item() if isinstance(value, torch.Tensor) else value, engine.state.iteration
)
elif is_scalar(loss): # not printing multi dimensional output
writer.add_scalar(self.tag_name, loss.item() if torch.is_tensor(loss) else loss, engine.state.iteration)
writer.add_scalar(
self.tag_name, loss.item() if isinstance(loss, torch.Tensor) else loss, engine.state.iteration
)
else:
warnings.warn(
"ignoring non-scalar output in TensorBoardStatsHandler,"
Expand Down Expand Up @@ -261,7 +265,7 @@ def __call__(self, engine: Engine) -> None:
"""
step = self.global_iter_transform(engine.state.epoch if self.epoch_level else engine.state.iteration)
show_images = self.batch_transform(engine.state.batch)[0]
if torch.is_tensor(show_images):
if isinstance(show_images, torch.Tensor):
show_images = show_images.detach().cpu().numpy()
if show_images is not None:
if not isinstance(show_images, np.ndarray):
Expand All @@ -274,7 +278,7 @@ def __call__(self, engine: Engine) -> None:
)

show_labels = self.batch_transform(engine.state.batch)[1]
if torch.is_tensor(show_labels):
if isinstance(show_labels, torch.Tensor):
show_labels = show_labels.detach().cpu().numpy()
if show_labels is not None:
if not isinstance(show_labels, np.ndarray):
Expand All @@ -287,7 +291,7 @@ def __call__(self, engine: Engine) -> None:
)

show_outputs = self.output_transform(engine.state.output)
if torch.is_tensor(show_outputs):
if isinstance(show_outputs, torch.Tensor):
show_outputs = show_outputs.detach().cpu().numpy()
if show_outputs is not None:
if not isinstance(show_outputs, np.ndarray):
Expand Down
6 changes: 3 additions & 3 deletions monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor:
data: source tensor to pad and execute all_gather in distributed data parallel.

"""
if not torch.is_tensor(data):
if not isinstance(data, torch.Tensor):
raise ValueError("input data must be PyTorch Tensor.")

if idist.get_world_size() <= 1:
Expand Down Expand Up @@ -127,7 +127,7 @@ def write_metrics_reports(

if metric_details is not None and len(metric_details) > 0:
for k, v in metric_details.items():
if torch.is_tensor(v):
if isinstance(v, torch.Tensor):
v = v.cpu().numpy()
if v.ndim == 0:
# reshape to [1, 1] if no batch and class dims
Expand Down Expand Up @@ -162,5 +162,5 @@ def write_metrics_reports(

with open(os.path.join(save_dir, f"{k}_summary.csv"), "w") as f:
f.write(f"class{deli}{deli.join(ops)}\n")
for i, c in enumerate(v.transpose()):
for i, c in enumerate(np.transpose(v)):
f.write(f"{class_labels[i]}{deli}{deli.join([f'{supported_ops[k](c):.4f}' for k in ops])}\n")
2 changes: 1 addition & 1 deletion monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def wasserstein_distance_map(self, flat_proba: torch.Tensor, flat_target: torch.
flat_target: the target tensor.
"""
# Turn the distance matrix to a map of identical matrix
m = torch.clone(self.m).to(flat_proba.device)
m = torch.clone(torch.as_tensor(self.m)).to(flat_proba.device)
m_extended = torch.unsqueeze(m, dim=0)
m_extended = torch.unsqueeze(m_extended, dim=3)
m_extended = m_extended.expand((flat_proba.size(0), m_extended.size(1), m_extended.size(2), flat_proba.size(2)))
Expand Down
7 changes: 4 additions & 3 deletions monai/metrics/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,10 @@ def compute_hausdorff_distance(
y_pred=y_pred,
y=y,
)

y = y.float()
y_pred = y_pred.float()
if isinstance(y, torch.Tensor):
y = y.float()
if isinstance(y_pred, torch.Tensor):
y_pred = y_pred.float()

if y.shape != y_pred.shape:
raise ValueError("y_pred and y should have same shapes.")
Expand Down
4 changes: 2 additions & 2 deletions monai/metrics/rocauc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

import warnings
from typing import Callable, List, Optional, Union, cast
from typing import Callable, Optional, Union, cast

import numpy as np
import torch
Expand Down Expand Up @@ -57,7 +57,7 @@ def compute_roc_auc(
softmax: bool = False,
other_act: Optional[Callable] = None,
average: Union[Average, str] = Average.MACRO,
) -> Union[np.ndarray, List[float], float]:
):
"""Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to:
`sklearn.metrics.roc_auc_score <https://scikit-learn.org/stable/modules/generated/
sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.
Expand Down
Loading