diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 2b88ae324fa..ba7696baf74 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -26,7 +26,6 @@ "combinations", "complex", "diag_embed", - "diagflat", "diagonal_copy", "diagonal_scatter", "diff", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 5e094d39927..1388e04e87d 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1928,6 +1928,17 @@ def _aten_diagonal(input, offset=0, dim1=0, dim2=1): return jnp.diagonal(input, offset, dim1, dim2) +# aten.diagflat +@op(torch.ops.aten.diagflat) +def _aten_diagflat(input, offset=0): + return jnp.diagflat(jnp.array(input), offset) + + +@op(torch.ops.aten.movedim) +def _aten_movedim(input, source, destination): + return jnp.moveaxis(input, source, destination) + + # aten.eq @op(torch.ops.aten.eq) def _aten_eq(input1, input2):