Skip to content

Commit

Permalink
[Relay][Op] Remove reverse attribute from reshape and reverse_reshape…
Browse files Browse the repository at this point in the history
… operators. (apache#7086)
  • Loading branch information
jwfromm authored and tkonolige committed Jan 11, 2021
1 parent 1ec54fb commit 369c069
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 25 deletions.
4 changes: 0 additions & 4 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,9 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
Array<Integer> newshape;
bool reverse;
TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
TVM_ATTR_FIELD(newshape).describe(
"The new shape. Should be compatible with the original shape.");
TVM_ATTR_FIELD(reverse)
.describe("Infer the special values from right to left if true")
.set_default(false);
}
}; // struct ReshapeAttrs

Expand Down
1 change: 0 additions & 1 deletion src/relay/op/dyn/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& in

Expr MakeReshape(Expr data, Expr newshape) {
auto attrs = make_object<ReshapeAttrs>();
attrs->reverse = false;
static const Op& op = Op::Get("dyn.reshape");
return Call(op, {data, newshape}, Attrs(attrs), {});
}
Expand Down
76 changes: 58 additions & 18 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,13 +453,14 @@ RELAY_REGISTER_OP("transpose")
TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs);

Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs) {
Array<IndexExpr> InferNewShape(const Array<IndexExpr>& data_shape, const Attrs& attrs,
bool reverse) {
const auto* param = attrs.as<ReshapeAttrs>();
Array<IndexExpr> oshape;
Array<IndexExpr> ishape;
Array<Integer> newshape;

if (param->reverse) {
if (reverse) {
ishape.Assign(data_shape.rbegin(), data_shape.rend());
newshape.Assign(param->newshape.rbegin(), param->newshape.rend());
} else {
Expand Down Expand Up @@ -582,7 +583,6 @@ Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs&

bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
const auto* param = attrs.as<ReshapeAttrs>();
// types: [data, result]
ICHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
Expand All @@ -592,16 +592,12 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return false;
}

const auto& oshape = infer_newshape(data->shape, attrs);
const auto& oshape = InferNewShape(data->shape, attrs, false);

// Verify that the sum of dimensions in the output shape is the sum of
// dimensions in the input shape
Array<IndexExpr> data_shape;
if (param->reverse) {
data_shape.Assign(data->shape.rbegin(), data->shape.rend());
} else {
data_shape = data->shape;
}
data_shape = data->shape;

bool found_dynamic = false;
int64_t oshape_sum = 1;
Expand Down Expand Up @@ -631,12 +627,58 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "Input tensor shape and reshaped shape are not compatible";
}

if (param->reverse) {
reporter->Assign(types[1],
TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
} else {
reporter->Assign(types[1], TensorType(oshape, data->dtype));
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}

bool ReverseReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, result]
ICHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
ICHECK(types[0].as<IncompleteTypeNode>())
<< "reshape: expect input type to be TensorType but get " << types[0];
return false;
}

const auto& oshape = InferNewShape(data->shape, attrs, true);

// Verify that the sum of dimensions in the output shape is the sum of
// dimensions in the input shape
Array<IndexExpr> data_shape;
data_shape.Assign(data->shape.rbegin(), data->shape.rend());

bool found_dynamic = false;
int64_t oshape_sum = 1;
for (auto& x : oshape) {
// Check if we have a dynamic shape. If we do, we can't verify if the
// reshape is valid. Dynamic shapes are marker by using Any, but can also
// occur from SizeVar's. In the case of SizeVar, the shape expression can
// be an AST. We can't easily check if we have an AST because of a ShapeVar
// or some other reason, so our check for dynamic shape is just if we can
// convert the shape to in integer or not.
if (!x->IsInstance<tvm::Integer::ContainerType>()) {
found_dynamic = true;
break;
}
oshape_sum *= Downcast<tvm::Integer>(x)->value;
}
int64_t data_shape_sum = 1;
for (auto& x : data_shape) {
if (!x->IsInstance<tvm::Integer::ContainerType>()) {
found_dynamic = true;
break;
}
data_shape_sum *= Downcast<tvm::Integer>(x)->value;
}
if (!found_dynamic) {
ICHECK_EQ(oshape_sum, data_shape_sum)
<< "Input tensor shape and reshaped shape are not compatible";
}

reporter->Assign(types[1],
TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
return true;
}

Expand Down Expand Up @@ -699,15 +741,14 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& in
}

if (newshape_has_any) {
newshape = infer_newshape(inputs[0]->shape, attrs);
newshape = InferNewShape(inputs[0]->shape, attrs, false);
}
return {topi::reshape(inputs[0], newshape)};
}

Expr MakeReshape(Expr data, Array<Integer> newshape) {
auto attrs = make_object<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
attrs->reverse = false;
static const Op& op = Op::Get("reshape");
return Call(op, {data}, Attrs(attrs), {});
}
Expand Down Expand Up @@ -2869,7 +2910,6 @@ RELAY_REGISTER_OP("auto_scheduler_layout_transform")
Expr MakeReverseReshape(Expr data, Array<Integer> newshape) {
auto attrs = make_object<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
attrs->reverse = true;
static const Op& op = Op::Get("contrib_reverse_reshape");
return Call(op, {data}, Attrs(attrs), {});
}
Expand All @@ -2894,7 +2934,7 @@ example below::
.set_attrs_type<ReshapeAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(10)
.add_type_rel("Reshape", ReshapeRel)
.add_type_rel("ReverseReshape", ReverseReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/tensor/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ static inline Array<Array<Layout>> ConcatenateLayout(const Attrs& attrs,
* \param attrs The attributes.
* \return Output shape.
*/
Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs);
Array<IndexExpr> InferNewShape(const Array<IndexExpr>& data_shape, const Attrs& attrs);

} // namespace relay
} // namespace tvm
Expand Down
1 change: 0 additions & 1 deletion tests/python/contrib/test_arm_compute_lib/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def _get_expected_codegen(input_shape, output_shape, dtype):
"newshape": [[str(s) for s in output_shape]],
"shape": [[list(output_shape)]],
"dtype": [[dtype]],
"reverse": [["0"]],
},
}

Expand Down

0 comments on commit 369c069

Please sign in to comment.