Skip to content

Commit

Permalink
Merge branch 'dev' into temp-fix-windows-test
Browse files Browse the repository at this point in the history
  • Loading branch information
wyli authored Jan 10, 2022
2 parents 7c7d4e1 + 2fef7ff commit d4a57d6
Showing 1 changed file with 50 additions and 13 deletions.
63 changes: 50 additions & 13 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:
return result


def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[NdarrayOrTensor, float, int]:
def percentile(
x: NdarrayOrTensor, q, dim: Optional[int] = None, keepdim: bool = False, **kwargs
) -> Union[NdarrayOrTensor, float, int]:
"""`np.percentile` with equivalent implementation for torch.
Pytorch uses `quantile`, but this functionality is only available from v1.7.
Expand All @@ -97,6 +99,9 @@ def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[Ndarra
q: percentile to compute (should in range 0 <= q <= 100)
dim: the dim along which the percentiles are computed. default is to compute the percentile
along a flattened version of the array. only work for numpy array or Tensor with PyTorch >= 1.7.0.
keepdim: whether the output data has dim retained or not.
kwargs: if `x` is numpy array, additional args for `np.percentile`, more details:
https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.
Returns:
Resulting value (scalar)
Expand All @@ -108,11 +113,11 @@ def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[Ndarra
raise ValueError
result: Union[NdarrayOrTensor, float, int]
if isinstance(x, np.ndarray):
result = np.percentile(x, q, axis=dim)
result = np.percentile(x, q, axis=dim, keepdims=keepdim, **kwargs)
else:
q = torch.tensor(q, device=x.device)
if hasattr(torch, "quantile"): # `quantile` is new in torch 1.7.0
result = torch.quantile(x, q / 100.0, dim=dim)
result = torch.quantile(x, q / 100.0, dim=dim, keepdim=keepdim)
else:
# Note that ``kthvalue()`` works one-based, i.e., the first sorted value
# corresponds to k=1, not k=0. Thus, we need the `1 +`.
Expand Down Expand Up @@ -282,13 +287,23 @@ def concatenate(to_cat: Sequence[NdarrayOrTensor], axis: int = 0, out=None) -> N
return torch.cat(to_cat, dim=axis, out=out) # type: ignore


def cumsum(a: NdarrayOrTensor, axis=None, **kwargs):
    """
    `np.cumsum` with equivalent implementation for torch.

    Args:
        a: input data to compute cumsum.
        axis: expected axis to compute cumsum.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.cumsum`, more details:
            https://pytorch.org/docs/stable/generated/torch.cumsum.html.
    """
    if isinstance(a, np.ndarray):
        return np.cumsum(a, axis)
    # torch has no axis=None "flatten" mode, so mirror numpy by summing over dim 0.
    # NOTE(review): for >1-D tensors this is cumsum along dim 0, not numpy's flattened
    # cumsum — presumably callers only pass 1-D tensors with axis=None; confirm.
    if axis is None:
        return torch.cumsum(a[:], 0, **kwargs)
    return torch.cumsum(a, dim=axis, **kwargs)


def isfinite(x):
Expand All @@ -298,18 +313,40 @@ def isfinite(x):
return torch.isfinite(x)


def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None, **kwargs):
    """
    `np.searchsorted` with equivalent implementation for torch.

    Args:
        a: numpy array or tensor, containing monotonically increasing sequence on the innermost dimension.
        v: containing the search values.
        right: if False, return the first suitable location that is found, if True, return the last such index.
        sorter: if `a` is numpy array, optional array of integer indices that sort array `a` into ascending order.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.searchsorted`, more details:
            https://pytorch.org/docs/stable/generated/torch.searchsorted.html.
    """
    if not isinstance(a, np.ndarray):
        return torch.searchsorted(a, v, right=right, **kwargs)  # type: ignore
    # numpy expresses the right/left choice as a string-valued `side` argument.
    return np.searchsorted(a, v, "right" if right else "left", sorter)  # type: ignore


def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None, **kwargs):
    """
    `np.repeat` with equivalent implementation for torch (`repeat_interleave`).

    Args:
        a: input data to repeat.
        repeats: number of repetitions for each element, repeats is broadcasted to fit the shape of the given axis.
        axis: axis along which to repeat values.
        kwargs: if `a` is PyTorch Tensor, additional args for `torch.repeat_interleave`, more details:
            https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html.
    """
    if not isinstance(a, np.ndarray):
        # axis=None flattens before repeating, matching numpy's default behaviour.
        return torch.repeat_interleave(a, repeats, dim=axis, **kwargs)
    return np.repeat(a, repeats, axis)


def isnan(x: NdarrayOrTensor):
Expand All @@ -330,7 +367,7 @@ def ascontiguousarray(x: NdarrayOrTensor, **kwargs):
Args:
x: array/tensor
kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details:
https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous.
https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html.
"""
if isinstance(x, np.ndarray):
Expand Down

0 comments on commit d4a57d6

Please sign in to comment.