[ONNX] Support complex in FX exporter
ghstack-source-id: cea8f8c8b39132d368439fd5b9ad208bcdeb0606
Pull Request resolved: #100554
titaiwangms committed Jun 19, 2023
1 parent 15eed5b commit 901ef00
Showing 4 changed files with 24 additions and 1 deletion.
6 changes: 5 additions & 1 deletion test/onnx/test_fx_op_consistency.py
@@ -136,6 +136,7 @@
     # "nn.functional.scaled_dot_product_attention" non-deterministic
     "scatter_add",
     "scatter_reduce",
+    "stft",
     "unflatten",
     "vstack",  # aten::cat is invoked instead
 ]
@@ -481,6 +482,10 @@
         variant_name="mean",
         reason="ONNX doesn't support reduce='mean' option",
     ),
+    xfail(
+        "stft",
+        reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"),
+    ),
     xfail(
         "unflatten", dtypes=onnx_test_common.BOOL_TYPES,
         reason=onnx_test_common.reason_onnx_does_not_support("Unflatten")
@@ -680,7 +685,6 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):

         if dtype == torch.float32:
             # Relax atol and rtol for float32 based on empirical results
-            # The current most relaxed values are for aten::stft
             rtol = 1e-5
             atol = 2e-5
         else:
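
A note on the new stft xfail: torch.stft returns a complex tensor by default, which is exactly the case this PR adds plumbing for, but its lowering relies on aten._fft_r2c.default, which Dynamo cannot yet export. A minimal sketch of the complex output (shape comments assume the default hop_length = n_fft // 4):

    import torch

    signal = torch.randn(1, 512)
    # torch.stft yields a complex tensor by default; its lowering goes through
    # aten._fft_r2c.default, unsupported by Dynamo, hence the xfail above.
    spectrum = torch.stft(
        signal, n_fft=64, window=torch.hann_window(64), return_complex=True
    )
    print(spectrum.dtype)  # torch.complex64
    print(spectrum.shape)  # torch.Size([1, 33, 33]): (batch, n_fft // 2 + 1, frames)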
3 changes: 3 additions & 0 deletions torch/onnx/_internal/fx/op_validation.py
@@ -194,6 +194,9 @@ def generate_random_tensors(shape: torch.Size, dtype: torch.dtype):
         return torch.where(
             random_numbers > 0.5, torch.tensor(True), torch.tensor(False)
         )
+    if dtype in _type_utils.COMPLEX_TO_FLOAT:
+        # ONNX does not support complex values, but supports the real representation
+        return torch.randn(shape + (2,), dtype=_type_utils.COMPLEX_TO_FLOAT[dtype])
     return torch.randn(shape, dtype=dtype)


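A minimal sketch of what the new branch produces, with COMPLEX_TO_FLOAT inlined so the snippet is self-contained (the real mapping lives in torch/onnx/_type_utils.py):

    import torch

    # Inlined copy of the mapping added in torch/onnx/_type_utils.py.
    COMPLEX_TO_FLOAT = {
        torch.complex32: torch.float16,
        torch.complex64: torch.float32,
        torch.complex128: torch.float64,
    }

    shape, dtype = torch.Size([2, 3]), torch.complex64

    # Instead of sampling a complex tensor, sample its real representation:
    # the same shape with a trailing dimension of 2 (real, imaginary).
    surrogate = torch.randn(shape + (2,), dtype=COMPLEX_TO_FLOAT[dtype])
    assert surrogate.shape == (2, 3, 2)

    # torch.view_as_complex reinterprets that layout back as a complex tensor.
    assert torch.view_as_complex(surrogate).dtype == torch.complex64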
11 changes: 11 additions & 0 deletions torch/onnx/_internal/fx/passes/fx_to_onnxscript.py
@@ -216,6 +216,17 @@ def _fill_tensor_shape_type(
         onnxscript_value.dtype = torch.int64
     elif isinstance(expected_value, torch.SymFloat):
         onnxscript_value.dtype = torch.float32
+    elif expected_value.dtype in _type_utils.COMPLEX_TO_FLOAT:
+        # Like torch.view_as_real, we flatten complex tensors to real tensors with
+        # an additional last dimension of 2
+        onnxscript_value.shape = tuple(
+            [dim if isinstance(dim, int) else None for dim in expected_value.size()]
+            + [2]
+        )
+        # complex64 -> float32, complex128 -> float64, etc.
+        onnxscript_value.dtype = _type_utils.COMPLEX_TO_FLOAT[expected_value.dtype]
+        # The dispatcher needs to know the value is complex
+        onnxscript_value.is_complex = True
+    else:
         # We set node output sizes to be dynamic to continue the model conversion,
         # and inputs are also set to be dynamic in add_input().
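
The branch above mirrors torch.view_as_real: static dims are kept, symbolic dims become None (dynamic), and a trailing dimension of 2 holds the real and imaginary parts. A self-contained sketch (COMPLEX_TO_FLOAT inlined from _type_utils):

    import torch

    COMPLEX_TO_FLOAT = {torch.complex64: torch.float32, torch.complex128: torch.float64}

    expected_value = torch.randn(4, 5, dtype=torch.complex64)

    # Static dims are kept as ints; symbolic dims would map to None (dynamic).
    shape = tuple(
        [dim if isinstance(dim, int) else None for dim in expected_value.size()] + [2]
    )
    dtype = COMPLEX_TO_FLOAT[expected_value.dtype]
    assert shape == (4, 5, 2) and dtype == torch.float32

    # torch.view_as_real performs the same flattening on the actual values.
    assert torch.view_as_real(expected_value).shape == (4, 5, 2)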
5 changes: 5 additions & 0 deletions torch/onnx/_type_utils.py
@@ -359,6 +359,11 @@ def valid_torch_name(torch_name: Union[TorchName, str]) -> bool:
     float: {"tensor(float16)", "tensor(float)", "tensor(double)"},
     bool: {"tensor(int32)", "tensor(int64)", "tensor(bool)"},
 }
+COMPLEX_TO_FLOAT = {
+    torch.complex32: torch.float16,
+    torch.complex64: torch.float32,
+    torch.complex128: torch.float64,  # NOTE: ORT doesn't support torch.float64
+}
 
 # NOTE: Belows are from torch/fx/node.py
 BaseArgumentTypes = Union[
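
The mapping agrees with the dtype PyTorch itself assigns to the real view of a complex tensor, which is easy to verify (complex32 omitted here, since chalf support was still beta at the time):

    import torch

    # Each complex dtype should map to the dtype of its real view.
    for complex_dtype, float_dtype in [
        (torch.complex64, torch.float32),
        (torch.complex128, torch.float64),
    ]:
        assert torch.view_as_real(torch.zeros(2, dtype=complex_dtype)).dtype == float_dtype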
