Skip to content

Commit

Permalink
optimize reshape/slice/transpose functor
Browse files Browse the repository at this point in the history
  • Loading branch information
wyushun committed Dec 7, 2021
1 parent db415d6 commit cbc5680
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 22 deletions.
65 changes: 43 additions & 22 deletions oneflow/core/functional/impl/array_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,32 +891,41 @@ class ReshapeFunctor {
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Shape& shape) const {
// if input tensor is eager local, than return tensor's view
if (x->is_eager() && x->is_local()) { return view::Reshape(x, shape); }
int need_infer_axis = -1;
size_t count = 1;
for (int i = 0; i < shape.NumAxes(); ++i) {
if (shape.At(i) == -1) {
CHECK_EQ_OR_RETURN(need_infer_axis, -1)
<< "Shape " << shape.ToString() << " has more than 1 axis that needs to be infered.";
need_infer_axis = i;

// if shape is x's shape, return directly
if (*(x->shape()) == shape) return x;

// normal handling routine, consider 0 and other negative numbers besides -1
auto new_shape = shape;
auto last_neg_pos = -1;
auto negative_axis_num = 0;
auto positive_cnt = 0;
for (auto i = 0; i < new_shape.NumAxes(); i++) {
if (new_shape.At(i) <= 0) {
negative_axis_num++;
last_neg_pos = i;
} else {
count *= shape.At(i);
positive_cnt *= new_shape.At(i);
}
}
size_t x_count = x->shape()->Count(0);
MutableAttrMap attrs;
if (need_infer_axis == -1) {
CHECK_EQ_OR_RETURN(shape.Count(0), x_count)
<< "\n Shape " << shape.ToString() << " is invalid for input shape "
<< x->shape()->ToString();
JUST(attrs.SetAttr<Shape>("shape", shape));
} else {
Shape infered_shape = shape;
infered_shape.Set(need_infer_axis, x_count / count);
CHECK_EQ_OR_RETURN(infered_shape.Count(0), x_count)
<< "\n Shape " << shape.ToString() << " is invalid for input shape "
<< x->shape()->ToString();
JUST(attrs.SetAttr<Shape>("shape", infered_shape));
if (negative_axis_num == 1) {
if (new_shape.At(last_neg_pos) != -1) {
Error::RuntimeError() << "Shape " << new_shape.ToString()
<< " has negative axis besides -1.";
} else {
if (x->shape()->Count(0) % positive_cnt) {
Error::RuntimeError() << "Shape " << new_shape.ToString()
<< " is invalid, -1 cannot be deduced.";
} else {
new_shape.Set(last_neg_pos, x->shape()->Count(0) / positive_cnt);
}
}
} else if (negative_axis_num > 1) {
Error::RuntimeError() << "Shape " << new_shape.ToString()
<< " has more than 2 negative axis.";
}
MutableAttrMap attrs;
JUST(attrs.SetAttr<Shape>("shape", new_shape));
return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
}

Expand All @@ -931,6 +940,18 @@ class SliceBaseFunctor {
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int64_t>& start,
const std::vector<int64_t>& stop,
const std::vector<int64_t>& step) const {
// judge whether is full slice in front end
if (op_->op_type_name() == "slice") {
auto is_full_slice = [&]() {
for (auto i = 0; i < x->ndim(); ++i) {
if (start[i] != 0 || stop[i] != x->dim(i) || step[i] != 1) { return false; }
}
return true;
}();
if (is_full_slice) return x;
}

// normal handling routine
MutableAttrMap attrs;
JUST(attrs.SetAttr<std::vector<int64_t>>("start", start));
JUST(attrs.SetAttr<std::vector<int64_t>>("stop", stop));
Expand Down
7 changes: 7 additions & 0 deletions oneflow/core/functional/impl/math_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,13 @@ class TransposeFunctor {
const std::vector<int32_t>& permute) const {
MutableAttrMap attrs;
CHECK_EQ_OR_RETURN(input->ndim(), permute.size()) << "number of dims don't match in permute";

// if permute vector is 0,1,...,n, return input directly
std::vector<int32_t> not_permute(permute.size());
std::iota(not_permute.begin(), not_permute.end(), 1);
if (not_permute == permute) return input;

// normal handling routine
JUST(attrs.SetAttr<std::vector<int32_t>>("perm", permute));
int32_t ndims = input->shape()->NumAxes();
for (int i = 0; i < permute.size(); i++) {
Expand Down

0 comments on commit cbc5680

Please sign in to comment.