From 9efc217a41af1a9f5662e3a270d9e3a6e2d4f3b8 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 10 Nov 2021 15:36:47 -0800 Subject: [PATCH] followups from #9312 --- src/relay/backend/te_compiler.cc | 11 ++++------- src/relay/op/memory/device_copy.cc | 13 ------------- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 915fc22b20528..418581bffd01b 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -617,15 +617,15 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } // Similarly transform arguments. - Array args; + Array visited_args; for (const auto& arg : call_node->args) { - args.push_back(VisitExpr(arg)); + visited_args.push_back(VisitExpr(arg)); } // Already lowered by other means so we don't need to mutate // the call but we do need to mutate the arguments if (prim_func->IsInstance()) { - return Call(call_node->op, args, call_node->attrs); + return Call(call_node->op, visited_args, call_node->attrs); } // Find the desired target device. @@ -639,10 +639,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // TODO(mbs): Replace device_type with target so this lookup is unnecessary. target = GetTargetFromInteger(device_type, targets_); } - Array visited_args; - for (const auto& arg : call_node->args) { - visited_args.push_back(VisitExpr(arg)); - } + // Lower the primitive function for that target. Function func = Downcast(prim_func); return MakeLoweredCall(func, visited_args, call_node->type_args, call_node->span, target); diff --git a/src/relay/op/memory/device_copy.cc b/src/relay/op/memory/device_copy.cc index 9106b95c92171..9a3df7a4f91f9 100644 --- a/src/relay/op/memory/device_copy.cc +++ b/src/relay/op/memory/device_copy.cc @@ -107,19 +107,6 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { } else { return {call_node->args[0], src_dev_type, dst_dev_type}; } - } else if (call_node->op == CallLoweredOp()) { - /* Get device props for a TIR function */ - CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); - - if (call_lowered_props.attrs.metadata.count("source_device") == 1 && - call_lowered_props.attrs.metadata.count("dst_device") == 1) { - ICHECK_EQ(call_lowered_props.arguments.size(), 1) << "device_copy is of arity 1"; - return {call_lowered_props.lowered_func, - static_cast( - Downcast(call_lowered_props.attrs.metadata["source_device"])->value), - static_cast( - Downcast(call_lowered_props.attrs.metadata["dst_device"])->value)}; - } } return {}; }