diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index 092727cb5115..c32cdc3aacb3 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -133,12 +133,12 @@ StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) { auto input_shape = input_sinfo->GetShape(); CHECK(input_shape.defined()) << "input tensor of scatter_from_worker0 should have defined shape."; - if (analyzer->CanProve(floormod(input_shape.value()[0], PrimExpr(num_workers))) != 0) { + if (analyzer->CanProve(floormod(input_shape.value()[attrs->axis], PrimExpr(num_workers)) != 0)) { ctx->ReportFatal(Diagnostic::Error(call) - << "scatter_from_worker0 expects the size of axis 0 of input tensor to be " - "divisible by the " - "num_workers. However, the axis 0 of input tensor is " - << input_shape.value() << " while num_workers is " << num_workers); + << "scatter_from_worker0 expects the size of axis " << attrs->axis + << " of input tensor to be divisible by the num_workers. However, axis " + << attrs->axis << " of input tensor is " << input_shape.value() + << " while num_workers is " << num_workers); } Array output_shape = input_shape.value();