Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix advanced indexing #777

Merged
merged 4 commits into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 53 additions & 17 deletions functorch/csrc/BatchRulesScatterOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

namespace at { namespace functorch {

std::vector<optional<Tensor>> batchIndices(
std::tuple<std::vector<optional<Tensor>>, int64_t, int64_t> batchIndices(
ArrayRef<optional<Tensor>> indices,
ArrayRef<optional<int64_t>> indices_bdims,
int64_t batch_size,
Expand All @@ -29,23 +29,38 @@ std::vector<optional<Tensor>> batchIndices(
//
// 2. self is not batched, some indices are batched.
// In this case, we don't need to do anything - indices will automatically
// broadcast to work with the unbatched self.
// broadcast to work with the unbatched self. If there are leading Nones or
// empty tensors in the indices, the batch dimension will appear after
// those leading blanks, so the 1st ret value captures that for the batching
// rules that use this
//
// 3. self is batched, some indices are batched.
// In this case, we simply need to add an arange that indexes along the first
// dimension (i.e. the batch dimension). We also need to make sure this
// broadcasts with the rest of the indices.
// broadcasts with the rest of the indices. If there are leading Nones or empty
// tensors, then this will mess us up because we will get [batch_dim, index_dim0, ...]
// instead of [batch_dim, self_dim0, ...] matching the number of leading Nones. By
// returning batchLoc (the number of leading Nones) and maxIndexDim (the maximum
// dim of any of the index), we can move dims of the result to make the right shape
//
// There is one more case worth mentioning - boolean tensor indices. If we
// have "batched" boolean tensor indices, that is unrepresentable, as each
// batch would result in a tensor with different values.
std::vector<optional<Tensor>> indices_;
int64_t minIndexDim = 0;
int64_t maxLogicalRank = 0;
bool indices_batched = false;
for (size_t i = 0; i < indices.size(); i++) {
if (indices[i].has_value()) {
maxLogicalRank = std::max(maxLogicalRank, rankWithoutBatchDim(indices[i].value(), indices_bdims[i]));
}
indices_batched = indices_batched || indices_bdims[i].has_value();
}

for (size_t i = 0; i < indices.size(); i++) {
auto index = indices[i];
if (index.has_value()) {
indices_.emplace_back(moveBatchDimToFront(index.value(), indices_bdims[i]));
minIndexDim = std::max(minIndexDim, index.value().dim());
if (index.has_value() && index->numel() != 0) {
const auto idx_bdim = indices_bdims[i];
indices_.emplace_back(maybePadToLogicalRank(moveBatchDimToFront(index.value(), idx_bdim), idx_bdim, maxLogicalRank));
if (index.value().dtype() == kBool && indices_bdims[i].has_value()) {
throw std::runtime_error("vmap: We do not support batching operators that can support dynamic shape. Attempting to batch over indexing with a boolean mask.");
}
Expand All @@ -54,12 +69,17 @@ std::vector<optional<Tensor>> batchIndices(
}
}

bool indices_batched = false;
for (auto idx : indices_bdims) {
indices_batched = indices_batched || idx.has_value();
auto maxIndexDim = maxLogicalRank;
if (indices_batched || values_bdim.has_value()) {
maxIndexDim += 1;
}
if (!indices_batched && values_bdim.has_value()) {
minIndexDim += 1;

size_t batchLoc = 0;
// If there's leading Nones ([:, :,...]) and indices is batched, the batch dimension will show up after the skipped dimensinos
if (indices_batched) {
while(!indices_[batchLoc].has_value() || indices_[batchLoc]->numel() == 0) {
batchLoc += 1;
}
}

if (!indices_batched && self_bdim.has_value()) {
Expand All @@ -68,12 +88,12 @@ std::vector<optional<Tensor>> batchIndices(
// do nothing
} else if (indices_batched && (self_bdim.has_value() || values_bdim.has_value())) {
auto arange_index = at::arange(0, batch_size);
while (arange_index.dim() < minIndexDim) {
while (arange_index.dim() < maxIndexDim) {
arange_index = arange_index.unsqueeze(-1);
}
indices_.insert(indices_.begin(), arange_index);
}
return indices_;
return std::make_tuple(indices_, batchLoc, maxIndexDim);
}

std::tuple<Tensor,optional<int64_t>> index_batch_rule(
Expand All @@ -84,8 +104,23 @@ std::tuple<Tensor,optional<int64_t>> index_batch_rule(

auto self_ = moveBatchDimToFront(self, self_bdim);
TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, self_.size(0), self_bdim);
return std::make_tuple(at::index(self_, List<optional<Tensor>>(indices_)), 0);
const auto ret = batchIndices(indices, indices_bdims, self_.size(0), self_bdim);
const std::vector<optional<Tensor>> indices_ = std::get<0>(ret);
auto res = at::index(self_, List<optional<Tensor>>(indices_));
auto outDim = std::get<1>(ret);
const auto maxIndexDim = std::get<2>(ret);
if (self_bdim.has_value() && outDim != 0) {
// this will only happen if at least one index is batched and there is at least one leading None.
// In this case, we will have [batch_dim, index_dim0, ..., self_dim0, ...] instead of [batch_dim, self_dim0, ..., index_dim0]
// so we move dims around until we get to the right shape. We know where self_dim0 is because maxIndexDim is
// how long the broadcasted indices will be and we know how many to move because outDim is the number of leading
// Nones
for (int i = 0; i < outDim; i ++) {
res = res.movedim(maxIndexDim + i, i + 1);
}
outDim = 0;
}
return std::make_tuple(res, outDim);
}

// plumbing done since we don't support List<optional<Tensor>> in codegen
Expand Down Expand Up @@ -195,7 +230,8 @@ namespace {
TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());

// we've already made sure that self has bdim at 0.
std::vector<optional<Tensor>> indices_ = batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim);
// TODO(samdow): fix. We made changes to batchIndices that fixed index and probably show issues in index_put
std::vector<optional<Tensor>> indices_ = std::get<0>(batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim));

auto indexed_shape = get_indexed_shape(self_, List<optional<Tensor>>(indices_));

Expand Down
51 changes: 51 additions & 0 deletions test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3515,6 +3515,57 @@ def f(x, y):
y = torch.randn(2, 3, device=device)
self.assertTrue(isinstance(vmap(f)(x, y), Point))

def test_advanced_indexing(self, device):
def test(f, args):
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}):
self.assertEqual(loop_out, batched_out)

def f(x, idx):
return x[:, idx]

def f2(x, idx):
return x[idx, :]

def f3(x, idx):
return x[:, :, idx]

inps = (torch.randn(5, 5, 5, device=device),
torch.randn(5, 5, 5, 5, device=device),
torch.randn(5, 5, 5, 5, 5, device=device))
idxes = (torch.tensor([0, 1, 2], device=device),
torch.tensor([0, 1, 2], device=device).reshape(3, 1),
torch.tensor([0, 1, 2], device=device).reshape(3, 1, 1))
for (inp, idx) in itertools.product(inps, idxes):
test(f, (inp, idx))
test(f2, (inp, idx))
test(f3, (inp, idx))

def test_nested_advanced_indexing(self, device):
e = torch.rand(7, 4, device=device)
idx = torch.tensor([0, 1], device=device).view(2, 1)

# simple reference implementation for comparison
def _fake_vmap(f, in_dims=0, out_dims=0):
def w(input):
r = [f(input.select(in_dims, i)) for i in range(input.size(in_dims))]
return torch.stack(r, out_dims)

return w

def with_vmap(_vmap):
def g(idx_):
def f(e_):
return e_[idx_]

return _vmap(f, in_dims=1)(e)

r = _vmap(g)(idx)
return r

a = with_vmap(vmap)
b = with_vmap(_fake_vmap)
self.assertEqual(a, b)


class TestRandomness(TestCase):
def _reset_random(self, generator, orig_state, use_generator, seed):
Expand Down