Skip to content

Commit

Permalink
Fix dimensionality check (#1759)
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam authored Jun 13, 2022
1 parent 2d6343f commit b263562
Showing 1 changed file with 5 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -595,15 +595,16 @@ void initNvFuserPythonBindings(PyObject* module) {
[](TensorView* input,
std::vector<int>& output_shape,
std::vector<int>& broadcast_dims) -> TensorView* {
const auto input_ndims = input->domain()->noReductions().size();
TORCH_CHECK(
output_shape.size() >= input->nDims(),
output_shape.size() >= input_ndims,
"The new shape is expected to be greater-then-or-equal to the input",
output_shape.size(),
input->nDims());
input_ndims);
TORCH_CHECK(
input->nDims() == broadcast_dims.size(),
input_ndims == broadcast_dims.size(),
"The broadcast dimensions should match the input dimensions.",
input->nDims(),
input_ndims,
broadcast_dims.size());

std::vector<bool> is_broadcast_dim(output_shape.size(), true);
Expand Down

0 comments on commit b263562

Please sign in to comment.