Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update aten operators #46

Merged
merged 14 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions py/torch_migraphx/fx/converters/acc_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def acc_ops_elu(mgx_module, node, args, kwargs):


@migraphx_converter(acc_ops.selu)
def acc_ops_elu(mgx_module, node, args, kwargs):
def acc_ops_selu(mgx_module, node, args, kwargs):

inp = kwargs['input']
dtype = get_arg_dtype(inp)
Expand Down Expand Up @@ -442,7 +442,7 @@ def acc_ops_elu(mgx_module, node, args, kwargs):


@migraphx_converter(acc_ops.softsign)
def acc_ops_elu(mgx_module, node, args, kwargs):
def acc_ops_softsign(mgx_module, node, args, kwargs):

inp = kwargs['input']
dtype = get_arg_dtype(inp)
Expand Down Expand Up @@ -787,6 +787,26 @@ def acc_ops_argmax(mgx_module, node, args, kwargs):
return out


@migraphx_converter(acc_ops.argmin)
def acc_ops_argmin(mgx_module, node, args, kwargs):
inp = kwargs['input']
dim = kwargs["dim"]
keepdim = kwargs["keepdim"]

if dim is None:
assert not keepdim, "keepdim cannot be true when dim is None"
inp = acc_ops_flatten(mgx_module, node, (), {"input": inp})
dim = 0

out = mgx_module.add_instruction(migraphx.op('argmin', axis=dim), [inp])

if not keepdim:
out = mgx_module.add_instruction(migraphx.op('squeeze', axes=[dim]),
[out])

return out


@migraphx_converter(acc_ops.embedding)
def acc_ops_embedding(mgx_module, node, args, kwargs):
inp = kwargs['input']
Expand Down Expand Up @@ -980,6 +1000,54 @@ def acc_ops_maximum(mgx_module, node, args, kwargs):
return mgx_module.add_instruction(migraphx.op('max'), [inp, other])


@migraphx_converter(acc_ops.max)
def acc_ops_max(mgx_module, node, args, kwargs):
inp = kwargs['input']
in_shape = inp.shape().lens()

if 'dim' not in kwargs:
dims = list(range(len(in_shape)))
max_ = mgx_module.add_instruction(
migraphx.op('reduce_max', axes=dims), [inp])
return mgx_module.add_instruction(migraphx.op('squeeze', axes=dims), [max_])
else:
dims = kwargs['dim']
indicies = acc_ops_argmax(mgx_module, node, args, kwargs)
max_ = mgx_module.add_instruction(
migraphx.op('reduce_max', axes=[dims]), [inp])

if 'keepdim' in kwargs and kwargs['keepdim']:
return [max_, indicies]

max_ = mgx_module.add_instruction(
migraphx.op('reduce_max', axes=[dims]), [inp])
return [mgx_module.add_instruction(migraphx.op('squeeze', axes=[dims]), [max_]), indicies]


@migraphx_converter(acc_ops.min)
def acc_ops_min(mgx_module, node, args, kwargs):
inp = kwargs['input']
in_shape = inp.shape().lens()

if 'dim' not in kwargs:
dims = list(range(len(in_shape)))
min_ = mgx_module.add_instruction(
migraphx.op('reduce_min', axes=dims), [inp])
return mgx_module.add_instruction(migraphx.op('squeeze', axes=dims), [min_])
else:
dims = kwargs['dim']
indicies = acc_ops_argmin(mgx_module, node, args, kwargs)
min_ = mgx_module.add_instruction(
migraphx.op('reduce_min', axes=[dims]), [inp])

if 'keepdim' in kwargs and kwargs['keepdim']:
return [min_, indicies]

min_ = mgx_module.add_instruction(
migraphx.op('reduce_min', axes=[dims]), [inp])
return [mgx_module.add_instruction(migraphx.op('squeeze', axes=[dims]), [min_]), indicies]


@migraphx_converter(acc_ops.mean)
def acc_ops_mean(mgx_module, node, args, kwargs):

Expand Down
109 changes: 105 additions & 4 deletions py/torch_migraphx/fx/converters/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import migraphx
import torch
from typing import cast, Iterable, List, Sequence
from ..converter_registry import migraphx_converter
from torch_migraphx.fx.converters import acc_ops_converters
from ..utils import torch_dtype_to_mgx_enum
Expand Down Expand Up @@ -248,15 +249,19 @@ def aten_ops_split(mgx_module, node, args, kwargs):
return slice_nodes


@migraphx_converter(torch.ops.aten.hardtanh.default)
@migraphx_converter(torch.ops.aten.clamp.default)
def aten_ops_clamp(mgx_module, node, args, kwargs):
assert len(args) >= 1
min_, max_ = None, None
if node.target == torch.ops.aten.hardtanh.default:
min_, max_ = -1, 1

acc_kwargs = {
"input": args[0],
"min": args[1] if len(args) >= 2 else None,
"max": args[2] if len(args) == 3 else None
"min": args[1] if len(args) >= 2 else min_,
"max": args[2] if len(args) == 3 else max_
}

return acc_ops_converters.acc_ops_clamp(mgx_module, node, (), acc_kwargs)


Expand All @@ -276,6 +281,16 @@ def aten_ops_tanh(mgx_module, node, args, kwargs):
return acc_ops_converters.acc_ops_tanh(mgx_module, node, (), acc_kwargs)


@migraphx_converter(torch.ops.aten.elu.default)
def aten_ops_elu(mgx_module, node, args, kwargs):
assert len(args) >= 1
inp = args[0]
alpha = 1.0 if len(args) < 2 else args[1]

acc_kwargs = {'input': inp, 'alpha': alpha}
return acc_ops_converters.acc_ops_elu(mgx_module, node, (), acc_kwargs)


@migraphx_converter(torch.ops.aten.leaky_relu.default)
def aten_ops_leaky_relu(mgx_module, node, args, kwargs):
assert len(args) >= 1
Expand All @@ -287,6 +302,17 @@ def aten_ops_leaky_relu(mgx_module, node, args, kwargs):
acc_kwargs)


@migraphx_converter(torch.ops.aten.hardswish.default)
TedThemistokleous marked this conversation as resolved.
Show resolved Hide resolved
def aten_ops_hardswish(mgx_module, node, args, kwargs):
assert len(args) == 1
acc_kwargs = {"input": args[0]}

hard_sig = acc_ops_converters.acc_ops_hard_sigmoid(mgx_module, node, (), acc_kwargs)

mul_kwargs = {"input": args[0], "other": hard_sig}
return acc_ops_converters.acc_ops_mul(mgx_module, node, (), mul_kwargs)


@migraphx_converter(torch.ops.aten.hardsigmoid.default)
def aten_ops_hardsigmoid(mgx_module, node, args, kwargs):
assert len(args) == 1
Expand Down Expand Up @@ -718,7 +744,6 @@ def aten_ops_embedding(mgx_module, node, args, kwargs):
return acc_ops_converters.acc_ops_embedding(mgx_module, node, (),
acc_kwargs)


@migraphx_converter(torch.ops.aten.argmax.default)
def aten_ops_argmax(mgx_module, node, args, kwargs):
assert len(args) >= 1
Expand All @@ -731,7 +756,83 @@ def aten_ops_argmax(mgx_module, node, args, kwargs):

return acc_ops_converters.acc_ops_argmax(mgx_module, node, (), acc_kwargs)

@migraphx_converter(torch.ops.aten.argmin.default)
def aten_ops_argmin(mgx_module, node, args, kwargs):
assert len(args) >= 1

acc_kwargs = {
"input": args[0],
"dim": args[1] if len(args) >= 2 else None,
"keepdim": args[2] if len(args) >= 3 else False
}

return acc_ops_converters.acc_ops_argmin(mgx_module, node, (), acc_kwargs)

@migraphx_converter(torch.ops.aten.max.default)
shivadbhavsar marked this conversation as resolved.
Show resolved Hide resolved
@migraphx_converter(torch.ops.aten.max.dim)
def aten_ops_max(mgx_module, node, args, kwargs):
assert len(args) >= 1

acc_kwargs = {
"input": args[0],
"keepdim": False,
}

if len(args) >= 2:
acc_kwargs["dim"] = args[1]

if len(args) >= 3:
acc_kwargs["keepdim"] = args[2]

return acc_ops_converters.acc_ops_max(mgx_module, node, (), acc_kwargs)

@migraphx_converter(torch.ops.aten.min.default)
@migraphx_converter(torch.ops.aten.min.dim)
def aten_ops_min(mgx_module, node, args, kwargs):
assert len(args) >= 1

acc_kwargs = {
"input": args[0],
"keepdim": False
}

if len(args) >= 2:
acc_kwargs["dim"] = args[1]

if len(args) >= 3:
acc_kwargs["keepdim"] = args[2]

return acc_ops_converters.acc_ops_min(mgx_module, node, (), acc_kwargs)

@migraphx_converter(torch.ops.aten.stack.default)
def aten_ops_stack(mgx_module, node, args, kwargs):
assert len(args) >= 1

"""
Map aten.stack to unsqueeze + cat acc ops.
"""
inputs = args[0]
assert isinstance(inputs, Sequence)

dims = args[1] if len(args) > 1 else 0

unsqueeze_kwargs={
"dim": dims
}
cat_kwargs={
"dim": dims
}

unsqueeze_nodes = []
for i, t in enumerate(inputs):
unsqueeze_kwargs["input"] = t
unsq_res = acc_ops_converters.acc_ops_unsqueeze(mgx_module, node, (), unsqueeze_kwargs)
unsqueeze_nodes.append(unsq_res)

cat_kwargs["tensors"] = unsqueeze_nodes
return acc_ops_converters.acc_ops_cat(mgx_module, node, (), cat_kwargs)


@migraphx_converter(torch.ops.aten.as_strided.default)
def aten_ops_as_strided(mgx_module, node, args, kwargs):
assert len(args) >= 3
Expand Down
56 changes: 56 additions & 0 deletions py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,54 @@ def maximum(*, input, other):
return torch.maximum(input=input, other=other)


@register_acc_op_mapping(
op_and_target=("call_method", "max"),
arg_replacement_tuples=[
("input", "input"),
("dim", "dim", this_arg_is_optional),
("keepdim", "keepdim", this_arg_is_optional),
],
)
@register_acc_op_mapping(
op_and_target=("call_function", torch.max),
arg_replacement_tuples=[
("input", "input"),
("dim", "dim", this_arg_is_optional),
("keepdim", "keepdim", this_arg_is_optional),
],
)
@register_acc_op
def max(*, input, dim=None, keepdim=False):
if dim is not None:
return torch.max(input, dim=dim, keepdim=keepdim)
else:
return torch.max(input)


@register_acc_op_mapping(
op_and_target=("call_method", "min"),
arg_replacement_tuples=[
("input", "input"),
("dim", "dim", this_arg_is_optional),
("keepdim", "keepdim", this_arg_is_optional),
],
)
@register_acc_op_mapping(
op_and_target=("call_function", torch.min),
arg_replacement_tuples=[
("input", "input"),
("dim", "dim", this_arg_is_optional),
("keepdim", "keepdim", this_arg_is_optional),
],
)
@register_acc_op
def min(*, input, dim=None, keepdim=False):
if dim is not None:
return torch.min(input, dim=dim, keepdim=keepdim)
else:
return torch.min(input)


@register_acc_op_mapping(op_and_target=("call_function", operator.getitem))
@register_acc_op
def getitem(*, input, idx):
Expand Down Expand Up @@ -914,6 +962,14 @@ def argmax(*, input, dim, keepdim):
return torch.argmax(input=input, dim=dim, keepdim=keepdim)


@register_acc_op_properties(AccOpProperty.unary)
@register_acc_op_mapping(op_and_target=("call_function", torch.argmin))
@register_acc_op_mapping(op_and_target=("call_method", "argmin"))
@register_acc_op
def argmin(*, input, dim, keepdim):
return torch.argmin(input=input, dim=dim, keepdim=keepdim)


@register_acc_op_mapping(op_and_target=("call_function",
nn.functional.embedding))
@register_acc_op
Expand Down
12 changes: 9 additions & 3 deletions tests/dynamo/converters/test_activations_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
pytest.skip(allow_module_level=True)


@pytest.mark.parametrize('op_alias', [torch.ops.aten.clamp.default])
@pytest.mark.parametrize('op_alias', [torch.ops.aten.clamp.default,
torch.ops.aten.hardtanh.default])
@pytest.mark.parametrize('inp_size', [(4, 2, 7), (128, 2048),
(1, 3, 6, 128, 128)])
def test_clamp(op_alias, inp_size):
Expand All @@ -20,8 +21,10 @@ def test_clamp(op_alias, inp_size):

@pytest.mark.parametrize('op_alias', [
torch.ops.aten.relu.default,
torch.ops.aten.elu.default,
torch.ops.aten.tanh.default,
torch.ops.aten.hardsigmoid.default,
torch.ops.aten.hardswish.default,
torch.ops.aten.sigmoid.default,
torch.ops.aten.gelu.default,
torch.ops.aten.silu.default,
Expand All @@ -33,13 +36,16 @@ def test_noparam_activation_funcs(op_alias):
verify_outputs(mod, mgx_mod, inp)


@pytest.mark.parametrize('op_alias', [torch.ops.aten.leaky_relu.default])
shivadbhavsar marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize('op_alias', [
torch.ops.aten.elu.default,
torch.ops.aten.leaky_relu.default,
])
@pytest.mark.parametrize('inp_size, alpha', [
((11, 3, 9), 0.1),
((6, 12, 32, 6), 0.05),
((2, ), 0),
])
def test_leaky_relu(op_alias, inp_size, alpha):
def test_single_param_activation_funcs(op_alias, inp_size, alpha):
inp = torch.randn(inp_size).cuda()
mod = FuncModule(op_alias, alpha).cuda()
mgx_mod = convert_to_mgx(mod, [inp])
Expand Down
9 changes: 6 additions & 3 deletions tests/dynamo/converters/test_maxmin_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
pytest.skip(allow_module_level=True)


@pytest.mark.parametrize('op_alias', [torch.ops.aten.argmax.default])
@pytest.mark.parametrize('op_alias', [
torch.ops.aten.argmax.default,
torch.ops.aten.argmax.default,
])
@pytest.mark.parametrize('dim, keepdim', [
(2, True),
(-1, False),
(0, False),
])
def test_argmax(op_alias, dim, keepdim):
def test_argmax_argmin(op_alias, dim, keepdim):
inp = torch.randn(10, 2, 12, 8, 14).cuda()
mod = FuncModule(torch.argmax, dim, keepdim)
mod = FuncModule(op_alias, dim, keepdim)
mgx_mod = convert_to_mgx(mod, [inp])
verify_outputs(mod, mgx_mod, inp)
Loading