Skip to content

Commit

Permalink
Shape and type deduction (#7)
Browse files Browse the repository at this point in the history
* Shape and type deduction.

* Fix header.

* Add call attrs to the deduce signature.

* Address comments.

* Add DiagnosticContext to IRBuilder and inference signature.

* Fix nits.
  • Loading branch information
YuchenJin authored and junrushao committed Feb 5, 2023
1 parent bd3ea31 commit 32781ef
Show file tree
Hide file tree
Showing 11 changed files with 369 additions and 62 deletions.
12 changes: 7 additions & 5 deletions include/tvm/relax/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class IRBuilderNode : public Object {
/*!
* \brief Generate an output for the current dataflow block or function.
* \param output The output variable of the block/function.
* \return The variable being binded to \p ouput.
* \return The variable being bound to \p output.
*/
Var EmitOutput(const Expr& output);
/*!
Expand All @@ -107,13 +107,15 @@ class IRBuilderNode : public Object {

private:
/*! \brief The state of the function currently being built. */
RelaxFunction func;
RelaxFunction func_;
/*! \brief A flag tracking if currently inside a dataflow block or not. */
bool is_dataflow = false;
bool is_dataflow_ = false;
/*! \brief A global variable counter for naming global variables. */
int global_var_counter = 0;
int global_var_counter_ = 0;
/*! \brief A dataflow variable counter for naming dataflow variables. */
int dataflow_var_counter = 0;
int dataflow_var_counter_ = 0;
/*! \brief A diagnostic context for reporting errors. */
DiagnosticContext diag_ctx_ = DiagnosticContext::Default(IRModule({}, {}));
};

class IRBuilder : public ObjectRef {
Expand Down
59 changes: 59 additions & 0 deletions include/tvm/relax/op_attr_types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relax/op_attr_types.h
 * \brief Data structures that can appear in operator attributes.
 */
#ifndef TVM_RELAX_OP_ATTR_TYPES_H_
#define TVM_RELAX_OP_ATTR_TYPES_H_

// Include what we use directly: DiagnosticContext lives in tvm/ir/diagnostic.h
// and runtime::TypedPackedFunc in tvm/runtime/packed_func.h; previously both
// were only available through transitive includes.
#include <tvm/ir/diagnostic.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/te/schedule.h>
#include <tvm/te/tensor.h>

#include <string>

namespace tvm {
namespace relax {

using relay::Call;

/*!
 * \brief Infer the output shape for operators. This function will
 * be invoked to fill the \p shape_ field of expressions.
 * \param call The call node.
 * \param diag_ctx The diagnostic context for reporting errors.
 * \return The inferred output shape expression.
 */
using FInferShape =
    runtime::TypedPackedFunc<Optional<RelayExpr>(const Call& call, DiagnosticContext diag_ctx)>;

/*!
 * \brief Infer the output type for operators. This function will
 * be invoked to fill the \p checked_type_ field of expressions.
 * \param call The call node.
 * \param diag_ctx The diagnostic context for reporting errors.
 * \return The inferred output type.
 */
using FInferType = runtime::TypedPackedFunc<Type(const Call& call, DiagnosticContext diag_ctx)>;

}  // namespace relax
}  // namespace tvm
#endif  // TVM_RELAX_OP_ATTR_TYPES_H_
2 changes: 1 addition & 1 deletion python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def checked_type(self):
checked_type : tvm.relay.Type
The checked type.
"""
ret = self._checked_type_
ret = self.checked_type_
if ret is None:
raise ValueError("The type checker has not populated" " the checked_type for this node")
return ret
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def __init__(
type_annotation: Optional[Type] = None,
span: Span = None,
) -> None:
if shape_annotation is not None:
shape_annotation = make_shape(shape_annotation)
self.__init_handle_by_constructor__(
_ffi_api.Var, name_hint, shape_annotation, type_annotation, span
)
Expand All @@ -86,6 +88,8 @@ def __init__(
type_annotation: Optional[Type] = None,
span: Span = None,
) -> None:
if shape_annotation is not None:
shape_annotation = make_shape(shape_annotation)
self.__init_handle_by_constructor__(
_ffi_api.DataflowVar, name_hint, shape_annotation, type_annotation, span
)
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/relax/op/tensor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Basic tensor operations."""
from . import _ffi_api
from ..expr import Expr

Expand Down
92 changes: 57 additions & 35 deletions src/relax/ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*/

#include <tvm/relax/ir_builder.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relay/op.h>

namespace tvm {
Expand All @@ -38,59 +39,84 @@ IRBuilder IRBuilderNode::Create() {

void IRBuilderNode::FillFuncNameParam(const Array<Var>& params, const std::string& func_name) {
if (!func_name.empty()) {
this->func.func_name = GlobalVar(func_name);
this->func_.func_name = GlobalVar(func_name);
}
this->func.params = params;

this->func_.params = params;
}

void IRBuilderNode::BuildFunction() {
SeqExpr seq = SeqExpr(this->func.binding_blocks, this->func.ret);
this->func.func = Function(this->func.func_name, this->func.params, seq, {});
this->global_var_counter = 0;
SeqExpr seq = SeqExpr(this->func_.binding_blocks, this->func_.ret);
this->func_.func = Function(this->func_.func_name, this->func_.params, seq, {});
this->global_var_counter_ = 0;
}

void IRBuilderNode::BuildBlock() {
if (!this->func.bindings.empty()) {
if (is_dataflow) {
this->func.binding_blocks.emplace_back(DataflowBlock(this->func.bindings));
if (!this->func_.bindings.empty()) {
if (is_dataflow_) {
this->func_.binding_blocks.emplace_back(DataflowBlock(this->func_.bindings));
} else {
this->func.binding_blocks.emplace_back(BindingBlock(this->func.bindings));
this->func_.binding_blocks.emplace_back(BindingBlock(this->func_.bindings));
}
this->func.bindings.clear();
this->func_.bindings.clear();
}
this->dataflow_var_counter = 0;
this->is_dataflow = !this->is_dataflow;
this->dataflow_var_counter_ = 0;
this->is_dataflow_ = !this->is_dataflow_;
}

/*!
 * \brief Deduce the output shape of a call by dispatching to the callee
 * operator's registered FInferShape function.
 * \param call The call node whose shape is being deduced.
 * \param diag_ctx The diagnostic context for reporting errors.
 * \return The inferred shape expression, or NullOpt when the callee is not an
 * Op or has no FInferShape attribute registered.
 * \note Emit() invokes this for every call, but only some operators (e.g. the
 * binary broadcast ops) register FInferShape, so both the downcast and the
 * attribute lookup must be guarded instead of assumed.
 */
Optional<RelayExpr> InferShape(const Call& call, DiagnosticContext diag_ctx) {
  auto op_map = Op::GetAttrMap<FInferShape>("FInferShape");
  // The callee may be something other than an Op (e.g. a function variable);
  // shape inference only applies to operators.
  if (const auto* op_node = call->op.as<OpNode>()) {
    Op op = GetRef<Op>(op_node);
    if (op_map.count(op)) {
      return op_map[op](call, diag_ctx);
    }
  }
  return NullOpt;
}

/*!
 * \brief Deduce the output type of a call by dispatching to the callee
 * operator's registered FInferType function.
 * \param call The call node whose type is being deduced.
 * \param diag_ctx The diagnostic context for reporting errors.
 * \return The inferred type, or a null Type when the callee is not an Op or
 * has no FInferType attribute registered (the expression is left undeduced).
 */
Type InferType(const Call& call, DiagnosticContext diag_ctx) {
  auto op_map = Op::GetAttrMap<FInferType>("FInferType");
  // Guard both the downcast and the attribute lookup: the callee may not be
  // an Op, and not every operator registers FInferType.
  if (const auto* op_node = call->op.as<OpNode>()) {
    Op op = GetRef<Op>(op_node);
    if (op_map.count(op)) {
      return op_map[op](call, diag_ctx);
    }
  }
  return Type();
}

Var IRBuilderNode::Emit(const Call& call) {
Var var;
if (is_dataflow) {
var = DataflowVar(Id("lv" + std::to_string(dataflow_var_counter++)), NullOpt, NullOpt);
if (is_dataflow_) {
var = DataflowVar(Id("lv" + std::to_string(dataflow_var_counter_++)), NullOpt, NullOpt);
} else {
var = Var(Id("gv" + std::to_string(global_var_counter++)), NullOpt, NullOpt);
var = Var(Id("gv" + std::to_string(global_var_counter_++)), NullOpt, NullOpt);
}

// Shape inference
auto inferred_shape = InferShape(call, this->diag_ctx_);
if (inferred_shape.defined()) {
if (auto* shape_expr = inferred_shape.value().as<ShapeExprNode>()) {
call->shape_ = GetRef<Expr>(shape_expr);
var->shape_ = call->shape_;
}
}
// Type inference
auto inferred_type = InferType(call, this->diag_ctx_);
call->checked_type_ = inferred_type;
var->checked_type_ = inferred_type;

this->func.bindings.emplace_back(VarBinding(var, call));
this->func_.bindings.emplace_back(VarBinding(var, call));
return var;
}

Var IRBuilderNode::EmitOutput(const Expr& output) {
Var ret;
if (is_dataflow) {
ret = Var(Id("gv" + std::to_string(global_var_counter++)), NullOpt, NullOpt);
if (is_dataflow_) {
ret = Var(Id("gv" + std::to_string(global_var_counter_++)), NullOpt, NullOpt);
ret->shape_ = output->shape_;
ret->checked_type_ = output->checked_type_;
this->func.bindings.emplace_back(VarBinding(ret, output));
this->func_.bindings.emplace_back(VarBinding(ret, output));
} else {
this->func.ret = output;
this->func_.ret = output;
}
return ret;
}

Function IRBuilderNode::Get() { return this->func.func; }
Function IRBuilderNode::Get() { return this->func_.func; }

std::vector<BindingBlock> IRBuilderNode::GetBlocks() { return this->func.binding_blocks; }
std::vector<BindingBlock> IRBuilderNode::GetBlocks() { return this->func_.binding_blocks; }

class FunctionScope::Internal {
public:
Expand Down Expand Up @@ -121,20 +147,16 @@ DataflowScope::DataflowScope(IRBuilder ib) {
data_ = std::move(n);
}

void DataflowScope::EnterWithScope() {
this->get()->ir_builder->BuildBlock();
}
void DataflowScope::EnterWithScope() { this->get()->ir_builder->BuildBlock(); }

void DataflowScope::ExitWithScope() {
this->get()->ir_builder->BuildBlock();
}
void DataflowScope::ExitWithScope() { this->get()->ir_builder->BuildBlock(); }

TVM_REGISTER_GLOBAL("relax.IRBuilderCreate").set_body_typed(IRBuilderNode::Create);

TVM_REGISTER_GLOBAL("relax.IRBuilderFillFuncNameParam")
.set_body_typed([](IRBuilder builder, const Array<Var>& params, const std::string& func_name) {
return builder->FillFuncNameParam(params, func_name);
});
.set_body_typed([](IRBuilder builder, const Array<Var>& params, const std::string& func_name) {
return builder->FillFuncNameParam(params, func_name);
});

TVM_REGISTER_GLOBAL("relax.IRBuilderBuildFunction").set_body_typed([](IRBuilder builder) {
return builder->BuildFunction();
Expand All @@ -145,9 +167,9 @@ TVM_REGISTER_GLOBAL("relax.IRBuilderEmit").set_body_typed([](IRBuilder builder,
});

TVM_REGISTER_GLOBAL("relax.IRBuilderEmitOutput")
.set_body_typed([](IRBuilder builder, const Expr& output) {
return builder->EmitOutput(output);
});
.set_body_typed([](IRBuilder builder, const Expr& output) {
return builder->EmitOutput(output);
});

TVM_REGISTER_GLOBAL("relax.IRBuilderGet").set_body_typed([](IRBuilder builder) {
return builder->Get();
Expand Down
10 changes: 6 additions & 4 deletions src/relax/op/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
#ifndef TVM_RELAX_OP_OP_COMMON_H_
#define TVM_RELAX_OP_OP_COMMON_H_

#include <tvm/relax/op_attr_types.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>

namespace tvm {
namespace relax {
Expand All @@ -42,15 +42,17 @@ namespace relax {
*
* \param OpName the name of registry.
*/
#define RELAX_REGISTER_BINARY_OP(OpName) \
#define RELAX_REGISTER_BINARY_BROADCAST_OP(OpName) \
TVM_REGISTER_GLOBAL("relax.op." OpName).set_body_typed([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
static const Op& op = Op::Get("relax." OpName); \
return Call(op, {lhs, rhs}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP("relax." OpName) \
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.")
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.set_attr<FInferShape>("FInferShape", InferShapeBinaryBroadcast) \
.set_attr<FInferType>("FInferType", InferTypeBinaryBroadcast)

} // namespace relax
} // namespace tvm
Expand Down
20 changes: 3 additions & 17 deletions src/relax/op/tensor/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,18 @@
* \brief binary broadcast operators.
*/

#include <tvm/arith/analyzer.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/type.h>
#include <tvm/tir/op.h>
#include <tvm/topi/broadcast.h>
#include "binary.h"

#include "../op_common.h"

namespace tvm {
namespace relax {

using Expr = tvm::RelayExpr;
using relay::Call;

#define RELAX_BINARY_COMPUTE(FTOPI) \
[](const Attrs& attrs, const Array<te::Tensor>& inputs, \
const Type& out_type) -> Array<te::Tensor> { \
ICHECK_EQ(inputs.size(), 2U); \
return {FTOPI(inputs[0], inputs[1])}; \
}

RELAX_REGISTER_BINARY_OP("add")
RELAX_REGISTER_BINARY_BROADCAST_OP("add")
.describe("Elementwise add with broadcasting")
.set_support_level(1);

RELAX_REGISTER_BINARY_OP("multiply")
RELAX_REGISTER_BINARY_BROADCAST_OP("multiply")
.describe("Elementwise multiply with broadcasting")
.set_support_level(1);

Expand Down
Loading

0 comments on commit 32781ef

Please sign in to comment.