diff --git a/src/relay/op/memory/device_copy.cc b/src/relay/op/memory/device_copy.cc index 538264ce9688..48d12368fa28 100644 --- a/src/relay/op/memory/device_copy.cc +++ b/src/relay/op/memory/device_copy.cc @@ -105,19 +105,6 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { } else { return {call_node->args[0], device_copy_attrs->src_se_scope, device_copy_attrs->dst_se_scope}; } - } else if (call_node->op == CallLoweredOp()) { - /* Get device props for a TIR function */ - CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); - - if (call_lowered_props.attrs.metadata.count("source_device") == 1 && - call_lowered_props.attrs.metadata.count("dst_device") == 1) { - ICHECK_EQ(call_lowered_props.arguments.size(), 1) << "device_copy is of arity 1"; - return {call_lowered_props.lowered_func, - static_cast( - Downcast(call_lowered_props.attrs.metadata["source_device"])->value), - static_cast( - Downcast(call_lowered_props.attrs.metadata["dst_device"])->value)}; - } } return {}; } diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc index 667379d7a9a0..44c0ecf41de8 100644 --- a/src/relay/transforms/device_domains.cc +++ b/src/relay/transforms/device_domains.cc @@ -51,10 +51,8 @@ DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) { call_lowered_props.attrs.metadata.count("dst_device") == 1) { ICHECK_EQ(call_lowered_props.arguments.size(), 1) << "device_copy is of arity 1"; return {call_lowered_props.arguments[0], - static_cast( - Downcast(call_lowered_props.attrs.metadata["source_device"])->value), - static_cast( - Downcast(call_lowered_props.attrs.metadata["dst_device"])->value)}; + Downcast(call_lowered_props.attrs.metadata["src_se_scope"]), + Downcast(call_lowered_props.attrs.metadata["dst_se_scope"])}; } } return {};