From c2defe8b9551af40fe02be0b9ba4a91e720d426b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 13 Jun 2022 12:33:38 -0700 Subject: [PATCH] Fix dimensionality check --- .../jit/codegen/cuda/python_frontend/python_bindings.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp index 08fa1d4e0411e..d6009cbae74ce 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp +++ b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp @@ -595,15 +595,16 @@ void initNvFuserPythonBindings(PyObject* module) { [](TensorView* input, std::vector& output_shape, std::vector& broadcast_dims) -> TensorView* { + const auto input_ndims = input->domain()->noReductions().size(); TORCH_CHECK( - output_shape.size() >= input->nDims(), + output_shape.size() >= input_ndims, "The new shape is expected to be greater-then-or-equal to the input", output_shape.size(), - input->nDims()); + input_ndims); TORCH_CHECK( - input->nDims() == broadcast_dims.size(), + input_ndims == broadcast_dims.size(), "The broadcast dimensions should match the input dimensions.", - input->nDims(), + input_ndims, broadcast_dims.size()); std::vector is_broadcast_dim(output_shape.size(), true);