Skip to content

Commit

Permalink
followups from #9312
Browse files Browse the repository at this point in the history
  • Loading branch information
electriclilies committed Nov 10, 2021
1 parent 0812c07 commit 9efc217
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 20 deletions.
11 changes: 4 additions & 7 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -617,15 +617,15 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
}

// Similarly transform arguments.
Array<Expr> args;
Array<Expr> visited_args;
for (const auto& arg : call_node->args) {
args.push_back(VisitExpr(arg));
visited_args.push_back(VisitExpr(arg));
}

// Already lowered by other means so we don't need to mutate
// the call but we do need to mutate the arguments
if (prim_func->IsInstance<tir::PrimFuncNode>()) {
return Call(call_node->op, args, call_node->attrs);
return Call(call_node->op, visited_args, call_node->attrs);
}

// Find the desired target device.
Expand All @@ -639,10 +639,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
// TODO(mbs): Replace device_type with target so this lookup is unnecessary.
target = GetTargetFromInteger(device_type, targets_);
}
Array<Expr> visited_args;
for (const auto& arg : call_node->args) {
visited_args.push_back(VisitExpr(arg));
}

// Lower the primitive function for that target.
Function func = Downcast<Function>(prim_func);
return MakeLoweredCall(func, visited_args, call_node->type_args, call_node->span, target);
Expand Down
13 changes: 0 additions & 13 deletions src/relay/op/memory/device_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,6 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) {
} else {
return {call_node->args[0], src_dev_type, dst_dev_type};
}
} else if (call_node->op == CallLoweredOp()) {
/* Get device props for a TIR function */
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

if (call_lowered_props.attrs.metadata.count("source_device") == 1 &&
call_lowered_props.attrs.metadata.count("dst_device") == 1) {
ICHECK_EQ(call_lowered_props.arguments.size(), 1) << "device_copy is of arity 1";
return {call_lowered_props.lowered_func,
static_cast<DLDeviceType>(
Downcast<Integer>(call_lowered_props.attrs.metadata["source_device"])->value),
static_cast<DLDeviceType>(
Downcast<Integer>(call_lowered_props.attrs.metadata["dst_device"])->value)};
}
}
return {};
}
Expand Down

0 comments on commit 9efc217

Please sign in to comment.