Skip to content

Commit

Permalink
Fix unbind op.
Browse files Browse the repository at this point in the history
torch.unbind does not lower to aten.unbind so we don't need to map it.

The test passed with the previous implementation but I created a
slightly more efficient one anyway.
  • Loading branch information
cornmander committed Sep 17, 2024
1 parent 0da723f commit d05c292
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 6 deletions.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@
"take_along_dim",
"to_sparse", # We are not supporting sparse tensors yet.
"triu",
"unbind",
"unfold_copy",
"unfold",
"unique_consecutive",
Expand Down
6 changes: 1 addition & 5 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2396,13 +2396,9 @@ def _aten_trunc(a):
return jnp.trunc(a)


@op(torch.ops.aten.unbind)
@op(torch.ops.aten.unbind_copy)
def _aten_unbind(a, dim=0):
return tuple(
_aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim)
for i in range(a.shape[dim])
)
return [jax.lax.index_in_dim(a, i, dim, keepdims=False) for i in range(a.shape[dim])]


# NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d
Expand Down

0 comments on commit d05c292

Please sign in to comment.