From 68c1409320a361f85e13ffcb96a41d04a8ba6744 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sun, 28 Sep 2025 11:55:25 -0400
Subject: [PATCH] [Fix] Update ShapeView use in nccl.cc

This PR updates nccl.cc for the introduction of ShapeView: the call sites
that previously used `Shape()->Product()` now use `Shape().Product()`.
---
 src/runtime/disco/nccl/nccl.cc | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 2eb0c3348bd5..fd4ad06c3fa8 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -150,13 +150,13 @@ void BroadcastFromWorker0(ffi::Optional<Tensor> send, bool in_group, Tensor recv
   const void* send_data = [&]() -> const void* {
     if (is_sender) {
       CHECK(send.defined());
-      CHECK(send.value().Shape()->Product() == recv.Shape()->Product());
+      CHECK(send.value().Shape().Product() == recv.Shape().Product());
       return send.value()->data;
     } else {
       return nullptr;
     }
   }();
-  int64_t numel = recv.Shape()->Product();
+  int64_t numel = recv.Shape().Product();
 
   deviceStream_t stream = ctx->GetDefaultStream();
   NCCL_CALL(ncclBroadcast(send_data, recv->data, numel,
@@ -176,7 +176,7 @@ void ScatterFromWorker0(ffi::Optional<Tensor> send, bool in_group, Tensor recv)
   if (is_sender) {
     CHECK(send.defined()) << "ValueError: buffer `send` must be provided when worker_id == 0.";
     Tensor buffer = send.value();
-    int64_t numel = buffer.Shape()->Product();
+    int64_t numel = buffer.Shape().Product();
     CHECK_EQ(numel % num_receiver, 0) << "ValueError: Scattering evenly requires that the number "
                                          "of elements in the buffer to be "
                                          "divisible by the number of workers, but got numel = "
@@ -184,11 +184,11 @@ void ScatterFromWorker0(ffi::Optional<Tensor> send, bool in_group, Tensor recv)
     DataType dtype(buffer->dtype);
     int64_t numel_per_shard = numel / num_receiver;
     int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
-    CHECK_EQ(numel_per_shard, recv.Shape()->Product())
+    CHECK_EQ(numel_per_shard, recv.Shape().Product())
         << "ValueError: The number of elements in buffer `recv` must be the same as each shard "
            "of "
            "buffer `send`. `send.size` is "
-        << numel << ", but `recv.size` is " << recv.Shape()->Product() << ".";
+        << numel << ", but `recv.size` is " << recv.Shape().Product() << ".";
     NCCL_CALL(ncclGroupStart());
     uint8_t* data = static_cast<uint8_t*>(buffer->data);
     for (int i = 0; i < num_receiver; ++i) {
@@ -204,7 +204,7 @@ void ScatterFromWorker0(ffi::Optional<Tensor> send, bool in_group, Tensor recv)
     }
     NCCL_CALL(ncclGroupStart());
   }
-  int64_t numel = recv.Shape()->Product();
+  int64_t numel = recv.Shape().Product();
   DataType dtype(recv->dtype);
   NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0,
                      in_group ? ctx->group_comm : ctx->global_comm, stream));
@@ -223,7 +223,7 @@ void GatherToWorker0(Tensor send, bool in_group, ffi::Optional<Tensor> recv) {
   if (is_sender) {
     CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0.";
     Tensor buffer = recv.value();
-    int64_t numel = buffer.Shape()->Product();
+    int64_t numel = buffer.Shape().Product();
     CHECK_EQ(numel % num_receiver, 0) << "ValueError: Gathering evenly requires that the number "
                                          "of elements in the buffer to be "
                                          "divisible by the number of workers, but got numel = "
@@ -231,11 +231,11 @@ void GatherToWorker0(Tensor send, bool in_group, ffi::Optional<Tensor> recv) {
     DataType dtype(buffer->dtype);
     int64_t numel_per_shard = numel / num_receiver;
     int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
-    CHECK_EQ(numel_per_shard, send.Shape()->Product())
+    CHECK_EQ(numel_per_shard, send.Shape().Product())
         << "ValueError: The number of elements in buffer `send` must be the same as each shard "
            "of "
           "buffer `recv`. `recv.size` is "
-        << numel << ", but `send.size` is " << send.Shape()->Product() << ".";
+        << numel << ", but `send.size` is " << send.Shape().Product() << ".";
     NCCL_CALL(ncclGroupStart());
     uint8_t* data = static_cast<uint8_t*>(buffer->data);
     for (int i = 0; i < num_receiver; ++i) {
@@ -251,7 +251,7 @@ void GatherToWorker0(Tensor send, bool in_group, ffi::Optional<Tensor> recv) {
     }
     NCCL_CALL(ncclGroupStart());
   }
-  int64_t numel = send.Shape()->Product();
+  int64_t numel = send.Shape().Product();
   DataType dtype(send->dtype);
   NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0,
                      in_group ? ctx->group_comm : ctx->global_comm, stream));
@@ -264,7 +264,7 @@ void RecvFromWorker0(Tensor buffer) {
   CHECK_NE(ctx->worker->worker_id, 0)
       << "ValueError: Worker 0 is not allowed to call RecvFromWorker0.";
   NCCL_CALL(ncclGroupStart());
-  NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), 0,
+  NCCL_CALL(ncclRecv(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()), 0,
                      ctx->global_comm, stream));
   NCCL_CALL(ncclGroupEnd());
 }
@@ -278,7 +278,7 @@ void SendToNextGroup(Tensor buffer) {
   CHECK_LT(receiver_id, ctx->worker->num_workers)
       << "The current group is already the last group and there is no such a next group.";
   NCCL_CALL(ncclGroupStart());
-  NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
+  NCCL_CALL(ncclSend(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()),
                      receiver_id, ctx->global_comm, stream));
   NCCL_CALL(ncclGroupEnd());
 }
@@ -292,7 +292,7 @@ void RecvFromPrevGroup(Tensor buffer) {
   CHECK_GE(sender_id, 0)
       << "The current group is already the first group and there is no such a previous group.";
   NCCL_CALL(ncclGroupStart());
-  NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
+  NCCL_CALL(ncclRecv(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()),
                      sender_id, ctx->global_comm, stream));
   NCCL_CALL(ncclGroupEnd());
 }
@@ -305,7 +305,7 @@ void SendToWorker(Tensor buffer, int receiver_id) {
       << "Invalid receiver id " << receiver_id << ". The world size is "
       << ctx->worker->num_workers;
   CHECK_NE(worker_id, receiver_id) << "Cannot send to worker itself.";
-  NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
+  NCCL_CALL(ncclSend(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()),
                      receiver_id, ctx->global_comm, stream));
 }
 
@@ -316,7 +316,7 @@ void RecvFromWorker(Tensor buffer, int sender_id) {
   CHECK(sender_id >= 0 && sender_id < ctx->worker->num_workers)
       << "Invalid sender id " << sender_id << ". The world size is " << ctx->worker->num_workers;
   CHECK_NE(worker_id, sender_id) << "Cannot receive from the worker itself.";
-  NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
+  NCCL_CALL(ncclRecv(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()),
                      sender_id, ctx->global_comm, stream));
 }
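
For context, a minimal C++ sketch of the call-pattern change this patch applies
throughout nccl.cc. Only the switch from `Shape()->Product()` to
`Shape().Product()` is confirmed by the diff; the `ShapeView` and `Tensor`
types below are simplified stand-ins for illustration, not TVM's actual
definitions:

```cpp
// Illustrative sketch only: `ShapeView` and `Tensor` are hypothetical
// stand-ins. The point is the access pattern this patch adopts: Shape()
// returns a lightweight view by value, so call sites use `.Product()`
// instead of the previous pointer-style `->Product()`.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

class ShapeView {
 public:
  explicit ShapeView(const std::vector<int64_t>* dims) : dims_(dims) {}
  // Number of elements described by the shape, i.e. the product of all dims.
  int64_t Product() const {
    int64_t numel = 1;
    for (int64_t d : *dims_) numel *= d;
    return numel;
  }

 private:
  const std::vector<int64_t>* dims_;  // non-owning view over the dimensions
};

class Tensor {
 public:
  explicit Tensor(std::vector<int64_t> dims) : dims_(std::move(dims)) {}
  // Returns the view by value, hence the `.Product()` call pattern.
  ShapeView Shape() const { return ShapeView(&dims_); }

 private:
  std::vector<int64_t> dims_;
};

int main() {
  Tensor recv({2, 3, 4});
  int64_t numel = recv.Shape().Product();  // new pattern used in this patch
  std::cout << numel << "\n";              // prints 24
  return 0;
}
```

In this sketch, returning a view by value rather than a pointer to a shape
object is what drops the `->` at every call site in the diff above.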