-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Change Call with TIRCallAttrs to call_lowered op #9312
Conversation
9435c6d
to
5ebbe78
Compare
e8c644a
to
831879c
Compare
@mbs-octoml PTAL, I'm pretty sure I've resolved all the issues, and this is ready for review. |
In the overview comment prob should show that args must be a tuple (even for the zero and one arity case) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Lily. Mostly nits and a few of my personal style preferences (feel free to ignore those), so don't panic! Otherwise:
- think we can tune up the Extract helper.
- a bit worries about hiding the lowered vs original forms of the operators (eg device_copy), see comment inline.
- feeling skittish about folding extern calls into the call_lowered convention just yet and suggest we back out of that.
Happy to look again on a PTAL.
include/tvm/relay/attrs/call.h
Outdated
*/ | ||
|
||
/*! | ||
* \file tvm/relay/attrs/annotation.h |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: update comment
src/relay/op/call/call.h
Outdated
namespace tvm { | ||
namespace relay { | ||
|
||
Expr CallLowered(Expr func, Expr inputs, Attrs attrs, Array<Type> type_args, Span span); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: comments thanks!
src/relay/op/call/call.h
Outdated
|
||
Expr CallLowered(Expr func, Expr inputs, Attrs attrs, Array<Type> type_args, Span span); | ||
const Op& CallLoweredOp(); | ||
std::pair<GlobalVar, Array<Expr>> ExtractFunctionAndArgs(const CallNode* call_node); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: it's a lot of boilerplate i know, but a struct for the return result is probably best long term.
super duper nit: for on_device and device_copy i called these Get.... Do I get to claim precedence?
src/relay/op/call/call.cc
Outdated
const TypeReporter& reporter) { | ||
// Types = [func, args, ret_type] | ||
ICHECK_EQ(types.size(), 3u); | ||
auto func_type = types[0].as<FuncTypeNode>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return false for these checks since we know nothing about how the expression was constructed.
even for the arity check i think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and then can just the Assign for the TupleType do the check for types[1]
src/relay/op/call/call.cc
Outdated
|
||
const Op& CallLoweredOp() { return Op::Get("call_lowered"); } | ||
|
||
Expr CallLowered(Expr func, Expr inputs, Attrs attrs, Array<Type> type_args, Span span) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: for symmetry with the Extract, inputs can be Array args.
You can std::move into the Call ctor (kinda a mico optimization, but good habit to get into to for when the perf really does matter).
src/relay/backend/te_compiler.cc
Outdated
return {ext_func->prim_fn_var, Attrs()}; | ||
auto call_lowered_attrs = make_object<CallLoweredAttrs>(); | ||
// Mark the function as a extern function so that AOT knows what to do with | ||
call_lowered_attrs->metadata.Set("extern_func", Integer(1)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see the check in the graph_executor_codegen.cc but not aot_executor_codegen.cc. I wonder if we should leave the extern calls as regular old boring calls with no attributes and tackle them separately. Perhaps this method can return the call instead of pair so that the different call representations can be handled.
@@ -103,6 +107,20 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { | |||
} else { | |||
return {call_node->args[0], src_dev_type, dst_dev_type}; | |||
} | |||
} else if (call_node->op == CallLoweredOp()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a bit nervous about this one since there's a lot of usage patterns such as
props = GetProps
if (prop.body.defined()) { .... MakeDeviceCopy(...) ... }
which would silently demote call_lowered copies to regular copies.
I think safer would be to have a GetLoweredDeviceCopyProps so it's always clear which form is being matched against.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Talked with mark and we determined this is fine, he was just being a bit skittish
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that this chunk of code is not necessary since I also deal with CallLoweredOp in GetPrimitiveDeviceCopyProps, I'm not sure why I put it in. Removing it
src/relay/op/vm/vm.cc
Outdated
@@ -22,8 +22,6 @@ | |||
* \brief Dialect operators for Relay VM. | |||
*/ | |||
|
|||
#include "vm.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: might as well keep it -- standard idiom even though not strictly needed.
if (tir_call_attrs->metadata.count("source_device") != 1 || | ||
tir_call_attrs->metadata.count("dst_device") != 1) { | ||
return {}; | ||
if (call_node->op == CallLoweredOp()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this would become the GetLoweredDeviceCopyProps mentioned above, so no need for the helper here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we decided we don't need that actually so this is OK as is.
@@ -275,6 +275,7 @@ class DeviceDomains { | |||
const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor"); | |||
const Op& shape_of_op = Op::Get("vm.shape_of"); | |||
const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op"); | |||
const Op& call_lowered = Op::Get("call_lowered"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: your CallLoweredOp() helper is good enough.
some day all these will use the same idiom.
5861dc6
to
d5a33f9
Compare
@mbs-octoml I think this is ready for another review! |
Add op vm.call_tir Change from checking if CallNode has CallTIRAttrs to checking if the Op is vm.call_tir Change device_domains to use vm.call_tir op more explicitly Fixed issue in type checker, now have seg fault :( Fix typo -- most of VM tests pass now Interpreter now deals with call_tir properly Fix typo in te_compiler Use InvokeTVMOp and CallTIR Add some checks to graph_plan_memory.cc Make GetToken skip function types C++ TESTS PASS WOOHOO Remove prints formatting vm.call_tir -> call_tir and more comment removals call_tir -> call_lowered fix lint clang format Remove compute from non computational vm ops missed some semicolons in prev commit Fix warning Move call_lowered to relay/op/call/call.cc and rename util func Add helper fn that returns lowered_call op fix import order clang format Add constraint to call_lowered type rel clean up empty token vector comment Move CallTIRAttrs to include/tvm/relay/attrs/call.h Rename TIRCallAttrs as CallLoweredAttrs lint Add helper for extracting func and args from call_lowered Change graph_executor_codegen to use helper function Update interpreter to use helper Fix device_domains.cc -- could still use cleanup, also I am not sure why there are still direct calls to primfns in DomainforCallee Clean up DeviceCopyProps and lint lint return CallLoweredAttrs with the extern func comment note in comment Progress & notes. Realized that I am not handling externs correctly not sure why this ever worked before? Clean up CreateFuncCall signature, notes comments Fix extern function handling extern_function -> extern_func fix DeviceAwareVisitExpr_ -- now it handles both lowered and normal calls yay passes AOT tests! formatting and comment removal cleanup Introduce call_lowered op
fddc231
to
207c487
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two minor suggestions which are easily done in follow up. Thanks!
@@ -628,15 +639,13 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { | |||
// TODO(mbs): Replace device_type with target so this lookup is unnecessary. | |||
target = GetTargetFromInteger(device_type, targets_); | |||
} | |||
|
|||
Array<Expr> visited_args; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not a correctness issue, but we are revisiting the args (bad rebase).
/* Get device props for a TIR function */ | ||
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); | ||
|
||
if (call_lowered_props.attrs.metadata.count("source_device") == 1 && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
confirming looks like we can avoid this here
(eventually GetPrimitiveDeviceCopyProps from device_domains.cc will need to get moved alongside this one (and renamed) but no need here)
Thanks @electriclilies @mbs-octoml |
@mbrookhart @mbs-octoml thanks! |
Followed up in #9491 |
* main: (119 commits) [Topi][Op][PyTorch][Vitas] Fix inconsistent kernel layout conventions for conv2d_transpose (apache#9336) Fix repository URL in ubuntu_install_rocm.sh (apache#9425) Add LLVM-13 installation to Docker setup (apache#9498) [Relay] Use target_host determined at Relay level instead of recalculating it (apache#9499) Arm(R) Ethos(TM)-U NPU BinaryElementwise operators support (apache#9442) [COMMUNITY] Junru's and Wuwei's PGP key for ASF release (apache#9488) Add default for split op (apache#9489) [HOTFIX][TARGET] Change LOG in compilation config to DLOG (apache#9486) Fixed some warnings about lambda's closures that are bigger than necessary (apache#9481) [Support] Add libinfo into the runtime build (apache#9310) Change Call with TIRCallAttrs to call_lowered op (apache#9312) [ETHOSN] Streamline Ethos(TM)-N cross-compile rpc usage (apache#9477) [CMSIS-NN] Assert correct amount of CMSIS-NN artifacts in MLF (apache#9480) [MicroTVM][PyTest] Explicitly skip MicroTVM unittests. (apache#9335) [microNPU] Replace ICHECK with diagnostic context in type inference (apache#9470) Better host handling in CompilationConfig & debug printing (apache#9460) [AOT][Tests] Use pre-built libraries in Reference System tests (apache#9271) [TIR] Add type hint for TIR (apache#9432) [TVMC] Add test for quantized pytorch model (apache#9467) [CMSIS-NN] Convert CMSIS-NN to use Target Hooks (apache#9397) ...
* followups from #9312 * remove unneeded moves
* Introduce call_lowered op Add op vm.call_tir Change from checking if CallNode has CallTIRAttrs to checking if the Op is vm.call_tir Change device_domains to use vm.call_tir op more explicitly Fixed issue in type checker, now have seg fault :( Fix typo -- most of VM tests pass now Interpreter now deals with call_tir properly Fix typo in te_compiler Use InvokeTVMOp and CallTIR Add some checks to graph_plan_memory.cc Make GetToken skip function types C++ TESTS PASS WOOHOO Remove prints formatting vm.call_tir -> call_tir and more comment removals call_tir -> call_lowered fix lint clang format Remove compute from non computational vm ops missed some semicolons in prev commit Fix warning Move call_lowered to relay/op/call/call.cc and rename util func Add helper fn that returns lowered_call op fix import order clang format Add constraint to call_lowered type rel clean up empty token vector comment Move CallTIRAttrs to include/tvm/relay/attrs/call.h Rename TIRCallAttrs as CallLoweredAttrs lint Add helper for extracting func and args from call_lowered Change graph_executor_codegen to use helper function Update interpreter to use helper Fix device_domains.cc -- could still use cleanup, also I am not sure why there are still direct calls to primfns in DomainforCallee Clean up DeviceCopyProps and lint lint return CallLoweredAttrs with the extern func comment note in comment Progress & notes. Realized that I am not handling externs correctly not sure why this ever worked before? Clean up CreateFuncCall signature, notes comments Fix extern function handling extern_function -> extern_func fix DeviceAwareVisitExpr_ -- now it handles both lowered and normal calls yay passes AOT tests! formatting and comment removal cleanup Introduce call_lowered op * lint * Fix AOT to deal with externs * add const auto& * Fix aot crt test
* followups from apache#9312 * remove unneeded moves
* Introduce call_lowered op Add op vm.call_tir Change from checking if CallNode has CallTIRAttrs to checking if the Op is vm.call_tir Change device_domains to use vm.call_tir op more explicitly Fixed issue in type checker, now have seg fault :( Fix typo -- most of VM tests pass now Interpreter now deals with call_tir properly Fix typo in te_compiler Use InvokeTVMOp and CallTIR Add some checks to graph_plan_memory.cc Make GetToken skip function types C++ TESTS PASS WOOHOO Remove prints formatting vm.call_tir -> call_tir and more comment removals call_tir -> call_lowered fix lint clang format Remove compute from non computational vm ops missed some semicolons in prev commit Fix warning Move call_lowered to relay/op/call/call.cc and rename util func Add helper fn that returns lowered_call op fix import order clang format Add constraint to call_lowered type rel clean up empty token vector comment Move CallTIRAttrs to include/tvm/relay/attrs/call.h Rename TIRCallAttrs as CallLoweredAttrs lint Add helper for extracting func and args from call_lowered Change graph_executor_codegen to use helper function Update interpreter to use helper Fix device_domains.cc -- could still use cleanup, also I am not sure why there are still direct calls to primfns in DomainforCallee Clean up DeviceCopyProps and lint lint return CallLoweredAttrs with the extern func comment note in comment Progress & notes. Realized that I am not handling externs correctly not sure why this ever worked before? Clean up CreateFuncCall signature, notes comments Fix extern function handling extern_function -> extern_func fix DeviceAwareVisitExpr_ -- now it handles both lowered and normal calls yay passes AOT tests! formatting and comment removal cleanup Introduce call_lowered op * lint * Fix AOT to deal with externs * add const auto& * Fix aot crt test
* followups from apache#9312 * remove unneeded moves
* Introduce call_lowered op Add op vm.call_tir Change from checking if CallNode has CallTIRAttrs to checking if the Op is vm.call_tir Change device_domains to use vm.call_tir op more explicitly Fixed issue in type checker, now have seg fault :( Fix typo -- most of VM tests pass now Interpreter now deals with call_tir properly Fix typo in te_compiler Use InvokeTVMOp and CallTIR Add some checks to graph_plan_memory.cc Make GetToken skip function types C++ TESTS PASS WOOHOO Remove prints formatting vm.call_tir -> call_tir and more comment removals call_tir -> call_lowered fix lint clang format Remove compute from non computational vm ops missed some semicolons in prev commit Fix warning Move call_lowered to relay/op/call/call.cc and rename util func Add helper fn that returns lowered_call op fix import order clang format Add constraint to call_lowered type rel clean up empty token vector comment Move CallTIRAttrs to include/tvm/relay/attrs/call.h Rename TIRCallAttrs as CallLoweredAttrs lint Add helper for extracting func and args from call_lowered Change graph_executor_codegen to use helper function Update interpreter to use helper Fix device_domains.cc -- could still use cleanup, also I am not sure why there are still direct calls to primfns in DomainforCallee Clean up DeviceCopyProps and lint lint return CallLoweredAttrs with the extern func comment note in comment Progress & notes. Realized that I am not handling externs correctly not sure why this ever worked before? Clean up CreateFuncCall signature, notes comments Fix extern function handling extern_function -> extern_func fix DeviceAwareVisitExpr_ -- now it handles both lowered and normal calls yay passes AOT tests! formatting and comment removal cleanup Introduce call_lowered op * lint * Fix AOT to deal with externs * add const auto& * Fix aot crt test
* followups from apache#9312 * remove unneeded moves
* Introduce call_lowered op Add op vm.call_tir Change from checking if CallNode has CallTIRAttrs to checking if the Op is vm.call_tir Change device_domains to use vm.call_tir op more explicitly Fixed issue in type checker, now have seg fault :( Fix typo -- most of VM tests pass now Interpreter now deals with call_tir properly Fix typo in te_compiler Use InvokeTVMOp and CallTIR Add some checks to graph_plan_memory.cc Make GetToken skip function types C++ TESTS PASS WOOHOO Remove prints formatting vm.call_tir -> call_tir and more comment removals call_tir -> call_lowered fix lint clang format Remove compute from non computational vm ops missed some semicolons in prev commit Fix warning Move call_lowered to relay/op/call/call.cc and rename util func Add helper fn that returns lowered_call op fix import order clang format Add constraint to call_lowered type rel clean up empty token vector comment Move CallTIRAttrs to include/tvm/relay/attrs/call.h Rename TIRCallAttrs as CallLoweredAttrs lint Add helper for extracting func and args from call_lowered Change graph_executor_codegen to use helper function Update interpreter to use helper Fix device_domains.cc -- could still use cleanup, also I am not sure why there are still direct calls to primfns in DomainforCallee Clean up DeviceCopyProps and lint lint return CallLoweredAttrs with the extern func comment note in comment Progress & notes. Realized that I am not handling externs correctly not sure why this ever worked before? Clean up CreateFuncCall signature, notes comments Fix extern function handling extern_function -> extern_func fix DeviceAwareVisitExpr_ -- now it handles both lowered and normal calls yay passes AOT tests! formatting and comment removal cleanup Introduce call_lowered op * lint * Fix AOT to deal with externs * add const auto& * Fix aot crt test
* followups from apache#9312 * remove unneeded moves
* Introduce call_lowered op Add op vm.call_tir Change from checking if CallNode has CallTIRAttrs to checking if the Op is vm.call_tir Change device_domains to use vm.call_tir op more explicitly Fixed issue in type checker, now have seg fault :( Fix typo -- most of VM tests pass now Interpreter now deals with call_tir properly Fix typo in te_compiler Use InvokeTVMOp and CallTIR Add some checks to graph_plan_memory.cc Make GetToken skip function types C++ TESTS PASS WOOHOO Remove prints formatting vm.call_tir -> call_tir and more comment removals call_tir -> call_lowered fix lint clang format Remove compute from non computational vm ops missed some semicolons in prev commit Fix warning Move call_lowered to relay/op/call/call.cc and rename util func Add helper fn that returns lowered_call op fix import order clang format Add constraint to call_lowered type rel clean up empty token vector comment Move CallTIRAttrs to include/tvm/relay/attrs/call.h Rename TIRCallAttrs as CallLoweredAttrs lint Add helper for extracting func and args from call_lowered Change graph_executor_codegen to use helper function Update interpreter to use helper Fix device_domains.cc -- could still use cleanup, also I am not sure why there are still direct calls to primfns in DomainforCallee Clean up DeviceCopyProps and lint lint return CallLoweredAttrs with the extern func comment note in comment Progress & notes. Realized that I am not handling externs correctly not sure why this ever worked before? Clean up CreateFuncCall signature, notes comments Fix extern function handling extern_function -> extern_func fix DeviceAwareVisitExpr_ -- now it handles both lowered and normal calls yay passes AOT tests! formatting and comment removal cleanup Introduce call_lowered op * lint * Fix AOT to deal with externs * add const auto& * Fix aot crt test
* followups from apache#9312 * remove unneeded moves
In this PR, I change
Call
withTIRCallAttrs
to a relay op called call_lowered. (I also renameCallTIRAttrs
toCallLoweredAttrs
). This means that after lowering, calls are of the formCall("call_lowered", [fn: GlobalVar, args: Tuple] Attrs(CallLoweredAttrs))
.One benefit to this approach is it is easy to identify calls to lowered functions-- instead of checking the type of the call's attributes, we can just check what the op is. This is especially helpful for passes that run both before an after lowering, like device planning and memory allocation.
I left the calling convention for extern functions as-is. A codegened extern call looks like this:
Call(fn: GlobalVar, [arg1: Expr, ... argn: Expr], Attrs())
with null attributes. In the future we will probably need to change the extern calling convention to have attributes that store information like the target, but that is outside the scope of this work.