This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit 60cc239: Update EmitMatchShape.
YuchenJin committed Sep 27, 2021
1 parent d0b648d

Showing 4 changed files with 89 additions and 12 deletions.
6 changes: 4 additions & 2 deletions include/tvm/relax/ir_builder.h
@@ -96,9 +96,11 @@ class IRBuilderNode : public Object {
   virtual Var Emit(const Var& var, const Call& call);
   /*!
    * \brief Emit a MatchShape.
-   * \param match_shape The MatchShape to be emitted.
+   * \param value The value of the MatchShape to be emitted.
+   * \param pattern The pattern of the MatchShape to be emitted.
+   * \return The variable bound to the MatchShape.
    */
-  void Emit(const MatchShape& match_shape);
+  Var EmitMatchShape(const Expr& value, const Array<PrimExpr>& pattern);
   /*!
    * \brief Generate an output for the current dataflow block or function.
    * \param output The output variable of the block/function.
19 changes: 14 additions & 5 deletions python/tvm/relax/ir_builder.py
@@ -132,16 +132,25 @@ def emit(self,
         """
         return _ffi_api.IRBuilderEmit(self, call)
 
-    def emit_matchshape(self,
-                        match_shape: MatchShape):
+    def match_shape(self,
+                    value: Expr,
+                    pattern: List[PrimExpr]):
         """Emit a MatchShape.
         Parameters
         ----------
-        match_shape : tvm.relax.MatchShape
-            The MatchShape to be emitted.
+        value : tvm.relay.Expr
+            The value of the MatchShape to be emitted.
+        pattern : List[PrimExpr]
+            The pattern of the MatchShape to be emitted.
         Returns
         -------
         ret : tvm.relax.Var
             A newly created variable that gets bound to the match shape.
         """
-        return _ffi_api.IRBuilderEmitMatchShape(self, match_shape)
+        return _ffi_api.IRBuilderEmitMatchShape(self, value, pattern)
 
     def emit_output(self,
                     output: Union[Expr, Tuple, List[Expr]]) -> None:
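For orientation, a minimal usage sketch of the renamed Python entry point, assuming fixtures like those the tests below set up (the tir dimension vars and the DynTensorType-annotated input are illustrative, not part of this diff):

    from tvm import tir
    from tvm import relax as rx

    m = tir.Var("m", dtype="int32")
    n = tir.Var("n", dtype="int32")
    x = rx.Var("x", type_annotation=rx.DynTensorType(-1, "float32"))

    ib = rx.IRBuilder()
    with ib.function([x]):
        with ib.dataflow() as df:
            # Emits a MatchShape binding and returns the fresh variable;
            # lv0 carries shape [m, n] and DynTensorType(rank=2, "float32").
            lv0 = ib.match_shape(x, [m, n])
            gv0 = ib.emit_output(lv0)
        ib.emit_output(gv0)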
27 changes: 24 additions & 3 deletions src/relax/ir_builder.cc
@@ -25,6 +25,7 @@
 #include <tvm/relax/op_attr_types.h>
 #include <tvm/relay/op.h>
 #include <tvm/arith/analyzer.h>
+#include <tvm/relax/type.h>
 
 namespace tvm {
 namespace relax {
@@ -116,8 +117,28 @@ Var IRBuilderNode::Emit(const Call& call) {
   return var;
 }
 
-void IRBuilderNode::Emit(const MatchShape& match_shape) {
+Var IRBuilderNode::EmitMatchShape(const Expr& value, const Array<PrimExpr>& pattern) {
+  Var var;
+  if (is_dataflow_) {
+    var = DataflowVar(Id("lv" + std::to_string(dataflow_var_counter_++)), NullOpt, NullOpt);
+  } else {
+    var = Var(Id("gv" + std::to_string(global_var_counter_++)), NullOpt, NullOpt);
+  }
+  if (value->checked_type().as<ShapeTypeNode>()) {
+    var->checked_type_ = ShapeType(Span());
+  } else if (value->checked_type().as<DynTensorTypeNode>()) {
+    ShapeExpr shape = ShapeExpr(pattern);
+    var->shape_ = shape;
+    DataType dtype = (Downcast<DynTensorType>(value->checked_type()))->dtype;
+    var->checked_type_ = DynTensorType(pattern.size(), dtype);
+  } else {
+    this->diag_ctx_.EmitFatal(Diagnostic::Error(value->span)
+                              << "The value passed to EmitMatchShape must be of DynTensorType or ShapeType.");
+  }
+
+  MatchShape match_shape = MatchShape(value, pattern, var);
   this->func_.bindings.emplace_back(match_shape);
+  return var;
 }
 
 Var IRBuilderNode::Emit(const VarBinding& binding) {
@@ -393,8 +414,8 @@ TVM_REGISTER_GLOBAL("relax.IRBuilderEmit").set_body_typed([](IRBuilder builder,
   return builder->Emit(call);
 });
 
-TVM_REGISTER_GLOBAL("relax.IRBuilderEmitMatchShape").set_body_typed([](IRBuilder builder, const MatchShape& match_shape) {
-  builder->Emit(match_shape);
+TVM_REGISTER_GLOBAL("relax.IRBuilderEmitMatchShape").set_body_typed([](IRBuilder builder, const Expr& value, const Array<PrimExpr>& pattern) {
+  return builder->EmitMatchShape(value, pattern);
 });
 
 TVM_REGISTER_GLOBAL("relax.IRBuilderEmitOutput")
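In short, EmitMatchShape infers the bound variable's type from the matched value: a ShapeType value yields a ShapeType variable, while a DynTensorType value yields a variable whose shape is the pattern and whose rank equals the pattern's length; any other type raises a fatal diagnostic. A condensed sketch of both branches as seen from Python (names mirror test_emit_match_shape below):

    # Tensor branch: x : Tensor[_, "float32"]  ->  lv0 : Tensor[(m, n), "float32"]
    lv0 = ib.match_shape(x, [m, n])
    assert lv0.checked_type.rank == 2           # rank = len(pattern)

    # Shape branch: y : Shape  ->  lv1 : Shape
    lv1 = ib.match_shape(y, [m, n])
    assert lv1.checked_type == rx.ShapeType()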
49 changes: 47 additions & 2 deletions tests/python/relax/test_irbuilder.py
@@ -42,8 +42,7 @@ def test_dataflow_block():
            lv1 = ib.emit(rx.op.multiply(lv0, y))
            assert lv1.name_hint == "lv1"
 
-           b0 = rx.MatchShape([m, n], x.shape)
-           ib.emit_matchshape(b0)
+           b0 = ib.match_shape(x, [m, n])
 
            gv0 = ib.emit_output(lv1)
            assert gv0.name_hint == "gv0"
@@ -178,6 +177,50 @@ def test_binary_shape_type_deduction():
     assert gv0.checked_type.rank == 1
     assert gv0.checked_type.dtype == "float16"
 
+
+def test_emit_match_shape():
+    m = tir.Var("m", dtype="int32")
+    n = tir.Var("n", dtype="int32")
+    type_anno0 = rx.DynTensorType(-1, "float32")
+    x = rx.Var("tensor_value", type_annotation=type_anno0)
+    shape_anno = [16, 8]
+    y = rx.Var("shape_value", type_annotation=rx.ShapeType(), shape_annotation=shape_anno)
+    ib = rx.IRBuilder()
+
+    with ib.function([x, y]):
+        with ib.dataflow() as df:
+            # lv0: Tensor[(m, n), "float32"] =
+            #   match_shape(x: Tensor[_, "float32"], [m, n])
+            lv0 = ib.match_shape(x, [m, n])
+            assert isinstance(lv0, rx.DataflowVar)
+            assert lv0.shape[0] == m
+            assert lv0.shape[1] == n
+            assert lv0.checked_type.rank == 2
+            assert lv0.checked_type.dtype == "float32"
+
+            # lv1: Shape = match_shape(shape, [m, n])
+            lv1 = ib.match_shape(y, [m, n])
+            assert lv1.checked_type == rx.ShapeType()
+            gv0 = ib.emit_output(lv1)
+
+        ib.emit_output(gv0)
+
+    block = ib.get_blocks()[-1]
+    b0, b1 = block.bindings[:2]
+    assert isinstance(b0, rx.MatchShape)
+    assert isinstance(b1, rx.MatchShape)
+
+    assert b0.value == x
+    assert b0.pattern[0] == m
+    assert b0.pattern[1] == n
+    assert b0.var == lv0
+
+    assert b1.value == y
+    assert b1.pattern[0] == m
+    assert b1.pattern[1] == n
+    assert b1.var == lv1
+
+
 def test_normalize():
     m = tir.Var("m", "int32")
     n = tir.Var("n", "int32")
@@ -195,9 +238,11 @@ def test_normalize():
     assert add_call.shape[0] == m
     assert add_call.shape[1] == n
 
+
 if __name__ == "__main__":
     test_dataflow_block()
     test_function_single_block()
     test_function_multi_blocks()
     test_binary_shape_type_deduction()
+    test_emit_match_shape()
     test_normalize()
