fix advanced indexing
samdow committed May 4, 2022
1 parent 4f06035 commit fdf0e5c
Showing 2 changed files with 53 additions and 9 deletions.
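
For context, the call pattern this commit addresses is advanced indexing inside a vmapped function. A minimal sketch of the kind of case involved (the shapes, the in_dims choice, and the functorch.vmap usage below are illustrative, mirroring the new test added in test_vmap.py):

import torch
from functorch import vmap

def f(x, idx):
    # advanced indexing on the second dimension of each per-example tensor
    return x[:, idx]

x = torch.randn(2, 5, 5)       # a batch of two (5, 5) tensors
idx = torch.tensor([0, 1, 2])  # shared (unbatched) index

out = vmap(f, in_dims=(0, None))(x, idx)
print(out.shape)               # torch.Size([2, 5, 3])

Each per-example result f(x[i], idx) has shape (5, 3), so the batched result has shape (2, 5, 3).
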
41 changes: 32 additions & 9 deletions functorch/csrc/BatchRulesScatterOps.cpp
@@ -16,7 +16,7 @@

 namespace at { namespace functorch {

-std::vector<optional<Tensor>> batchIndices(
+std::tuple<std::vector<optional<Tensor>>, int64_t, int64_t> batchIndices(
   ArrayRef<optional<Tensor>> indices,
   ArrayRef<optional<int64_t>> indices_bdims,
   int64_t batch_size,
@@ -41,11 +41,17 @@ std::vector<optional<Tensor>> batchIndices(
   // batch would result in a tensor with different values.
   std::vector<optional<Tensor>> indices_;
   int64_t minIndexDim = 0;
+  for (size_t i = 0; i < indices.size(); i++) {
+    if (indices[i].has_value()) {
+      minIndexDim = std::max(minIndexDim, rankWithoutBatchDim(indices[i].value(), indices_bdims[i]));
+    }
+  }
+
   for (size_t i = 0; i < indices.size(); i++) {
     auto index = indices[i];
-    if (index.has_value()) {
-      indices_.emplace_back(moveBatchDimToFront(index.value(), indices_bdims[i]));
-      minIndexDim = std::max(minIndexDim, index.value().dim());
+    if (index.has_value() && index->numel() != 0) {
+      const auto idx_bdim = indices_bdims[i];
+      indices_.emplace_back(maybePadToLogicalRank(moveBatchDimToFront(index.value(), idx_bdim), idx_bdim, minIndexDim));
       if (index.value().dtype() == kBool && indices_bdims[i].has_value()) {
         throw std::runtime_error("vmap: We do not support batching operators that can support dynamic shape. Attempting to batch over indexing with a boolean mask.");
       }
@@ -58,10 +64,17 @@ std::vector<optional<Tensor>> batchIndices(
   for (auto idx : indices_bdims) {
     indices_batched = indices_batched || idx.has_value();
   }
-  if (!indices_batched && values_bdim.has_value()) {
+  if (indices_batched || values_bdim.has_value()) {
     minIndexDim += 1;
   }

+  size_t batchLoc = 0;
+  if (indices_batched) {
+    while(!indices_[batchLoc].has_value() || indices_[batchLoc]->numel() == 0) {
+      batchLoc += 1;
+    }
+  }
+
   if (!indices_batched && self_bdim.has_value()) {
     indices_.insert(indices_.begin(), nullopt);
   } else if (indices_batched && !self_bdim.has_value()) {
@@ -73,7 +86,7 @@ std::vector<optional<Tensor>> batchIndices(
     }
     indices_.insert(indices_.begin(), arange_index);
   }
-  return indices_;
+  return std::make_tuple(indices_, batchLoc, minIndexDim);
 }

 std::tuple<Tensor,optional<int64_t>> index_batch_rule(
@@ -84,8 +97,18 @@ std::tuple<Tensor,optional<int64_t>> index_batch_rule(

   auto self_ = moveBatchDimToFront(self, self_bdim);
   TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
-  std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, self_.size(0), self_bdim);
-  return std::make_tuple(at::index(self_, List<optional<Tensor>>(indices_)), 0);
+  const auto ret = batchIndices(indices, indices_bdims, self_.size(0), self_bdim);
+  const std::vector<optional<Tensor>> indices_ = std::get<0>(ret);
+  auto res = at::index(self_, List<optional<Tensor>>(indices_));
+  auto outDim = std::get<1>(ret);
+  const auto minIndexDim = std::get<2>(ret);
+  if (self_bdim.has_value() && outDim != 0) {
+    for (int i = 0; i < outDim; i ++) {
+      res = res.movedim(minIndexDim + i, i + 1);
+    }
+    outDim = 0;
+  }
+  return std::make_tuple(res, outDim);
 }

 // plumbing done since we don't support List<optional<Tensor>> in codegen
@@ -195,7 +218,7 @@ namespace {
   TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());

   // we've already made sure that self has bdim at 0.
-  std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim);
+  std::vector<optional<Tensor>> indices_ = std::get<0>(batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim));

   auto indexed_shape = get_indexed_shape(self_, List<optional<Tensor>>(indices_));

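The extra values that batchIndices now returns (batchLoc and minIndexDim) tie into how PyTorch places the dimensions produced by advanced indices: a single advanced index (or a group of adjacent ones) keeps its position in the result, while advanced indices separated by a slice push the index dimensions to the front. A small illustration of that placement rule in plain PyTorch (shapes chosen only for illustration):

import torch

x = torch.randn(5, 6, 7)
idx = torch.tensor([0, 1, 2])

# one advanced index: the index dimension stays where the index was
print(x[:, idx].shape)       # torch.Size([5, 3, 7])

# advanced indices separated by a slice: index dims move to the front
print(x[idx, :, idx].shape)  # torch.Size([3, 6])

Roughly speaking, once the batch rule prepends an index over the batch dimension to the index list, the batched at::index call can land in the second case even when the per-example call was the first; the batchLoc/minIndexDim bookkeeping and the movedim loop in index_batch_rule put the result dimensions back where the per-example call would have produced them.
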
21 changes: 21 additions & 0 deletions test/test_vmap.py
@@ -3513,6 +3513,27 @@ def f(x, y):
         y = torch.randn(2, 3, device=device)
         self.assertTrue(isinstance(vmap(f)(x, y), Point))

+    def test_advanced_indexing(self, device):
+        def test(f, args):
+            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}):
+                self.assertEqual(loop_out, batched_out)
+
+        def f(x, idx):
+            return x[:, idx]
+
+        def f2(x, idx):
+            return x[idx, :]
+
+        def f3(x, idx):
+            return x[:, :, idx]
+
+        inps = (torch.randn(5, 5, 5), torch.randn(5, 5, 5, 5), torch.randn(5, 5, 5, 5, 5))
+        idxes = (torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2]).reshape(3, 1), torch.tensor([0, 1, 2]).reshape(3, 1, 1))
+        for (inp, idx) in itertools.product(inps, idxes):
+            test(f, (inp, idx))
+            test(f2, (inp, idx))
+            test(f3, (inp, idx))
+

 class TestRandomness(TestCase):
     def _reset_random(self, generator, orig_state, use_generator, seed):
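The idx tensors in the new test include reshaped variants with trailing singleton dimensions ((3, 1) and (3, 1, 1)), which appears to be the kind of case the minIndexDim / maybePadToLogicalRank handling above is aimed at. The expected per-example shapes are easy to check by hand; for example (illustrative, matching the unbatched case of the test):

import torch

x = torch.randn(5, 5, 5)
idx = torch.tensor([0, 1, 2]).reshape(3, 1, 1)

# the broadcast index shape (3, 1, 1) replaces the indexed dimension
print(x[:, idx].shape)   # torch.Size([5, 3, 1, 1, 5])
print(x[idx, :].shape)   # torch.Size([3, 1, 1, 5, 5])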
