[functorch] Fix advanced indexing (pytorch/functorch#777)
* fix advanced indexing

* add comments

* fix test device loc, add nested version

* fix cuda test
Samantha Andow authored and zou3519 committed Jul 20, 2022
1 parent 5ef6034 commit 7435588
Showing 2 changed files with 104 additions and 17 deletions.
70 changes: 53 additions & 17 deletions functorch/functorch/csrc/BatchRulesScatterOps.cpp
@@ -16,7 +16,7 @@

namespace at { namespace functorch {

std::vector<optional<Tensor>> batchIndices(
std::tuple<std::vector<optional<Tensor>>, int64_t, int64_t> batchIndices(
ArrayRef<optional<Tensor>> indices,
ArrayRef<optional<int64_t>> indices_bdims,
int64_t batch_size,
@@ -29,23 +29,38 @@ std::vector<optional<Tensor>> batchIndices(
//
// 2. self is not batched, some indices are batched.
// In this case, we don't need to do anything - indices will automatically
// broadcast to work with the unbatched self.
// broadcast to work with the unbatched self. If there are leading Nones or
// empty tensors in the indices, the batch dimension will appear after
// those leading blanks, so the second return value (batchLoc) captures that
// for the batching rules that use this.
//
// 3. self is batched, some indices are batched.
// In this case, we simply need to add an arange that indexes along the first
// dimension (i.e. the batch dimension). We also need to make sure this
// broadcasts with the rest of the indices.
// broadcasts with the rest of the indices. If there are leading Nones or empty
// tensors, then this will mess us up because we will get [batch_dim, index_dim0, ...]
// instead of [batch_dim, self_dim0, ...] matching the number of leading Nones. By
// returning batchLoc (the number of leading Nones) and maxIndexDim (the maximum
// dim of any of the indices), we can move dims of the result to make the right shape.
//
// There is one more case worth mentioning - boolean tensor indices. If we
// have "batched" boolean tensor indices, that is unrepresentable, as each
// batch would result in a tensor with different values.
std::vector<optional<Tensor>> indices_;
int64_t minIndexDim = 0;
int64_t maxLogicalRank = 0;
bool indices_batched = false;
for (size_t i = 0; i < indices.size(); i++) {
if (indices[i].has_value()) {
maxLogicalRank = std::max(maxLogicalRank, rankWithoutBatchDim(indices[i].value(), indices_bdims[i]));
}
indices_batched = indices_batched || indices_bdims[i].has_value();
}

for (size_t i = 0; i < indices.size(); i++) {
auto index = indices[i];
if (index.has_value()) {
indices_.emplace_back(moveBatchDimToFront(index.value(), indices_bdims[i]));
minIndexDim = std::max(minIndexDim, index.value().dim());
if (index.has_value() && index->numel() != 0) {
const auto idx_bdim = indices_bdims[i];
indices_.emplace_back(maybePadToLogicalRank(moveBatchDimToFront(index.value(), idx_bdim), idx_bdim, maxLogicalRank));
if (index.value().dtype() == kBool && indices_bdims[i].has_value()) {
throw std::runtime_error("vmap: We do not support batching operators that can support dynamic shape. Attempting to batch over indexing with a boolean mask.");
}
@@ -54,12 +69,17 @@ std::vector<optional<Tensor>> batchIndices(
}
}

bool indices_batched = false;
for (auto idx : indices_bdims) {
indices_batched = indices_batched || idx.has_value();
auto maxIndexDim = maxLogicalRank;
if (indices_batched || values_bdim.has_value()) {
maxIndexDim += 1;
}
if (!indices_batched && values_bdim.has_value()) {
minIndexDim += 1;

size_t batchLoc = 0;
// If there are leading Nones ([:, :, ...]) and the indices are batched, the batch dimension will show up after the skipped dimensions
if (indices_batched) {
while(!indices_[batchLoc].has_value() || indices_[batchLoc]->numel() == 0) {
batchLoc += 1;
}
}

if (!indices_batched && self_bdim.has_value()) {
@@ -68,12 +88,12 @@ std::vector<optional<Tensor>> batchIndices(
// do nothing
} else if (indices_batched && (self_bdim.has_value() || values_bdim.has_value())) {
auto arange_index = at::arange(0, batch_size);
while (arange_index.dim() < minIndexDim) {
while (arange_index.dim() < maxIndexDim) {
arange_index = arange_index.unsqueeze(-1);
}
indices_.insert(indices_.begin(), arange_index);
}
return indices_;
return std::make_tuple(indices_, batchLoc, maxIndexDim);
}

std::tuple<Tensor,optional<int64_t>> index_batch_rule(
@@ -84,8 +104,23 @@ std::tuple<Tensor,optional<int64_t>> index_batch_rule(

auto self_ = moveBatchDimToFront(self, self_bdim);
TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, self_.size(0), self_bdim);
return std::make_tuple(at::index(self_, List<optional<Tensor>>(indices_)), 0);
const auto ret = batchIndices(indices, indices_bdims, self_.size(0), self_bdim);
const std::vector<optional<Tensor>> indices_ = std::get<0>(ret);
auto res = at::index(self_, List<optional<Tensor>>(indices_));
auto outDim = std::get<1>(ret);
const auto maxIndexDim = std::get<2>(ret);
if (self_bdim.has_value() && outDim != 0) {
// this will only happen if at least one index is batched and there is at least one leading None.
// In this case, we will have [batch_dim, index_dim0, ..., self_dim0, ...] instead of [batch_dim, self_dim0, ..., index_dim0, ...]
// so we move dims around until we get to the right shape. We know where self_dim0 is because maxIndexDim is
// how long the broadcasted indices will be and we know how many to move because outDim is the number of leading
// Nones
for (int i = 0; i < outDim; i ++) {
res = res.movedim(maxIndexDim + i, i + 1);
}
outDim = 0;
}
return std::make_tuple(res, outDim);
}

// plumbing done since we don't support List<optional<Tensor>> in codegen
@@ -195,7 +230,8 @@ namespace {
TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());

// we've already made sure that self has bdim at 0.
std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim);
// TODO(samdow): fix. We made changes to batchIndices that fixed index and probably expose issues in index_put
std::vector<optional<Tensor>> indices_ = std::get<0>(batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim));

auto indexed_shape = get_indexed_shape(self_, List<optional<Tensor>>(indices_));

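To make case 3 above concrete, here is a minimal Python sketch of what the batching rule does when both self and an index are batched: an arange over the batch dimension is prepended and unsqueezed until it broadcasts against the other indices. The tensor names and shapes below are illustrative, not taken from the commit.

```python
import torch

B, D0, D1 = 4, 5, 6
x = torch.randn(B, D0, D1)               # batched "self", batch dim moved to the front
idx = torch.randint(0, D0, (B, 3))       # batched index: one row of indices per example

# Case 3: prepend an arange over the batch dimension and unsqueeze it until it
# broadcasts against the (batched) index, then do one big advanced index.
arange = torch.arange(B).unsqueeze(-1)   # shape (B, 1), broadcasts with idx of shape (B, 3)
out = x[arange, idx]                     # shape (B, 3, D1)

# Reference semantics: index each example separately, then stack.
expected = torch.stack([x[b][idx[b]] for b in range(B)])
assert torch.equal(out, expected)
```

The leading-None complication tracked by batchLoc and maxIndexDim does not arise in this sketch because no dimensions are skipped before the first index.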
51 changes: 51 additions & 0 deletions functorch/test/test_vmap.py
@@ -3515,6 +3515,57 @@ def f(x, y):
        y = torch.randn(2, 3, device=device)
        self.assertTrue(isinstance(vmap(f)(x, y), Point))

    def test_advanced_indexing(self, device):
        def test(f, args):
            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}):
                self.assertEqual(loop_out, batched_out)

        def f(x, idx):
            return x[:, idx]

        def f2(x, idx):
            return x[idx, :]

        def f3(x, idx):
            return x[:, :, idx]

        inps = (torch.randn(5, 5, 5, device=device),
                torch.randn(5, 5, 5, 5, device=device),
                torch.randn(5, 5, 5, 5, 5, device=device))
        idxes = (torch.tensor([0, 1, 2], device=device),
                 torch.tensor([0, 1, 2], device=device).reshape(3, 1),
                 torch.tensor([0, 1, 2], device=device).reshape(3, 1, 1))
        for (inp, idx) in itertools.product(inps, idxes):
            test(f, (inp, idx))
            test(f2, (inp, idx))
            test(f3, (inp, idx))

    def test_nested_advanced_indexing(self, device):
        e = torch.rand(7, 4, device=device)
        idx = torch.tensor([0, 1], device=device).view(2, 1)

        # simple reference implementation for comparison
        def _fake_vmap(f, in_dims=0, out_dims=0):
            def w(input):
                r = [f(input.select(in_dims, i)) for i in range(input.size(in_dims))]
                return torch.stack(r, out_dims)

            return w

        def with_vmap(_vmap):
            def g(idx_):
                def f(e_):
                    return e_[idx_]

                return _vmap(f, in_dims=1)(e)

            r = _vmap(g)(idx)
            return r

        a = with_vmap(vmap)
        b = with_vmap(_fake_vmap)
        self.assertEqual(a, b)


class TestRandomness(TestCase):
    def _reset_random(self, generator, orig_state, use_generator, seed):
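The new tests above exercise indexing with leading `:` slices, which is exactly the leading-None situation the batchLoc/movedim logic in BatchRulesScatterOps.cpp handles: with a leading slice, the advanced-index dimensions land after the kept self dimension, and under vmap the batch dimension must still end up in front of all of that. A minimal sketch of that expected behavior, assuming functorch's vmap and this commit's fix (the names f, x, idx are illustrative):

```python
import torch
from functorch import vmap  # functorch's vmap, as used by the tests above

x = torch.randn(2, 5, 5, 5)                  # a batch of two (5, 5, 5) inputs
idx = torch.tensor([0, 1, 2]).reshape(3, 1)  # shared (unbatched) advanced index

def f(x_, idx_):
    return x_[:, idx_]  # leading ':' keeps dim 0; index dims land after it

print(f(x[0], idx).shape)  # torch.Size([5, 3, 1, 5]) for a single example

# Under vmap, the batch dim must stay in front of that per-example shape.
out = vmap(f, in_dims=(0, None))(x, idx)     # expected shape (2, 5, 3, 1, 5)
expected = torch.stack([f(x[b], idx) for b in range(x.size(0))])
assert torch.equal(out, expected)
```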
