Skip to content

Commit

Permalink
[XPU] bkcl_broadcast support int64_t (#53720)
Browse files Browse the repository at this point in the history
  • Loading branch information
houj04 authored May 12, 2023
1 parent b150b16 commit 13cdaab
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions paddle/fluid/distributed/collective/process_group_bkcl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,16 +334,30 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
<< ", root: " << root << ", numel: " << input.numel()
<< ", dtype: " << input.type() << ", sync_op: " << sync_op
<< ", use_calc_stream: " << use_calc_stream;
int r =
bkcl_broadcast(comm,
input.data(),
output->data(),
input.numel(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
root,
stream);
return r;
if (framework::TransToProtoVarType(input.dtype()) ==
framework::proto::VarType::INT64) {
// special for int64_t, send as int32_t with DOUBLE NUMEL
int r = bkcl_broadcast(
comm,
input.data(),
output->data(),
input.numel() * 2,
platform::ToBKCLDataType(framework::proto::VarType::INT32),
root,
stream);
return r;
} else {
int r =
bkcl_broadcast(comm,
input.data(),
output->data(),
input.numel(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
root,
stream);
return r;
}
},
CommType::BROADCAST,
sync_op,
Expand Down

0 comments on commit 13cdaab

Please sign in to comment.