diff --git a/py/torch_migraphx/fx/converters/acc_ops_converters.py b/py/torch_migraphx/fx/converters/acc_ops_converters.py
index ef60342b..ba1b6a20 100644
--- a/py/torch_migraphx/fx/converters/acc_ops_converters.py
+++ b/py/torch_migraphx/fx/converters/acc_ops_converters.py
@@ -589,6 +589,32 @@ def acc_ops_hard_sigmoid(mgx_module, node, args, kwargs):
     return mgx_module.add_instruction(migraphx.op('clip'), [add, zeros, ones])
 
 
+@migraphx_converter(acc_ops.hardswish)
+def acc_ops_hard_swish(mgx_module, node, args, kwargs):
+
+    inp = kwargs['input']
+    dtype = get_arg_dtype(inp)
+    shape = inp.shape().lens()
+
+    alpha = mgx_module.add_instruction(
+        migraphx.op('multibroadcast', out_lens=shape),
+        [mgx_module.add_literal(torch.tensor([1 / 6], dtype=dtype).numpy())])
+
+    beta = mgx_module.add_instruction(
+        migraphx.op('multibroadcast', out_lens=shape),
+        [mgx_module.add_literal(torch.tensor([1 / 2], dtype=dtype).numpy())])
+
+    zeros = mgx_module.add_instruction(
+        migraphx.op('multibroadcast', out_lens=shape),
+        [mgx_module.add_literal(torch.tensor([0], dtype=dtype).numpy())])
+
+    mul = mgx_module.add_instruction(migraphx.op('mul'), [alpha, inp])
+    add = mgx_module.add_instruction(migraphx.op('add'), [beta, mul])
+
+    mul2 = mgx_module.add_instruction(migraphx.op('mul'), [add, inp])
+    return mgx_module.add_instruction(migraphx.op('clip'), [zeros, inp, mul2])
+
+
 @migraphx_converter(acc_ops.softmax)
 def acc_ops_softmax(mgx_module, node, args, kwargs):
 
diff --git a/py/torch_migraphx/fx/converters/aten_ops_converters.py b/py/torch_migraphx/fx/converters/aten_ops_converters.py
index fd3034b8..a61e5fd9 100644
--- a/py/torch_migraphx/fx/converters/aten_ops_converters.py
+++ b/py/torch_migraphx/fx/converters/aten_ops_converters.py
@@ -288,7 +288,16 @@ def aten_ops_leaky_relu(mgx_module, node, args, kwargs):
 
     acc_kwargs = {'input': inp, 'negative_slope': neg_slope}
     return acc_ops_converters.acc_ops_leaky_relu(mgx_module, node, (),
                                                  acc_kwargs)
+
+
+@migraphx_converter(torch.ops.aten.hardswish.default)
+def aten_ops_hardswish(mgx_module, node, args, kwargs):
+    assert len(args) == 1
+    acc_kwargs = {"input": args[0]}
+
+    return acc_ops_converters.acc_ops_hard_swish(mgx_module, node, (),
+                                                 acc_kwargs)
 
 
 @migraphx_converter(torch.ops.aten.hardsigmoid.default)
diff --git a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py
index f4af53b0..5b53276a 100644
--- a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py
+++ b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py
@@ -446,6 +446,11 @@ def hardsigmoid(*, input):
     return nn.functional.hardsigmoid(input)
 
 
+@register_acc_op
+def hardswish(*, input):
+    return nn.functional.hardswish(input)
+
+
 @register_acc_op_mapping(
     op_and_target=("call_method", "softmax"),
     arg_replacement_tuples=[
diff --git a/tests/dynamo/converters/test_activations_dynamo.py b/tests/dynamo/converters/test_activations_dynamo.py
index b842cd7b..a7321419 100644
--- a/tests/dynamo/converters/test_activations_dynamo.py
+++ b/tests/dynamo/converters/test_activations_dynamo.py
@@ -23,6 +23,7 @@ def test_clamp(op_alias, inp_size):
     torch.ops.aten.relu.default,
     torch.ops.aten.tanh.default,
     torch.ops.aten.hardsigmoid.default,
+    torch.ops.aten.hardswish.default,
     torch.ops.aten.sigmoid.default,
     torch.ops.aten.gelu.default,
     torch.ops.aten.silu.default,
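
Note on the acc_ops_hard_swish lowering: the final instruction, clip with inputs [zeros, inp, mul2], relies on the identity hardswish(x) = min(max(0, x), x * (x/6 + 1/2)), assuming MIGraphX's clip(x, min, max) follows the ONNX convention min(max(x, min), max). Below is a minimal standalone sketch (not part of the diff; plain tensors stand in for MIGraphX instructions) that sanity-checks this identity against torch.nn.functional.hardswish:

    import torch
    import torch.nn.functional as F

    x = torch.linspace(-6.0, 6.0, steps=1001)

    # Mirror the converter's instruction graph on plain tensors:
    # mul = x * 1/6, add = mul + 1/2, mul2 = add * x, then clip(0, min=x, max=mul2)
    mul2 = (x / 6 + 0.5) * x
    emulated = torch.minimum(torch.maximum(torch.zeros_like(x), x), mul2)

    # Matches the reference hardswish over the full range, including the
    # saturation regions x <= -3 and x >= 3.
    assert torch.allclose(emulated, F.hardswish(x), atol=1e-5)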