Skip to content

Commit

Permalink
support export hardsigmoid in torch<=1.8 (open-mmlab#169)
Browse files Browse the repository at this point in the history
* support export hardsigmoid in torch<=1.8

* fix lint
  • Loading branch information
q.yao authored Feb 24, 2022
1 parent 486d45e commit e9ee21f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mmdeploy/pytorch/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
adaptive_avg_pool2d__default,
adaptive_avg_pool3d__default)
from .grid_sampler import grid_sampler__default
from .hardsigmoid import hardsigmoid__default
from .instance_norm import instance_norm__tensorrt
from .lstm import generic_rnn__ncnn
from .squeeze import squeeze__default

__all__ = [
'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default',
'adaptive_avg_pool3d__default', 'grid_sampler__default',
'instance_norm__tensorrt', 'generic_rnn__ncnn', 'squeeze__default'
'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn',
'squeeze__default'
]
12 changes: 12 additions & 0 deletions mmdeploy/pytorch/ops/hardsigmoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from:
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
from mmdeploy.core import SYMBOLIC_REWRITER


@SYMBOLIC_REWRITER.register_symbolic(
'hardsigmoid', is_pytorch=True, arg_descriptors=['v'])
def hardsigmoid__default(ctx, g, self):
"""Support export hardsigmoid This rewrite enable export hardsigmoid in
torch<=1.8.2."""
return g.op('HardSigmoid', self, alpha_f=1 / 6)
7 changes: 7 additions & 0 deletions tests/test_pytorch/test_pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,10 @@ def test_squeeze(self):
nodes = get_model_onnx_nodes(model, x)
assert nodes[0].attribute[0].ints == [0]
assert nodes[0].op_type == 'Squeeze'


def test_hardsigmoid():
x = torch.rand(1, 2, 3, 4)
model = torch.nn.Hardsigmoid().eval()
nodes = get_model_onnx_nodes(model, x)
assert nodes[0].op_type == 'HardSigmoid'

0 comments on commit e9ee21f

Please sign in to comment.