Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0c8c904
suddenly copy.default is unsupported
hugolatendresse Mar 16, 2025
3ae2b79
Merge branch 'main' into split
hugolatendresse Mar 16, 2025
f7f0637
wip
hugolatendresse Mar 16, 2025
7e4cf05
Able to split uneven tensors!
hugolatendresse Mar 16, 2025
dcbee0c
split size test passes!
hugolatendresse Mar 16, 2025
75890ce
test sizes and lists
hugolatendresse Mar 16, 2025
5a8eab1
just one func
hugolatendresse Mar 16, 2025
c771389
cleanup
hugolatendresse Mar 16, 2025
2fbe4c1
no assert
hugolatendresse Mar 16, 2025
e5095b8
linting
hugolatendresse Mar 16, 2025
490a454
chunk
hugolatendresse Mar 16, 2025
1ef561e
Merge branch 'main' of https://github.com/apache/tvm into split_uneven
hugolatendresse Mar 20, 2025
ec7311c
remove unsused modulo
hugolatendresse Mar 20, 2025
0744701
fixed first test
hugolatendresse Mar 20, 2025
40f1711
fixed second test and lint
hugolatendresse Mar 20, 2025
ca2bf9a
Merge branch 'main' into chunk2
hugolatendresse Mar 24, 2025
e20a43f
Merge branch 'split_uneven' into chunk2
hugolatendresse Mar 24, 2025
7bdc0bf
merge main
hugolatendresse Mar 24, 2025
4582c3a
linting
hugolatendresse Mar 24, 2025
afa793a
fix one test
hugolatendresse Mar 24, 2025
d71b518
chunk not passing anymore
hugolatendresse Mar 24, 2025
e95aef6
get_item error
hugolatendresse Mar 24, 2025
bc50446
chunk unit tests
hugolatendresse Mar 24, 2025
9951924
fix conflicts
hugolatendresse Mar 26, 2025
00f04b1
Merge branch 'chunk2' into index_tensor
hugolatendresse Mar 27, 2025
db5ec01
index select test passes
hugolatendresse Mar 28, 2025
859ca17
merge main
hugolatendresse Mar 28, 2025
c39e6e1
fix test
hugolatendresse Mar 28, 2025
f8d50f2
cleanup
hugolatendresse Mar 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,12 @@ def _fill(self, node: fx.Node) -> relax.Var:
value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))

def _index_select(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dim = node.args[1]
index = self.env[node.args[2]]
return self.block_builder.emit(relax.op.take(x, index, dim))

def _new_ones(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
self_var = args[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ def create_convert_map(
"reshape.default": self._reshape,
# tensor creation
"_to_copy.default": self._to_copy,
"lift_fresh_copy.default": self._to_copy,
"detach.default": self._detach,
"detach_.default": self._detach,
"arange.start": self._arange,
Expand All @@ -398,6 +397,8 @@ def create_convert_map(
"empty.memory_format": self._empty,
"empty_like.default": self._empty_like,
"fill.Scalar": self._fill,
"index_select.default": self._index_select,
"lift_fresh_copy.default": self._to_copy,
"new_ones.default": self._new_ones,
"one_hot.default": self._one_hot,
# other
Expand Down
6 changes: 0 additions & 6 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,6 @@ def _full(self, node: fx.Node) -> relax.Var:
)
)

def _index_select(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dim = node.args[1]
index = self.env[node.args[2]]
return self.block_builder.emit(relax.op.take(x, index, dim))

def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
mask = self.env[node.args[1]]
Expand Down
12 changes: 12 additions & 0 deletions tests/python/relax/test_from_exported_to_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,5 +467,17 @@ def forward(self, x):
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_index_select(target, dev):
class IndexSelectModel(nn.Module):
def forward(self, x):
indices = torch.tensor([0, 2])
return torch.index_select(x, 0, indices)

raw_data = np.random.rand(3, 4).astype("float32")
torch_module = IndexSelectModel().eval()
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


if __name__ == "__main__":
tvm.testing.main()
Loading