Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Operator] Added advanced tensor indexing (#251)
Browse files Browse the repository at this point in the history
Added advanced tensor indexing. A new task, "AdvancedIndexingTask", had to be
created for this; otherwise indexing would not work with dynamic shapes.

---------

Co-authored-by: Zhumakhan <nazirzhumakhan@gmail.com>
2 people authored and vadiklyutiy committed Jul 22, 2024
1 parent a94c0c6 commit 4706425
Showing 2 changed files with 202 additions and 17 deletions.
77 changes: 61 additions & 16 deletions python/hidet/graph/tensor.py
Original file line number Diff line number Diff line change
@@ -13,9 +13,9 @@

from typing import List, Optional, Tuple, Sequence, Union
import warnings

import numpy as np

from hidet.utils.py import same_list
import hidet.runtime.storage
import hidet.cuda
from hidet.ir import dtypes
@@ -345,33 +345,21 @@ def __str__(self):
return '{}\nfrom {}'.format(head, self.trace)

def __getitem__(self, item):
from hidet.graph.ops import strided_slice
from .ops import take, reshape, flatten
from .ops import strided_slice, take
from .ops import reshape, transpose

if isinstance(item, Tensor):
if not item.dtype.is_integer():
raise TypeError("Tensor indexing via Tensor requires integer index tensor")

if len(item.shape) == 1:
return take(self, item, axis=0)

if len(item.shape) >= 2:
item_1d = reshape(item, flatten(item).shape)
return take(self, item_1d, axis=0).reshape(tuple(item.shape) + tuple(self.shape[1:]))
return take(self, item, axis=0)

if isinstance(item, list):
item = tuple(item)

if not isinstance(item, tuple):
item = tuple([item])

# now, the item could have
# 1. integer index
# 2. slice
# 3. Ellipsis
# 4. None
# e.g., [1, 3:5, ..., None]

# process Ellipsis
# e.g., x[1, ..., 2] -> x[1, :, :, 2]
if Ellipsis in item:
@@ -382,6 +370,63 @@ def __getitem__(self, item):
ellipsis_ndim = max(ellipsis_ndim, 0)
item = item[:ellipsis_index] + (slice(None),) * ellipsis_ndim + item[ellipsis_index + 1 :]

# if some elements in item are tensors
# advanced indexing will be used
if any(isinstance(it, Tensor) for it in item):
tensor_indices = []
slice_indices = []
for i, it in enumerate(item):
if isinstance(it, Tensor):
tensor_indices.append(i)
else:
slice_indices.append(i)

if len(self.shape) == 2 and same_list(tensor_indices, list(range(2)), use_equal=True):
x = self
else:
x = transpose(self, tensor_indices + slice_indices + list(range(len(item), len(self.shape))))
n = len(tensor_indices)

item_sum = item[tensor_indices[0]] * prod(x.shape[1:n])
for i in range(1, n):
item_sum += item[tensor_indices[i]] * prod(x.shape[i + 1 : n])
x = take(reshape(x, (prod(x.shape[:n]),) + x.shape[n:]), item_sum)

# check if there is slice index between tensor indices
# if no, slice indices that came before tensor indices need to go to their initial positions
# otherwise, they stay as they are at the current line
if tensor_indices[-1] - tensor_indices[0] + 1 == n:
transpose_back = []
new_item = []
idx = 0
for _ in range(len(slice_indices)):
if slice_indices[idx] == idx:
transpose_back.append(len(item[tensor_indices[0]].shape) + idx)
new_item.append(item[idx])
idx += 1
else:
break

for i in range(len(item[tensor_indices[0]].shape)):
transpose_back.append(i)
new_item.append(slice(None))

for idx in range(idx, len(slice_indices)):
transpose_back.append(idx + len(item[tensor_indices[0]].shape))
new_item.append(item[idx + n])

if len(x.shape) != 2 or not same_list(tensor_indices, list(range(2)), use_equal=True):
x = transpose(x, transpose_back)
else:
new_item = [slice(None) for _ in range(n)] + [item[idx] for idx in slice_indices]
return x.__getitem__(new_item)

# now, the item could have
# 1. integer index
# 2. slice
# 3. None
# e.g., [1, 3:5, None] (any Ellipsis has already been expanded into slices above)

# process None
# e.g., x[2, None, 3] -> x[2, 1, 3]
if None in item:
142 changes: 141 additions & 1 deletion tests/operators/test_transform.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
import hidet
import hidet as hi
from hidet import ops
from hidet import symbol, trace_from
from hidet.ir.utils import broadcast_shape
from hidet.utils import prod
from hidet.graph.tensor import asarray
@@ -213,7 +214,7 @@ def numpy_getitem(data, item):
[[129], [2, 3]],
],
)
def test_getitem(a_shape, b_shape):
def test_getitem_nd(a_shape, b_shape):
for device in ['cuda', 'cpu']:
a = np.array(np.random.randn(*a_shape)).astype('float32')
b = np.array(np.random.randint(low=0, high=a_shape[0], size=b_shape)).astype('int32')
@@ -260,6 +261,145 @@ def test_meshgrid(shapes, indexing):
)


@pytest.mark.parametrize(
    "a_shape, b_shape",
    [
        [[('a', 1), ('b', 1000)], [('c', 2), ('d', 3)]],
        [[('a', 16), ('b', 1000)], [('c', 1), ('d', 2)]],
        [[('a', 1), ('b', 1000), ('c', 1), ('d', 1)], [('e', 10), ('f', 2), ('g', 3)]],
        [[('a', 16), ('b', 1000), ('c', 1), ('d', 1)], [('e', 3), ('f', 2), ('g', 4), ('h', 2)]],
        [[('a', 1), ('b', 128), ('c', 128), ('d', 128)], [('e', 2), ('f', 2)]],
        [[('a', 1), ('b', 128), ('c', 128), ('d', 128)], [('e', 2)]],
        [[('a', 129)], [('b', 2)]],
        [[('a', 129)], [('b', 2), ('c', 3)]],
    ],
)
def test_getitem_nd_dynamic(a_shape, b_shape):
    """Compare ``a[b]`` traced from ``symbol`` tensors against numpy.

    Each dimension in ``a_shape`` / ``b_shape`` is either a plain ``int`` or a
    ``(name, size)`` pair: the name is passed to ``symbol`` and the size is
    used to build the concrete numpy inputs. The traced function is run on the
    concrete inputs and must match ``a[b]`` from numpy exactly.
    """

    def concrete_dims(dims):
        # (name, size) -> size; plain ints pass through unchanged
        return [(d if isinstance(d, int) else d[1]) for d in dims]

    def symbolic_dims(dims):
        # (name, size) -> name; plain ints pass through unchanged
        return [(d if isinstance(d, int) else d[0]) for d in dims]

    for dev in ['cuda', 'cpu']:
        a_concrete_shape = concrete_dims(a_shape)
        a_symbolic_shape = symbolic_dims(a_shape)
        b_concrete_shape = concrete_dims(b_shape)
        b_symbolic_shape = symbolic_dims(b_shape)

        # reference data: indices are drawn from the valid range of a's first axis
        a = np.array(np.random.randn(*a_concrete_shape)).astype('float32')
        b = np.array(np.random.randint(low=0, high=a_concrete_shape[0], size=b_concrete_shape)).astype('int32')
        numpy_result = a[b]

        a_hidet = asarray(a).to(device=dev)
        b_hidet = asarray(b).to(device=dev)

        # build the indexing expression on symbol tensors and trace it
        sym_a = symbol(a_symbolic_shape, dtype=a_hidet.dtype, device=a_hidet.device)
        sym_b = symbol(b_symbolic_shape, dtype=b_hidet.dtype, device=b_hidet.device)
        sym_result = sym_a[sym_b]
        func = trace_from(sym_result, [sym_a, sym_b])

        hidet_result = func(a_hidet, b_hidet).cpu().numpy()
        np.testing.assert_allclose(actual=hidet_result, desired=numpy_result, atol=0, rtol=0)


@pytest.mark.parametrize(
    "a_shape, b_shape, c_shape",
    [
        [[10, 150], [2, 3], [1, 3]],
        [[16, 130], [1, 2], [1, 1]],
        [[10, 140], [10, 3], [10, 3]],
        [[16, 136], [1, 2], [4, 2]],
        [[128, 128], [2, 2], [2, 2]],
        [[10, 128], [2, 1], [1, 1]],
        [[129, 138], [1, 2], [1, 1]],
        [[129, 138], [2, 3], [2, 3]],
    ],
)
def test_getitem_advanced(a_shape, b_shape, c_shape):
    """Compare ``a[b, c]`` (two integer index tensors) against numpy.

    ``b`` indexes the first axis of ``a`` and ``c`` the second; their shapes
    must be broadcastable with each other. The hidet result is required to
    match the numpy result exactly (``atol=0, rtol=0``).
    """
    atol = 0
    rtol = 0
    for device in ['cuda', 'cpu']:
        a = np.random.randn(*a_shape).astype('float32')
        # Draw indices from the full extent of the axis they index instead of
        # fixed bounds, so every row/column of ``a`` can be hit.
        b = np.random.randint(low=0, high=a_shape[0], size=b_shape).astype('int32')
        c = np.random.randint(low=0, high=a_shape[1], size=c_shape).astype('int32')

        numpy_result = a[b, c]

        # keep the numpy inputs and the hidet tensors under separate names
        a_hidet = asarray(a).to(device=device)
        b_hidet = asarray(b).to(device=device)
        c_hidet = asarray(c).to(device=device)

        hidet_result = a_hidet[b_hidet, c_hidet].cpu().numpy()
        np.testing.assert_allclose(actual=hidet_result, desired=numpy_result, atol=atol, rtol=rtol)


@pytest.mark.parametrize(
    "a_shape, b_shape, c_shape",
    [
        [[('a', 10), ('b', 150), 10], [2, 3], [1, 3]],
        [[('a', 16), ('b', 130)], [1, 2], [1, 1]],
        [[('a', 10), ('b', 140), 10], [10, 3], [10, 3]],
        [[('a', 16), ('b', 136), 11, 1], [1, 2], [4, 2]],
        [[('a', 128), ('b', 128)], [2, 2], [2, 2]],
        [[('a', 10), ('b', 128)], [2, 1], [1, 1]],
        [[('a', 129), ('b', 138)], [1, 2], [1, 1]],
        [[('a', 129), ('b', 138)], [2, 3], [2, 3]],
    ],
)
def test_getitem_advanced_dynamic(a_shape, b_shape, c_shape):
    """Compare ``a[b, c]`` traced from ``symbol`` tensors against numpy.

    Each dimension is either a plain ``int`` or a ``(name, size)`` pair: the
    name is passed to ``symbol`` and the size is used for the concrete numpy
    inputs. The traced function is run on the concrete inputs and must match
    numpy's ``a[b, c]`` exactly.
    """

    def concrete_dims(dims):
        # (name, size) -> size; plain ints pass through unchanged
        return [(d if isinstance(d, int) else d[1]) for d in dims]

    def symbolic_dims(dims):
        # (name, size) -> name; plain ints pass through unchanged
        return [(d if isinstance(d, int) else d[0]) for d in dims]

    for dev in ['cuda', 'cpu']:
        a_concrete_shape = concrete_dims(a_shape)
        a_symbolic_shape = symbolic_dims(a_shape)

        b_concrete_shape = concrete_dims(b_shape)
        b_symbolic_shape = symbolic_dims(b_shape)

        c_concrete_shape = concrete_dims(c_shape)
        c_symbolic_shape = symbolic_dims(c_shape)

        a = np.random.randn(*a_concrete_shape).astype('float32')
        # Draw indices from the full extent of the axis they index instead of
        # fixed bounds, so every row/column of ``a`` can be hit.
        b = np.random.randint(low=0, high=a_concrete_shape[0], size=b_concrete_shape).astype('int32')
        c = np.random.randint(low=0, high=a_concrete_shape[1], size=c_concrete_shape).astype('int32')

        numpy_result = a[b, c]
        a_hidet = asarray(a).to(device=dev)
        b_hidet = asarray(b).to(device=dev)
        c_hidet = asarray(c).to(device=dev)

        # build the indexing expression on symbol tensors and trace it
        sym_a = symbol(a_symbolic_shape, dtype=a_hidet.dtype, device=a_hidet.device)
        sym_b = symbol(b_symbolic_shape, dtype=b_hidet.dtype, device=b_hidet.device)
        sym_c = symbol(c_symbolic_shape, dtype=c_hidet.dtype, device=c_hidet.device)
        sym_result = sym_a[sym_b, sym_c]

        func = trace_from(sym_result, [sym_a, sym_b, sym_c])
        hidet_result = func(a_hidet, b_hidet, c_hidet).cpu().numpy()
        np.testing.assert_allclose(actual=hidet_result, desired=numpy_result, atol=0, rtol=0)


def test_adv_indexing_with_slices():
    """Check indexing expressions that mix index tensors with slices / Ellipsis.

    The same list of indexing expressions is evaluated once on hidet tensors
    and once on their numpy copies; the results are compared pairwise with
    ``np.testing.assert_allclose`` (default tolerances).
    """

    def index_cases(x, ind1, ind2):
        # Works for both hidet tensors and numpy arrays: only ``[]`` is used.
        return [
            # one index tensor
            x[ind1, :],
            x[ind1, ...],
            x[..., ind1, :],
            x[:, :, ind1, :],
            # two index tensors, some separated by a slice
            x[ind1, :, ind2],
            x[ind1, ..., ind2],
            x[..., ind1, :, ind2],
            x[:, :, ind1, ind2],
            # two adjacent index tensors
            x[ind1, ind2, :],
            x[ind1, ind2, ...],
            # NOTE(review): same expression as the case two lines above; kept
            # as-is because the intended variant is not known.
            x[ind1, ind2, :],
            x[:, ind2, ind1, :],
            # index tensor combined with a bounded slice
            x[:, :, ind1, :10],
        ]

    x = hidet.randn(shape=(10, 11, 12, 13), dtype='float32', device='cuda')
    ind1 = hidet.randint(low=0, high=10, shape=(1, 1)).cuda()
    ind2 = hidet.randint(low=0, high=10, shape=(1, 2)).cuda()

    hidet_outs = index_cases(x, ind1, ind2)

    x_np, ind1_np, ind2_np = [t.cpu().numpy() for t in [x, ind1, ind2]]
    numpy_outs = index_cases(x_np, ind1_np, ind2_np)

    for hidet_out, numpy_out in zip(hidet_outs, numpy_outs):
        np.testing.assert_allclose(actual=hidet_out.cpu().numpy(), desired=numpy_out)


@pytest.mark.parametrize(
"input_shape, repeats, dim", [([2, 3, 4], 2, 0), ([1, 2, 9], 3, 1), ([1, 3, 4], 4, 2), ([1, 2, 3], 3, None)]
)

0 comments on commit 4706425

Please sign in to comment.