diff --git a/oneflow/core/framework/tensor_methods.cpp b/oneflow/core/framework/tensor_methods.cpp
index 8f08df8e08a..85eb23adddb 100644
--- a/oneflow/core/framework/tensor_methods.cpp
+++ b/oneflow/core/framework/tensor_methods.cpp
@@ -49,6 +49,15 @@ Maybe<bool> IsContiguous(const std::shared_ptr<Tensor>& tensor) {
 
 namespace view {
 
+bool IsViewApplicable(const std::shared_ptr<Tensor>& input) {
+  // NOTE: only eager local tensors support view for now
+  // elem_cnt() >= 1 is used to exclude tensors with a 0-size shape
+  if (input->is_local() && !(LazyMode::is_enabled()) && input->shape()->elem_cnt() >= 1) {
+    return true;
+  }
+  return false;
+}
+
 Maybe<Tensor> BasicView(const std::shared_ptr<Tensor>& input, const Shape& target_shape,
                         int64_t storage_offset) {
   /**
@@ -64,7 +73,6 @@ Maybe<Tensor> BasicView(const std::shared_ptr<Tensor>& input, const Shape& targe
 
 Maybe<Tensor> BasicView(const std::shared_ptr<Tensor>& input, const Shape& target_shape,
                         const Stride& target_stride, int64_t storage_offset) {
-  storage_offset = storage_offset + JUST(JUST(input->AsMirroredTensor())->storage_offset());
   // TODO(): Check shape compatible.
   auto device = JUST(input->device());
   auto tensor_meta = std::make_shared<MirroredTensorMeta>(
@@ -86,38 +94,23 @@ Maybe<Tensor> BasicView(const std::shared_ptr<Tensor>& input, const Shape& targe
   return output;
 }
 
-Maybe<Tensor> Reshape(const std::shared_ptr<Tensor>& input, const Shape& shape) {
-  if (!(input->is_eager() && input->is_local())) {
-    return Error::RuntimeError() << "view::Reshape(): input should be eager local tensor, but got "
-                                 << (input->is_lazy() ? "lazy" : "consistent");
-  }
-  int need_infer_axis = -1;
-  size_t count = 1;
-  for (int i = 0; i < shape.NumAxes(); ++i) {
-    if (shape.At(i) < -1) {
-      return Error::RuntimeError() << "Invalid shape dimension " << shape.At(i);
-    } else if (shape.At(i) == -1) {
-      CHECK_EQ_OR_RETURN(need_infer_axis, -1)
-          << "Shape " << shape.ToString() << " has more than 1 axis that needs to be infered.";
-      need_infer_axis = i;
-    } else {
-      count *= shape.At(i);
-    }
-  }
+Maybe<Tensor> Reshape(const std::shared_ptr<Tensor>& input, const Shape& target_shape) {
+  Stride target_stride(target_shape);
+  return Reshape(input, target_shape, target_stride);
+}
 
-  std::shared_ptr<Tensor> output;
-  size_t x_count = input->shape()->Count(0);
-  if (need_infer_axis == -1) {
-    CHECK_EQ_OR_RETURN(shape.Count(0), x_count);
-    output = JUST(BasicView(input, shape, 0));
-  } else {
-    Shape infered_shape = shape;
-    infered_shape.Set(need_infer_axis, x_count / count);
-    CHECK_EQ_OR_RETURN(infered_shape.Count(0), x_count)
-        << "Shape " << shape.ToString() << " is invalid for input of shape "
-        << input->shape()->ToString();
-    output = JUST(BasicView(input, infered_shape, 0));
-  }
+Maybe<Tensor> Reshape(const std::shared_ptr<Tensor>& input, const Shape& target_shape,
+                      const Stride& target_stride) {
+  // TODO(zhaoluyang): check that the input tensor is contiguous
+  CHECK_OR_RETURN(IsViewApplicable(input))
+      << Error::RuntimeError()
+      << "view::Reshape(): input should be an eager local tensor with element count >= 1, but got "
+      << (input->is_lazy() ? "lazy tensor" : "consistent tensor")
+      << " with shape: " << input->shape()->ToString() << "; element count: " << input->nelement();
+
+  int64_t storage_offset = JUST(JUST(input->AsMirroredTensor())->storage_offset());
+  std::shared_ptr<Tensor> output =
+      JUST(BasicView(input, target_shape, target_stride, storage_offset));
 
   if (autograd::GradMode::is_enabled() && input->requires_grad()) {
     Shape input_shape(input->shape()->dim_vec());
@@ -128,7 +121,8 @@ Maybe<Tensor> Reshape(const std::shared_ptr<Tensor>& input, const Shape& shape)
       autograd::AutoGradMode mode(create_graph);
       CHECK_EQ_OR_RETURN(out_grads.size(), 1);
       in_grads->resize(1);
-      in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), input_shape));
+      *JUST(oneflow::VectorAt(in_grads, 0)) =
+          JUST(functional::Reshape(JUST(oneflow::VectorAt(out_grads, 0)), input_shape));
       return Maybe<void>::Ok();
     });
     TensorTuple outputs{output};
@@ -140,9 +134,10 @@ Maybe<Tensor> Reshape(const std::shared_ptr<Tensor>& input, const Shape& shape)
 
 Maybe<Tensor> Slice(const std::shared_ptr<Tensor>& input, const std::vector<int64_t>& starts,
                     const std::vector<int64_t>& ends, const std::vector<int64_t>& steps) {
-  CHECK_OR_RETURN(input->is_eager() && input->is_local())
+  CHECK_OR_RETURN(IsViewApplicable(input))
       << Error::RuntimeError() << "view::Slice(): input should be eager local tensor, but is "
-      << (input->is_lazy() ? "lazy" : "consistent");
+      << (input->is_lazy() ? "lazy tensor" : "consistent tensor")
+      << " with shape: " << input->shape()->ToString() << "; element count: " << input->nelement();
   const auto& shape = input->shape();
   const auto& strides = JUST(input->stride());
   const int64_t ndim = starts.size();
"lazy tensor" : "consistent tensor") + << " with shape: " << input->shape()->ToString() << "; element count: " << input->nelement(); + + const auto& shape = input->shape(); + const auto& strides = JUST(input->stride()); + const auto& ndim = shape->NumAxes(); + + DimVector target_dim_vec(ndim + 1); + StrideVector target_stride_vec(ndim + 1); + + { + int cnt = 0; + for (int i = 0; i < ndim; i++) { + if (i == expand_dim) { cnt++; } + target_dim_vec[cnt] = shape->At(i); + target_stride_vec[cnt] = strides->At(i); + cnt++; + } + target_dim_vec[expand_dim] = 1; + target_stride_vec[expand_dim] = strides->At(expand_dim); + } + + int64_t storage_offset = JUST(JUST(input->AsMirroredTensor())->storage_offset()); + std::shared_ptr output = + JUST(BasicView(input, Shape(target_dim_vec), Stride(target_stride_vec), storage_offset)); + + if (autograd::GradMode::is_enabled() && input->requires_grad()) { + auto backward_fn = + std::make_shared(const TensorTuple&, TensorTuple*, bool)>>( + [=](const TensorTuple& out_grads, TensorTuple* in_grads, + bool create_graph) -> Maybe { + autograd::AutoGradMode mode(create_graph); + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + *JUST(oneflow::VectorAt(in_grads, 0)) = + JUST(functional::Reshape(JUST(oneflow::VectorAt(out_grads, 0)), *shape)); + return Maybe::Ok(); + }); + TensorTuple outputs{output}; + JUST(GetThreadLocalAutogradEngine()->AddBackwardFuncPtr("view::unsqueeze_backward", backward_fn, + {input}, &outputs)); + } + return output; +} + +Maybe Squeeze(const std::shared_ptr& input, + const std::vector& squeeze_dims) { + CHECK_OR_RETURN(IsViewApplicable(input)) + << Error::RuntimeError() << "view::Squeeze(): input should be eager local tensor, but got " + << (input->is_lazy() ? "lazy tensor" : "consistent tensor") + << " with shape: " << input->shape()->ToString() << "; element count: " << input->nelement(); + + const auto& shape = input->shape(); + const auto& strides = JUST(input->stride()); + const int64_t ndim = shape->NumAxes(); + + const int target_ndim = ndim - squeeze_dims.size(); + DimVector target_dim_vec(target_ndim); + StrideVector target_stride_vec(target_ndim); + + { + int cnt = 0; + for (int i = 0; i < ndim; i++) { + if (find(squeeze_dims.begin(), squeeze_dims.end(), i) == squeeze_dims.end()) { + target_dim_vec[cnt] = shape->At(i); + target_stride_vec[cnt] = strides->At(i); + cnt++; + } + } + } + + int64_t storage_offset = JUST(JUST(input->AsMirroredTensor())->storage_offset()); + std::shared_ptr output = + JUST(BasicView(input, Shape(target_dim_vec), Stride(target_stride_vec), storage_offset)); + + if (autograd::GradMode::is_enabled() && input->requires_grad()) { + auto backward_fn = + std::make_shared(const TensorTuple&, TensorTuple*, bool)>>( + [=](const TensorTuple& out_grads, TensorTuple* in_grads, + bool create_graph) -> Maybe { + autograd::AutoGradMode mode(create_graph); + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + *JUST(oneflow::VectorAt(in_grads, 0)) = JUST(functional::Reshape( + JUST(oneflow::VectorAt(out_grads, 0)), Shape(input->shape()->dim_vec()))); + return Maybe::Ok(); + }); + TensorTuple outputs{output}; + JUST(GetThreadLocalAutogradEngine()->AddBackwardFuncPtr("view::squeeze_backward", backward_fn, + {input}, &outputs)); + } + return output; +} + } // namespace view } // namespace one } // namespace oneflow diff --git a/oneflow/core/framework/tensor_methods.h b/oneflow/core/framework/tensor_methods.h index d4d3384e0f7..60781ecfbb0 100644 --- a/oneflow/core/framework/tensor_methods.h +++ 
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index b76a43b1bfb..15d4e86b94a 100755
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -788,6 +788,12 @@
   signature: "Tensor (Tensor input, Int32 dim) => Unsqueeze"
   bind_python: True
 
+- name: "squeeze"
+  signature: [
+    "Tensor (Tensor x, Int32List[1] dim=None) => Squeeze",
+  ]
+  bind_python: True
+
 - name: "exp"
   signature: "Tensor (Tensor x) => Exp"
   bind_python: True
@@ -1111,6 +1117,10 @@
   signature: "Tensor (Tensor x, Shape shape) => Reshape"
   bind_python: True
 
+- name: "view"
+  signature: "Tensor (Tensor x, Shape shape) => View"
+  bind_python: True
+
 - name: "slice_view_1d_contiguous"
   signature: "Tensor (Tensor x, Int64 start, Int64 end) => SliceView1dContiguous"
   bind_python: True
@@ -1143,12 +1153,6 @@
   signature: "Void (Tensor ref, Tensor value, Int64List start, Int64List stop, Int64List step) => LogicalSliceAssign"
   bind_python: True
 
-- name: "squeeze"
-  signature: [
-    "Tensor (Tensor x, Int32List[1] dim=None) => Squeeze",
-  ]
-  bind_python: True
-
 - name: "copy"
   signature: "Tensor (Tensor x, String device_type, Int64 device_id) => Copy"
   bind_python: True
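
With the YAML entries above, code generation exposes flow._C.squeeze and flow._C.view
on the Python side, dispatching to the Squeeze/View functors registered in
array_functor.cpp below. A usage sketch, assuming the generated bindings accept these
call forms:

import oneflow as flow

x = flow.ones(2, 1, 3)
y = flow._C.view(x, (3, 2))    # "Tensor (Tensor x, Shape shape) => View"
z = flow._C.squeeze(x, [1])    # "Tensor (Tensor x, Int32List[1] dim=None) => Squeeze"
print(y.shape, z.shape)        # (3, 2) and (2, 3)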
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp
index 317bb048697..4cbf678e0fd 100644
--- a/oneflow/core/functional/impl/array_functor.cpp
+++ b/oneflow/core/functional/impl/array_functor.cpp
@@ -46,7 +46,6 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 namespace functional {
-
 namespace impl {
 
 class ArgMaxFunctor {
@@ -652,6 +651,9 @@ class ExpandDimsFunctor {
     if (dim < 0) { expand_dim = dim + ndim + 1; }
     MutableAttrMap attrs;
     JUST(attrs.SetAttr<int32_t>("axis", expand_dim));
+
+    if (view::IsViewApplicable(input)) { return view::Unsqueeze(input, expand_dim); }
+
     return OpInterpUtil::Dispatch<Tensor>(*op_, {input}, attrs);
   }
 
@@ -659,6 +661,43 @@ class ExpandDimsFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
+class SqueezeFunctor {
+ public:
+  SqueezeFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("squeeze").Input("in").Output("out").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
+                           const Optional<std::vector<int32_t>>& dim) const {
+    int32_t ndim = x->shape()->NumAxes();
+    std::vector<int32_t> squeeze_dims;
+    squeeze_dims.reserve(ndim);
+    if (dim.has_value()) {
+      std::vector<int32_t> dims = *JUST(dim);
+      for (int32_t dim_i : dims) {
+        CHECK_OR_RETURN((dim_i >= -ndim) && (dim_i <= ndim - 1))
+            << "Dimension out of range (expected to be in range of [" << -ndim << "," << ndim - 1
+            << "], but got " << dim_i;
+        if (dim_i < 0) { dim_i += ndim; }
+        if (x->shape()->At(dim_i) == 1) { squeeze_dims.emplace_back(dim_i); }
+      }
+    } else {
+      for (int i = 0; i < ndim; ++i) {
+        if (x->shape()->At(i) == 1) { squeeze_dims.emplace_back(i); }
+      }
+    }
+
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<std::vector<int32_t>>("axes", squeeze_dims));
+
+    if (view::IsViewApplicable(x)) { return view::Squeeze(x, squeeze_dims); }
+
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
 class RollFunctor {
  public:
   RollFunctor() { op_ = CHECK_JUST(one::OpBuilder("roll").Input("in").Output("out").Build()); }
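
SqueezeFunctor normalizes dims before deciding between the view path and the squeeze
op: negative dims wrap around, out-of-range dims are rejected, and only axes of size 1
are actually squeezed. A pure-Python sketch of that normalization (hypothetical helper):

def normalize_squeeze_dims(shape, dims=None):
    ndim = len(shape)
    if dims is None:
        # dim=None squeezes every size-1 axis
        return [i for i, d in enumerate(shape) if d == 1]
    squeeze_dims = []
    for d in dims:
        if not -ndim <= d <= ndim - 1:
            raise RuntimeError(f"Dimension out of range (expected to be in range of "
                               f"[{-ndim},{ndim - 1}], but got {d})")
        if d < 0:
            d += ndim  # negative dims wrap around
        if shape[d] == 1:
            squeeze_dims.append(d)
    return squeeze_dims

assert normalize_squeeze_dims([2, 1, 3]) == [1]
assert normalize_squeeze_dims([2, 1, 3], [-2]) == [1]  # -2 wraps to axis 1
assert normalize_squeeze_dims([2, 1, 3], [0]) == []    # axis 0 has size 2, so it is kept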
@@ -989,36 +1028,41 @@ class ReshapeFunctor {
     op_ = CHECK_JUST(one::OpBuilder("reshape").Input("in").Output("out").Build());
   }
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Shape& shape) const {
-    // if input tensor is eager local, than return tensor's view
-    if (x->is_local() && !(LazyMode::is_enabled())) { return view::Reshape(x, shape); }
-    int need_infer_axis = -1;
-    size_t count = 1;
-    for (int i = 0; i < shape.NumAxes(); ++i) {
-      if (shape.At(i) < -1) {
-        return Error::RuntimeError() << "Invalid shape dimension " << shape.At(i);
-      } else if (shape.At(i) == -1) {
-        CHECK_EQ_OR_RETURN(need_infer_axis, -1)
-            << "Shape " << shape.ToString() << " has more than 1 axis that needs to be infered.";
-        need_infer_axis = i;
-      } else {
-        count *= shape.At(i);
+    Shape infered_shape = *JUST(InferShape(x, shape));
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<Shape>("shape", infered_shape));
+
+    if (view::IsViewApplicable(x)) {
+      Optional<Stride> infered_stride =
+          ComputeStride(*(x->shape()), *JUST(x->stride()), infered_shape);
+      if (infered_stride.has_value()) {
+        return view::Reshape(x, infered_shape, *JUST(infered_stride));
       }
     }
-    size_t x_count = x->shape()->Count(0);
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class ViewFunctor {
+ public:
+  ViewFunctor() { op_ = CHECK_JUST(one::OpBuilder("reshape").Input("in").Output("out").Build()); }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Shape& shape) const {
+    Shape infered_shape = *JUST(InferShape(x, shape));
     MutableAttrMap attrs;
-    if (need_infer_axis == -1) {
-      CHECK_EQ_OR_RETURN(shape.Count(0), x_count)
-          << "\n Shape " << shape.ToString() << " is invalid for input shape "
-          << x->shape()->ToString();
-      JUST(attrs.SetAttr<Shape>("shape", shape));
-    } else {
-      Shape infered_shape = shape;
-      infered_shape.Set(need_infer_axis, x_count / count);
-      CHECK_EQ_OR_RETURN(infered_shape.Count(0), x_count)
-          << "\n Shape " << shape.ToString() << " is invalid for input shape "
-          << x->shape()->ToString();
-      JUST(attrs.SetAttr<Shape>("shape", infered_shape));
+    JUST(attrs.SetAttr<Shape>("shape", infered_shape));
+
+    if (view::IsViewApplicable(x)) {
+      Optional<Stride> infered_stride =
+          ComputeStride(*(x->shape()), *JUST(x->stride()), infered_shape);
+      CHECK_OR_RETURN(infered_stride.has_value())
+          << " >> view size is not compatible with input tensor's size and stride (at least one "
+             "dimension spans across two contiguous subspaces). Use .reshape(...) instead.";
+      return view::Reshape(x, infered_shape, *JUST(infered_stride));
     }
+
     return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
   }
@@ -1172,40 +1216,6 @@ class SliceUpdateFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
-class SqueezeFunctor {
- public:
-  SqueezeFunctor() {
-    op_ = CHECK_JUST(one::OpBuilder("squeeze").Input("in").Output("out").Build());
-  }
-  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
-                           const Optional<std::vector<int32_t>>& dim) const {
-    int32_t ndim = x->shape()->NumAxes();
-    std::vector<int32_t> squeeze_dims;
-    squeeze_dims.reserve(ndim);
-    if (dim.has_value() == true) {
-      std::vector<int32_t> dims = *JUST(dim);
-      for (int32_t dim_i : dims) {
-        CHECK_OR_RETURN((dim_i >= -(ndim + 1)) && (dim_i <= ndim))
-            << "Dimension out of range (expected to be in range of [" << -ndim << "," << ndim - 1
-            << "], but got " << dim_i;
-        if (dim_i < 0) { dim_i += ndim; }
-        if (x->shape()->At(dim_i) == 1) { squeeze_dims.emplace_back(dim_i); }
-      }
-    } else {
-      for (int i = 0; i < ndim; ++i) {
-        if (x->shape()->At(i) == 1) { squeeze_dims.emplace_back(i); }
-      }
-    }
-
-    MutableAttrMap attrs;
-    JUST(attrs.SetAttr<std::vector<int32_t>>("axes", squeeze_dims));
-    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
-  }
-
- private:
-  std::shared_ptr<OpExpr> op_;
-};
-
 class UpsampleGradFunctor {
  public:
   UpsampleGradFunctor() {
@@ -2707,6 +2717,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::ExpandGradFunctor>("ExpandGrad");
   m.add_functor<impl::ExpandDimsFunctor>("ExpandDims");
   m.add_functor<impl::ExpandDimsFunctor>("Unsqueeze");
+  m.add_functor<impl::SqueezeFunctor>("Squeeze");
   m.add_functor<impl::RollFunctor>("Roll");
   m.add_functor<impl::GatherFunctor>("Gather");
   m.add_functor<impl::DimGatherFunctor>("DimGather");
@@ -2716,6 +2727,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::TensorScatterNdUpdateFunctor>("TensorScatterNdUpdate");
   m.add_functor<impl::ScatterNdLikeFunctor>("ScatterNdLike");
   m.add_functor<impl::ReshapeFunctor>("Reshape");
+  m.add_functor<impl::ViewFunctor>("View");
   m.add_functor<impl::SliceFunctor>("Slice");
   m.add_functor<impl::SliceGradFunctor>("SliceGrad");
   m.add_functor<impl::NarrowFunctor>("Narrow");
@@ -2724,7 +2736,6 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::LogicalSliceFunctor>("LogicalSlice");
   m.add_functor<impl::SliceUpdateFunctor>("SliceUpdate");
   m.add_functor<impl::SliceView1dContiguousFunctor>("SliceView1dContiguous");
-  m.add_functor<impl::SqueezeFunctor>("Squeeze");
   m.add_functor<impl::CopyFunctor>("Copy");
   m.add_functor<impl::FlipFunctor>("Flip");
   m.add_functor<impl::FlipGradFunctor>("FlipGrad");
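
The two functors above now diverge exactly as in PyTorch: ReshapeFunctor tries the view
path and silently falls back to the copying "reshape" op when ComputeStride fails, while
ViewFunctor treats that failure as an error. A behavioral sketch, assuming oneflow's
PyTorch-aligned permute/reshape/view tensor methods:

import oneflow as flow

x = flow.ones(2, 3).permute(1, 0)  # strides no longer describe a row-major layout
y = x.reshape(6)                   # OK: ComputeStride fails, so a copy is made
try:
    z = x.view(6)                  # CHECK_OR_RETURN fires: size/stride incompatible
except Exception as e:
    print(e)                       # message suggests using .reshape(...) instead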
diff --git a/oneflow/core/functional/impl/common.cpp b/oneflow/core/functional/impl/common.cpp
index 31f002806f0..4483b1dcd4e 100644
--- a/oneflow/core/functional/impl/common.cpp
+++ b/oneflow/core/functional/impl/common.cpp
@@ -80,6 +80,81 @@ Maybe<void> CheckShapeCanExpandTo(const Shape& shape, const Shape& expand_shape)
   return Maybe<void>::Ok();
 }
 
+Optional<Stride> ComputeStride(const Shape& shape, const Stride& stride,
+                               const Shape& target_shape) {
+  /*************************************************
+   * Description: in some cases the view operation is not allowed, so we need to
+   * check its validity; the check refers to PyTorch
+   * (aten/src/ATen/native/TensorShape.cpp).
+   *************************************************/
+  if (stride.NumAxes() == 0) { return NullOpt; }
+  int64_t elem_count = shape.elem_cnt();
+  int64_t ndim = shape.NumAxes();
+  int64_t tgt_ndim = target_shape.NumAxes();
+  DimVector shape_vec = shape.dim_vec();
+  DimVector tgt_shape_vec = target_shape.dim_vec();
+  DimVector stride_vec = stride.StrideVec();
+  if (elem_count == 0) { return NullOpt; }
+
+  int64_t view_d = tgt_ndim - 1;
+  // stride for each subspace in the chunk
+  int64_t chunk_base_stride = stride_vec.back();
+  DimVector newstride(tgt_ndim);
+  // numel in current chunk
+  int64_t tensor_numel = 1;
+  int64_t view_numel = 1;
+  for (int64_t tensor_d = ndim - 1; tensor_d >= 0; tensor_d--) {
+    tensor_numel *= shape_vec[tensor_d];
+    // if at the end of a tensor-size chunk, check the view
+    if ((tensor_d == 0)
+        || (shape_vec[tensor_d - 1] != 1
+            && stride_vec[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
+      while (view_d >= 0 && (view_numel < tensor_numel || tgt_shape_vec[view_d] == 1)) {
+        newstride[view_d] = view_numel * chunk_base_stride;
+        view_numel *= tgt_shape_vec[view_d];
+        view_d--;
+      }
+      if (view_numel != tensor_numel) { return NullOpt; }
+      if (tensor_d > 0) {
+        chunk_base_stride = stride_vec[tensor_d - 1];
+        tensor_numel = 1;
+        view_numel = 1;
+      }
+    }
+  }
+  if (view_d != -1) { return NullOpt; }
+  Stride target_stride(newstride);
+  return target_stride;
+}
+
+Maybe<Shape> InferShape(const std::shared_ptr<Tensor>& x, const Shape& shape) {
+  int need_infer_axis = -1;
+  size_t count = 1;
+  for (int i = 0; i < shape.NumAxes(); ++i) {
+    if (shape.At(i) < -1) {
+      return Error::RuntimeError() << "Invalid shape dimension " << shape.At(i);
+    } else if (shape.At(i) == -1) {
+      CHECK_EQ_OR_RETURN(need_infer_axis, -1)
+          << "Shape " << shape.ToString() << " has more than 1 axis that needs to be inferred.";
+      need_infer_axis = i;
+    } else {
+      count *= shape.At(i);
+    }
+  }
+  size_t x_count = x->shape()->Count(0);
+  Shape infered_shape = shape;
+  if (need_infer_axis == -1) {
+    CHECK_EQ_OR_RETURN(shape.Count(0), x_count)
+        << "\n Shape " << shape.ToString() << " is invalid for input shape "
+        << x->shape()->ToString();
+  } else {
+    infered_shape.Set(need_infer_axis, x_count / count);
+    CHECK_EQ_OR_RETURN(infered_shape.Count(0), x_count)
+        << "\n Shape " << shape.ToString() << " is invalid for input shape "
+        << x->shape()->ToString();
+  }
+  return infered_shape;
+}
+
 }  // namespace functional
 }  // namespace one
 }  // namespace oneflow
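
To make the chunk logic in ComputeStride concrete, here is a pure-Python port (returning
None where the C++ returns NullOpt); like the original, it walks the tensor dims from the
innermost outward and, at each contiguous-chunk boundary, checks that the target dims
consumed so far cover the chunk exactly:

def compute_stride(shape, stride, target_shape):
    if not stride or 0 in shape:
        return None
    view_d = len(target_shape) - 1
    chunk_base_stride = stride[-1]
    newstride = [0] * len(target_shape)
    tensor_numel = view_numel = 1
    for tensor_d in range(len(shape) - 1, -1, -1):
        tensor_numel *= shape[tensor_d]
        # a chunk ends where the next stride stops being contiguous with this one
        if tensor_d == 0 or (shape[tensor_d - 1] != 1
                             and stride[tensor_d - 1] != tensor_numel * chunk_base_stride):
            while view_d >= 0 and (view_numel < tensor_numel or target_shape[view_d] == 1):
                newstride[view_d] = view_numel * chunk_base_stride
                view_numel *= target_shape[view_d]
                view_d -= 1
            if view_numel != tensor_numel:
                return None  # a target dim would span two non-contiguous chunks
            if tensor_d > 0:
                chunk_base_stride = stride[tensor_d - 1]
                tensor_numel = view_numel = 1
    return newstride if view_d == -1 else None

assert compute_stride([2, 3], [3, 1], [6]) == [1]   # contiguous: viewable
assert compute_stride([3, 2], [1, 3], [6]) is None  # transposed: not viewable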
diff --git a/oneflow/core/functional/impl/common.h b/oneflow/core/functional/impl/common.h
index e81d0427d24..860954b4568 100644
--- a/oneflow/core/functional/impl/common.h
+++ b/oneflow/core/functional/impl/common.h
@@ -17,6 +17,7 @@ limitations under the License.
 #define ONEFLOW_CORE_FUNCTIONAL_IMPL_COMMON_H_
 
 #include "oneflow/core/framework/tensor.h"
+#include "oneflow/core/framework/stride.h"
 
 namespace oneflow {
 namespace one {
@@ -34,6 +35,8 @@ Maybe<void> CheckInplaceValid(const std::shared_ptr<Tensor>& x);
 Maybe<void> CheckInplaceCastValid(const std::shared_ptr<Tensor>& x,
                                   const std::shared_ptr<Tensor>& x_cast);
 Maybe<void> CheckShapeCanExpandTo(const Shape& shape, const Shape& expand_shape);
+Optional<Stride> ComputeStride(const Shape& shape, const Stride& stride, const Shape& target_shape);
+Maybe<Shape> InferShape(const std::shared_ptr<Tensor>& x, const Shape& shape);
 
 }  // namespace functional
 }  // namespace one
diff --git a/python/oneflow/framework/tensor.py b/python/oneflow/framework/tensor.py
index 0cf99c03386..d5da13185e5 100644
--- a/python/oneflow/framework/tensor.py
+++ b/python/oneflow/framework/tensor.py
@@ -925,7 +925,13 @@ def _reshape(self, *shape):
 
 
 def _view(self, *shape):
-    return flow.view(self, *shape)
+    if len(shape) == 1:
+        new_shape = shape[0]
+        if isinstance(new_shape, int):
+            new_shape = (new_shape,)
+    else:
+        new_shape = shape
+    return flow._C.view(self, new_shape)
 
 
 def _sort(self, dim: int = -1, descending: bool = False):
diff --git a/python/oneflow/nn/modules/reshape.py b/python/oneflow/nn/modules/reshape.py
index 5f29a4ef52f..354039e4506 100644
--- a/python/oneflow/nn/modules/reshape.py
+++ b/python/oneflow/nn/modules/reshape.py
@@ -65,7 +65,7 @@ def view_op(input, *shape):
             new_shape = (new_shape,)
     else:
         new_shape = shape
-    return flow._C.reshape(input, new_shape)
+    return flow._C.view(input, new_shape)
 
 
 if __name__ == "__main__":
diff --git a/python/oneflow/test/modules/test_view.py b/python/oneflow/test/modules/test_view.py
index 117c7b96b51..f4a3bd59c6d 100644
--- a/python/oneflow/test/modules/test_view.py
+++ b/python/oneflow/test/modules/test_view.py
@@ -74,6 +74,7 @@ def _test_view_flow_size(test_case, device):
 
 @flow.unittest.skip_unless_1n1d()
 class TestView(flow.unittest.TestCase):
+    # TODO(zhaoluyang): add a test case that triggers tensor.view's check
    def test_view(test_case):
        arg_dict = OrderedDict()
        arg_dict["test_fun"] = [
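
The _view wrapper accepts both PyTorch calling conventions — a varargs shape and a
single tuple (or int) — before handing the normalized tuple to flow._C.view. A quick
sketch of the equivalence, assuming the eager local tensor path:

import oneflow as flow

x = flow.ones(6)
a = x.view(2, 3)    # varargs: shape == (2, 3)
b = x.view((2, 3))  # single tuple: shape[0] is used as-is
c = x.view(6)       # single int: wrapped into the 1-tuple (6,)
assert tuple(a.shape) == tuple(b.shape) == (2, 3) and tuple(c.shape) == (6,)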