@@ -386,20 +386,22 @@ Tensor where_mps(const Tensor& condition,
386386
387387 auto max_dim = std::max (condition.dim (), std::max (self.dim (), other.dim ()));
388388
389- auto sum_dims = condition.dim () + self.dim () + other.dim ();
390-
391- TORCH_CHECK (max_dim == 0 || !(sum_dims % max_dim), " All inputs of where should have same/compatible number of dims" )
392-
393389 std::vector<int64_t > out_arr (max_dim);
394390
395391 // Broadcasted output shape
396392 for (int i = 0 ; i < max_dim; i++) {
397393
398- int64_t cond_num = cond_zero_shape ? 0 : condition.size (i);
399- int64_t self_num = self_zero_shape ? 0 : self.size (i);
400- int64_t other_num = other_zero_shape ? 0 : other.size (i);
394+ int64_t cond_idx = cond_zero_shape ? 1 : (i < condition.dim () ? condition.size (i) : 1 );
395+ int64_t self_idx = self_zero_shape ? 1 : (i < self.dim () ? self.size (i) : 1 );
396+ int64_t other_idx = other_zero_shape ? 1 : (i < other.dim () ? other.size (i) : 1 );
397+
398+ auto max_idx = std::max ({cond_idx, self_idx, other_idx});
399+
400+ TORCH_CHECK (cond_idx == max_idx || cond_idx == 1 , i, " 'th index " , cond_idx, " of condition tensor does not match the other tensors" )
401+ TORCH_CHECK (self_idx == max_idx || self_idx == 1 , i, " 'th index " , self_idx, " of x tensor does not match the other tensors" )
402+ TORCH_CHECK (other_idx == max_idx || other_idx == 1 , i, " 'th index " , other_idx, " of x tensor does not match the other tensors" )
401403
402- out_arr[i] = std::max (cond_num, std::max (self_num, other_num)) ;
404+ out_arr[i] = max_idx ;
403405 }
404406
405407 Tensor ret = empty_mps (IntArrayRef (out_arr),
0 commit comments