Commit
Add converter for hardswish activation
Had to add my own migraphx converter, as the order of values passed to clip would otherwise fail the test
TedThemistokleous committed Oct 25, 2023
1 parent e237c6e commit 4cabff1
Showing 4 changed files with 41 additions and 1 deletion.
26 changes: 26 additions & 0 deletions py/torch_migraphx/fx/converters/acc_ops_converters.py
@@ -589,6 +589,32 @@ def acc_ops_hard_sigmoid(mgx_module, node, args, kwargs):
    return mgx_module.add_instruction(migraphx.op('clip'), [add, zeros, ones])


@migraphx_converter(acc_ops.hardswish)
def acc_ops_hard_swish(mgx_module, node, args, kwargs):

    inp = kwargs['input']
    dtype = get_arg_dtype(inp)
    shape = inp.shape().lens()

    # Broadcast the hardswish constants (1/6, 1/2, 0) to the input shape.
    alpha = mgx_module.add_instruction(
        migraphx.op('multibroadcast', out_lens=shape),
        [mgx_module.add_literal(torch.tensor([1 / 6], dtype=dtype).numpy())])

    beta = mgx_module.add_instruction(
        migraphx.op('multibroadcast', out_lens=shape),
        [mgx_module.add_literal(torch.tensor([1 / 2], dtype=dtype).numpy())])

    zeros = mgx_module.add_instruction(
        migraphx.op('multibroadcast', out_lens=shape),
        [mgx_module.add_literal(torch.tensor([0], dtype=dtype).numpy())])

    # mul2 = x * (x/6 + 1/2) = x * (x + 3) / 6
    mul = mgx_module.add_instruction(migraphx.op('mul'), [alpha, inp])
    add = mgx_module.add_instruction(migraphx.op('add'), [beta, mul])

    mul2 = mgx_module.add_instruction(migraphx.op('mul'), [add, inp])
    # With clip(x, min, max) = min(max(x, min), max), passing zeros as the
    # input and (x, x*(x+3)/6) as the bounds reproduces hardswish piecewise;
    # the unusual argument order is deliberate (see commit message).
    return mgx_module.add_instruction(migraphx.op('clip'), [zeros, inp, mul2])


@migraphx_converter(acc_ops.softmax)
def acc_ops_softmax(mgx_module, node, args, kwargs):

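A note on the clip at the end of the converter: assuming MIGraphX's clip computes min(max(input, min_val), max_val) elementwise (which the passing test suggests), clip(zeros, inp, mul2) evaluates to min(max(0, x), x*(x+3)/6), which matches hardswish piecewise: 0 for x <= -3, x*(x+3)/6 in between, and x for x >= 3. A minimal PyTorch sketch (not part of this commit) checking that identity numerically:

import torch
import torch.nn.functional as F

x = torch.linspace(-6, 6, steps=121)

# Emulate clip(zeros, inp, mul2) under min(max(...)) clamp semantics.
via_clip = torch.minimum(torch.maximum(torch.zeros_like(x), x),
                         x * (x + 3) / 6)

assert torch.allclose(via_clip, F.hardswish(x), atol=1e-6)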
10 changes: 9 additions & 1 deletion py/torch_migraphx/fx/converters/aten_ops_converters.py
@@ -288,7 +288,15 @@ def aten_ops_leaky_relu(mgx_module, node, args, kwargs):

    acc_kwargs = {'input': inp, 'negative_slope': neg_slope}
    return acc_ops_converters.acc_ops_leaky_relu(mgx_module, node, (),
                                                 acc_kwargs)


@migraphx_converter(torch.ops.aten.hardswish.default)
def aten_ops_hardswish(mgx_module, node, args, kwargs):
    assert len(args) == 1
    acc_kwargs = {"input": args[0]}

    return acc_ops_converters.acc_ops_hard_swish(mgx_module, node, (),
                                                 acc_kwargs)


@migraphx_converter(torch.ops.aten.hardsigmoid.default)
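For the dynamo path, the registration above routes aten.hardswish.default nodes through the same acc converter. A hedged usage sketch, assuming torch_migraphx registers a "migraphx" backend for torch.compile on import:

import torch
import torch_migraphx  # noqa: F401 -- assumed to register the "migraphx" backend

model = torch.nn.Sequential(torch.nn.Linear(16, 16),
                            torch.nn.Hardswish()).eval().cuda()

# aten.hardswish.default in the captured graph should hit the converter above.
compiled = torch.compile(model, backend="migraphx")
out = compiled(torch.randn(8, 16).cuda())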
5 changes: 5 additions & 0 deletions py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py
@@ -446,6 +446,11 @@ def hardsigmoid(*, input):
    return nn.functional.hardsigmoid(input)


@register_acc_op
def hardswish(*, input):
    return nn.functional.hardswish(input)


@register_acc_op_mapping(
    op_and_target=("call_method", "softmax"),
    arg_replacement_tuples=[
Expand Down
1 change: 1 addition & 0 deletions tests/dynamo/converters/test_activations_dynamo.py
@@ -23,6 +23,7 @@ def test_clamp(op_alias, inp_size):
    torch.ops.aten.relu.default,
    torch.ops.aten.tanh.default,
    torch.ops.aten.hardsigmoid.default,
    torch.ops.aten.hardswish.default,
    torch.ops.aten.sigmoid.default,
    torch.ops.aten.gelu.default,
    torch.ops.aten.silu.default,
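The test body itself is outside this diff; the added alias just enrolls aten.hardswish.default in the suite's parametrized activation test. A hypothetical standalone equivalent of what such a test typically does (names and tolerances assumed, not taken from the suite):

import pytest
import torch

@pytest.mark.parametrize('op_alias', [torch.ops.aten.hardswish.default])
def test_hardswish_matches_eager(op_alias):
    x = torch.randn(32, 64).cuda()
    fn = lambda t: op_alias(t)
    # Hypothetical: compile through the migraphx backend, compare to eager.
    compiled = torch.compile(fn, backend='migraphx')
    assert torch.allclose(compiled(x), fn(x), rtol=1e-3, atol=1e-3)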
