From eba28f172c306492ca7010678fc62e1c49f70783 Mon Sep 17 00:00:00 2001 From: David Huang Date: Fri, 4 Oct 2024 16:16:33 -0700 Subject: [PATCH] [torch_xla2] Fix cholesky_inverse (#8214) --- experimental/torch_xla2/test/test_ops.py | 1 - experimental/torch_xla2/torch_xla2/ops/jaten.py | 8 ++++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index eb98549e933..8174c10273b 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -21,7 +21,6 @@ "cdist", "ceil", "cholesky", - "cholesky_inverse", "cholesky_solve", "complex", "diagonal_copy", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index a30e85bae30..1cc30bfe0e4 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2041,6 +2041,14 @@ def _aten__pdist_forward(x, p=2): return condensed_dists +@op(torch.ops.aten.cholesky_inverse) +def _aten_cholesky_inverse(input, upper=False): + t = jnp.matrix_transpose(input) + if "complex" in str(input.dtype): + t = t.conjugate() + return jnp.linalg.inv(input @ t) + + # aten.cos @op(torch.ops.aten.cos) @op_base.promote_int_input