[Op][Manip] collapse_sum_like, collapse_sum_to (#87)
* collapse_sum_like & collapse_sum_to

* address comments

* test coverage

* fix some tests

* collapse_sum_to tgt shape var test

* format

* format

* reformat

* reformat

* type: ignore

* handle shape var cases and add regression tests
SiriusNEO authored and MasterJH5574 committed Jan 9, 2023
1 parent 1bd2329 commit d4a1be8
Showing 6 changed files with 564 additions and 0 deletions.
53 changes: 53 additions & 0 deletions python/tvm/relax/op/manipulate.py
@@ -220,3 +220,56 @@ def squeeze(x: Expr, axis: Optional[Union[int, List[int]]] = None) -> Expr:
    if isinstance(axis, int):
        axis = [axis]
    return _ffi_api.squeeze(x, axis)  # type: ignore


def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr:
    """Return a summation of data to the shape of collapse_target.

    For details, please see relax.op.collapse_sum_to.

    Parameters
    ----------
    data : relax.Expr
        The input tensor.

    collapse_target : relax.Expr
        The tensor whose shape is the shape to collapse to.

    Returns
    -------
    result : relax.Expr
        The result tensor after summation.
    """
    return _ffi_api.collapse_sum_like(data, collapse_target)  # type: ignore


def collapse_sum_to(data: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr:
    """Return a summation of data to the given shape.

    collapse_sum_to is intended as the backward operator of tvm.relax.op.broadcast_to and
    other broadcast operators in the automatic differentiation process.

    We expect that data is the result of broadcasting some tensor of the given shape in some
    broadcast operation. Thus the given `shape` and `data.shape` must follow broadcast rules.

    During computation, all axes of `data.shape` and `shape` are checked from right to left,
    and `data` is summed over an axis if either of the following holds:

    - the axis exists in `data.shape` but not in `shape`, or
    - the axis exists in both, and the corresponding dimension in `shape` equals 1.

    Parameters
    ----------
    data : relax.Expr
        The input tensor.

    shape : Union[Tuple[PrimExprLike], relax.Expr]
        The shape to collapse to.

    Returns
    -------
    result : relax.Expr
        The result tensor of the given shape after summation.
    """
    if isinstance(shape, (tuple, list)):
        shape = ShapeExpr(shape)
    return _ffi_api.collapse_sum_to(data, shape)  # type: ignore
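
For reference, a minimal sketch of driving the two new operators from the Python API, assuming the dev-branch conventions where relax.Var takes struct info and R.Tensor annotations are used in tests; the names bb and main are illustrative, not part of this patch:

import tvm
from tvm import relax
from tvm.script import relax as R

x = relax.Var("x", R.Tensor((3, 2, 3), "float32"))
y = relax.Var("y", R.Tensor((2, 1), "float32"))

bb = relax.BlockBuilder()
with bb.function("main", [x, y]):
    # Both calls should infer an output of shape (2, 1): axis 0 of x is
    # summed away, and axis 2 is collapsed into the size-1 target dim.
    lv0 = bb.emit(relax.op.collapse_sum_like(x, y))
    lv1 = bb.emit(relax.op.collapse_sum_to(x, (2, 1)))
    bb.emit_func_output(relax.Tuple((lv0, lv1)))

mod = bb.get()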
4 changes: 4 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -38,6 +38,8 @@
    null_value,
    call_tir,
    call_builtin,
    collapse_sum_like,
    collapse_sum_to,
    concat,
    cos,
    divide,
@@ -400,6 +402,8 @@ def tuple(*fields: List[Expr]) -> Expr:
"call_tir",
"call_builtin",
"cos",
"collapse_sum_like",
"collapse_sum_to",
"null_value",
"concat",
"const",
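
Since these names are now exported through the IRBuilder namespace, they should also be usable from TVMScript. A hedged sketch under the dev-branch parser conventions; this module is illustrative, not part of the patch:

import tvm
from tvm.script import relax as R

@tvm.script.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((3, 2, 3), "float32")) -> R.Tensor((2, 1), "float32"):
        # Collapse (3, 2, 3) to (2, 1) by summing axis 0 away and folding
        # axis 2 into the size-1 dim.
        gv: R.Tensor((2, 1), "float32") = R.collapse_sum_to(x, (2, 1))
        return gv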
130 changes: 130 additions & 0 deletions src/relax/op/tensor/manipulate.cc
@@ -799,5 +799,135 @@ TVM_REGISTER_OP("relax.squeeze")
    .add_argument("x", "Tensor", "The input tensor.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSqueeze);

void CheckCollapseShape(const Call& call, const BlockBuilder& ctx,
                        const Array<PrimExpr>& data_shape, const Array<PrimExpr>& target_shape) {
  arith::Analyzer* analyzer = ctx->GetAnalyzer();

  int data_ndim = data_shape.size();
  int target_ndim = target_shape.size();

  // Match the axes from right to left. Data axes with no remaining target
  // axis are summed away entirely.
  int data_ax = data_ndim - 1;
  int target_ax = target_ndim - 1;
  for (; data_ax >= 0; --data_ax) {
    if (target_ax < 0) {
      continue;
    }
    const PrimExpr& dim0 = data_shape[data_ax];
    const PrimExpr& dim1 = target_shape[target_ax];
    const auto* int_dim0 = dim0.as<IntImmNode>();
    const auto* int_dim1 = dim1.as<IntImmNode>();

    if (analyzer->CanProveEqual(dim0, dim1) || (int_dim1 != nullptr && int_dim1->value == 1)) {
      // Either the dims match, or the data axis is collapsed into a size-1
      // target dim; the target axis is consumed in both cases.
      --target_ax;
    } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "In " << call->op << ", the data shape at dim " << data_ax << " is "
                       << dim0 << " and the target shape at dim " << target_ax << " is " << dim1
                       << ", which do not match the rule of collapse sum.");
    } else {
      // Todo(relax-team): At this moment, enforcing MatchCast is fine. But we may need to revisit
      // this requirement to reduce the workload of importers and better support dynamic shapes.
      ctx->ReportFatal(Diagnostic::Error(call)
                       << call->op
                       << " fails to match the axes because of an unknown dim or symbolic"
                          " shape. In this position the dim of the data shape is "
                       << dim0 << " while the dim of the target shape is " << dim1
                       << ". If it is symbolic, consider using MatchCast first.");
    }
  }
}
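
To make the matching rule concrete, here is a small Python sketch (not part of the patch) that mirrors the loop above for static shapes and returns the data axes that end up being summed; collapse_axes is a hypothetical helper name:

def collapse_axes(data_shape, target_shape):
    # Mirror CheckCollapseShape's right-to-left matching, static dims only.
    summed = []
    target_ax = len(target_shape) - 1
    for data_ax in range(len(data_shape) - 1, -1, -1):
        if target_ax < 0:
            summed.append(data_ax)  # no target axis left: summed away entirely
            continue
        if data_shape[data_ax] == target_shape[target_ax]:
            pass  # dims match: nothing to sum on this axis
        elif target_shape[target_ax] == 1:
            summed.append(data_ax)  # collapsed into a size-1 target dim
        else:
            raise ValueError("shapes do not follow the collapse-sum rule")
        target_ax -= 1
    return summed

assert collapse_axes((3, 2, 3), (2, 1)) == [2, 0]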

/* relax.collapse_sum_like */
Expr collapse_sum_like(Expr data, Expr collapse_target) {
  static const Op& op = Op::Get("relax.collapse_sum_like");
  return Call(op, {std::move(data), std::move(collapse_target)}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relax.op.collapse_sum_like").set_body_typed(collapse_sum_like);

StructInfo InferStructInfoCollapseSumLike(const Call& call, const BlockBuilder& ctx) {
  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
  TensorStructInfo data_sinfo = input_sinfo[0];
  TensorStructInfo collapse_target_sinfo = input_sinfo[1];

  DataType output_dtype = data_sinfo->dtype;

  Optional<Array<PrimExpr>> data_shape_value;
  if (data_sinfo->shape.defined()) {
    data_shape_value = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value())->values;
  }
  Optional<Array<PrimExpr>> collapse_target_shape_value;
  if (collapse_target_sinfo->shape.defined()) {
    collapse_target_shape_value =
        GetStructInfoAs<ShapeStructInfoNode>(collapse_target_sinfo->shape.value())->values;
  }

  if (data_shape_value.defined() && collapse_target_shape_value.defined()) {
    CheckCollapseShape(call, ctx, data_shape_value.value(), collapse_target_shape_value.value());
  }

  if (collapse_target_sinfo->shape.defined()) {
    return TensorStructInfo(collapse_target_sinfo->shape.value(), output_dtype);
  } else {
    return TensorStructInfo(output_dtype, collapse_target_sinfo->ndim);
  }
}

TVM_REGISTER_OP("relax.collapse_sum_like")
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("collapse_target", "Tensor",
                  "The tensor whose shape is the shape to collapse to.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCollapseSumLike);

/* relax.collapse_sum_to */
Expr collapse_sum_to(Expr data, Expr shape) {
  static const Op& op = Op::Get("relax.collapse_sum_to");
  return Call(op, {std::move(data), std::move(shape)}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relax.op.collapse_sum_to").set_body_typed(collapse_sum_to);

StructInfo InferStructInfoCollapseSumTo(const Call& call, const BlockBuilder& ctx) {
  if (call->args.size() != 2) {
    ctx->ReportFatal(Diagnostic::Error(call) << "CollapseSumTo should have 2 arguments");
  }

  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
  const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);

  if (data_sinfo == nullptr) {
    ctx->ReportFatal(
        Diagnostic::Error(call)
        << "CollapseSumTo requires the input data to be a Tensor. However, the given one is "
        << call->args[0]->struct_info_->GetTypeKey());
  }
  if (shape_sinfo == nullptr) {
    ctx->ReportFatal(
        Diagnostic::Error(call)
        << "CollapseSumTo requires the input shape to be a Shape. However, the given one is "
        << call->args[1]->struct_info_->GetTypeKey());
  }

  DataType output_dtype = data_sinfo->dtype;

  Optional<Array<PrimExpr>> data_shape_value;
  if (data_sinfo->shape.defined()) {
    data_shape_value = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value())->values;
  }

  if (data_shape_value.defined() && shape_sinfo->values.defined()) {
    CheckCollapseShape(call, ctx, data_shape_value.value(), shape_sinfo->values.value());
  }

  return TensorStructInfo(/*shape=*/call->args[1], output_dtype);
}

TVM_REGISTER_OP("relax.collapse_sum_to")
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("shape", "Shape", "The shape to collapse to.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCollapseSumTo);

} // namespace relax
} // namespace tvm
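
As a sanity check on the docstring's claim that collapse_sum_to is the backward of broadcast_to, here is a NumPy sketch of the intended semantics; NumPy stands in for the lowered op, and the numbers are illustrative only:

import numpy as np

x = np.ones((2, 1), "float32")
y = np.broadcast_to(x, (3, 2, 3))  # forward: (2, 1) -> (3, 2, 3)
grad_y = np.ones_like(y)           # upstream gradient

# collapse_sum_to(grad_y, (2, 1)): sum axis 0 away, keep axis 2 as size 1.
grad_x = grad_y.sum(axis=0).sum(axis=-1, keepdims=True)
assert grad_x.shape == x.shape
assert (grad_x == 9).all()  # each input element fed 3 * 3 = 9 outputs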
21 changes: 21 additions & 0 deletions src/relax/op/tensor/manipulate.h
@@ -102,6 +102,27 @@ Expr split(Expr x, ObjectRef indices_or_sections, int axis);
*/
Expr squeeze(Expr x, Optional<Array<Integer>> axis);

/*!
 * \brief Return a summation of data to the shape of collapse_target.
 * For details, please see the operator `relax.collapse_sum_to`.
 * \param data The input tensor.
 * \param collapse_target The tensor whose shape is the shape to collapse to.
 * \return The result tensor after summation.
 */
Expr collapse_sum_like(Expr data, Expr collapse_target);

/*!
 * \brief Return a summation of data to the given shape.
 * collapse_sum_to is intended as the backward operator of broadcast_to and
 * other broadcast operators in the automatic differentiation process.
 * We expect that data is the result of broadcasting some tensor of the given shape in some
 * broadcast operation. Thus the given shape and data.shape must follow broadcast rules.
 * \param data The input tensor.
 * \param shape The shape to collapse to.
 * \return The result tensor of the given shape after summation.
 */
Expr collapse_sum_to(Expr data, Expr shape);

} // namespace relax
} // namespace tvm

