diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index 94c6f5c00f4..afb4e2feef8 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -320,7 +320,10 @@ def _extract_input_args(exported_model, options): def _run_decompositions(exported_model): - decomp_table = core_aten_decompositions() + if hasattr(torch.export, 'core_op_decompositions'): + decomp_table = torch.export.core_op_decompositions() + else: + decomp_table = core_aten_decompositions() decomp_table.update(_extra_decompositions) decomp_table[torch.ops.aten._safe_softmax.default] = torch.softmax exported_model = exported_model.run_decompositions(decomp_table)