From 1075677fe1cb27176573834c8629bd2c0b9dc685 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 10 Nov 2021 14:20:01 -0800 Subject: [PATCH] Change Call with TIRCallAttrs to call_lowered op (#9312) * Introduce call_lowered op Add op vm.call_tir Change from checking if CallNode has CallTIRAttrs to checking if the Op is vm.call_tir Change device_domains to use vm.call_tir op more explicitly Fixed issue in type checker, now have seg fault :( Fix typo -- most of VM tests pass now Interpreter now deals with call_tir properly Fix typo in te_compiler Use InvokeTVMOp and CallTIR Add some checks to graph_plan_memory.cc Make GetToken skip function types C++ TESTS PASS WOOHOO Remove prints formatting vm.call_tir -> call_tir and more comment removals call_tir -> call_lowered fix lint clang format Remove compute from non computational vm ops missed some semicolons in prev commit Fix warning Move call_lowered to relay/op/call/call.cc and rename util func Add helper fn that returns lowered_call op fix import order clang format Add constraint to call_lowered type rel clean up empty token vector comment Move CallTIRAttrs to include/tvm/relay/attrs/call.h Rename TIRCallAttrs as CallLoweredAttrs lint Add helper for extracting func and args from call_lowered Change graph_executor_codegen to use helper function Update interpreter to use helper Fix device_domains.cc -- could still use cleanup, also I am not sure why there are still direct calls to primfns in DomainforCallee Clean up DeviceCopyProps and lint lint return CallLoweredAttrs with the extern func comment note in comment Progress & notes. Realized that I am not handling externs correctly not sure why this ever worked before? Clean up CreateFuncCall signature, notes comments Fix extern function handling extern_function -> extern_func fix DeviceAwareVisitExpr_ -- now it handles both lowered and normal calls yay passes AOT tests! 
formatting and comment removal cleanup Introduce call_lowered op * lint * Fix AOT to deal with externs * add const auto& * Fix aot crt test --- include/tvm/relay/attrs/annotation.h | 11 -- include/tvm/relay/attrs/call.h | 48 +++++ src/relay/backend/aot_executor_codegen.cc | 77 +++++--- .../example_target_hooks/relay_to_tir.cc | 12 +- src/relay/backend/graph_executor_codegen.cc | 89 +++++---- src/relay/backend/graph_plan_memory.cc | 52 +++--- src/relay/backend/interpreter.cc | 152 +++++++++------- src/relay/backend/te_compiler.cc | 169 +++++++++--------- src/relay/op/call/call.cc | 116 ++++++++++++ src/relay/op/call/call.h | 74 ++++++++ src/relay/op/memory/device_copy.cc | 17 ++ src/relay/op/vm/vm.h | 2 +- src/relay/transforms/device_domains.cc | 33 ++-- src/relay/transforms/memory_alloc.cc | 16 +- 14 files changed, 603 insertions(+), 265 deletions(-) create mode 100644 include/tvm/relay/attrs/call.h create mode 100644 src/relay/op/call/call.cc create mode 100644 src/relay/op/call/call.h diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index 85ac3f36ff60..f88ca8ef6380 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -116,17 +116,6 @@ struct CompilerAttrs : public tvm::AttrsNode { } }; -/*! - * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR. - */ -struct TIRCallAttrs : public tvm::AttrsNode { - /*! \brief The metadata attached to the call node. 
*/ - Map metadata; - - TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") { - TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call."); - } -}; } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/attrs/call.h b/include/tvm/relay/attrs/call.h new file mode 100644 index 000000000000..2b02c6a5edac --- /dev/null +++ b/include/tvm/relay/attrs/call.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/attrs/call.h + * \brief Attribute for call_lowered operator. + */ +#ifndef TVM_RELAY_ATTRS_CALL_H_ +#define TVM_RELAY_ATTRS_CALL_H_ + +#include + +#include + +namespace tvm { +namespace relay { + +/*! + * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR. + */ +struct CallLoweredAttrs : public tvm::AttrsNode { + /*! \brief The metadata attached to the call node. 
*/ + Map metadata; + + TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") { + TVM_ATTR_FIELD(metadata).describe("Metadata attached to the lowered function call."); + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ATTRS_CALL_H_ diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 7e5702296542..58bcccf90879 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -40,6 +41,7 @@ #include #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../transforms/device_aware_visitors.h" #include "./te_compiler.h" #include "./utils.h" @@ -72,14 +74,34 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { AssignReturnSid(GetRef(op)); } - void DeviceAwareVisitExpr_(const CallNode* op) final { - // create token for the call node. - VisitExpr(op->op); - CreateStorage(op); - for (Expr arg : op->args) { + void DeviceAwareVisitExpr_(const CallNode* call_node) final { + // AOTOnDemandAllocator is run both before and after lowering, so we need to handle the case + // where the op of the call is a generic function + + Expr func; + Array args; + + if (call_node->op == CallLoweredOp()) { + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); + func = call_lowered_props.lowered_func; + args = call_lowered_props.arguments; + } else { // Relay functions that have not been lowered and lowered extern functions + func = call_node->op; + args = call_node->args; + if (call_node->op.as()) { // Lowered extern function + ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes."; + } else { // Relay function which has not been lowered yet + ICHECK(call_node->op.as()) + << "Expected the call to be to a lowered primfunc, a lowered extern function or a " + "unlowered Relay function."; + } + } + 
VisitExpr(func); + CreateStorage(call_node); + for (const Expr& arg : args) { GetStorage(arg); } - AssignReturnSid(GetRef(op)); + AssignReturnSid(GetRef(call_node)); } void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef(op)); } @@ -287,13 +309,18 @@ class AOTExecutorCodegen : public MixedModeVisitor { } /*! - * brief Call a function with a given name + * brief Create a function call + * \param call_lowered_props The lowered function and the arguments to call it with + * \param call The call we got func and args from */ - void CreateFuncCall(Call call, std::string func_name) { + void CreateFuncCall(CallLoweredProps call_lowered_props, Call call) { + std::string func_name = call_lowered_props.lowered_func->name_hint; + tvm::Array args{tvm::tir::StringImm(func_name)}; std::vector create_func_call_stmts; + // Pack the inputs - for (Expr arg : call->args) { + for (const Expr& arg : call_lowered_props.arguments) { if (params_by_expr_.find(arg) != params_by_expr_.end()) { auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), {tir::StringImm(params_by_expr_[arg])}); @@ -371,21 +398,25 @@ class AOTExecutorCodegen : public MixedModeVisitor { return ss.str(); } - void VisitExpr_(const CallNode* op) override { + void VisitExpr_(const CallNode* call_node) override { // Descend the call tree - for (auto arg : op->args) { - VisitExpr(arg); - } - - if (op->op.as()) { - LOG(FATAL) << "Operators should be transformed away; try applying" - << "the fuse_ops transformation to the expression."; - } else if (op->op.as()) { - GlobalVar node = GetRef(op->op.as()); - CreateFuncCall(GetRef(op), node->name_hint); + CallLoweredProps call_lowered_props; + if (const auto* gvn = call_node->op.as()) { // Lowered extern function + ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes."; + for (const auto& arg : call_node->args) { + VisitExpr(arg); + } + call_lowered_props = CallLoweredProps{GetRef(gvn), 
call_node->args, {}}; } else { - LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey(); + ICHECK(call_node->op == CallLoweredOp()) << "Operators should be transformed away; Try " + "applying the fuse_ops transformation to the " + "expression."; + call_lowered_props = GetCallLoweredProps(call_node); + for (const auto& arg : call_lowered_props.arguments) { + VisitExpr(arg); + } } + CreateFuncCall(call_lowered_props, GetRef(call_node)); } void VisitExpr_(const VarNode* op) override { @@ -443,7 +474,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { } void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); } void VisitExpr_(const OpNode* op) override { - LOG(FATAL) << "All OpNodes should have been expanded"; + if (GetRef(op) != CallLoweredOp()) { + LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded"; + } } void VisitExpr_(const IfNode* op) override { LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's Codegen is called"; diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index cae20210ec4f..c41399e314ef 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -17,14 +17,18 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include #include #include +#include #include #include #include #include #include +#include "../../../op/call/call.h" + namespace tvm { namespace relay { namespace contrib { @@ -109,7 +113,13 @@ class ConvertAddToSubtract : public MixedModeMutator { GlobalVar new_global_var(func_name.value()); new_global_var->checked_type_ = func->checked_type(); ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef(func)); - return Call(new_global_var, call->args, call->attrs, call->type_args, call->span); + + // Since we are replacing the Relay function with a call to a TIR function, we must use the + // call_lowered op. + auto call_lowered_attrs = make_object(); + call_lowered_attrs->metadata.Set("relay_attrs", call->attrs); + return CallLowered(std::move(new_global_var), call->args, + std::move(Attrs(call_lowered_attrs)), call->type_args, call->span); } } diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index d32ded379688..ac3c835ed648 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -37,6 +38,7 @@ #include #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../transforms/device_aware_visitors.h" #include "./te_compiler.h" #include "./utils.h" @@ -403,64 +405,75 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator GraphAddCallNode(const CallNode* op, const std::string& func_name, - GraphAttrs attrs) { + std::vector GraphAddCallNode(const CallNode* call_node, GraphAttrs attrs) { + Call call = GetRef(call_node); std::vector inputs; - for (auto arg : op->args) { - auto res = VisitExpr(arg); - for (auto nr : res) { - inputs.push_back(nr); - } - } + std::string func_name; - /// An adapted version of the storage optimization for the time being. 
- bool reshape_only = false; - if (op->attrs.defined()) { - if (auto tir_call_attrs = op->attrs.as()) { - Map metadata = tir_call_attrs->metadata; - if (metadata.count(attr::kReshapeOnly) && - Downcast(metadata[attr::kReshapeOnly])->value == 1) { - reshape_only = true; - } + if (call->op == CallLoweredOp()) { + // Extract function and arguments from the call_lowered op + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); - auto relay_attrs = Downcast(tir_call_attrs->metadata["relay_attrs"]); + func_name = call_lowered_props.lowered_func->name_hint; - for (auto p : relay_attrs->dict) { - if (p.second.as()) { - attrs[p.first] = std::string(Downcast(p.second)); + for (const Expr& arg : call_lowered_props.arguments) { + for (auto n : VisitExpr(arg)) { + inputs.push_back(n); + } + } + if (call_lowered_props.attrs.metadata.count("relay_attrs")) { + if (auto relay_attrs = + call_lowered_props.attrs.metadata["relay_attrs"].as()) { + for (auto p : relay_attrs->dict) { + if (p.second.as()) { + attrs[p.first] = std::string(Downcast(p.second)); + } } } } - } - - if (reshape_only && ShareSameStorage(GetRef(op), op->args[0])) { - auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs); - return AddNode(node, GetRef(op)); + bool reshape_only = false; + if (call_lowered_props.attrs.metadata.count(attr::kReshapeOnly) && + Downcast(call_lowered_props.attrs.metadata[attr::kReshapeOnly])->value == + 1) { + reshape_only = true; + } + if (reshape_only && + ShareSameStorage(GetRef(call_node), call_lowered_props.arguments[0])) { + auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs); + return AddNode(node, call); + } + } else if (!call_node->attrs.defined()) { // Call is an extern function + std::cout << "call_node: \n" << PrettyPrint(call) << std::endl; + const auto* func = call_node->op.as(); + ICHECK(func) << "Expected the operator to be a global var, but got " + << 
call_node->op->GetTypeKey(); // getting a relay fn here, not sure why. + func_name = func->name_hint; + + for (const Expr& arg : call_node->args) { + for (auto n : VisitExpr(arg)) { + inputs.push_back(n); + } + } + } else { + LOG(FATAL) << "Non-primitive-call nodes should have been transformed away.\n" + << "The graph executor code generator expects all calls to be call_lowered, " + << "but found: " << std::endl + << PrettyPrint(call); } // Compute the operator name, because we used the get unique name when generating the kernel. auto op_name = _GetUniqueName(func_name); auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs); - return AddNode(node, GetRef(op)); + return AddNode(node, call); } std::vector VisitExpr_(const CallNode* call_node) override { - relay::Call call = GetRef(call_node); auto props = GetOnDeviceProps(call_node); if (props.body.defined()) { // See through "on_device" calls. return VisitExpr(props.body); } - - const auto* global_node = call->op.as(); - ICHECK(global_node) - << "Non-primitive-call nodes should have been transformed away.\n" - << "The graph executor code generator expects all calls to have their callee " - "normalized to a GlobalVar, but found:" - << std::endl - << PrettyPrint(call); - auto prim_fn_name = global_node->name_hint; - return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs()); + return GraphAddCallNode(call_node, GraphAttrs()); } std::vector VisitExpr_(const LetNode* op) override { diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 961252a14fa7..4031dfdcd6e7 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -24,6 +24,7 @@ */ #include #include +#include #include #include #include @@ -32,6 +33,7 @@ #include "../../support/arena.h" #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../op/memory/memory.h" #include "../transforms/device_aware_visitors.h" 
#include "./utils.h" @@ -139,6 +141,8 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor { protected: /*! \brief internal token map */ std::unordered_map> token_map_; + /*! \brief empty token map */ + const std::vector no_tokens_; /*! * \brief Get the necessary token. @@ -146,6 +150,11 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor { * \return The corresponding token. */ const std::vector& GetToken(const Expr& expr) { + this->VisitExpr(expr); + // Functions don't require data storage, represented by the empty token + if (expr->checked_type().as()) { + return no_tokens_; + } // See through on_device calls. Expr real_expr = IgnoreOnDevice(expr); this->VisitExpr(real_expr); @@ -159,8 +168,9 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor { * \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding * the result of evaluating \p op. */ - void CreateToken(const ExprNode* op, bool can_realloc) { - return CreateTokenOnDevice(op, GetInScopeDeviceType(GetRef(op)), can_realloc); + void CreateToken(const ExprNode* expr_node, bool can_realloc) { + return CreateTokenOnDevice(expr_node, GetInScopeDeviceType(GetRef(expr_node)), + can_realloc); } /*! @@ -203,12 +213,12 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_; - void DeviceAwareVisitExpr_(const CallNode* op) final { + void DeviceAwareVisitExpr_(const CallNode* call_node) final { // create token for the call node. - CreateToken(op, true); + CreateToken(call_node, true); // for each input, visit argument token. - for (Expr arg : op->args) { + for (Expr arg : call_node->args) { for (StorageToken* tok : GetToken(arg)) { tok->ref_counter += 1; } @@ -273,7 +283,6 @@ class StorageAllocator : public StorageAllocaBaseVisitor { << "expressions are assigned with virtual device types. 
Either all " "or none of the expressions are expected to be annotated."; } - return backend::StaticMemoryPlan(smap); } @@ -320,10 +329,13 @@ class StorageAllocator : public StorageAllocaBaseVisitor { using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_; // The call map - void DeviceAwareVisitExpr_(const CallNode* op) final { + void DeviceAwareVisitExpr_(const CallNode* call_node) final { std::vector args; // for each input, visit argument token. - for (Expr arg : op->args) { + + for (const Expr& arg : call_node->args) { + // Note: GetToken skips GlobalVars and handles tuples properly, so we don't need to treat + // call_lowered specially. for (StorageToken* tok : GetToken(arg)) { args.push_back(tok); } @@ -337,20 +349,17 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // // TODO(tvm-team) Update checks of flat memory enablement when we support // opaque-nd memory planning to skip this path. - if (IsReshape(op)) { - // TODO(@electriclilies, jroesch): This check is failing because the size of args is 3 - // I can't figure out where the extra args are coming from, I assume it must be related - // to the relay_attrs field we added to the TIRCallArgs, but I don't know where / how - // that's happening... + + if (IsReshape(call_node)) { ICHECK_EQ(args.size(), 1U); - ReuseInputToken(op, args[0]); + ReuseInputToken(call_node, args[0]); } else { // create token for the call node. - CreateToken(op, true); + CreateToken(call_node, true); } // check if there is orphaned output that can be released immediately. 
- for (StorageToken* tok : token_map_.at(op)) { + for (StorageToken* tok : token_map_.at(call_node)) { CheckForRelease(tok); } for (StorageToken* tok : args) { @@ -376,12 +385,11 @@ class StorageAllocator : public StorageAllocaBaseVisitor { return fn->HasNonzeroAttr(attr::kReshapeOnly); } - if (call->attrs.defined()) { - if (auto tir_call_attrs = call->attrs.as()) { - Map metadata = tir_call_attrs->metadata; - return metadata.count(attr::kReshapeOnly) && - (Downcast(metadata[attr::kReshapeOnly])->value == 1); - } + if (call->op == CallLoweredOp()) { + CallLoweredProps call_lowered_props = GetCallLoweredProps(call); + Map metadata = call_lowered_props.attrs.metadata; + return metadata.count(attr::kReshapeOnly) && + (Downcast(metadata[attr::kReshapeOnly])->value == 1); } return false; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 13b855624461..4835d7618a2e 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -37,6 +38,7 @@ #include #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../transforms/pass_utils.h" #include "te_compiler.h" @@ -682,82 +684,94 @@ class Interpreter : public ExprFunctor, } ObjectRef VisitExpr_(const CallNode* call_node) final { - std::vector args; - for (auto arg : call_node->args) { - args.push_back(Eval(arg)); - } + if (call_node->op == CallLoweredOp()) { // Special case: Call a lowered TIR function. + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); - if (call_node->op == OnDeviceOp()) { - // Special case: The call 'on_device(expr)' denotes that expr should be executed on - // a particular device. We can ignore this during interpretation. 
- ICHECK_EQ(call_node->args.size(), 1UL); - return args[0]; - } + // Evaluate only function args + std::vector args; + for (auto arg : call_lowered_props.arguments) { + args.push_back(Eval(arg)); + } - // We should not find calls to operators after running fusion and lowering. - if (const OpNode* op_node = call_node->op.as()) { - LOG(FATAL) << "found " << op_node->name - << "; operators should have been removed by previous passes; try " - "fusing and lowering"; - } + // TODO(mbs): Make calling convention first-class in Relay. + Array all_prim_fn_vars; + if (call_lowered_props.attrs.metadata.count("all_prim_fn_vars")) { + all_prim_fn_vars = + Downcast>(call_lowered_props.attrs.metadata.at("all_prim_fn_vars")); + } + GlobalVar prim_shape_fn_var; + if (call_lowered_props.attrs.metadata.count("prim_shape_fn_var")) { + prim_shape_fn_var = + Downcast(call_lowered_props.attrs.metadata.at("prim_shape_fn_var")); + } + Array all_prim_shape_fn_vars; + if (call_lowered_props.attrs.metadata.count("all_prim_shape_fn_vars")) { + all_prim_shape_fn_vars = Downcast>( + call_lowered_props.attrs.metadata.at("all_prim_shape_fn_vars")); + } + Array prim_shape_fn_states; + if (call_lowered_props.attrs.metadata.count("prim_shape_fn_states")) { + prim_shape_fn_states = + Downcast>(call_lowered_props.attrs.metadata.at("prim_shape_fn_states")); + } - if (const ConstructorNode* con = call_node->op.as()) { - // Special case: ADT constructor - return ConstructorValue(con->tag, args, GetRef(con)); - } + size_t num_shape_inputs = 0; + if (call_lowered_props.attrs.metadata.count("prim_shape_fn_num_inputs")) { + num_shape_inputs = static_cast( + Downcast(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_inputs")) + ->value); + } + size_t num_shape_outputs = 0; + if (call_lowered_props.attrs.metadata.count("prim_shape_fn_num_outputs")) { + num_shape_outputs = static_cast( + Downcast(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_outputs")) + ->value); + } + 
ICHECK(config_->optional_homogeneous_target.defined()); + return InvokePrimitiveOp(call_lowered_props.lowered_func, all_prim_fn_vars, + config_->optional_homogeneous_target, prim_shape_fn_var, + all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs, + num_shape_outputs, config_->host_se_scope->target, args); + } else { // All other calls + // Evaluate all arguments + std::vector args; + for (auto arg : call_node->args) { + args.push_back(Eval(arg)); + } - if (const GlobalVarNode* gvn = call_node->op.as()) { - if (const TIRCallAttrs* attrs = call_node->attrs.as()) { - // Special case: Call a lowered TIR function. - // TODO(mbs): Make calling convention first-class in Relay. - Array all_prim_fn_vars; - if (attrs->metadata.count("all_prim_fn_vars")) { - all_prim_fn_vars = Downcast>(attrs->metadata.at("all_prim_fn_vars")); - } - GlobalVar prim_shape_fn_var; - if (attrs->metadata.count("prim_shape_fn_var")) { - prim_shape_fn_var = Downcast(attrs->metadata.at("prim_shape_fn_var")); - } - Array all_prim_shape_fn_vars; - if (attrs->metadata.count("all_prim_shape_fn_vars")) { - all_prim_shape_fn_vars = - Downcast>(attrs->metadata.at("all_prim_shape_fn_vars")); - } - Array prim_shape_fn_states; - if (attrs->metadata.count("prim_shape_fn_states")) { - prim_shape_fn_states = - Downcast>(attrs->metadata.at("prim_shape_fn_states")); - } - size_t num_shape_inputs = 0; - if (attrs->metadata.count("prim_shape_fn_num_inputs")) { - num_shape_inputs = static_cast( - Downcast(attrs->metadata.at("prim_shape_fn_num_inputs"))->value); - } - size_t num_shape_outputs = 0; - if (attrs->metadata.count("prim_shape_fn_num_outputs")) { - num_shape_outputs = static_cast( - Downcast(attrs->metadata.at("prim_shape_fn_num_outputs"))->value); - } + if (call_node->op == OnDeviceOp()) { + // Special case: The call 'on_device(expr)' denotes that expr should be executed on + // a particular device. We can ignore this during interpretation. 
+ ICHECK_EQ(call_node->args.size(), 1UL); + return args[0]; + } + if (const ConstructorNode* con = call_node->op.as()) { + // Special case: ADT constructor - ICHECK(config_->optional_homogeneous_target.defined()); - return InvokePrimitiveOp(GetRef(gvn), all_prim_fn_vars, - config_->optional_homogeneous_target, prim_shape_fn_var, - all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs, - num_shape_outputs, config_->host_se_scope->target, args); + return ConstructorValue(con->tag, args, GetRef(con)); } - } - // Now we just evaluate and expect to find a closure. - ObjectRef fn_val = Eval(call_node->op); - if (const InterpreterClosureObj* closure_node = fn_val.as()) { - auto closure = GetRef(closure_node); - return Invoke(closure, args); - } else if (const RecClosureObj* closure_node = fn_val.as()) { - return Invoke(closure_node->clos, args, closure_node->bind); - } else { - LOG(FATAL) << "internal error: type error, expected function value in the call " - << "position"; - return ObjectRef(); + if (const OpNode* op_node = call_node->op.as()) { + // Except for call_lowered and on_device, we should not find calls to operators after + // running fusion and lowering. + LOG(FATAL) << "found " << op_node->name + << "; operators should have been removed by previous passes; try " + "fusing and lowering"; + } + + // Now we just evaluate and expect to find a closure. + // TODO(@electriclilies): How should call_lowered behave with closures? 
+ ObjectRef fn_val = Eval(call_node->op); + if (const InterpreterClosureObj* closure_node = fn_val.as()) { + auto closure = GetRef(closure_node); + return Invoke(closure, args); + } else if (const RecClosureObj* closure_node = fn_val.as()) { + return Invoke(closure_node->clos, args, closure_node->bind); + } else { + LOG(FATAL) << "internal error: type error, expected function value in the call " + << "position"; + return ObjectRef(); + } } } diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 163bb9f71f9c..915fc22b2052 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -43,6 +44,7 @@ #include #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../transforms/device_aware_visitors.h" #include "./te_compiler_cache.h" #include "./utils.h" @@ -460,7 +462,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { * to the TIR implementation, and attributes to attach to the call to identify it as * a TIR call. */ - std::pair LowerFunction(Function func, Target target) { + Expr MakeLoweredCall(Function func, Array visited_args, Array type_args, Span span, + Target target) { if (func->GetAttr(attr::kCompiler).defined()) { // BYOC flow. CCacheKey key = CCacheKey(func, target); @@ -468,6 +471,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { ICHECK(ext_func.defined()) << "Lowering returned undefined function for " << ext_func->prim_fn_var->name_hint; + // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT Map prim_fns; relay::Function func_with_metadata = func; func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var); @@ -478,87 +482,91 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // act when we process a function. 
this->process_fn_(func_with_metadata); - // TODO(mbs): Need TIRCallAttrs or equiv so targets know this is an extern. // TODO(mbs): Dynamic shapes? - return {ext_func->prim_fn_var, Attrs()}; - } + // TODO(@mbs, electriclilies): Make extern functions explicit + return Call(ext_func->prim_fn_var, visited_args, Attrs(), type_args, span); - // Non-External Relay Function - VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n" << PrettyPrint(func); - CCacheKey key = CCacheKey(func, target); - CachedFunc lowered_func = compiler_->Lower(key, module_name_); - VLOG(1) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'"; - - // Collect all the lowered functions produced for this primitive function. - Map prim_fns; - Array all_prim_fn_vars; - for (auto prim_fn : lowered_func->funcs->functions) { - CHECK(prim_fn.second.as()) << "must be a prim fn"; - prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); - all_prim_fn_vars.push_back(prim_fn.first); - VLOG(1) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) << "'"; - } + } else { + // Non-External Relay Function + VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n" + << PrettyPrint(func); + CCacheKey key = CCacheKey(func, target); + CachedFunc lowered_func = compiler_->Lower(key, module_name_); + VLOG(1) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'"; - // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT - relay::Function func_with_metadata = func; - func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var); - func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); - func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, lowered_func->target); + // Collect all the lowered functions produced for this primitive function. 
+ Map prim_fns; + Array all_prim_fn_vars; + for (auto prim_fn : lowered_func->funcs->functions) { + CHECK(prim_fn.second.as()) << "must be a prim fn"; + prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); + all_prim_fn_vars.push_back(prim_fn.first); + VLOG(1) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) << "'"; + } - // Provide a callback hook which allows one-level up code generators to - // act when we process a function. - this->process_fn_(func_with_metadata); + // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT + relay::Function func_with_metadata = func; + func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var); + func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); + func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, lowered_func->target); - auto tir_call_attrs = make_object(); - if (func->HasNonzeroAttr(attr::kReshapeOnly)) { - tir_call_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); - } + // Provide a callback hook which allows one-level up code generators to + // act when we process a function. + this->process_fn_(func_with_metadata); - auto device_copy = IsDeviceCopy(func); - if (std::get<0>(device_copy)) { - // Record that device copy source and destination devices so the device planner can - // still follow along. 
- auto source_device = std::get<1>(device_copy); - auto dst_device = std::get<2>(device_copy); - tir_call_attrs->metadata.Set("source_device", tvm::Integer(source_device)); - tir_call_attrs->metadata.Set("dst_device", tvm::Integer(dst_device)); - } + auto call_lowered_attrs = make_object(); + if (func->HasNonzeroAttr(attr::kReshapeOnly)) { + call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); + } - tir_call_attrs->metadata.Set("relay_attrs", func->attrs); - tir_call_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars); - - if (IsDynamic(func->ret_type)) { - // Also lower the dynamic shape function. - // Shape function keys use the underlying primitive function as their 'function', - // but the generic 'cpu' target as the target since all shape functions run - // on the host cpu irrespective of where the primitive runs. - // TODO(mbs): Cleanup target handling. - Target shape_target("llvm"); - VLOG(1) << "lowering to target '" << shape_target->str() - << "' for dynamic shape function for primitive"; - CCacheKey shape_key(func, shape_target); - CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); - // Capture the shape function's global var and parameters 'states' in call - // annotations so calling convention can be recovered. - // TODO(mbs): Capture all this as part of a 'call into TIR' construct once available. - // The way the shape function calling convention is derived and passed to call sites - // via the 'parameter states' could be improved. 
- tir_call_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); - tir_call_attrs->metadata.Set("prim_shape_fn_states", - lowered_shape_func->shape_func_param_states); - tir_call_attrs->metadata.Set("prim_shape_fn_num_inputs", - Integer(static_cast(lowered_shape_func->inputs.size()))); - tir_call_attrs->metadata.Set("prim_shape_fn_num_outputs", - Integer(static_cast(lowered_shape_func->outputs.size()))); - Array all_prim_shape_fn_vars; - for (auto prim_shape_fn : lowered_shape_func->funcs->functions) { - CHECK(prim_shape_fn.second.as()) << "must be a prim fn"; - all_prim_shape_fn_vars.push_back(prim_shape_fn.first); + auto device_copy = IsDeviceCopy(func); + if (std::get<0>(device_copy)) { + // Record that device copy source and destination devices so the device planner can + // still follow along. + auto source_device = std::get<1>(device_copy); + auto dst_device = std::get<2>(device_copy); + call_lowered_attrs->metadata.Set("source_device", tvm::Integer(source_device)); + call_lowered_attrs->metadata.Set("dst_device", tvm::Integer(dst_device)); } - tir_call_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); - } - return {lowered_func->prim_fn_var, Attrs(tir_call_attrs)}; + call_lowered_attrs->metadata.Set("relay_attrs", func->attrs); + call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars); + + if (IsDynamic(func->ret_type)) { + // Also lower the dynamic shape function. + // Shape function keys use the underlying primitive function as their 'function', + // but the generic 'cpu' target as the target since all shape functions run + // on the host cpu irrespective of where the primitive runs. + // TODO(mbs): Cleanup target handling. 
+ Target shape_target("llvm"); + VLOG(1) << "lowering to target '" << shape_target->str() + << "' for dynamic shape function for primitive"; + CCacheKey shape_key(func, shape_target); + CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); + // Capture the shape function's global var and parameters 'states' in call + // annotations so calling convention can be recovered. + // TODO(mbs): Capture all this as part of a 'call into TIR' construct once available. + // The way the shape function calling convention is derived and passed to call sites + // via the 'parameter states' could be improved. + call_lowered_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); + call_lowered_attrs->metadata.Set("prim_shape_fn_states", + lowered_shape_func->shape_func_param_states); + call_lowered_attrs->metadata.Set( + "prim_shape_fn_num_inputs", + Integer(static_cast(lowered_shape_func->inputs.size()))); + call_lowered_attrs->metadata.Set( + "prim_shape_fn_num_outputs", + Integer(static_cast(lowered_shape_func->outputs.size()))); + Array all_prim_shape_fn_vars; + for (auto prim_shape_fn : lowered_shape_func->funcs->functions) { + CHECK(prim_shape_fn.second.as()) << "must be a prim fn"; + all_prim_shape_fn_vars.push_back(prim_shape_fn.first); + } + call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); + } + return CallLowered(lowered_func->prim_fn_var, visited_args, Attrs(call_lowered_attrs), + type_args, span); + } } std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { @@ -593,6 +601,9 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } Expr DeviceAwareVisitExpr_(const CallNode* call_node) override { + // Passes before lowering might insert a call_lowered to call a function that has already + // been lowered. 
Therefore we might see call_lowered ops here, but we don't need to do anything + // because ResolveToPrimitive returns null for all calls where the call_node->op is an OpNode Call call = GetRef(call_node); // Look for (indirect) calls to primitives. @@ -628,15 +639,13 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // TODO(mbs): Replace device_type with target so this lookup is unnecessary. target = GetTargetFromInteger(device_type, targets_); } - + Array visited_args; + for (const auto& arg : call_node->args) { + visited_args.push_back(VisitExpr(arg)); + } // Lower the primitive function for that target. Function func = Downcast(prim_func); - std::pair pair = LowerFunction(func, target); - - // Replace with direct call to lowered primitive, and attach annotations to record calling - // convention. - // =====> in new call_lowered form - return Call(pair.first, args, pair.second); + return MakeLoweredCall(func, visited_args, call_node->type_args, call_node->span, target); } IRModule module_; diff --git a/src/relay/op/call/call.cc b/src/relay/op/call/call.cc new file mode 100644 index 000000000000..9485b72d8374 --- /dev/null +++ b/src/relay/op/call/call.cc @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/call/call.cc + * \brief Operators for calling lowered functions. + */ + +#include "./call.h" + +#include +#include +#include +#include + +#include "../../transforms/infer_layout_utils.h" + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(CallLoweredAttrs); + +// call_lowered +bool CallLoweredRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // Types = [func, call_args, ret_type] + if (types.size() != 3u) { + return false; + } + const auto* func_type = types[0].as(); + if (!func_type) { + return false; + } + + const auto* tuple_type_node = types[1].as(); + if (!tuple_type_node) { + return false; + } + + // Constraint to ensure function arguments are the same type as the inputs to the function (modulo + // the Tuple wrapper) + reporter->Assign(GetRef(tuple_type_node), TupleType(func_type->arg_types, {})); + // Constraint to ensure the output of call_lowered is the same as the function's return type + reporter->Assign(types[2], func_type->ret_type); + return true; +} + +const Op& CallLoweredOp() { return Op::Get("call_lowered"); } + +Expr CallLowered(Expr func, Array inputs, Attrs attrs, Array type_args, Span span) { + // Right now, call_lowered only supports func being a global var pointing to the lowered + // function. 
+ ICHECK(func.as()) + << "Function to call should be GlobalVarNode, but got " << func->GetTypeKey(); + ICHECK(attrs.as()) + << "Expected attributes to be CallLoweredAttrs, but got " << attrs->GetTypeKey(); + return Call(CallLoweredOp(), {std::move(func), Tuple(std::move(inputs))}, std::move(attrs), + std::move(type_args), std::move(span)); +} + +TVM_REGISTER_GLOBAL("relay.op.call_lowered") + .set_body_typed([](Expr func, Array inputs, Attrs attrs, Array type_args, + Span span) { + const TupleNode* tuple_node = inputs.as(); + return CallLowered(func, tuple_node->fields, attrs, type_args, span); + }); + +RELAY_REGISTER_OP("call_lowered") + .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE) + .set_num_inputs(2) + .set_attrs_type() + .add_argument("func", "Function", "The lowered function to call.") + .add_argument("call_args", "Tuple", "The input tensors.") + .add_type_rel("CallLoweredRel", CallLoweredRel) + .set_support_level(10) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("TNonComputational", true) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); + +CallLoweredProps GetCallLoweredProps(const CallNode* call_node) { + ICHECK(call_node->op == CallLoweredOp()) + << "GetCallLoweredProps expects the op to be call_lowered. "; + ICHECK(call_node->args.size() == 2) << "Expected call_lowered to have 2 arguments. "; + const auto* function = call_node->args[0].as(); + ICHECK(function) << "Expected first arg to call_lowered to be a GlobalVar. "; + + const auto* tuple_args = call_node->args[1].as(); + ICHECK(tuple_args) << "Expected second arg to call_lowered to be a Tuple. 
"; + + ICHECK(call_node->attrs.defined()) << "Attributes for call_lowered should be defined!"; + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs) << "Expected call_lowered op to have CallLoweredAttrs, but found " + << call_node->attrs->GetTypeKey(); + return CallLoweredProps{std::move(GetRef(function)), std::move(tuple_args->fields), + std::move(*attrs)}; +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/call/call.h b/src/relay/op/call/call.h new file mode 100644 index 000000000000..381be6724e0d --- /dev/null +++ b/src/relay/op/call/call.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/call/call.h + * \brief Operators for calling lowered functions. + */ +#ifndef TVM_RELAY_OP_CALL_CALL_H_ +#define TVM_RELAY_OP_CALL_CALL_H_ + +#include +#include + +#include + +namespace tvm { +namespace relay { + +/*! + * \brief Helper to construct a Relay call with the call_lowered op. + * \param func Lowered function to call with call_lowered. + * \param inputs Arguments to be passed to the function. + * \param attrs Function attributes, should be TIRCallAttrs. + * \param type_args Type arguments for the call. 
+ * \param span TVM span for propagating debugging info. + * \return A Call to the call_lowered op wrapping func and the tuple of inputs. + */ +Expr CallLowered(Expr func, Array inputs, Attrs attrs, Array type_args, Span span); + +/*! + * \brief Returns the Relay call_lowered op. Use this helper to avoid extraneous calls to + * Registry::Get. + */ +const Op& CallLoweredOp(); + +/*! + * \brief Lowered function and the arguments to call it with. + */ +struct CallLoweredProps { + /*! \brief Global variable pointing to the lowered function. */ + GlobalVar lowered_func; + /*! \brief Array of the arguments to call lowered_func with. */ + Array arguments; + /*! \brief Attributes from the call_lowered op. */ + CallLoweredAttrs attrs; +}; + +/*! + * \brief Helper to extract the lowered function and its arguments from Call("call_lowered", ...). + * Will fail if called on a Call whose op is not "call_lowered". + * \param call_node CallNode that we want to get the function and its arguments from. + */ +CallLoweredProps GetCallLoweredProps(const CallNode* call_node); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_OP_CALL_CALL_H_ diff --git a/src/relay/op/memory/device_copy.cc b/src/relay/op/memory/device_copy.cc index dce89aa91b65..9106b95c9217 100644 --- a/src/relay/op/memory/device_copy.cc +++ b/src/relay/op/memory/device_copy.cc @@ -24,6 +24,7 @@ #include "./device_copy.h" +#include #include #include #include @@ -31,6 +32,8 @@ #include #include "../../transforms/infer_layout_utils.h" +#include "../annotation/annotation.h" +#include "../call/call.h" #include "../type_relations.h" namespace tvm { @@ -86,6 +89,7 @@ on different devices.
return {topi::identity(inputs[0])}; }); +// Get device copy props for original device copy op DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { if (call_node->op == DeviceCopyOp()) { ICHECK_EQ(call_node->args.size(), 1) << "device_copy expects one argument"; @@ -103,6 +107,19 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { } else { return {call_node->args[0], src_dev_type, dst_dev_type}; } + } else if (call_node->op == CallLoweredOp()) { + /* Get device props for a TIR function */ + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); + + if (call_lowered_props.attrs.metadata.count("source_device") == 1 && + call_lowered_props.attrs.metadata.count("dst_device") == 1) { + ICHECK_EQ(call_lowered_props.arguments.size(), 1) << "device_copy is of arity 1"; + return {call_lowered_props.lowered_func, + static_cast( + Downcast(call_lowered_props.attrs.metadata["source_device"])->value), + static_cast( + Downcast(call_lowered_props.attrs.metadata["dst_device"])->value)}; + } } return {}; } diff --git a/src/relay/op/vm/vm.h b/src/relay/op/vm/vm.h index 802c8100125a..68d25b097bce 100644 --- a/src/relay/op/vm/vm.h +++ b/src/relay/op/vm/vm.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_OP_VM_VM_H_ #define TVM_RELAY_OP_VM_VM_H_ -#include "tvm/relay/expr.h" +#include namespace tvm { namespace relay { diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc index 15784856edbf..b9fa0494d3b5 100644 --- a/src/relay/transforms/device_domains.cc +++ b/src/relay/transforms/device_domains.cc @@ -24,9 +24,11 @@ #include "./device_domains.h" +#include #include #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../op/memory/device_copy.h" namespace tvm { @@ -47,20 +49,19 @@ constexpr size_t mix(size_t h1, size_t h2) { * See te_compiler.cc for where this rewriting occurs. 
*/ DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) { - auto tir_call_attrs = call_node->attrs.as(); - if (tir_call_attrs == nullptr) { - return {}; - } - if (tir_call_attrs->metadata.count("source_device") != 1 || - tir_call_attrs->metadata.count("dst_device") != 1) { - return {}; + if (call_node->op == CallLoweredOp()) { + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); + if (call_lowered_props.attrs.metadata.count("source_device") == 1 && + call_lowered_props.attrs.metadata.count("dst_device") == 1) { + ICHECK_EQ(call_lowered_props.arguments.size(), 1) << "device_copy is of arity 1"; + return {call_lowered_props.arguments[0], + static_cast( + Downcast(call_lowered_props.attrs.metadata["source_device"])->value), + static_cast( + Downcast(call_lowered_props.attrs.metadata["dst_device"])->value)}; + } } - ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1"; - return { - call_node->args[0], - static_cast( - Downcast(tir_call_attrs->metadata["source_device"])->value), - static_cast(Downcast(tir_call_attrs->metadata["dst_device"])->value)}; + return {}; } } // namespace @@ -319,8 +320,12 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { args_and_result.emplace_back(param_domain); } args_and_result.emplace_back(result_domain); + } else if (call->op == CallLoweredOp()) { + CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get()); + return DomainFor(call_lowered_props.lowered_func); } else { - // Defer to normal case where op can be an arbitrary expression. + // We still need to handle the case where the function / op is not lowered + // because the device planner runs before and after lowering. 
return DomainFor(call->op); } auto domain = MakeDomain(std::move(args_and_result)); diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 81d704e2be8e..a328eaa82aa2 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -44,6 +45,7 @@ #include "../backend/te_compiler.h" #include "../backend/te_compiler_cache.h" #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../op/memory/device_copy.h" #include "../op/memory/memory.h" #include "../op/vm/vm.h" @@ -74,12 +76,11 @@ bool IsReshapeOnly(const Expr& expr) { return func->HasNonzeroAttr(attr::kReshapeOnly); } if (const CallNode* call = expr.as()) { - if (call->attrs.defined()) { - if (auto tir_call_attrs = call->attrs.as()) { - Map metadata = tir_call_attrs->metadata; - return metadata.count(attr::kReshapeOnly) && - (Downcast(metadata[attr::kReshapeOnly])->value == 1); - } + if (call->op == CallLoweredOp()) { + CallLoweredProps call_lowered_props = GetCallLoweredProps(call); + Map metadata = call_lowered_props.attrs.metadata; + return metadata.count(attr::kReshapeOnly) && + (Downcast(metadata[attr::kReshapeOnly])->value == 1); } } return false; @@ -377,7 +378,8 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { } Tuple tuple_outs(outs); - auto invoke = OnDevice(InvokeTVMOp(func, ins, tuple_outs), dev.device_type, /*is_fixed=*/true); + auto call = InvokeTVMOp(func, ins, tuple_outs); + auto invoke = OnDevice(call, dev.device_type, /*is_fixed=*/true); scope->Push(invoke); return ToTupleType(ret_type, std::vector(tuple_outs->fields.begin(), tuple_outs->fields.end()));