feat: support Dynamo converter for torch.ops.aten.erf.default op
Adds a Dynamo converter for the torch.ops.aten.erf.default op, lowering it to TensorRT's ERF unary operation; int8/int32 inputs are cast to float32 before conversion.
bowang007 authored Sep 22, 2023
2 parents e6e8099 + 3c4c2fe commit 670d2be
Showing 3 changed files with 89 additions and 3 deletions.
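
For context, here is a minimal usage sketch of the path this converter plugs into (not part of the commit; the compile call and flags are illustrative and assume the standard torch_tensorrt.compile entry point with the Dynamo frontend):

import torch
import torch_tensorrt

class ErfModel(torch.nn.Module):
    def forward(self, x):
        # traces to torch.ops.aten.erf.default, the op this commit adds a converter for
        return torch.erf(x)

model = ErfModel().eval().cuda()
inputs = [torch.randn(2, 2).cuda()]

# ir="dynamo" selects the Dynamo frontend whose converter registry is extended here
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))
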
23 changes: 20 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -329,6 +329,23 @@ def aten_ops_squeeze(
return impl.squeeze.squeeze(network, target, SourceIR.ATEN, name, args[0], args[1])


@dynamo_tensorrt_converter(torch.ops.aten.erf.default) # type: ignore[misc]
def aten_ops_erf(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.unary.erf(
network,
target,
SourceIR.ATEN,
name,
args[0],
)


@dynamo_tensorrt_converter(torch.ops.aten.unsqueeze.default) # type: ignore[misc]
def aten_ops_unsqueeze(
network: TRTNetwork,
@@ -357,14 +374,14 @@ def aten_ops_softmax(

@dynamo_tensorrt_converter(
torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1])
-)
+) # type: ignore[misc]
@dynamo_tensorrt_converter(
torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1])
-)
+) # type: ignore[misc]
@dynamo_tensorrt_converter(
torch.ops.aten.split_with_sizes.default,
capability_validator=dynamic_unsupported_with_args([1]),
-)
+) # type: ignore[misc]
def aten_ops_split(
network: TRTNetwork,
target: Target,
17 changes: 17 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -401,3 +401,20 @@ def neg(
return convert_unary(
network, target, source_ir, name, trt.UnaryOperation.NEG, input_val
)


def erf(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
) -> TRTTensor:
if (isinstance(input_val, TRTTensor)) and (
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
):
input_val = cast_trt_tensor(network, input_val, trt.float32, name)

return convert_unary(
network, target, source_ir, name, trt.UnaryOperation.ERF, input_val
)
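
The int8/int32 branch above mirrors eager-mode behavior: torch.erf type-promotes integer inputs to a floating dtype, and the TensorRT ERF unary operation does not operate on integer tensors, so the converter casts to float32 before lowering. A quick eager-mode reference (illustrative, not part of the commit):

import torch

x_int = torch.randint(0, 5, (2, 2), dtype=torch.int32)
y = torch.erf(x_int)  # integer input is promoted to a floating dtype
print(y.dtype)        # torch.float32 (the default float dtype)
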
52 changes: 52 additions & 0 deletions tests/py/dynamo/conversion/test_erf_aten.py
@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestErfConverter(DispatchTestCase):
@parameterized.expand(
[
("2d_dim_dtype_float", (2, 2), torch.float),
("3d_dim_dtype_float", (2, 2, 2), torch.float),
("2d_dim_dtype_half", (2, 2), torch.half),
("3d_dim_dtype_half", (2, 2, 2), torch.half),
]
)
def test_erf_float(self, _, x, type):
class erf(nn.Module):
def forward(self, input):
return torch.erf(input)

inputs = [torch.randn(x, dtype=type)]
self.run_test(
erf(),
inputs,
precision=type,
expected_ops={torch.ops.aten.erf.default},
)

@parameterized.expand(
[
("2d_dim_dtype_int32", (2, 2), torch.int32, 0, 5),
("3d_dim_dtype_int32", (2, 2, 2), torch.int32, 0, 5),
]
)
def test_erf_int(self, _, x, type, min, max):
class erf(nn.Module):
def forward(self, input):
return torch.erf(input)

inputs = [torch.randint(min, max, x, dtype=type)]
self.run_test(
erf(),
inputs,
expected_ops={torch.ops.aten.erf.default},
)


if __name__ == "__main__":
run_tests()
