Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove assert statement from non-test files #1446

Merged
merged 4 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions monai/data/csv_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def __init__(self, output_dir: str = "./", filename: str = "predictions.csv", ov
"""
self.output_dir = output_dir
self._cache_dict: OrderedDict = OrderedDict()
assert isinstance(filename, str) and filename[-4:] == ".csv", "filename must be a string with CSV format."
if not (isinstance(filename, str) and filename[-4:] == ".csv"):
raise AssertionError("filename must be a string with CSV format.")
self._filepath = os.path.join(output_dir, filename)
self.overwrite = overwrite
self._data_index = 0
Expand Down Expand Up @@ -76,7 +77,8 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
self._data_index += 1
if torch.is_tensor(data):
data = data.detach().cpu().numpy()
assert isinstance(data, np.ndarray)
if not isinstance(data, np.ndarray):
raise AssertionError
self._cache_dict[save_key] = data.astype(np.float32)

def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
Expand Down
3 changes: 2 additions & 1 deletion monai/data/nifti_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def write_nifti(
the output data type is always ``np.float32``.
output_dtype: data type for saving data. Defaults to ``np.float32``.
"""
assert isinstance(data, np.ndarray), "input data must be numpy array."
if not isinstance(data, np.ndarray):
raise AssertionError("input data must be numpy array.")
dtype = dtype or data.dtype
sr = min(data.ndim, 3)
if affine is None:
Expand Down
3 changes: 2 additions & 1 deletion monai/data/png_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def write_png(
ValueError: When ``scale`` is not one of [255, 65535].

"""
assert isinstance(data, np.ndarray), "input data must be numpy array."
if not isinstance(data, np.ndarray):
raise AssertionError("input data must be numpy array.")
if len(data.shape) == 3 and data.shape[2] == 1: # PIL Image can't save image with 1 channel
data = data.squeeze(2)
if output_spatial_shape is not None:
Expand Down
6 changes: 4 additions & 2 deletions monai/data/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def create_test_image_2d(
noisyimage = rescale_array(np.maximum(image, norm))

if channel_dim is not None:
assert isinstance(channel_dim, int) and channel_dim in (-1, 0, 2), "invalid channel dim."
if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 2)):
raise AssertionError("invalid channel dim.")
if channel_dim == 0:
noisyimage = noisyimage[None]
labels = labels[None]
Expand Down Expand Up @@ -131,7 +132,8 @@ def create_test_image_3d(
noisyimage = rescale_array(np.maximum(image, norm))

if channel_dim is not None:
assert isinstance(channel_dim, int) and channel_dim in (-1, 0, 3), "invalid channel dim."
if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 3)):
raise AssertionError("invalid channel dim.")
noisyimage, labels = (
(noisyimage[None], labels[None]) if channel_dim == 0 else (noisyimage[..., None], labels[..., None])
)
Expand Down
3 changes: 2 additions & 1 deletion monai/engines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def default_prepare_batch(
image, label(optional).

"""
assert isinstance(batchdata, dict), "default prepare_batch expects dictionary input data."
if not isinstance(batchdata, dict):
raise AssertionError("default prepare_batch expects dictionary input data.")
if CommonKeys.LABEL in batchdata:
return (
batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking),
Expand Down
3 changes: 2 additions & 1 deletion monai/engines/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def set_sampler_epoch(engine: Engine):

@self.on(Events.ITERATION_COMPLETED)
def run_post_transform(engine: Engine) -> None:
assert post_transform is not None
if post_transform is None:
raise AssertionError
engine.state.output = apply_transform(post_transform, engine.state.output)

if key_metric is not None:
Expand Down
6 changes: 4 additions & 2 deletions monai/handlers/checkpoint_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ def __init__(
name: Optional[str] = None,
map_location: Optional[Dict] = None,
) -> None:
assert load_path is not None, "must provide clear path to load checkpoint."
if load_path is None:
raise AssertionError("must provide clear path to load checkpoint.")
self.load_path = load_path
assert load_dict is not None and len(load_dict) > 0, "must provide target objects to load."
if not (load_dict is not None and len(load_dict) > 0):
raise AssertionError("must provide target objects to load.")
self.logger = logging.getLogger(name)
self.load_dict = load_dict
self._name = name
Expand Down
36 changes: 24 additions & 12 deletions monai/handlers/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,11 @@ def __init__(
save_interval: int = 0,
n_saved: Optional[int] = None,
) -> None:
assert save_dir is not None, "must provide directory to save the checkpoints."
if save_dir is None:
raise AssertionError("must provide directory to save the checkpoints.")
self.save_dir = save_dir
assert save_dict is not None and len(save_dict) > 0, "must provide source objects to save."
if not (save_dict is not None and len(save_dict) > 0):
raise AssertionError("must provide source objects to save.")
self.save_dict = save_dict
self.logger = logging.getLogger(name)
self.epoch_level = epoch_level
Expand Down Expand Up @@ -202,12 +204,15 @@ def completed(self, engine: Engine) -> None:
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
assert callable(self._final_checkpoint), "Error: _final_checkpoint function not specified."
if not callable(self._final_checkpoint):
raise AssertionError("Error: _final_checkpoint function not specified.")
# delete previous saved final checkpoint if existing
self._delete_previous_final_ckpt()
self._final_checkpoint(engine)
assert self.logger is not None
assert hasattr(self.logger, "info"), "Error, provided logger has not info attribute."
if self.logger is None:
raise AssertionError
if not hasattr(self.logger, "info"):
raise AssertionError("Error, provided logger has not info attribute.")
self.logger.info(f"Train completed, saved final checkpoint: {self._final_checkpoint.last_checkpoint}")

def exception_raised(self, engine: Engine, e: Exception) -> None:
Expand All @@ -219,12 +224,15 @@ def exception_raised(self, engine: Engine, e: Exception) -> None:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
e: the exception caught in Ignite during engine.run().
"""
assert callable(self._final_checkpoint), "Error: _final_checkpoint function not specified."
if not callable(self._final_checkpoint):
raise AssertionError("Error: _final_checkpoint function not specified.")
# delete previous saved final checkpoint if existing
self._delete_previous_final_ckpt()
self._final_checkpoint(engine)
assert self.logger is not None
assert hasattr(self.logger, "info"), "Error, provided logger has not info attribute."
if self.logger is None:
raise AssertionError
if not hasattr(self.logger, "info"):
raise AssertionError("Error, provided logger has not info attribute.")
self.logger.info(f"Exception_raised, saved exception checkpoint: {self._final_checkpoint.last_checkpoint}")
raise e

Expand All @@ -234,7 +242,8 @@ def metrics_completed(self, engine: Engine) -> None:
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
assert callable(self._key_metric_checkpoint), "Error: _key_metric_checkpoint function not specified."
if not callable(self._key_metric_checkpoint):
raise AssertionError("Error: _key_metric_checkpoint function not specified.")
self._key_metric_checkpoint(engine)

def interval_completed(self, engine: Engine) -> None:
Expand All @@ -244,10 +253,13 @@ def interval_completed(self, engine: Engine) -> None:
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
assert callable(self._interval_checkpoint), "Error: _interval_checkpoint function not specified."
if not callable(self._interval_checkpoint):
raise AssertionError("Error: _interval_checkpoint function not specified.")
self._interval_checkpoint(engine)
assert self.logger is not None
assert hasattr(self.logger, "info"), "Error, provided logger has not info attribute."
if self.logger is None:
raise AssertionError
if not hasattr(self.logger, "info"):
raise AssertionError("Error, provided logger has not info attribute.")
if self.epoch_level:
self.logger.info(f"Saved checkpoint at epoch: {engine.state.epoch}")
else:
Expand Down
3 changes: 2 additions & 1 deletion monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def sliding_window_inference(

"""
num_spatial_dims = len(inputs.shape) - 2
assert 0 <= overlap < 1, "overlap must be >= 0 and < 1."
if overlap < 0 or overlap >= 1:
raise AssertionError("overlap must be >= 0 and < 1.")

# determine image spatial size and batch size
# Note: all input images must have the same image size and batch size
Expand Down
26 changes: 12 additions & 14 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
target = target[:, 1:]
input = input[:, 1:]

assert (
target.shape == input.shape
), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"
if target.shape != input.shape:
raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

# reducing only spatial dimensions (not batch nor channels)
reduce_axis = list(range(2, len(input.shape)))
Expand Down Expand Up @@ -192,16 +191,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: Optional[torc
"""
if mask is not None:
# checking if mask is of proper shape
assert input.dim() == mask.dim(), f"dim of input ({input.shape}) is different from mask ({mask.shape})"
assert (
input.shape[0] == mask.shape[0] or mask.shape[0] == 1
), f" batch size of mask ({mask.shape}) must be 1 or equal to input ({input.shape})"
if input.dim() != mask.dim():
raise AssertionError(f"dim of input ({input.shape}) is different from mask ({mask.shape})")
if not (input.shape[0] == mask.shape[0] or mask.shape[0] == 1):
raise AssertionError(f" batch size of mask ({mask.shape}) must be 1 or equal to input ({input.shape})")

if target.dim() > 1:
assert mask.shape[1] == 1, f"mask ({mask.shape}) must have only 1 channel"
assert (
input.shape[2:] == mask.shape[2:]
), f"spatial size of input ({input.shape}) is different from mask ({mask.shape})"
if mask.shape[1] != 1:
raise AssertionError(f"mask ({mask.shape}) must have only 1 channel")
if input.shape[2:] != mask.shape[2:]:
raise AssertionError(f"spatial size of input ({input.shape}) is different from mask ({mask.shape})")

input = input * mask
target = target * mask
Expand Down Expand Up @@ -322,9 +321,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
target = target[:, 1:]
input = input[:, 1:]

assert (
target.shape == input.shape
), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"
if target.shape != input.shape:
raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

# reducing only spatial dimensions (not batch nor channels)
reduce_axis = list(range(2, len(input.shape)))
Expand Down
5 changes: 2 additions & 3 deletions monai/losses/tversky.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
target = target[:, 1:]
input = input[:, 1:]

assert (
target.shape == input.shape
), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"
if target.shape != input.shape:
raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

p0 = input
p1 = 1 - p0
Expand Down
13 changes: 6 additions & 7 deletions monai/metrics/rocauc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@


def _calculate(y: torch.Tensor, y_pred: torch.Tensor) -> float:
assert y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(
y_pred
), "y and y_pred must be 1 dimension data with same length."
assert y.unique().equal(
torch.tensor([0, 1], dtype=y.dtype, device=y.device)
), "y values must be 0 or 1, can not be all 0 or all 1."
if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)):
raise AssertionError("y and y_pred must be 1 dimension data with same length.")
if not y.unique().equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)):
raise AssertionError("y values must be 0 or 1, can not be all 0 or all 1.")
n = len(y)
indices = y_pred.argsort()
y = y[indices].cpu().numpy()
Expand Down Expand Up @@ -126,7 +124,8 @@ def compute_roc_auc(
raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
y_pred = other_act(y_pred)

assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match."
if y.shape != y_pred.shape:
raise AssertionError("data shapes of y_pred and y do not match.")

average = Average(average)
if average == Average.MICRO:
Expand Down
11 changes: 6 additions & 5 deletions monai/networks/blocks/dynunet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ def get_norm_layer(spatial_dims: int, out_channels: int, norm_name: str, num_gro
if norm_name not in ["batch", "instance", "group"]:
raise ValueError(f"Unsupported normalization mode: {norm_name}")
if norm_name == "group":
assert out_channels % num_groups == 0, "out_channels should be divisible by num_groups."
if out_channels % num_groups != 0:
raise AssertionError("out_channels should be divisible by num_groups.")
norm = Norm[norm_name](num_groups=num_groups, num_channels=out_channels, affine=True)
else:
norm = Norm[norm_name, spatial_dims](out_channels, affine=True)
Expand Down Expand Up @@ -276,8 +277,8 @@ def get_padding(
kernel_size_np = np.atleast_1d(kernel_size)
stride_np = np.atleast_1d(stride)
padding_np = (kernel_size_np - stride_np + 1) / 2
error_msg = "padding value should not be negative, please change the kernel size and/or stride."
assert np.min(padding_np) >= 0, error_msg
if np.min(padding_np) < 0:
raise AssertionError("padding value should not be negative, please change the kernel size and/or stride.")
padding = tuple(int(p) for p in padding_np)

return padding if len(padding) > 1 else padding[0]
Expand All @@ -293,8 +294,8 @@ def get_output_padding(
padding_np = np.atleast_1d(padding)

out_padding_np = 2 * padding_np + stride_np - kernel_size_np
error_msg = "out_padding value should not be negative, please change the kernel size and/or stride."
assert np.min(out_padding_np) >= 0, error_msg
if np.min(out_padding_np) < 0:
raise AssertionError("out_padding value should not be negative, please change the kernel size and/or stride.")
out_padding = tuple(int(p) for p in out_padding_np)

return out_padding if len(out_padding) > 1 else out_padding[0]
6 changes: 4 additions & 2 deletions monai/networks/blocks/segresnet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ def __init__(

super().__init__()

assert kernel_size % 2 == 1, "kernel_size should be an odd number."
assert in_channels % num_groups == 0, "in_channels should be divisible by num_groups."
if kernel_size % 2 != 1:
raise AssertionError("kernel_size should be an odd number.")
if in_channels % num_groups != 0:
raise AssertionError("in_channels should be divisible by num_groups.")

self.norm1 = get_norm_layer(spatial_dims, in_channels, norm_name, num_groups=num_groups)
self.norm2 = get_norm_layer(spatial_dims, in_channels, norm_name, num_groups=num_groups)
Expand Down
6 changes: 4 additions & 2 deletions monai/networks/nets/ahnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,10 @@ def __init__(
self.spatial_dims = spatial_dims
self.psp_block_num = psp_block_num

assert spatial_dims in [2, 3], "spatial_dims can only be 2 or 3."
assert psp_block_num in [0, 1, 2, 3, 4], "psp_block_num should be an integer that belongs to [0, 4]."
if spatial_dims not in [2, 3]:
raise AssertionError("spatial_dims can only be 2 or 3.")
if psp_block_num not in [0, 1, 2, 3, 4]:
raise AssertionError("psp_block_num should be an integer that belongs to [0, 4].")

self.conv1 = conv_type(
in_channels,
Expand Down
19 changes: 12 additions & 7 deletions monai/networks/nets/dynunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,10 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck):
shouldn't be associated with a supervision head.
"""

assert len(downsamples) == len(upsamples), f"{len(downsamples)} != {len(upsamples)}"
assert (len(downsamples) - len(superheads)) in (1, 0), f"{len(downsamples)}-(0,1) != {len(superheads)}"
if len(downsamples) != len(upsamples):
raise AssertionError(f"{len(downsamples)} != {len(upsamples)}")
if (len(downsamples) - len(superheads)) not in (1, 0):
raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}")

if len(downsamples) == 0: # bottom of the network, pass the bottleneck block
return bottleneck
Expand All @@ -157,22 +159,25 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck):
def check_kernel_stride(self):
kernels, strides = self.kernel_size, self.strides
error_msg = "length of kernel_size and strides should be the same, and no less than 3."
assert len(kernels) == len(strides) and len(kernels) >= 3, error_msg
if not (len(kernels) == len(strides) and len(kernels) >= 3):
raise AssertionError(error_msg)

for idx in range(len(kernels)):
kernel, stride = kernels[idx], strides[idx]
if not isinstance(kernel, int):
error_msg = "length of kernel_size in block {} should be the same as spatial_dims.".format(idx)
assert len(kernel) == self.spatial_dims, error_msg
if len(kernel) != self.spatial_dims:
raise AssertionError(error_msg)
if not isinstance(stride, int):
error_msg = "length of stride in block {} should be the same as spatial_dims.".format(idx)
assert len(stride) == self.spatial_dims, error_msg
if len(stride) != self.spatial_dims:
raise AssertionError(error_msg)

def check_deep_supr_num(self):
deep_supr_num, strides = self.deep_supr_num, self.strides
num_up_layers = len(strides) - 1
error_msg = "deep_supr_num should be less than the number of up sample layers."
assert 1 <= deep_supr_num < num_up_layers, error_msg
if deep_supr_num < 1 or deep_supr_num >= num_up_layers:
raise AssertionError("deep_supr_num should be less than the number of up sample layers.")

def forward(self, x):
out = self.skip_layers(x)
Expand Down
3 changes: 2 additions & 1 deletion monai/networks/nets/segresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def __init__(
):
super().__init__()

assert spatial_dims == 2 or spatial_dims == 3, "spatial_dims can only be 2 or 3."
if spatial_dims not in (2, 3):
raise AssertionError("spatial_dims can only be 2 or 3.")

self.spatial_dims = spatial_dims
self.init_filters = init_filters
Expand Down
3 changes: 2 additions & 1 deletion monai/networks/nets/vnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ def __init__(
):
super().__init__()

assert spatial_dims == 2 or spatial_dims == 3, "spatial_dims can only be 2 or 3."
if spatial_dims not in (2, 3):
raise AssertionError("spatial_dims can only be 2 or 3.")

self.in_tr = InputTransition(spatial_dims, in_channels, 16, act)
self.down_tr32 = DownTransition(spatial_dims, 16, 1, act)
Expand Down
Loading