diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index ba7bb3d9fe3..1d0150e5050 100755 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -536,17 +536,17 @@ - name: "arange" signature: [ - "Tensor (Scalar start, Scalar end, Scalar step=1, *, DataType dtype=kInt64, + "Tensor (Scalar start, Scalar end, Scalar step=1, *, DataType dtype=None, Device device=None) => Arange", - "Tensor (Scalar end, *, DataType dtype=kInt64, Device device=None) => Arange", + "Tensor (Scalar end, *, DataType dtype=None, Device device=None) => Arange", ] bind_python: True - name: "consistent_arange" signature: [ - "Tensor (Scalar start, Scalar end, Scalar step=1, *, DataType dtype=kInt64, + "Tensor (Scalar start, Scalar end, Scalar step=1, *, DataType dtype=None, Placement placement, SbpList sbp) => ConsistentArange", - "Tensor (Scalar end, *, DataType dtype=kInt64, Placement placement, SbpList sbp) => ConsistentArange", + "Tensor (Scalar end, *, DataType dtype=None, Placement placement, SbpList sbp) => ConsistentArange", ] bind_python: True diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp index a21b10e6955..4a78e740811 100644 --- a/oneflow/core/functional/impl/math_functor.cpp +++ b/oneflow/core/functional/impl/math_functor.cpp @@ -678,19 +678,34 @@ class ArangeFunctor { public: ArangeFunctor() { op_ = CHECK_JUST(one::OpBuilder("arange").Output("out").Build()); } Maybe operator()(const Scalar& start, const Scalar& limit, const Scalar& delta, - const Symbol& dtype, + const Optional>& dtype, const Optional>& device) const { MutableAttrMap attrs; - const DataType range_dtype = dtype->data_type(); - JUST(attrs.SetAttr("dtype", range_dtype)); - if (IsIntegralDataType(range_dtype)) { - JUST(attrs.SetAttr("integer_start", JUST(start.As()))); - JUST(attrs.SetAttr("integer_limit", JUST(limit.As()))); - JUST(attrs.SetAttr("integer_delta", JUST(delta.As()))); + if (dtype.has_value()) { + const DataType range_dtype = JUST(dtype)->data_type(); + if (IsIntegralDataType(range_dtype)) { + JUST(attrs.SetAttr("integer_start", JUST(start.As()))); + JUST(attrs.SetAttr("integer_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("integer_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", range_dtype)); + } else { + JUST(attrs.SetAttr("float_start", JUST(start.As()))); + JUST(attrs.SetAttr("float_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("float_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", range_dtype)); + } } else { - JUST(attrs.SetAttr("float_start", JUST(start.As()))); - JUST(attrs.SetAttr("float_limit", JUST(limit.As()))); - JUST(attrs.SetAttr("float_delta", JUST(delta.As()))); + if (delta.IsIntegral()) { + JUST(attrs.SetAttr("integer_start", JUST(start.As()))); + JUST(attrs.SetAttr("integer_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("integer_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", DType::Int64()->data_type())); + } else { + JUST(attrs.SetAttr("float_start", JUST(start.As()))); + JUST(attrs.SetAttr("float_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("float_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", DType::Float()->data_type())); + } } OpExprInterpContext ctx(attrs); ctx.device = device; @@ -703,7 +718,7 @@ class ArangeFunctor { class Arange2Functor { public: - Maybe operator()(const Scalar& limit, const Symbol& dtype, + Maybe operator()(const Scalar& limit, const Optional>& dtype, const Optional>& device) const { return Arange(Scalar(0), limit, Scalar(1), dtype, device); } @@ -713,21 +728,36 @@ class ConsistentArangeFunctor { public: ConsistentArangeFunctor() { op_ = CHECK_JUST(one::OpBuilder("arange").Output("out").Build()); } Maybe operator()(const Scalar& start, const Scalar& limit, const Scalar& delta, - const Symbol& dtype, const Symbol& placement, + const Optional>& dtype, + const Symbol& placement, const std::vector>& sbp_tuple) const { MutableAttrMap attrs; - const DataType range_dtype = dtype->data_type(); - JUST(attrs.SetAttr("dtype", range_dtype)); - if (IsIntegralDataType(range_dtype)) { - JUST(attrs.SetAttr("integer_start", JUST(start.As()))); - JUST(attrs.SetAttr("integer_limit", JUST(limit.As()))); - JUST(attrs.SetAttr("integer_delta", JUST(delta.As()))); + if (dtype.has_value()) { + const DataType range_dtype = JUST(dtype)->data_type(); + if (IsIntegralDataType(range_dtype)) { + JUST(attrs.SetAttr("integer_start", JUST(start.As()))); + JUST(attrs.SetAttr("integer_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("integer_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", range_dtype)); + } else { + JUST(attrs.SetAttr("float_start", JUST(start.As()))); + JUST(attrs.SetAttr("float_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("float_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", range_dtype)); + } } else { - JUST(attrs.SetAttr("float_start", JUST(start.As()))); - JUST(attrs.SetAttr("float_limit", JUST(limit.As()))); - JUST(attrs.SetAttr("float_delta", JUST(delta.As()))); + if (delta.IsIntegral()) { + JUST(attrs.SetAttr("integer_start", JUST(start.As()))); + JUST(attrs.SetAttr("integer_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("integer_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", DType::Int64()->data_type())); + } else { + JUST(attrs.SetAttr("float_start", JUST(start.As()))); + JUST(attrs.SetAttr("float_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("float_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", DType::Float()->data_type())); + } } - if (LazyMode::is_enabled()) { std::vector nd_sbp(sbp_tuple.size()); { diff --git a/python/oneflow/nn/modules/arange.py b/python/oneflow/nn/modules/arange.py index ff1b005277b..980bd151130 100644 --- a/python/oneflow/nn/modules/arange.py +++ b/python/oneflow/nn/modules/arange.py @@ -24,7 +24,7 @@ def arange_op( start: int = 0, end: int = None, step: int = 1, - dtype: flow.dtype = flow.int64, + dtype: flow.dtype = None, device: Union[str, flow.device] = None, placement: flow.placement = None, sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, diff --git a/python/oneflow/test/modules/test_arange.py b/python/oneflow/test/modules/test_arange.py index dd05da8c62c..5d873692c66 100644 --- a/python/oneflow/test/modules/test_arange.py +++ b/python/oneflow/test/modules/test_arange.py @@ -77,6 +77,16 @@ def test_arange_with_random_data(test_case): x.to(device) return x + @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=False) + def test_arange_with_float_delta(test_case): + start = random().to(int) + end = start + random().to(int) + step = random(0, end - start).to(float) + x = torch.arange(start=start, end=end, step=step) + device = random_device() + x.to(device) + return x + def test_consistent_naive(test_case): placement = flow.placement("cpu", {0: [0]}) sbp = (flow.sbp.broadcast,)