Skip to content

Commit

Permalink
Introduce call_lowered op
Browse files Browse the repository at this point in the history
Add op vm.call_tir

Change from checking if CallNode has CallTIRAttrs to checking if the Op is vm.call_tir

Change device_domains to use vm.call_tir op more explicitly

Fixed issue in type checker, now have seg fault :(

Fix typo -- most of VM tests pass now

Interpreter now deals with call_tir properly

Fix typo in te_compiler

Use InvokeTVMOp and CallTIR

Add some checks to graph_plan_memory.cc

Make GetToken skip function types

C++ TESTS PASS WOOHOO

Remove prints

formatting

vm.call_tir -> call_tir and more comment removals

call_tir -> call_lowered

fix lint

clang format

Remove compute from non computational vm ops

missed some semicolons in prev commit

Fix warning

Move call_lowered to relay/op/call/call.cc and rename util func

Add helper fn that returns lowered_call op

fix import order

clang format

Add constraint to call_lowered type rel

clean up empty token vector

comment

Move CallTIRAttrs to include/tvm/relay/attrs/call.h

Rename TIRCallAttrs as CallLoweredAttrs

lint

Add helper for extracting func and args from call_lowered

Change graph_executor_codegen to use helper function

Update interpreter to use helper

Fix device_domains.cc -- could still use cleanup; also I am not sure why there are still direct calls to prim funcs in DomainForCallee

Clean up DeviceCopyProps and lint

lint

return CallLoweredAttrs with the extern func

comment

note in comment

Progress & notes. Realized that I am not handling externs correctly

not sure why this ever worked before?

Clean up CreateFuncCall signature, notes

comments

Fix extern function handling

extern_function -> extern_func

fix DeviceAwareVisitExpr_ -- now it handles both lowered and normal calls

yay passes AOT tests!

formatting and comment removal

cleanup
  • Loading branch information
electriclilies committed Oct 26, 2021
1 parent 649ee20 commit a379c27
Show file tree
Hide file tree
Showing 15 changed files with 473 additions and 206 deletions.
11 changes: 0 additions & 11 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,6 @@ struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
}
};

/*!
* \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR.
*
* NOTE: removed in this commit; superseded by CallLoweredAttrs in
* include/tvm/relay/attrs/call.h (same layout, new name).
*/
struct TIRCallAttrs : public tvm::AttrsNode<TIRCallAttrs> {
/*! \brief The metadata attached to the call node (string-keyed map of arbitrary ObjectRefs). */
Map<String, ObjectRef> metadata;

TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") {
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call.");
}
};

} // namespace relay
} // namespace tvm
Expand Down
48 changes: 48 additions & 0 deletions include/tvm/relay/attrs/call.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/attrs/call.h
 * \brief Attributes for call operators (in particular the call_lowered op).
 */
#ifndef TVM_RELAY_ATTRS_CALL_H_
#define TVM_RELAY_ATTRS_CALL_H_

#include <tvm/ir/attrs.h>

#include <string>

namespace tvm {
namespace relay {

/*!
 * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR.
 */
struct CallLoweredAttrs : public tvm::AttrsNode<CallLoweredAttrs> {
  /*! \brief The metadata attached to the call node (string-keyed map of arbitrary ObjectRefs,
   *  e.g. "relay_attrs", "extern_func", reshape-only flags). */
  Map<String, ObjectRef> metadata;

  TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") {
    TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call.");
  }
};

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_ATTRS_CALL_H_
70 changes: 50 additions & 20 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <tvm/ir/module.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
Expand All @@ -40,6 +41,7 @@
#include <vector>

#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../transforms/device_aware_visitors.h"
#include "./te_compiler.h"
#include "./utils.h"
Expand Down Expand Up @@ -72,14 +74,31 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
AssignReturnSid(GetRef<Expr>(op));
}

void DeviceAwareVisitExpr_(const CallNode* op) final {
// create token for the call node.
VisitExpr(op->op);
CreateStorage(op);
for (Expr arg : op->args) {
void DeviceAwareVisitExpr_(const CallNode* call_node) final {
// AOTOnDemandAllocator is run both before and after lowering, so we need to handle the case
// where the op of the call is a generic function

Expr func;
Array<Expr> args;

if (call_node->op == CallLoweredOp()) {
// Extract function and arguments from the call_lowered op
std::pair<GlobalVar, Array<Expr>> func_and_args = ExtractFunctionAndArgs(call_node);
func = func_and_args.first;
args = func_and_args.second;

} else {
ICHECK(call_node->op.as<FunctionNode>())
<< "Expect call to be call_lowered op or function node. ";
func = call_node->op;
args = call_node->args;
}
VisitExpr(func);
CreateStorage(call_node);
for (Expr arg : args) {
GetStorage(arg);
}
AssignReturnSid(GetRef<Expr>(op));
AssignReturnSid(GetRef<Expr>(call_node));
}

void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef<Expr>(op)); }
Expand Down Expand Up @@ -289,11 +308,21 @@ class AOTExecutorCodegen : public MixedModeVisitor {
/*!
* brief Call a function with a given name
*/
void CreateFuncCall(Call call, std::string func_name) {
void CreateFuncCall(const CallNode* call_node) {
Call call = GetRef<Call>(call_node);

// Extract function and arguments from the call_lowered op
std::pair<GlobalVar, Array<Expr>> func_and_args = ExtractFunctionAndArgs(call_node);
GlobalVar func = func_and_args.first;
Array<Expr> call_args = func_and_args.second;

std::string func_name = func->name_hint;

tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
std::vector<tir::Stmt> create_func_call_stmts;
// Pack the inputs
for (Expr arg : call->args) {

for (Expr arg : call_args) {
if (params_by_expr_.find(arg) != params_by_expr_.end()) {
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[arg])});
Expand Down Expand Up @@ -371,21 +400,20 @@ class AOTExecutorCodegen : public MixedModeVisitor {
return ss.str();
}

void VisitExpr_(const CallNode* op) override {
void VisitExpr_(const CallNode* call_node) override {
// Descend the call tree
for (auto arg : op->args) {
ICHECK(call_node->op == CallLoweredOp()) << "Only expect call_lowered op at this point";

// Extract function and arguments from the call_lowered op
std::pair<GlobalVar, Array<Expr>> func_and_args = ExtractFunctionAndArgs(call_node);
GlobalVar func = func_and_args.first;
Array<Expr> call_args = func_and_args.second;

for (auto arg : call_args) {
VisitExpr(arg);
}

if (op->op.as<OpNode>()) {
LOG(FATAL) << "Operators should be transformed away; try applying"
<< "the fuse_ops transformation to the expression.";
} else if (op->op.as<GlobalVarNode>()) {
GlobalVar node = GetRef<GlobalVar>(op->op.as<GlobalVarNode>());
CreateFuncCall(GetRef<Call>(op), node->name_hint);
} else {
LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
}
CreateFuncCall(call_node);
}

void VisitExpr_(const VarNode* op) override {
Expand Down Expand Up @@ -443,7 +471,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); }
void VisitExpr_(const OpNode* op) override {
LOG(FATAL) << "All OpNodes should have been expanded";
if (GetRef<Op>(op) != CallLoweredOp()) {
LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded";
}
}
void VisitExpr_(const IfNode* op) override {
LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's Codegen is called";
Expand Down
79 changes: 46 additions & 33 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <dmlc/json.h>
#include <tvm/ir/module.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
Expand All @@ -37,6 +38,7 @@
#include <vector>

#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../transforms/device_aware_visitors.h"
#include "./te_compiler.h"
#include "./utils.h"
Expand Down Expand Up @@ -403,64 +405,75 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
return lhs_storage_id == rhs_storage_id;
}

std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const std::string& func_name,
GraphAttrs attrs) {
std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* call_node, GraphAttrs attrs) {
Call call = GetRef<Call>(call_node);
ICHECK(call->op == CallLoweredOp())
<< "Non-primitive-call nodes should have been transformed away.\n"
<< "The graph executor code generator expects all calls to be call_lowered, "
<< "but found: " << std::endl
<< PrettyPrint(call);

// Extract function and arguments from the call_lowered op
std::pair<GlobalVar, Array<Expr>> func_and_args = ExtractFunctionAndArgs(call_node);
GlobalVar func = func_and_args.first;
Array<Expr> call_args = func_and_args.second;

std::string func_name = func->name_hint;

std::vector<GraphNodeRef> inputs;
for (auto arg : op->args) {
auto res = VisitExpr(arg);
for (auto nr : res) {
inputs.push_back(nr);
// Visit all the arguments to call_lowered
for (Expr arg : call_args) {
for (auto n : VisitExpr(arg)) {
inputs.push_back(n);
}
}

/// An adapted version of the storage optimization for the time being.
bool reshape_only = false;
if (op->attrs.defined()) {
if (auto tir_call_attrs = op->attrs.as<TIRCallAttrs>()) {
Map<String, ObjectRef> metadata = tir_call_attrs->metadata;
if (metadata.count(attr::kReshapeOnly) &&
Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 1) {
reshape_only = true;
}

auto relay_attrs = Downcast<DictAttrs>(tir_call_attrs->metadata["relay_attrs"]);
ICHECK(call_node->attrs.defined()) << "Attrs should be defined!";
auto call_lowered_attrs = call_node->attrs.as<CallLoweredAttrs>();
ICHECK(call_lowered_attrs) << "Expected call_lowered to have CallLoweredAttrs";

// Need to check if this is an extern or not
Map<String, ObjectRef> metadata = call_lowered_attrs->metadata;
if (metadata.count(attr::kReshapeOnly) &&
Downcast<tvm::Integer>(metadata[attr::kReshapeOnly])->value == 1) {
reshape_only = true;
}

for (auto p : relay_attrs->dict) {
if (p.second.as<StringObj>()) {
attrs[p.first] = std::string(Downcast<String>(p.second));
}
if (!call_lowered_attrs->metadata.count(
"extern_func")) { // Extern funcs won't have relay attrs
// In main, I don't understand why this was running properly unless this was not actually
// called on the function?? Looks like maybe something is messed up with how call nodes are
// getting passed around. IDK tho
ICHECK(call_lowered_attrs->metadata.count("relay_attrs"))
<< "Expected there to be relay attrs stored in the metadata. ";
auto relay_attrs = Downcast<DictAttrs>(call_lowered_attrs->metadata["relay_attrs"]);
for (auto p : relay_attrs->dict) {
if (p.second.as<StringObj>()) {
attrs[p.first] = std::string(Downcast<String>(p.second));
}
}
}

if (reshape_only && ShareSameStorage(GetRef<Expr>(op), op->args[0])) {
if (reshape_only && ShareSameStorage(GetRef<Expr>(call_node), func)) {
auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs);
return AddNode(node, GetRef<Expr>(op));
return AddNode(node, call);
}

// Compute the operator name, because we used the get unique name when generating the kernel.
auto op_name = _GetUniqueName(func_name);
auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs);
return AddNode(node, GetRef<Expr>(op));
return AddNode(node, call);
}

std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
relay::Call call = GetRef<Call>(call_node);
auto props = GetOnDeviceProps(call_node);
if (props.body.defined()) {
// See through "on_device" calls.
return VisitExpr(props.body);
}

const auto* global_node = call->op.as<GlobalVarNode>();
ICHECK(global_node)
<< "Non-primitive-call nodes should have been transformed away.\n"
<< "The graph executor code generator expects all calls to have their callee "
"normalized to a GlobalVar, but found:"
<< std::endl
<< PrettyPrint(call);
auto prim_fn_name = global_node->name_hint;
return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs());
return GraphAddCallNode(call_node, GraphAttrs());
}

std::vector<GraphNodeRef> VisitExpr_(const LetNode* op) override {
Expand Down
Loading

0 comments on commit a379c27

Please sign in to comment.