diff --git a/functorch/csrc/BatchRulesScatterOps.cpp b/functorch/csrc/BatchRulesScatterOps.cpp
index ebdb8466dd..6a2a10255d 100644
--- a/functorch/csrc/BatchRulesScatterOps.cpp
+++ b/functorch/csrc/BatchRulesScatterOps.cpp
@@ -16,7 +16,7 @@ namespace at { namespace functorch {
-std::vector<optional<Tensor>> batchIndices(
+std::tuple<std::vector<optional<Tensor>>, int64_t, int64_t> batchIndices(
   ArrayRef<optional<Tensor>> indices,
   ArrayRef<optional<int64_t>> indices_bdims,
   int64_t batch_size,
@@ -41,11 +41,17 @@ std::vector<optional<Tensor>> batchIndices(
   // batch would result in a tensor with different values.
   std::vector<optional<Tensor>> indices_;
   int64_t minIndexDim = 0;
+  for (size_t i = 0; i < indices.size(); i++) {
+    if (indices[i].has_value()) {
+      minIndexDim = std::max(minIndexDim, rankWithoutBatchDim(indices[i].value(), indices_bdims[i]));
+    }
+  }
+
   for (size_t i = 0; i < indices.size(); i++) {
     auto index = indices[i];
-    if (index.has_value()) {
-      indices_.emplace_back(moveBatchDimToFront(index.value(), indices_bdims[i]));
-      minIndexDim = std::max(minIndexDim, index.value().dim());
+    if (index.has_value() && index->numel() != 0) {
+      const auto idx_bdim = indices_bdims[i];
+      indices_.emplace_back(maybePadToLogicalRank(moveBatchDimToFront(index.value(), idx_bdim), idx_bdim, minIndexDim));
       if (index.value().dtype() == kBool && indices_bdims[i].has_value()) {
         throw std::runtime_error("vmap: We do not support batching operators that can support dynamic shape. Attempting to batch over indexing with a boolean mask.");
       }
@@ -58,10 +64,17 @@ std::vector<optional<Tensor>> batchIndices(
   for (auto idx : indices_bdims) {
     indices_batched = indices_batched || idx.has_value();
   }
-  if (!indices_batched && values_bdim.has_value()) {
+  if (indices_batched || values_bdim.has_value()) {
     minIndexDim += 1;
   }
 
+  size_t batchLoc = 0;
+  if (indices_batched) {
+    while(!indices_[batchLoc].has_value() || indices_[batchLoc]->numel() == 0) {
+      batchLoc += 1;
+    }
+  }
+
   if (!indices_batched && self_bdim.has_value()) {
     indices_.insert(indices_.begin(), nullopt);
   } else if (indices_batched && !self_bdim.has_value()) {
@@ -73,7 +86,7 @@
     }
     indices_.insert(indices_.begin(), arange_index);
   }
-  return indices_;
+  return std::make_tuple(indices_, batchLoc, minIndexDim);
 }
 
 std::tuple<Tensor, optional<int64_t>> index_batch_rule(
@@ -84,8 +97,18 @@ std::tuple<Tensor, optional<int64_t>> index_batch_rule(
   auto self_ = moveBatchDimToFront(self, self_bdim);
   TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
-  std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, self_.size(0), self_bdim);
-  return std::make_tuple(at::index(self_, List<optional<Tensor>>(indices_)), 0);
+  const auto ret = batchIndices(indices, indices_bdims, self_.size(0), self_bdim);
+  const std::vector<optional<Tensor>> indices_ = std::get<0>(ret);
+  auto res = at::index(self_, List<optional<Tensor>>(indices_));
+  auto outDim = std::get<1>(ret);
+  const auto minIndexDim = std::get<2>(ret);
+  if (self_bdim.has_value() && outDim != 0) {
+    for (int i = 0; i < outDim; i ++) {
+      res = res.movedim(minIndexDim + i, i + 1);
+    }
+    outDim = 0;
+  }
+  return std::make_tuple(res, outDim);
 }
 
 // plumbing done since we don't support List<optional<Tensor>> in codegen
@@ -195,7 +218,7 @@ namespace {
   TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
   // we've already made sure that self has bdim at 0.
-  std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim);
+  std::vector<optional<Tensor>> indices_ = std::get<0>(batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim));
   auto indexed_shape = get_indexed_shape(self_, List<optional<Tensor>>(indices_));
diff --git a/test/test_vmap.py b/test/test_vmap.py
index 61b4913386..330232ac23 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -3513,6 +3513,27 @@ def f(x, y):
         y = torch.randn(2, 3, device=device)
         self.assertTrue(isinstance(vmap(f)(x, y), Point))
 
+    def test_advanced_indexing(self, device):
+        def test(f, args):
+            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}):
+                self.assertEqual(loop_out, batched_out)
+
+        def f(x, idx):
+            return x[:, idx]
+
+        def f2(x, idx):
+            return x[idx, :]
+
+        def f3(x, idx):
+            return x[:, :, idx]
+
+        inps = (torch.randn(5, 5, 5), torch.randn(5, 5, 5, 5), torch.randn(5, 5, 5, 5, 5))
+        idxes = (torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2]).reshape(3, 1), torch.tensor([0, 1, 2]).reshape(3, 1, 1))
+        for (inp, idx) in itertools.product(inps, idxes):
+            test(f, (inp, idx))
+            test(f2, (inp, idx))
+            test(f3, (inp, idx))
+
 class TestRandomness(TestCase):
     def _reset_random(self, generator, orig_state, use_generator, seed):
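For readers of the diff, here is a minimal standalone sketch (not part of the patch) of the invariant the reworked batch rule and the new test enforce: vmapping an advanced-indexing function over the leading dimension of the input must match indexing each slice in a plain Python loop, including when the index tensor carries extra unit dimensions. It assumes the functorch-era API (`from functorch import vmap`); the function `f` and the shapes mirror one of the `test_advanced_indexing` cases, and `get_fallback_and_vmap_exhaustive` in the test automates this loop-vs-batched comparison.

# Standalone sketch, not part of the diff: mirrors the f / (3, 1)-shaped index case
# from test_advanced_indexing above. Assumes `from functorch import vmap`.
import torch
from functorch import vmap

x = torch.randn(5, 5, 5)
idx = torch.tensor([0, 1, 2]).reshape(3, 1)  # advanced index with an extra unit dim

def f(x, idx):
    return x[:, idx]

# Batched result, produced by the index batch rule changed above...
batched = vmap(f, in_dims=(0, None))(x, idx)
# ...must equal applying f to each slice of x in a plain loop.
expected = torch.stack([f(xi, idx) for xi in x])
assert torch.allclose(batched, expected)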