Skip to content

Commit 2eb6aef

Browse files
abhudevpytorchmergebot
authored andcommitted
Handle compatible inputs to where (#124)
* Handle compatible input dims; move where to allowlist * Add bool back to block list * Refactor max operation
1 parent 0a7d8b4 commit 2eb6aef

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

aten/src/ATen/native/mps/operations/TensorCompare.mm

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -386,20 +386,22 @@ Tensor where_mps(const Tensor& condition,
386386

387387
auto max_dim = std::max(condition.dim(), std::max(self.dim(), other.dim()));
388388

389-
auto sum_dims = condition.dim() + self.dim() + other.dim();
390-
391-
TORCH_CHECK(max_dim == 0 || !(sum_dims % max_dim), "All inputs of where should have same/compatible number of dims")
392-
393389
std::vector<int64_t> out_arr(max_dim);
394390

395391
// Broadcasted output shape
396392
for(int i = 0; i < max_dim; i++) {
397393

398-
int64_t cond_num = cond_zero_shape ? 0 : condition.size(i);
399-
int64_t self_num = self_zero_shape ? 0 : self.size(i);
400-
int64_t other_num = other_zero_shape ? 0 : other.size(i);
394+
int64_t cond_idx = cond_zero_shape ? 1 : (i < condition.dim() ? condition.size(i) : 1);
395+
int64_t self_idx = self_zero_shape ? 1 : (i < self.dim() ? self.size(i) : 1);
396+
int64_t other_idx = other_zero_shape ? 1 : (i < other.dim() ? other.size(i) : 1);
397+
398+
auto max_idx = std::max({cond_idx, self_idx, other_idx});
399+
400+
TORCH_CHECK(cond_idx == max_idx || cond_idx == 1, i, "'th index ", cond_idx, " of condition tensor does not match the other tensors")
401+
TORCH_CHECK(self_idx == max_idx || self_idx == 1, i, "'th index ", self_idx, " of x tensor does not match the other tensors")
402+
TORCH_CHECK(other_idx == max_idx || other_idx == 1, i, "'th index ", other_idx, " of x tensor does not match the other tensors")
401403

402-
out_arr[i] = std::max(cond_num, std::max(self_num, other_num));
404+
out_arr[i] = max_idx;
403405
}
404406

405407
Tensor ret = empty_mps(IntArrayRef(out_arr),

test/test_mps.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7264,7 +7264,8 @@ class TestConsistency(TestCase):
72647264
'clamp_min': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
72657265
'logical_and': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
72667266
'logical_or': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
7267-
'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8']}
7267+
'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
7268+
'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8']}
72687269

72697270

72707271
ALLOWLIST_OP_GRAD = {

0 commit comments

Comments
 (0)