Skip to content

Commit

Permalink
Change Call with TIRCallAttrs to call_lowered op (apache#9312)
Browse files Browse the repository at this point in the history
* Introduce call_lowered op

Add op vm.call_tir

Change from checking if CallNode has CallTIRAttrs to checking if the Op is vm.call_tir

Change device_domains to use vm.call_tir op more explicitly

Fixed issue in type checker, now have seg fault :(

Fix typo -- most of VM tests pass now

Interpreter now deals with call_tir properly

Fix typo in te_compiler

Use InvokeTVMOp and CallTIR

Add some checks to graph_plan_memory.cc

Make GetToken skip function types

C++ TESTS PASS WOOHOO

Remove prints

formatting

vm.call_tir -> call_tir and more comment removals

call_tir -> call_lowered

fix lint

clang format

Remove compute from non computational vm ops

missed some semicolons in prev commit

Fix warning

Move call_lowered to relay/op/call/call.cc and rename util func

Add helper fn that returns lowered_call op

fix import order

clang format

Add constraint to call_lowered type rel

clean up empty token vector

comment

Move CallTIRAttrs to include/tvm/relay/attrs/call.h

Rename TIRCallAttrs as CallLoweredAttrs

lint

Add helper for extracting func and args from call_lowered

Change graph_executor_codegen to use helper function

Update interpreter to use helper

Fix device_domains.cc -- could still use cleanup, also I am not sure why there are still direct calls to primfns in DomainforCallee

Clean up DeviceCopyProps and lint

lint

return CallLoweredAttrs with the extern func

comment

note in comment

Progress & notes. Realized that I am not handling externs correctly

not sure why this ever worked before?

Clean up CreateFuncCall signature, notes

comments

Fix extern function handling

extern_function -> extern_func

fix DeviceAwareVisitExpr_ -- now it handles both lowered and normal calls

yay passes AOT tests!

formatting and comment removal

cleanup

Introduce call_lowered op

* lint

* Fix AOT to deal with externs

* add const auto&

* Fix aot crt test
  • Loading branch information
electriclilies authored Nov 10, 2021
1 parent 86781e9 commit 0812c07
Show file tree
Hide file tree
Showing 14 changed files with 603 additions and 265 deletions.
11 changes: 0 additions & 11 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,6 @@ struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
}
};

/*!
* \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR.
*/
struct TIRCallAttrs : public tvm::AttrsNode<TIRCallAttrs> {
/*! \brief The metadata attached to the call node. */
Map<String, ObjectRef> metadata;

TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") {
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call.");
}
};

} // namespace relay
} // namespace tvm
Expand Down
48 changes: 48 additions & 0 deletions include/tvm/relay/attrs/call.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/attrs/call.h
* \brief Attribute for call_lowered operator.
*/
#ifndef TVM_RELAY_ATTRS_CALL_H_
#define TVM_RELAY_ATTRS_CALL_H_

#include <tvm/ir/attrs.h>

#include <string>

namespace tvm {
namespace relay {

/*!
* \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR.
*/
struct CallLoweredAttrs : public tvm::AttrsNode<CallLoweredAttrs> {
/*! \brief The metadata attached to the call node. */
Map<String, ObjectRef> metadata;

TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") {
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the lowered function call.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_CALL_H_
77 changes: 55 additions & 22 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <tvm/ir/module.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
Expand All @@ -40,6 +41,7 @@
#include <vector>

#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../transforms/device_aware_visitors.h"
#include "./te_compiler.h"
#include "./utils.h"
Expand Down Expand Up @@ -72,14 +74,34 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
AssignReturnSid(GetRef<Expr>(op));
}

void DeviceAwareVisitExpr_(const CallNode* op) final {
// create token for the call node.
VisitExpr(op->op);
CreateStorage(op);
for (Expr arg : op->args) {
void DeviceAwareVisitExpr_(const CallNode* call_node) final {
// AOTOnDemandAllocator is run both before and after lowering, so we need to handle the case
// where the op of the call is a generic function

Expr func;
Array<Expr> args;

if (call_node->op == CallLoweredOp()) {
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
func = call_lowered_props.lowered_func;
args = call_lowered_props.arguments;
} else { // Relay functions that have not been lowered and lowered extern functions
func = call_node->op;
args = call_node->args;
if (call_node->op.as<GlobalVarNode>()) { // Lowered extern function
ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes.";
} else { // Relay function which has not been lowered yet
ICHECK(call_node->op.as<FunctionNode>())
<< "Expected the call to be to a lowered primfunc, a lowered extern function or a "
"unlowered Relay function.";
}
}
VisitExpr(func);
CreateStorage(call_node);
for (const Expr& arg : args) {
GetStorage(arg);
}
AssignReturnSid(GetRef<Expr>(op));
AssignReturnSid(GetRef<Expr>(call_node));
}

void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef<Expr>(op)); }
Expand Down Expand Up @@ -287,13 +309,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

/*!
* brief Call a function with a given name
* brief Create a function call
* \param call_lowered_props The lowered function and the arguments to call it with
* \param call The call we got func and args from
*/
void CreateFuncCall(Call call, std::string func_name) {
void CreateFuncCall(CallLoweredProps call_lowered_props, Call call) {
std::string func_name = call_lowered_props.lowered_func->name_hint;

tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
std::vector<tir::Stmt> create_func_call_stmts;

// Pack the inputs
for (Expr arg : call->args) {
for (const Expr& arg : call_lowered_props.arguments) {
if (params_by_expr_.find(arg) != params_by_expr_.end()) {
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[arg])});
Expand Down Expand Up @@ -371,21 +398,25 @@ class AOTExecutorCodegen : public MixedModeVisitor {
return ss.str();
}

void VisitExpr_(const CallNode* op) override {
void VisitExpr_(const CallNode* call_node) override {
// Descend the call tree
for (auto arg : op->args) {
VisitExpr(arg);
}

if (op->op.as<OpNode>()) {
LOG(FATAL) << "Operators should be transformed away; try applying"
<< "the fuse_ops transformation to the expression.";
} else if (op->op.as<GlobalVarNode>()) {
GlobalVar node = GetRef<GlobalVar>(op->op.as<GlobalVarNode>());
CreateFuncCall(GetRef<Call>(op), node->name_hint);
CallLoweredProps call_lowered_props;
if (const auto* gvn = call_node->op.as<GlobalVarNode>()) { // Lowered extern function
ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes.";
for (const auto& arg : call_node->args) {
VisitExpr(arg);
}
call_lowered_props = CallLoweredProps{GetRef<GlobalVar>(gvn), call_node->args, {}};
} else {
LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
ICHECK(call_node->op == CallLoweredOp()) << "Operators should be transformed away; Try "
"applying the fuse_ops transformation to the "
"expression.";
call_lowered_props = GetCallLoweredProps(call_node);
for (const auto& arg : call_lowered_props.arguments) {
VisitExpr(arg);
}
}
CreateFuncCall(call_lowered_props, GetRef<Call>(call_node));
}

void VisitExpr_(const VarNode* op) override {
Expand Down Expand Up @@ -443,7 +474,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); }
void VisitExpr_(const OpNode* op) override {
LOG(FATAL) << "All OpNodes should have been expanded";
if (GetRef<Op>(op) != CallLoweredOp()) {
LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded";
}
}
void VisitExpr_(const IfNode* op) override {
LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's Codegen is called";
Expand Down
12 changes: 11 additions & 1 deletion src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/memory.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>

#include "../../../op/call/call.h"

namespace tvm {
namespace relay {
namespace contrib {
Expand Down Expand Up @@ -109,7 +113,13 @@ class ConvertAddToSubtract : public MixedModeMutator {
GlobalVar new_global_var(func_name.value());
new_global_var->checked_type_ = func->checked_type();
ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef<Function>(func));
return Call(new_global_var, call->args, call->attrs, call->type_args, call->span);

// Since we are replacing the Relay function with a call to a TIR function, we must use the
// call_lowered op.
auto call_lowered_attrs = make_object<CallLoweredAttrs>();
call_lowered_attrs->metadata.Set("relay_attrs", call->attrs);
return CallLowered(std::move(new_global_var), call->args,
std::move(Attrs(call_lowered_attrs)), call->type_args, call->span);
}
}

Expand Down
89 changes: 51 additions & 38 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <dmlc/json.h>
#include <tvm/ir/module.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
Expand All @@ -37,6 +38,7 @@
#include <vector>

#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../transforms/device_aware_visitors.h"
#include "./te_compiler.h"
#include "./utils.h"
Expand Down Expand Up @@ -403,64 +405,75 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
return lhs_storage_id == rhs_storage_id;
}

std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const std::string& func_name,
GraphAttrs attrs) {
std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* call_node, GraphAttrs attrs) {
Call call = GetRef<Call>(call_node);
std::vector<GraphNodeRef> inputs;
for (auto arg : op->args) {
auto res = VisitExpr(arg);
for (auto nr : res) {
inputs.push_back(nr);
}
}
std::string func_name;

/// An adapted version of the storage optimization for the time being.
bool reshape_only = false;
if (op->attrs.defined()) {
if (auto tir_call_attrs = op->attrs.as<TIRCallAttrs>()) {
Map<String, ObjectRef> metadata = tir_call_attrs->metadata;
if (metadata.count(attr::kReshapeOnly) &&
Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 1) {
reshape_only = true;
}
if (call->op == CallLoweredOp()) {
// Extract function and arguments from the call_lowered op
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

auto relay_attrs = Downcast<DictAttrs>(tir_call_attrs->metadata["relay_attrs"]);
func_name = call_lowered_props.lowered_func->name_hint;

for (auto p : relay_attrs->dict) {
if (p.second.as<StringObj>()) {
attrs[p.first] = std::string(Downcast<String>(p.second));
for (const Expr& arg : call_lowered_props.arguments) {
for (auto n : VisitExpr(arg)) {
inputs.push_back(n);
}
}
if (call_lowered_props.attrs.metadata.count("relay_attrs")) {
if (auto relay_attrs =
call_lowered_props.attrs.metadata["relay_attrs"].as<DictAttrsNode>()) {
for (auto p : relay_attrs->dict) {
if (p.second.as<StringObj>()) {
attrs[p.first] = std::string(Downcast<String>(p.second));
}
}
}
}
}

if (reshape_only && ShareSameStorage(GetRef<Expr>(op), op->args[0])) {
auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs);
return AddNode(node, GetRef<Expr>(op));
bool reshape_only = false;
if (call_lowered_props.attrs.metadata.count(attr::kReshapeOnly) &&
Downcast<tvm::Integer>(call_lowered_props.attrs.metadata[attr::kReshapeOnly])->value ==
1) {
reshape_only = true;
}
if (reshape_only &&
ShareSameStorage(GetRef<Expr>(call_node), call_lowered_props.arguments[0])) {
auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs);
return AddNode(node, call);
}
} else if (!call_node->attrs.defined()) { // Call is an extern function
std::cout << "call_node: \n" << PrettyPrint(call) << std::endl;
const auto* func = call_node->op.as<GlobalVarNode>();
ICHECK(func) << "Expected the operator to be a global var, but got "
<< call_node->op->GetTypeKey(); // getting a relay fn here, not sure why.
func_name = func->name_hint;

for (const Expr& arg : call_node->args) {
for (auto n : VisitExpr(arg)) {
inputs.push_back(n);
}
}
} else {
LOG(FATAL) << "Non-primitive-call nodes should have been transformed away.\n"
<< "The graph executor code generator expects all calls to be call_lowered, "
<< "but found: " << std::endl
<< PrettyPrint(call);
}

// Compute the operator name, because we used the get unique name when generating the kernel.
auto op_name = _GetUniqueName(func_name);
auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs);
return AddNode(node, GetRef<Expr>(op));
return AddNode(node, call);
}

std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
relay::Call call = GetRef<Call>(call_node);
auto props = GetOnDeviceProps(call_node);
if (props.body.defined()) {
// See through "on_device" calls.
return VisitExpr(props.body);
}

const auto* global_node = call->op.as<GlobalVarNode>();
ICHECK(global_node)
<< "Non-primitive-call nodes should have been transformed away.\n"
<< "The graph executor code generator expects all calls to have their callee "
"normalized to a GlobalVar, but found:"
<< std::endl
<< PrettyPrint(call);
auto prim_fn_name = global_node->name_hint;
return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs());
return GraphAddCallNode(call_node, GraphAttrs());
}

std::vector<GraphNodeRef> VisitExpr_(const LetNode* op) override {
Expand Down
Loading

0 comments on commit 0812c07

Please sign in to comment.