Skip to content

Commit 0ee3b84

Browse files
nkaretnikovpytorchmergebot
authored andcommitted
[pt2] add meta for cholesky_inverse (pytorch#106120)
Pull Request resolved: pytorch#106120 Approved by: https://github.com/ezyang
1 parent 8075588 commit 0ee3b84

File tree

5 files changed

+7
-7
lines changed

5 files changed

+7
-7
lines changed

test/functorch/test_aotdispatch.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2750,7 +2750,6 @@ def forward(self, x):
27502750
skip('max_pool2d_with_indices_backward'),
27512751

27522752
# Worked with real but not with fake
2753-
xfail('cholesky_inverse'),
27542753
xfail('_segment_reduce', 'lengths'),
27552754
skip('nn.functional.nll_loss', ''), # UBSAN failure!
27562755

@@ -2794,7 +2793,6 @@ def forward(self, x):
27942793
symbolic_aot_autograd_failures = {
27952794
xfail('block_diag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
27962795
xfail('cdist', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
2797-
xfail('cholesky_inverse', ''), # could not find kernel
27982796
xfail('combinations', ''), # aten.masked_select.default
27992797
xfail('diff', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition
28002798
xfail('digamma', ''), # aten.polygamma.default - couldn't find symbolic meta function/decomposition

test/test_meta.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,6 @@ def run_meta_crossref(
633633
torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
634634
torch.nn.functional.one_hot : {i64},
635635
torch._segment_reduce : {f64, f16, bf16, f32},
636-
torch.cholesky_inverse : {f64, f32, c128, c64},
637636
torch.linalg.eig : {f64, f32, c128, c64},
638637
torch.linalg.eigvals : {f64, f32, c128, c64},
639638
torch.linalg.lstsq : {f64, f32, c128, c64},
@@ -803,8 +802,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
803802
# these always fail
804803
meta_dispatch_expected_failures = {
805804
aten.allclose.default: {f16, bf16, f32, f64, c64, c128}, # NotImplementedError: 'aten::_local_scalar_dense'
806-
aten.cholesky_inverse.default : {c64, c128, f64, f32},
807-
aten.cholesky_inverse.out : {c64, c128, f64, f32},
808805
aten.geqrf.default : {c64, c128, f64, f32},
809806
aten.linalg_eig.default : {c64, c128, f64, f32},
810807
aten.linalg_lstsq.default : {c64, c128, f64, f32},

test/test_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1935,7 +1935,6 @@ def test_refs_are_in_decomp_table(self, op):
19351935

19361936
fake_skips = (
19371937
"aminmax", # failing input
1938-
"cholesky_inverse", # Could not run 'aten::cholesky' with arguments from the 'Meta' backend
19391938
"cov", # aweights cannot be negtaive
19401939
"istft", # window overlap add min: 0
19411940
"linalg.eigvals", # The tensor has a non-zero number of elements, but its data is not allocated yet

test/test_proxy_tensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1521,7 +1521,6 @@ def f(t):
15211521
fake_tensor_failures = {
15221522
# FakeTensor fallback doesn't work
15231523
xfail('_segment_reduce', 'lengths'),
1524-
xfail('cholesky_inverse'),
15251524
# cannot do these as they rely on tensor data
15261525
xfail('repeat_interleave'),
15271526
# ASAN failures due to divide by 0

torch/_meta_registrations.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,13 @@ def cholesky(self: Tensor, upper: bool = False) -> Tensor:
584584
return cloneBatchedColumnMajor(self)
585585

586586

587+
@register_meta(aten.cholesky_inverse)
588+
@out_wrapper()
589+
def cholesky_inverse(self: Tensor, upper: bool = False) -> Tensor:
590+
squareCheckInputs(self, "cholesky_inverse")
591+
return cloneBatchedColumnMajor(self)
592+
593+
587594
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
588595
@register_meta(aten.linalg_cholesky_ex.default)
589596
def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):

0 commit comments

Comments
 (0)