diff --git a/python/oneflow/nn/modules/linspace.py b/python/oneflow/nn/modules/linspace.py index 9c8b902ff40..9962104485e 100644 --- a/python/oneflow/nn/modules/linspace.py +++ b/python/oneflow/nn/modules/linspace.py @@ -19,9 +19,9 @@ def linspace_op( - start: float, - end: float, - steps: int, + start: Union[float, flow.Tensor], + end: Union[float, flow.Tensor], + steps: Union[int, flow.Tensor], dtype: flow.dtype = flow.float32, device: Union[str, flow.device] = None, placement: flow.placement = None, @@ -60,6 +60,35 @@ def linspace_op( tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000], dtype=oneflow.float32) """ + + def is_scalar(tensor): + return tensor.ndim == 0 and tensor.nelement() == 1 + + if isinstance(start, flow.Tensor): + if not is_scalar(start): + raise TypeError( + "linspace(): argument 'start' (position 1) must be Number, not Tensor" + ) + start = start.item() + if isinstance(end, flow.Tensor): + if not is_scalar(end): + raise TypeError( + "linspace(): argument 'end' (position 2) must be Number, not Tensor" + ) + end = end.item() + if isinstance(steps, flow.Tensor): + if not is_scalar(steps): + raise TypeError( + "linspace(): argument 'steps' (position 3) must be Number, not Tensor" + ) + if flow.is_floating_point(steps): + raise TypeError( + "linspace(): argument 'steps' (position 3) must be int, not Tensor (with dtype: " + + str(steps.dtype) + + ")" + ) + steps = steps.item() + if start == end: return flow.full((steps,), start * 1.0) step = 1.0 diff --git a/python/oneflow/test/modules/test_linspace.py b/python/oneflow/test/modules/test_linspace.py index 1e9ed197ad4..678d0e6d254 100644 --- a/python/oneflow/test/modules/test_linspace.py +++ b/python/oneflow/test/modules/test_linspace.py @@ -48,6 +48,14 @@ def test_linspace_float_with_random_data(test_case): x.to(device) return x + @autotest(n=5, auto_backward=False) + def test_linspace_with_scalar_tensor_as_params(test_case): + start = random_tensor(2, 3, 4, requires_grad=False).mean() + end = start + random_tensor(2, 3, 4, requires_grad=False).mean() + steps = random(0, 10).to(int) + y = torch.linspace(start=start, end=end, steps=steps) + return y + def test_global_naive(test_case): placement = flow.placement("cpu", ranks=[0]) sbp = (flow.sbp.broadcast,)