Skip to content

Commit

Permalink
[Unity] Relax op: search (apache#13992)
Browse files Browse the repository at this point in the history
This PR is about the high-level tensor computation operators in Relax.

This PR includes the search operators.
  • Loading branch information
MasterJH5574 authored and yongwww committed Feb 27, 2023
1 parent 9ebc992 commit b061f74
Show file tree
Hide file tree
Showing 7 changed files with 532 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .manipulate import *
from .op_attrs import *
from .statistical import *
from .search import *
from .set import *
from .ternary import *
from .unary import *
Expand Down
50 changes: 50 additions & 0 deletions python/tvm/relax/op/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Search operators."""
from . import _ffi_api
from ..expr import Expr


def where(condition: Expr, x1: Expr, x2: Expr) -> Expr:
"""Selecting elements from either the input tensors depending on the value of the
condition.
For a given position, return the corresponding value in `x1` if `condition` is True,
and return the corresponding value in `x2` otherwise.
Parameters
----------
condition : relax.Expr
When True, yield `x1`; otherwise, yield `x2`.
Must be broadcasting compatible with `x1` and `x2`.
Must have boolean dtype.
x1 : relax.Expr
The first input tensor.
Must be broadcasting compatible with `condition` and `x2`.
x2 : relax.Expr
The second input tensor.
Must be broadcasting compatible with `condition` and `x1`.
Returns
-------
result : relax.Expr
The result tensor.
"""
return _ffi_api.where(condition, x1, x2) # type: ignore
4 changes: 3 additions & 1 deletion python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
tril,
triu,
unique,
where,
zeros,
zeros_like,
nn,
Expand Down Expand Up @@ -547,8 +548,9 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"tril",
"triu",
"tuple",
"variance",
"unique",
"variance",
"where",
"zeros",
"zeros_like",
"nn",
Expand Down
99 changes: 99 additions & 0 deletions src/relax/op/tensor/search.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file search.cc
* \brief Searching operators.
*/

#include "search.h"

#include <algorithm>
#include <utility>

namespace tvm {
namespace relax {

/* relax.where */
Expr where(Expr condition, Expr x1, Expr x2) {
static const Op& op = Op::Get("relax.where");
return Call(op, {std::move(condition), std::move(x1), std::move(x2)}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relax.op.where").set_body_typed(where);

StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) {
Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
TensorStructInfo cond_sinfo = input_sinfo[0];
TensorStructInfo x1_sinfo = input_sinfo[1];
TensorStructInfo x2_sinfo = input_sinfo[2];

if (!cond_sinfo->dtype.is_bool()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Where requires the input condition tensor to have boolean dtype. However, "
"the given condition dtype is "
<< cond_sinfo->dtype);
}
DataType output_dtype = InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo);

int output_ndim;
if (cond_sinfo->IsUnknownNdim() || x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) {
output_ndim = kUnknownNDim;
} else {
output_ndim = std::max(cond_sinfo->ndim, std::max(x1_sinfo->ndim, x2_sinfo->ndim));
}

const auto* cond_shape = cond_sinfo->shape.as<ShapeExprNode>();
const auto* x1_shape = x1_sinfo->shape.as<ShapeExprNode>();
const auto* x2_shape = x2_sinfo->shape.as<ShapeExprNode>();
if (cond_shape && x1_shape && x2_shape) {
// Step 1. Compute the broadcasted shape of x1's and x2's
Optional<Array<PrimExpr>> broadcasted_shape =
InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values);
if (!broadcasted_shape.defined()) {
return TensorStructInfo(output_dtype, output_ndim);
}
// Step 2. Compute the broadcasted shape of cond's and the previous broadcasted shape.
broadcasted_shape =
InferBinaryBroadcastShape(call, ctx, cond_shape->values, broadcasted_shape.value());
if (!broadcasted_shape.defined()) {
return TensorStructInfo(output_dtype, output_ndim);
}
ICHECK_EQ(static_cast<int>(broadcasted_shape.value().size()), output_ndim);
return TensorStructInfo(ShapeExpr(broadcasted_shape.value()), output_dtype);
} else if (cond_sinfo->shape.defined() && //
x1_sinfo->shape.defined() && //
x2_sinfo->shape.defined() && //
cond_sinfo->shape.same_as(x1_sinfo->shape) && //
cond_sinfo->shape.same_as(x2_sinfo->shape)) {
return TensorStructInfo(cond_sinfo->shape.value(), output_dtype);
} else {
return TensorStructInfo(output_dtype, output_ndim);
}
}

TVM_REGISTER_OP("relax.where")
.set_num_inputs(3)
.add_argument("condition", "Tensor", "When True, yield `x1`; otherwise, yield `x2`.")
.add_argument("x1", "Tensor", "The first input tensor.")
.add_argument("x2", "Tensor", "The second input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoWhere);

} // namespace relax
} // namespace tvm
41 changes: 41 additions & 0 deletions src/relax/op/tensor/search.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file search.h
* \brief The functions to make Relax searching operator calls.
*/
#ifndef TVM_RELAX_OP_TENSOR_SEARCH_H_
#define TVM_RELAX_OP_TENSOR_SEARCH_H_

#include "../op_common.h"

namespace tvm {
namespace relax {

/*!
* \brief Selecting elements from either the input tensors depending on the value of the
* condition.
*/
Expr where(Expr condition, Expr x1, Expr x2);

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_OP_TENSOR_SEARCH_H_
Loading

0 comments on commit b061f74

Please sign in to comment.