diff --git a/src/torch_xccl.cpp b/src/torch_xccl.cpp index 5e6d781..0fda62a 100644 --- a/src/torch_xccl.cpp +++ b/src/torch_xccl.cpp @@ -10,6 +10,15 @@ namespace c10d { +#define XCCL_CHECK_GOTO(_cmd, _label) \ + do { \ + xccl_status_t st = _cmd; \ + if (XCCL_OK != st) { \ + fprintf(stderr, "TorchUCC error: %s:%d %d\n", __FILE__, __LINE__, st); \ + goto _label; \ + } \ + } while (0) + struct xccl_oob_allgather_req_t { xccl_ep_range_t range; void* sbuf; @@ -333,12 +342,16 @@ torch_ucc_status_t torch_xccl_allgather( coll_args.buffer_info.len = buf_len; coll_args.alg.set_by_user = 0; - xccl_collective_init(&coll_args, &xccl_req, xccl_comm->xccl_team); - + XCCL_CHECK_GOTO( + xccl_collective_init(&coll_args, &xccl_req, xccl_comm->xccl_team), error); coll_req->request = xccl_req; *request = (torch_ucc_coll_request_t*)coll_req; return TORCH_UCC_OK; +error: + fprintf(stderr, "TorchUCC: allgather init failed\n"); + delete coll_req; + return TORCH_UCC_ERROR; } torch_ucc_status_t torch_xccl_alltoall( @@ -372,12 +385,16 @@ torch_ucc_status_t torch_xccl_alltoall( coll_args.buffer_info.len = buf_len; coll_args.alg.set_by_user = 0; - xccl_collective_init(&coll_args, &xccl_req, xccl_comm->xccl_team); - + XCCL_CHECK_GOTO( + xccl_collective_init(&coll_args, &xccl_req, xccl_comm->xccl_team), error); coll_req->request = xccl_req; *request = (torch_ucc_coll_request_t*)coll_req; return TORCH_UCC_OK; +error: + fprintf(stderr, "TorchUCC: alltoall init failed\n"); + delete coll_req; + return TORCH_UCC_ERROR; } torch_ucc_status_t torch_xccl_alltoallv( @@ -419,12 +436,16 @@ torch_ucc_status_t torch_xccl_alltoallv( xccl_type_map.at(output_tensor.scalar_type()); coll_args.alg.set_by_user = 0; - xccl_collective_init(&coll_args, &xccl_req, xccl_comm->xccl_team); - + XCCL_CHECK_GOTO( + xccl_collective_init(&coll_args, &xccl_req, xccl_comm->xccl_team), error); coll_req->request = xccl_req; *request = (torch_ucc_coll_request_t*)coll_req; return TORCH_UCC_OK; +error: + fprintf(stderr, "TorchUCC: alltoallv init failed\n"); + delete coll_req; + return TORCH_UCC_ERROR; } torch_ucc_status_t torch_xccl_allreduce( @@ -454,12 +475,16 @@ torch_ucc_status_t torch_xccl_allreduce( coll_args.reduce_info.count = tensor.numel(); coll_args.alg.set_by_user = 0; - xccl_collective_init(&coll_args, &xccl_req, xccl_comm->xccl_team); - + XCCL_CHECK_GOTO( + xccl_collective_init(&coll_args, &xccl_req, xccl_comm->xccl_team), error); coll_req->request = xccl_req; *request = (torch_ucc_coll_request_t*)coll_req; return TORCH_UCC_OK; +error: + fprintf(stderr, "TorchUCC: allreduce init failed\n"); + delete coll_req; + return TORCH_UCC_ERROR; } torch_ucc_status_t torch_xccl_barrier( @@ -478,12 +503,16 @@ torch_ucc_status_t torch_xccl_barrier( coll_req->status = TORCH_UCC_OPERATION_INITIALIZED; coll_args.coll_type = XCCL_BARRIER; - xccl_collective_init(&coll_args, &xccl_req, xccl_comm->xccl_team); - + XCCL_CHECK_GOTO( + xccl_collective_init(&coll_args, &xccl_req, xccl_comm->xccl_team), error); coll_req->request = xccl_req; *request = (torch_ucc_coll_request_t*)coll_req; return TORCH_UCC_OK; +error: + fprintf(stderr, "TorchUCC: barrier init failed\n"); + delete coll_req; + return TORCH_UCC_ERROR; } torch_ucc_status_t torch_xccl_broadcast( @@ -511,12 +540,16 @@ torch_ucc_status_t torch_xccl_broadcast( coll_args.root = root; coll_args.alg.set_by_user = 0; - xccl_collective_init(&coll_args, &xccl_req, xccl_comm->xccl_team); - + XCCL_CHECK_GOTO( + xccl_collective_init(&coll_args, &xccl_req, xccl_comm->xccl_team), error); coll_req->request = xccl_req; *request = (torch_ucc_coll_request_t*)coll_req; return TORCH_UCC_OK; +error: + fprintf(stderr, "TorchUCC: broadcast init failed\n"); + delete coll_req; + return TORCH_UCC_ERROR; } torch_ucc_status_t torch_xccl_progress(torch_ucc_coll_request_t* request) {