From e237c6e91ccafaf62003ef99c014db421c6e59f8 Mon Sep 17 00:00:00 2001
From: Ted Themistokleous
Date: Tue, 24 Oct 2023 14:38:34 -0700
Subject: [PATCH 01/14] Update aten.hardtanh support

Add a converter for hardtanh, similar to that of clamp
---
 py/torch_migraphx/fx/converters/aten_ops_converters.py | 10 +++++++---
 tests/dynamo/converters/test_activations_dynamo.py | 3 ++-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/py/torch_migraphx/fx/converters/aten_ops_converters.py b/py/torch_migraphx/fx/converters/aten_ops_converters.py
index 17adfa36..fd3034b8 100644
--- a/py/torch_migraphx/fx/converters/aten_ops_converters.py
+++ b/py/torch_migraphx/fx/converters/aten_ops_converters.py
@@ -248,15 +248,19 @@ def aten_ops_split(mgx_module, node, args, kwargs):
     return slice_nodes
 
 
+@migraphx_converter(torch.ops.aten.hardtanh.default)
 @migraphx_converter(torch.ops.aten.clamp.default)
 def aten_ops_clamp(mgx_module, node, args, kwargs):
     assert len(args) >= 1
+    min_, max_ = None, None
+    if node.target == torch.ops.aten.hardtanh.default:
+        min_, max_ = -1, 1
+
     acc_kwargs = {
         "input": args[0],
-        "min": args[1] if len(args) >= 2 else None,
-        "max": args[2] if len(args) == 3 else None
+        "min": args[1] if len(args) >= 2 else min_,
+        "max": args[2] if len(args) == 3 else max_
     }
-
     return acc_ops_converters.acc_ops_clamp(mgx_module, node, (), acc_kwargs)
 
 
diff --git a/tests/dynamo/converters/test_activations_dynamo.py b/tests/dynamo/converters/test_activations_dynamo.py
index 6842655f..b842cd7b 100644
--- a/tests/dynamo/converters/test_activations_dynamo.py
+++ b/tests/dynamo/converters/test_activations_dynamo.py
@@ -7,7 +7,8 @@
     pytest.skip(allow_module_level=True)
 
 
-@pytest.mark.parametrize('op_alias', [torch.ops.aten.clamp.default])
+@pytest.mark.parametrize('op_alias', [torch.ops.aten.clamp.default,
+                                      torch.ops.aten.hardtanh.default])
 @pytest.mark.parametrize('inp_size', [(4, 2, 7), (128, 2048),
                                       (1, 3, 6, 128, 128)])
 def test_clamp(op_alias, inp_size):

From e9fe73169ba14632ab07fbfbc74569a46cd839e5 Mon Sep 17 00:00:00 2001
From: Ted Themistokleous
Date: Wed, 25 Oct 2023 12:04:15 -0700
Subject: [PATCH 02/14] Add converter for hardswish activation

Had to add my own migraphx converter, as the order of values passed to
clip would otherwise fail the test
---
 .../fx/converters/acc_ops_converters.py | 26 +++++++++++++++++++
 .../fx/converters/aten_ops_converters.py | 8 ++++++
 .../fx/tracer/acc_tracer/acc_ops.py | 5 ++++
 .../converters/test_activations_dynamo.py | 1 +
 4 files changed, 40 insertions(+)

diff --git a/py/torch_migraphx/fx/converters/acc_ops_converters.py b/py/torch_migraphx/fx/converters/acc_ops_converters.py
index ef60342b..ba1b6a20 100644
--- a/py/torch_migraphx/fx/converters/acc_ops_converters.py
+++ b/py/torch_migraphx/fx/converters/acc_ops_converters.py
@@ -589,6 +589,32 @@ def acc_ops_hard_sigmoid(mgx_module, node, args, kwargs):
     return mgx_module.add_instruction(migraphx.op('clip'), [add, zeros, ones])
 
 
+@migraphx_converter(acc_ops.hardswish)
+def acc_ops_hard_swish(mgx_module, node, args, kwargs):
+
+    inp = kwargs['input']
+    dtype = get_arg_dtype(inp)
+    shape = inp.shape().lens()
+
+    alpha = mgx_module.add_instruction(
+        migraphx.op('multibroadcast', out_lens=shape),
+        [mgx_module.add_literal(torch.tensor([1 / 6], dtype=dtype).numpy())])
+
+    beta = mgx_module.add_instruction(
+        migraphx.op('multibroadcast', out_lens=shape),
+        [mgx_module.add_literal(torch.tensor([1 / 2], dtype=dtype).numpy())])
+
+    zeros = mgx_module.add_instruction(
+        migraphx.op('multibroadcast', out_lens=shape),
[mgx_module.add_literal(torch.tensor([0], dtype=dtype).numpy())]) + + mul = mgx_module.add_instruction(migraphx.op('mul'), [alpha, inp]) + add = mgx_module.add_instruction(migraphx.op('add'), [beta, mul]) + + mul2 = mgx_module.add_instruction(migraphx.op('mul'), [add, inp]) + return mgx_module.add_instruction(migraphx.op('clip'), [zeros, inp, mul2]) + + @migraphx_converter(acc_ops.softmax) def acc_ops_softmax(mgx_module, node, args, kwargs): diff --git a/py/torch_migraphx/fx/converters/aten_ops_converters.py b/py/torch_migraphx/fx/converters/aten_ops_converters.py index fd3034b8..4c6da8c1 100644 --- a/py/torch_migraphx/fx/converters/aten_ops_converters.py +++ b/py/torch_migraphx/fx/converters/aten_ops_converters.py @@ -290,6 +290,14 @@ def aten_ops_leaky_relu(mgx_module, node, args, kwargs): return acc_ops_converters.acc_ops_leaky_relu(mgx_module, node, (), acc_kwargs) +@migraphx_converter(torch.ops.aten.hardswish.default) +def aten_ops_hardswish(mgx_module, node, args, kwargs): + assert len(args) == 1 + acc_kwargs = {"input": args[0]} + + return acc_ops_converters.acc_ops_hard_swish(mgx_module, node, (), + acc_kwargs) + @migraphx_converter(torch.ops.aten.hardsigmoid.default) def aten_ops_hardsigmoid(mgx_module, node, args, kwargs): diff --git a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py index f4af53b0..5b53276a 100644 --- a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py @@ -446,6 +446,11 @@ def hardsigmoid(*, input): return nn.functional.hardsigmoid(input) +@register_acc_op +def hardswish(*, input): + return nn.functional.hardswish(input) + + @register_acc_op_mapping( op_and_target=("call_method", "softmax"), arg_replacement_tuples=[ diff --git a/tests/dynamo/converters/test_activations_dynamo.py b/tests/dynamo/converters/test_activations_dynamo.py index b842cd7b..a7321419 100644 --- a/tests/dynamo/converters/test_activations_dynamo.py +++ b/tests/dynamo/converters/test_activations_dynamo.py @@ -23,6 +23,7 @@ def test_clamp(op_alias, inp_size): torch.ops.aten.relu.default, torch.ops.aten.tanh.default, torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardswish.default, torch.ops.aten.sigmoid.default, torch.ops.aten.gelu.default, torch.ops.aten.silu.default, From b07ad2e53a80e2f2b965d5bd53709654687c0eb5 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 25 Oct 2023 13:37:25 -0700 Subject: [PATCH 03/14] fixup! 
Add converter for hardswish activation
---
 .../fx/converters/acc_ops_converters.py | 26 -------------------
 .../fx/converters/aten_ops_converters.py | 7 +++--
 .../fx/tracer/acc_tracer/acc_ops.py | 5 ----
 3 files changed, 5 insertions(+), 33 deletions(-)

diff --git a/py/torch_migraphx/fx/converters/acc_ops_converters.py b/py/torch_migraphx/fx/converters/acc_ops_converters.py
index ba1b6a20..ef60342b 100644
--- a/py/torch_migraphx/fx/converters/acc_ops_converters.py
+++ b/py/torch_migraphx/fx/converters/acc_ops_converters.py
@@ -589,32 +589,6 @@ def acc_ops_hard_sigmoid(mgx_module, node, args, kwargs):
     return mgx_module.add_instruction(migraphx.op('clip'), [add, zeros, ones])
 
 
-@migraphx_converter(acc_ops.hardswish)
-def acc_ops_hard_swish(mgx_module, node, args, kwargs):
-
-    inp = kwargs['input']
-    dtype = get_arg_dtype(inp)
-    shape = inp.shape().lens()
-
-    alpha = mgx_module.add_instruction(
-        migraphx.op('multibroadcast', out_lens=shape),
-        [mgx_module.add_literal(torch.tensor([1 / 6], dtype=dtype).numpy())])
-
-    beta = mgx_module.add_instruction(
-        migraphx.op('multibroadcast', out_lens=shape),
-        [mgx_module.add_literal(torch.tensor([1 / 2], dtype=dtype).numpy())])
-
-    zeros = mgx_module.add_instruction(
-        migraphx.op('multibroadcast', out_lens=shape),
-        [mgx_module.add_literal(torch.tensor([0], dtype=dtype).numpy())])
-
-    mul = mgx_module.add_instruction(migraphx.op('mul'), [alpha, inp])
-    add = mgx_module.add_instruction(migraphx.op('add'), [beta, mul])
-
-    mul2 = mgx_module.add_instruction(migraphx.op('mul'), [add, inp])
-    return mgx_module.add_instruction(migraphx.op('clip'), [zeros, inp, mul2])
-
-
 @migraphx_converter(acc_ops.softmax)
 def acc_ops_softmax(mgx_module, node, args, kwargs):
 
diff --git a/py/torch_migraphx/fx/converters/aten_ops_converters.py b/py/torch_migraphx/fx/converters/aten_ops_converters.py
index 4c6da8c1..c0f83e32 100644
--- a/py/torch_migraphx/fx/converters/aten_ops_converters.py
+++ b/py/torch_migraphx/fx/converters/aten_ops_converters.py
@@ -290,13 +290,16 @@ def aten_ops_leaky_relu(mgx_module, node, args, kwargs):
     return acc_ops_converters.acc_ops_leaky_relu(mgx_module, node, (),
                                                  acc_kwargs)
 
+
 @migraphx_converter(torch.ops.aten.hardswish.default)
 def aten_ops_hardswish(mgx_module, node, args, kwargs):
     assert len(args) == 1
     acc_kwargs = {"input": args[0]}
 
-    return acc_ops_converters.acc_ops_hard_swish(mgx_module, node, (),
-                                                 acc_kwargs)
+    hard_sig = acc_ops_converters.acc_ops_hard_sigmoid(mgx_module, node, (), acc_kwargs)
+
+    mul_kwargs = {"input": args[0], "other": hard_sig}
+    return acc_ops_converters.acc_ops_mul(mgx_module, node, (), mul_kwargs)
 
 
 @migraphx_converter(torch.ops.aten.hardsigmoid.default)
diff --git a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py
index 5b53276a..f4af53b0 100644
--- a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py
+++ b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py
@@ -446,11 +446,6 @@ def hardsigmoid(*, input):
     return nn.functional.hardsigmoid(input)
 
 
-@register_acc_op
-def hardswish(*, input):
-    return nn.functional.hardswish(input)
-
-
 @register_acc_op_mapping(
     op_and_target=("call_method", "softmax"),
     arg_replacement_tuples=[

From 7d754204faa2dbfa98bd422bac607bb26e154673 Mon Sep 17 00:00:00 2001
From: Ted Themistokleous
Date: Wed, 25 Oct 2023 13:58:03 -0700
Subject: [PATCH 04/14] Fix names for selu & softsign acc_ops converters

Both were previously named acc_ops_elu, which was breaking the aten elu
converter when it tried to use the real elu converter.
---
 py/torch_migraphx/fx/converters/acc_ops_converters.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/py/torch_migraphx/fx/converters/acc_ops_converters.py b/py/torch_migraphx/fx/converters/acc_ops_converters.py
index ef60342b..9a3ace7c 100644
--- a/py/torch_migraphx/fx/converters/acc_ops_converters.py
+++ b/py/torch_migraphx/fx/converters/acc_ops_converters.py
@@ -399,7 +399,7 @@ def acc_ops_elu(mgx_module, node, args, kwargs):
 
 
 @migraphx_converter(acc_ops.selu)
-def acc_ops_elu(mgx_module, node, args, kwargs):
+def acc_ops_selu(mgx_module, node, args, kwargs):
 
     inp = kwargs['input']
     dtype = get_arg_dtype(inp)
@@ -442,7 +442,7 @@ def acc_ops_elu(mgx_module, node, args, kwargs):
 
 
 @migraphx_converter(acc_ops.softsign)
-def acc_ops_elu(mgx_module, node, args, kwargs):
+def acc_ops_softsign(mgx_module, node, args, kwargs):
 
     inp = kwargs['input']
     dtype = get_arg_dtype(inp)

From bb8ff6ff586bf9f543d319d7d2a82d72f62d0ac4 Mon Sep 17 00:00:00 2001
From: Ted Themistokleous
Date: Wed, 25 Oct 2023 13:59:38 -0700
Subject: [PATCH 05/14] Add elu aten converter op

- Test both the non-parameterized and parameterized alpha/default values
---
 py/torch_migraphx/fx/converters/aten_ops_converters.py | 10 ++++++++++
 tests/dynamo/converters/test_activations_dynamo.py | 6 +++++-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/py/torch_migraphx/fx/converters/aten_ops_converters.py b/py/torch_migraphx/fx/converters/aten_ops_converters.py
index c0f83e32..81a5fb01 100644
--- a/py/torch_migraphx/fx/converters/aten_ops_converters.py
+++ b/py/torch_migraphx/fx/converters/aten_ops_converters.py
@@ -280,6 +280,16 @@ def aten_ops_tanh(mgx_module, node, args, kwargs):
     return acc_ops_converters.acc_ops_tanh(mgx_module, node, (), acc_kwargs)
 
 
+@migraphx_converter(torch.ops.aten.elu.default)
+def aten_ops_elu(mgx_module, node, args, kwargs):
+    assert len(args) >= 1
+    inp = args[0]
+    alpha = 1.0 if len(args) < 2 else args[1]
+
+    acc_kwargs = {'input': inp, 'alpha': alpha}
+    return acc_ops_converters.acc_ops_elu(mgx_module, node, (), acc_kwargs)
+
+
 @migraphx_converter(torch.ops.aten.leaky_relu.default)
 def aten_ops_leaky_relu(mgx_module, node, args, kwargs):
     assert len(args) >= 1
diff --git a/tests/dynamo/converters/test_activations_dynamo.py b/tests/dynamo/converters/test_activations_dynamo.py
index a7321419..5f7a5d7d 100644
--- a/tests/dynamo/converters/test_activations_dynamo.py
+++ b/tests/dynamo/converters/test_activations_dynamo.py
@@ -21,6 +21,7 @@ def test_clamp(op_alias, inp_size):
 
 @pytest.mark.parametrize('op_alias', [
     torch.ops.aten.relu.default,
+    torch.ops.aten.elu.default,
     torch.ops.aten.tanh.default,
     torch.ops.aten.hardsigmoid.default,
     torch.ops.aten.hardswish.default,
     torch.ops.aten.sigmoid.default,
     torch.ops.aten.gelu.default,
     torch.ops.aten.silu.default,
@@ -35,7 +36,10 @@ def test_noparam_activation_funcs(op_alias):
     verify_outputs(mod, mgx_mod, inp)
 
 
-@pytest.mark.parametrize('op_alias', [torch.ops.aten.leaky_relu.default])
+@pytest.mark.parametrize('op_alias', [
+    torch.ops.aten.elu.default,
+    torch.ops.aten.leaky_relu.default,
+])
 @pytest.mark.parametrize('inp_size, alpha', [
     ((11, 3, 9), 0.1),
     ((6, 12, 32, 6), 0.05),
     ((2, ), 0),
 ])

From 22f15e3bb27353e2ff2995c70eed7d88cade891a Mon Sep 17 00:00:00 2001
From: Ted Themistokleous
Date: Wed, 25 Oct 2023 16:39:58 -0700
Subject: [PATCH 06/14] Add aten.max for max operator

Handle getting the max value from an input tensor.
--- .../fx/converters/aten_ops_converters.py | 12 +++++++++++- tests/dynamo/converters/test_maxmin_dynamo.py | 5 ++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/py/torch_migraphx/fx/converters/aten_ops_converters.py b/py/torch_migraphx/fx/converters/aten_ops_converters.py index 81a5fb01..e7c587fa 100644 --- a/py/torch_migraphx/fx/converters/aten_ops_converters.py +++ b/py/torch_migraphx/fx/converters/aten_ops_converters.py @@ -743,7 +743,6 @@ def aten_ops_embedding(mgx_module, node, args, kwargs): return acc_ops_converters.acc_ops_embedding(mgx_module, node, (), acc_kwargs) - @migraphx_converter(torch.ops.aten.argmax.default) def aten_ops_argmax(mgx_module, node, args, kwargs): assert len(args) >= 1 @@ -756,6 +755,17 @@ def aten_ops_argmax(mgx_module, node, args, kwargs): return acc_ops_converters.acc_ops_argmax(mgx_module, node, (), acc_kwargs) +@migraphx_converter(torch.ops.aten.max.default) +def aten_ops_max(mgx_module, node, args, kwargs): + assert len(args) >= 1 + + acc_kwargs = { + "input": args[0], + "dim": args[1] if len(args) >= 2 else None, + "keepdim": args[2] if len(args) >= 3 else False + } + + return acc_ops_converters.acc_ops_maximum(mgx_module, node, (), acc_kwargs) @migraphx_converter(torch.ops.aten.as_strided.default) def aten_ops_as_strided(mgx_module, node, args, kwargs): diff --git a/tests/dynamo/converters/test_maxmin_dynamo.py b/tests/dynamo/converters/test_maxmin_dynamo.py index 8a7b845c..d2b077d7 100644 --- a/tests/dynamo/converters/test_maxmin_dynamo.py +++ b/tests/dynamo/converters/test_maxmin_dynamo.py @@ -7,7 +7,10 @@ pytest.skip(allow_module_level=True) -@pytest.mark.parametrize('op_alias', [torch.ops.aten.argmax.default]) +@pytest.mark.parametrize('op_alias', + [torch.ops.aten.argmax.default, + torch.ops.aten.max.default, +]) @pytest.mark.parametrize('dim, keepdim', [ (2, True), (-1, False), From b154ac2612bd4c2733482cbdbb0aa7649062c2d0 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 30 Oct 2023 12:05:46 -0700 Subject: [PATCH 07/14] fixup! 
Add aten.max for max operator --- py/torch_migraphx/fx/converters/acc_ops_converters.py | 6 ++++++ py/torch_migraphx/fx/converters/aten_ops_converters.py | 3 ++- py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py | 7 +++++++ tests/dynamo/converters/test_maxmin_dynamo.py | 1 + 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/py/torch_migraphx/fx/converters/acc_ops_converters.py b/py/torch_migraphx/fx/converters/acc_ops_converters.py index 9a3ace7c..32e8a02a 100644 --- a/py/torch_migraphx/fx/converters/acc_ops_converters.py +++ b/py/torch_migraphx/fx/converters/acc_ops_converters.py @@ -980,6 +980,12 @@ def acc_ops_maximum(mgx_module, node, args, kwargs): return mgx_module.add_instruction(migraphx.op('max'), [inp, other]) +@migraphx_converter(acc_ops.max) +def acc_ops_max(mgx_module, node, args, kwargs): + inp, dim, keepdims = kwargs["input"], kwargs["dim"], kwargs["keepdims"] + return mgx_module.add_instruction(migraphx.op('reduce_max'), [inp, dim, keepdims]) + + @migraphx_converter(acc_ops.mean) def acc_ops_mean(mgx_module, node, args, kwargs): diff --git a/py/torch_migraphx/fx/converters/aten_ops_converters.py b/py/torch_migraphx/fx/converters/aten_ops_converters.py index e7c587fa..72e5a39c 100644 --- a/py/torch_migraphx/fx/converters/aten_ops_converters.py +++ b/py/torch_migraphx/fx/converters/aten_ops_converters.py @@ -756,6 +756,7 @@ def aten_ops_argmax(mgx_module, node, args, kwargs): return acc_ops_converters.acc_ops_argmax(mgx_module, node, (), acc_kwargs) @migraphx_converter(torch.ops.aten.max.default) +@migraphx_converter(torch.ops.aten.max.dim) def aten_ops_max(mgx_module, node, args, kwargs): assert len(args) >= 1 @@ -765,7 +766,7 @@ def aten_ops_max(mgx_module, node, args, kwargs): "keepdim": args[2] if len(args) >= 3 else False } - return acc_ops_converters.acc_ops_maximum(mgx_module, node, (), acc_kwargs) + return acc_ops_converters.acc_ops_max(mgx_module, node, (), acc_kwargs) @migraphx_converter(torch.ops.aten.as_strided.default) def aten_ops_as_strided(mgx_module, node, args, kwargs): diff --git a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py index f4af53b0..140ebf0d 100644 --- a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py @@ -166,6 +166,13 @@ def maximum(*, input, other): return torch.maximum(input=input, other=other) +@register_acc_op_mapping(op_and_target=("call_function", torch.max)) +@register_acc_op_mapping(op_and_target=("call_method", "max")) +@register_acc_op +def max(*, input, dim, keepdim=False): + return torch.max(input=input, dim=dim, keepdim=keepdim) + + @register_acc_op_mapping(op_and_target=("call_function", operator.getitem)) @register_acc_op def getitem(*, input, idx): diff --git a/tests/dynamo/converters/test_maxmin_dynamo.py b/tests/dynamo/converters/test_maxmin_dynamo.py index d2b077d7..bbc0ddea 100644 --- a/tests/dynamo/converters/test_maxmin_dynamo.py +++ b/tests/dynamo/converters/test_maxmin_dynamo.py @@ -10,6 +10,7 @@ @pytest.mark.parametrize('op_alias', [torch.ops.aten.argmax.default, torch.ops.aten.max.default, + torch.ops.aten.max.dim, ]) @pytest.mark.parametrize('dim, keepdim', [ (2, True), From d72b455cb0e0cca3bb5bcdf125dd18cb26ab8864 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 30 Oct 2023 12:14:58 -0700 Subject: [PATCH 08/14] fixup! fixup! 
Add aten.max for max operator
---
 .../fx/converters/acc_ops_converters.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/py/torch_migraphx/fx/converters/acc_ops_converters.py b/py/torch_migraphx/fx/converters/acc_ops_converters.py
index 32e8a02a..f21d9a36 100644
--- a/py/torch_migraphx/fx/converters/acc_ops_converters.py
+++ b/py/torch_migraphx/fx/converters/acc_ops_converters.py
@@ -982,8 +982,16 @@ def acc_ops_maximum(mgx_module, node, args, kwargs):
 
 
 @migraphx_converter(acc_ops.max)
 def acc_ops_max(mgx_module, node, args, kwargs):
-    inp, dim, keepdims = kwargs["input"], kwargs["dim"], kwargs["keepdims"]
-    return mgx_module.add_instruction(migraphx.op('reduce_max'), [inp, dim, keepdims])
+    min = mgx_module.add_instruction(
+        migraphx.op('reduce_max', axes=list(kwargs['dim'])),
+        [kwargs['input']])
+
+    if 'keepdim' in kwargs and kwargs['keepdim']:
+        return max
+
+    return mgx_module.add_instruction(
+        migraphx.op('squeeze', axes=list(kwargs['dim'])), [max])
+
 
 @migraphx_converter(acc_ops.mean)

From 2fdf562109a147b406f80fddc41afc0108c25026 Mon Sep 17 00:00:00 2001
From: Ted Themistokleous
Date: Mon, 30 Oct 2023 12:34:45 -0700
Subject: [PATCH 09/14] Add support and aten op for torch.min

Similar to the onnx reduce_min operator. The implementation is similar to
that of mean and max, and maps to the corresponding reduce ops in MIGraphX.
This one is a freebie when doing max.
---
 .../fx/converters/acc_ops_converters.py | 12 ++++++++++++
 .../fx/converters/aten_ops_converters.py | 13 +++++++++++++
 py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py | 7 +++++++
 tests/dynamo/converters/test_maxmin_dynamo.py | 3 +++
 4 files changed, 35 insertions(+)

diff --git a/py/torch_migraphx/fx/converters/acc_ops_converters.py b/py/torch_migraphx/fx/converters/acc_ops_converters.py
index f21d9a36..b73f6a73 100644
--- a/py/torch_migraphx/fx/converters/acc_ops_converters.py
+++ b/py/torch_migraphx/fx/converters/acc_ops_converters.py
@@ -993,6 +993,18 @@ def acc_ops_max(mgx_module, node, args, kwargs):
         migraphx.op('squeeze', axes=list(kwargs['dim'])), [max])
 
 
+@migraphx_converter(acc_ops.min)
+def acc_ops_min(mgx_module, node, args, kwargs):
+    min = mgx_module.add_instruction(
+        migraphx.op('reduce_min', axes=list(kwargs['dim'])),
+        [kwargs['input']])
+
+    if 'keepdim' in kwargs and kwargs['keepdim']:
+        return min
+
+    return mgx_module.add_instruction(
+        migraphx.op('squeeze', axes=list(kwargs['dim'])), [min])
+
 
 @migraphx_converter(acc_ops.mean)
 def acc_ops_mean(mgx_module, node, args, kwargs):
diff --git a/py/torch_migraphx/fx/converters/aten_ops_converters.py b/py/torch_migraphx/fx/converters/aten_ops_converters.py
index 72e5a39c..6da02542 100644
--- a/py/torch_migraphx/fx/converters/aten_ops_converters.py
+++ b/py/torch_migraphx/fx/converters/aten_ops_converters.py
@@ -768,6 +768,19 @@ def aten_ops_max(mgx_module, node, args, kwargs):
     return acc_ops_converters.acc_ops_max(mgx_module, node, (), acc_kwargs)
 
 
+@migraphx_converter(torch.ops.aten.min.default)
+@migraphx_converter(torch.ops.aten.min.dim)
+def aten_ops_min(mgx_module, node, args, kwargs):
+    assert len(args) >= 1
+
+    acc_kwargs = {
+        "input": args[0],
+        "dim": args[1] if len(args) >= 2 else None,
+        "keepdim": args[2] if len(args) >= 3 else False
+    }
+
+    return acc_ops_converters.acc_ops_min(mgx_module, node, (), acc_kwargs)
+
 @migraphx_converter(torch.ops.aten.as_strided.default)
 def aten_ops_as_strided(mgx_module, node, args, kwargs):
     assert len(args) >= 3
diff --git a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py
b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py index 140ebf0d..2ee54b3f 100644 --- a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py @@ -173,6 +173,13 @@ def max(*, input, dim, keepdim=False): return torch.max(input=input, dim=dim, keepdim=keepdim) +@register_acc_op_mapping(op_and_target=("call_function", torch.min)) +@register_acc_op_mapping(op_and_target=("call_method", "min")) +@register_acc_op +def min(*, input, dim, keepdim=False): + return torch.min(input=input, dim=dim, keepdim=keepdim) + + @register_acc_op_mapping(op_and_target=("call_function", operator.getitem)) @register_acc_op def getitem(*, input, idx): diff --git a/tests/dynamo/converters/test_maxmin_dynamo.py b/tests/dynamo/converters/test_maxmin_dynamo.py index bbc0ddea..1ce29cb8 100644 --- a/tests/dynamo/converters/test_maxmin_dynamo.py +++ b/tests/dynamo/converters/test_maxmin_dynamo.py @@ -11,6 +11,9 @@ [torch.ops.aten.argmax.default, torch.ops.aten.max.default, torch.ops.aten.max.dim, + torch.ops.aten.min.default, + torch.ops.aten.min.dim, + ]) @pytest.mark.parametrize('dim, keepdim', [ (2, True), From 020c5d99396500d09e71b726fb9fd0c75d206dfa Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 9 Nov 2023 14:23:48 -0800 Subject: [PATCH 10/14] Add changes for stack op --- .../fx/converters/aten_ops_converters.py | 32 +++++++++++++++++++ .../converters/test_shape_ops_dynamo.py | 15 ++++++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/py/torch_migraphx/fx/converters/aten_ops_converters.py b/py/torch_migraphx/fx/converters/aten_ops_converters.py index 6da02542..23e47270 100644 --- a/py/torch_migraphx/fx/converters/aten_ops_converters.py +++ b/py/torch_migraphx/fx/converters/aten_ops_converters.py @@ -29,6 +29,7 @@ import migraphx import torch +from typing import cast, Iterable, List, Sequence from ..converter_registry import migraphx_converter from torch_migraphx.fx.converters import acc_ops_converters from ..utils import torch_dtype_to_mgx_enum @@ -781,6 +782,37 @@ def aten_ops_min(mgx_module, node, args, kwargs): return acc_ops_converters.acc_ops_min(mgx_module, node, (), acc_kwargs) +@migraphx_converter(torch.ops.aten.stack.default) +def aten_ops_stack(mgx_module, node, args, kwargs): + assert len(args) >= 1 + + """ + Map aten.stack to unsqueeze + cat acc ops. 
+ """ + inputs = args[0] + assert isinstance(inputs, Sequence) + + print(inputs) + + dims = args[1] if len(args) > 1 else 0 + + unsqueeze_kwargs={ + "dim": dims + } + cat_kwargs={ + "dim": dims + } + + unsqueeze_nodes = [] + for i, t in enumerate(inputs): + unsqueeze_kwargs["input"] = t + unsq_res = acc_ops_converters.acc_ops_unsqueeze(mgx_module, node, (), unsqueeze_kwargs) + unsqueeze_nodes.append(unsq_res) + + cat_kwargs["tensors"] = unsqueeze_nodes + return acc_ops_converters.acc_ops_cat(mgx_module, node, (), cat_kwargs) + + @migraphx_converter(torch.ops.aten.as_strided.default) def aten_ops_as_strided(mgx_module, node, args, kwargs): assert len(args) >= 3 diff --git a/tests/dynamo/converters/test_shape_ops_dynamo.py b/tests/dynamo/converters/test_shape_ops_dynamo.py index aa24b8b8..545ccb31 100644 --- a/tests/dynamo/converters/test_shape_ops_dynamo.py +++ b/tests/dynamo/converters/test_shape_ops_dynamo.py @@ -148,4 +148,17 @@ def test_as_strided(op_alias, size, new_size, strides, offset): inp = torch.randn(size).cuda() mod = FuncModule(op_alias, new_size, strides, offset) mgx_mod = convert_to_mgx(mod, [inp]) - verify_outputs(mod, mgx_mod, inp) \ No newline at end of file + verify_outputs(mod, mgx_mod, inp) + +class StackModule(FuncModule): + def forward(self, x1, x2, x3): + return self.func([x1, x2, x3], *self.args, **self.kwargs) + +@pytest.mark.parametrize('op_alias', [torch.ops.aten.stack.default]) +@pytest.mark.parametrize('dim', [0, 3, -1]) +def test_stack(op_alias, dim): + inp = [torch.randn(20, 12, 15, 40).cuda() for _ in range(3)] + mod = StackModule(op_alias, dim=dim) + mgx_mod = convert_to_mgx(mod, inp) + verify_outputs(mod, mgx_mod, inp) + From 66d354908af226fcf51bbd63eec6c768f577b71f Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 13 Nov 2023 14:31:40 -0800 Subject: [PATCH 11/14] Add fx/dynamo changes for argmin op Required if we want to support min() operator down the road. Added converter for fx and dynamo in similar vein as the argmax function. 
Also added unit tests --- .../fx/converters/acc_ops_converters.py | 20 ++++++++++++++++ .../fx/converters/aten_ops_converters.py | 12 ++++++++++ .../fx/tracer/acc_tracer/acc_ops.py | 8 +++++++ tests/dynamo/converters/test_maxmin_dynamo.py | 24 ++++++++++++------- tests/fx/converters/test_maxmin_fx.py | 16 +++++++++++++ 5 files changed, 71 insertions(+), 9 deletions(-) diff --git a/py/torch_migraphx/fx/converters/acc_ops_converters.py b/py/torch_migraphx/fx/converters/acc_ops_converters.py index b73f6a73..a129c2a6 100644 --- a/py/torch_migraphx/fx/converters/acc_ops_converters.py +++ b/py/torch_migraphx/fx/converters/acc_ops_converters.py @@ -787,6 +787,26 @@ def acc_ops_argmax(mgx_module, node, args, kwargs): return out +@migraphx_converter(acc_ops.argmin) +def acc_ops_argmin(mgx_module, node, args, kwargs): + inp = kwargs['input'] + dim = kwargs["dim"] + keepdim = kwargs["keepdim"] + + if dim is None: + assert not keepdim, "keepdim cannot be true when dim is None" + inp = acc_ops_flatten(mgx_module, node, (), {"input": inp}) + dim = 0 + + out = mgx_module.add_instruction(migraphx.op('argmin', axis=dim), [inp]) + + if not keepdim: + out = mgx_module.add_instruction(migraphx.op('squeeze', axes=[dim]), + [out]) + + return out + + @migraphx_converter(acc_ops.embedding) def acc_ops_embedding(mgx_module, node, args, kwargs): inp = kwargs['input'] diff --git a/py/torch_migraphx/fx/converters/aten_ops_converters.py b/py/torch_migraphx/fx/converters/aten_ops_converters.py index 23e47270..8c8d6e62 100644 --- a/py/torch_migraphx/fx/converters/aten_ops_converters.py +++ b/py/torch_migraphx/fx/converters/aten_ops_converters.py @@ -756,6 +756,18 @@ def aten_ops_argmax(mgx_module, node, args, kwargs): return acc_ops_converters.acc_ops_argmax(mgx_module, node, (), acc_kwargs) +@migraphx_converter(torch.ops.aten.argmin.default) +def aten_ops_argmin(mgx_module, node, args, kwargs): + assert len(args) >= 1 + + acc_kwargs = { + "input": args[0], + "dim": args[1] if len(args) >= 2 else None, + "keepdim": args[2] if len(args) >= 3 else False + } + + return acc_ops_converters.acc_ops_argmin(mgx_module, node, (), acc_kwargs) + @migraphx_converter(torch.ops.aten.max.default) @migraphx_converter(torch.ops.aten.max.dim) def aten_ops_max(mgx_module, node, args, kwargs): diff --git a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py index 2ee54b3f..3ee0d1b2 100644 --- a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py @@ -928,6 +928,14 @@ def argmax(*, input, dim, keepdim): return torch.argmax(input=input, dim=dim, keepdim=keepdim) +@register_acc_op_properties(AccOpProperty.unary) +@register_acc_op_mapping(op_and_target=("call_function", torch.argmin)) +@register_acc_op_mapping(op_and_target=("call_method", "argmin")) +@register_acc_op +def argmin(*, input, dim, keepdim): + return torch.argmin(input=input, dim=dim, keepdim=keepdim) + + @register_acc_op_mapping(op_and_target=("call_function", nn.functional.embedding)) @register_acc_op diff --git a/tests/dynamo/converters/test_maxmin_dynamo.py b/tests/dynamo/converters/test_maxmin_dynamo.py index 1ce29cb8..abcf3d92 100644 --- a/tests/dynamo/converters/test_maxmin_dynamo.py +++ b/tests/dynamo/converters/test_maxmin_dynamo.py @@ -7,14 +7,7 @@ pytest.skip(allow_module_level=True) -@pytest.mark.parametrize('op_alias', - [torch.ops.aten.argmax.default, - torch.ops.aten.max.default, - torch.ops.aten.max.dim, - torch.ops.aten.min.default, - torch.ops.aten.min.dim, - 
-]) +@pytest.mark.parametrize('op_alias', [torch.ops.aten.argmax.default]) @pytest.mark.parametrize('dim, keepdim', [ (2, True), (-1, False), @@ -22,6 +15,19 @@ ]) def test_argmax(op_alias, dim, keepdim): inp = torch.randn(10, 2, 12, 8, 14).cuda() - mod = FuncModule(torch.argmax, dim, keepdim) + mod = FuncModule(op_alias, dim, keepdim) + mgx_mod = convert_to_mgx(mod, [inp]) + verify_outputs(mod, mgx_mod, inp) + + +@pytest.mark.parametrize('op_alias', [torch.ops.aten.argmin.default]) +@pytest.mark.parametrize('dim, keepdim', [ + (2, True), + (-1, False), + (0, False), +]) +def test_argmin(op_alias, dim, keepdim): + inp = torch.randn(10, 2, 12, 8, 14).cuda() + mod = FuncModule(op_alias, dim, keepdim) mgx_mod = convert_to_mgx(mod, [inp]) verify_outputs(mod, mgx_mod, inp) \ No newline at end of file diff --git a/tests/fx/converters/test_maxmin_fx.py b/tests/fx/converters/test_maxmin_fx.py index 82ba32dd..d35c6a63 100644 --- a/tests/fx/converters/test_maxmin_fx.py +++ b/tests/fx/converters/test_maxmin_fx.py @@ -46,6 +46,22 @@ def test_argmax(dim, keepdim): mod_func = FuncModule(torch.argmax, dim=dim, keepdim=keepdim) mod_method = MethodModule('argmax', dim=dim, keepdim=keepdim) + for mod in [mod_func, mod_method]: + mgx_mod = convert_to_mgx(mod, [inp]) + verify_outputs(mod, mgx_mod, inp) + + +@pytest.mark.parametrize('dim, keepdim', [ + (2, True), + (-1, False), + (None, False), +]) +def test_argmin(dim, keepdim): + inp = torch.randn(10, 2, 12, 8, 14) + + mod_func = FuncModule(torch.argmin, dim=dim, keepdim=keepdim) + mod_method = MethodModule('argmin', dim=dim, keepdim=keepdim) + for mod in [mod_func, mod_method]: mgx_mod = convert_to_mgx(mod, [inp]) verify_outputs(mod, mgx_mod, inp) \ No newline at end of file From 417d9759cda19fc50351e87db21f5f437595ade9 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 13 Nov 2023 14:46:04 -0800 Subject: [PATCH 12/14] Fix acc op and converters for max/min ops Updated unit tests as well --- .../fx/converters/acc_ops_converters.py | 50 +++++++++++++------ .../fx/tracer/acc_tracer/acc_ops.py | 48 +++++++++++++++--- tests/fx/converters/test_reduce_ops_fx.py | 44 ++++++++++++++++ 3 files changed, 121 insertions(+), 21 deletions(-) diff --git a/py/torch_migraphx/fx/converters/acc_ops_converters.py b/py/torch_migraphx/fx/converters/acc_ops_converters.py index a129c2a6..b68a78dc 100644 --- a/py/torch_migraphx/fx/converters/acc_ops_converters.py +++ b/py/torch_migraphx/fx/converters/acc_ops_converters.py @@ -1002,28 +1002,50 @@ def acc_ops_maximum(mgx_module, node, args, kwargs): @migraphx_converter(acc_ops.max) def acc_ops_max(mgx_module, node, args, kwargs): - min = mgx_module.add_instruction( - migraphx.op('reduce_max', axes=list(kwargs['dim'])), - [kwargs['input']]) + inp = kwargs['input'] + in_shape = inp.shape().lens() - if 'keepdim' in kwargs and kwargs['keepdim']: - return max + if 'dim' not in kwargs: + dims = list(range(len(in_shape))) + max_ = mgx_module.add_instruction( + migraphx.op('reduce_max', axes=dims), [inp]) + return mgx_module.add_instruction(migraphx.op('squeeze', axes=dims), [max_]) + else: + dims = kwargs['dim'] + indicies = acc_ops_argmax(mgx_module, node, args, kwargs) + max_ = mgx_module.add_instruction( + migraphx.op('reduce_max', axes=[dims]), [inp]) - return mgx_module.add_instruction( - migraphx.op('squeeze', axes=list(kwargs['dim'])), [max]) + if 'keepdim' in kwargs and kwargs['keepdim']: + return [max_, indicies] + + max_ = mgx_module.add_instruction( + migraphx.op('reduce_max', axes=[dims]), [inp]) + return 
[mgx_module.add_instruction(migraphx.op('squeeze', axes=[dims]), [max_]), indicies] @migraphx_converter(acc_ops.min) def acc_ops_min(mgx_module, node, args, kwargs): - min = mgx_module.add_instruction( - migraphx.op('reduce_min', axes=list(kwargs['dim'])), - [kwargs['input']]) + inp = kwargs['input'] + in_shape = inp.shape().lens() - if 'keepdim' in kwargs and kwargs['keepdim']: - return min + if 'dim' not in kwargs: + dims = list(range(len(in_shape))) + min_ = mgx_module.add_instruction( + migraphx.op('reduce_min', axes=dims), [inp]) + return mgx_module.add_instruction(migraphx.op('squeeze', axes=dims), [min_]) + else: + dims = kwargs['dim'] + indicies = acc_ops_argmin(mgx_module, node, args, kwargs) + min_ = mgx_module.add_instruction( + migraphx.op('reduce_min', axes=[dims]), [inp]) - return mgx_module.add_instruction( - migraphx.op('squeeze', axes=list(kwargs['dim'])), [min]) + if 'keepdim' in kwargs and kwargs['keepdim']: + return [min_, indicies] + + min_ = mgx_module.add_instruction( + migraphx.op('reduce_min', axes=[dims]), [inp]) + return [mgx_module.add_instruction(migraphx.op('squeeze', axes=[dims]), [min_]), indicies] @migraphx_converter(acc_ops.mean) diff --git a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py index 3ee0d1b2..14aab29e 100644 --- a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py @@ -166,18 +166,52 @@ def maximum(*, input, other): return torch.maximum(input=input, other=other) -@register_acc_op_mapping(op_and_target=("call_function", torch.max)) -@register_acc_op_mapping(op_and_target=("call_method", "max")) +@register_acc_op_mapping( + op_and_target=("call_method", "max"), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim", this_arg_is_optional), + ("keepdim", "keepdim", this_arg_is_optional), + ], +) +@register_acc_op_mapping( + op_and_target=("call_function", torch.max), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim", this_arg_is_optional), + ("keepdim", "keepdim", this_arg_is_optional), + ], +) @register_acc_op -def max(*, input, dim, keepdim=False): - return torch.max(input=input, dim=dim, keepdim=keepdim) +def max(*, input, dim=None, keepdim=False): + if dim is not None: + return torch.max(input, dim=dim, keepdim=keepdim) + else: + return torch.max(input) -@register_acc_op_mapping(op_and_target=("call_function", torch.min)) -@register_acc_op_mapping(op_and_target=("call_method", "min")) +@register_acc_op_mapping( + op_and_target=("call_method", "min"), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim", this_arg_is_optional), + ("keepdim", "keepdim", this_arg_is_optional), + ], +) +@register_acc_op_mapping( + op_and_target=("call_function", torch.min), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim", this_arg_is_optional), + ("keepdim", "keepdim", this_arg_is_optional), + ], +) @register_acc_op def min(*, input, dim, keepdim=False): - return torch.min(input=input, dim=dim, keepdim=keepdim) + if dim is not None: + return torch.min(input, dim=dim, keepdim=keepdim) + else: + return torch.min(input) @register_acc_op_mapping(op_and_target=("call_function", operator.getitem)) diff --git a/tests/fx/converters/test_reduce_ops_fx.py b/tests/fx/converters/test_reduce_ops_fx.py index 21b1d292..2c1aebb3 100644 --- a/tests/fx/converters/test_reduce_ops_fx.py +++ b/tests/fx/converters/test_reduce_ops_fx.py @@ -15,6 +15,50 @@ def test_mean(dim, keepdim): verify_outputs(mod, mgx_mod, inp) 
+@pytest.mark.parametrize('dim, keepdim', [(0, True), (-1, False), (3, False), + (-2, True)]) +def test_max_dim(dim, keepdim): + inp = torch.randn(32, 43, 11, 2, 12) + mod_func = FuncModule(torch.max, dim=dim, keepdim=keepdim) + mod_method = MethodModule('max', dim=dim, keepdim=keepdim) + + for mod in [mod_func, mod_method]: + mgx_mod = convert_to_mgx(mod, [inp]) + verify_outputs(mod, mgx_mod, inp) + + +def test_max_no_opt_param(): + inp = torch.randn(32, 43, 11, 2, 12) + mod_func = FuncModule(torch.max) + mod_method = MethodModule('max') + + for mod in [mod_func, mod_method]: + mgx_mod = convert_to_mgx(mod, [inp]) + verify_outputs(mod, mgx_mod, inp) + + +@pytest.mark.parametrize('dim, keepdim', [(0, True), (-1, False), (3, False), + (-2, True)]) +def test_min_dim(dim, keepdim): + inp = torch.randn(32, 43, 11, 2, 12) + mod_func = FuncModule(torch.min, dim=dim, keepdim=keepdim) + mod_method = MethodModule('min', dim=dim, keepdim=keepdim) + + for mod in [mod_func, mod_method]: + mgx_mod = convert_to_mgx(mod, [inp]) + verify_outputs(mod, mgx_mod, inp) + + +def test_min_no_opt_param(): + inp = torch.randn(32, 43, 11, 2, 12) + mod_func = FuncModule(torch.min) + mod_method = MethodModule('min') + + for mod in [mod_func, mod_method]: + mgx_mod = convert_to_mgx(mod, [inp]) + verify_outputs(mod, mgx_mod, inp) + + @pytest.mark.parametrize('dim, keepdim', [(0, True), (-1, False), ([2, 3], False), (None, None)]) def test_sum(dim, keepdim): From f992c019b161f25f7cd46b6a99c3a3d0fcbaf43d Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Tue, 14 Nov 2023 12:07:19 -0800 Subject: [PATCH 13/14] Fix aten op for max/min added changes to unit tests and fix operators to handle multi input args correctly. --- .../fx/converters/aten_ops_converters.py | 18 ++++++++--- .../converters/test_reduce_ops_dynamo.py | 30 ++++++++++++++++++- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/py/torch_migraphx/fx/converters/aten_ops_converters.py b/py/torch_migraphx/fx/converters/aten_ops_converters.py index 8c8d6e62..3512fab2 100644 --- a/py/torch_migraphx/fx/converters/aten_ops_converters.py +++ b/py/torch_migraphx/fx/converters/aten_ops_converters.py @@ -775,10 +775,15 @@ def aten_ops_max(mgx_module, node, args, kwargs): acc_kwargs = { "input": args[0], - "dim": args[1] if len(args) >= 2 else None, - "keepdim": args[2] if len(args) >= 3 else False + "keepdim": False, } + if len(args) >= 2: + acc_kwargs["dim"] = args[1] + + if len(args) >= 3: + acc_kwargs["keepdim"] = args[2] + return acc_ops_converters.acc_ops_max(mgx_module, node, (), acc_kwargs) @migraphx_converter(torch.ops.aten.min.default) @@ -788,10 +793,15 @@ def aten_ops_min(mgx_module, node, args, kwargs): acc_kwargs = { "input": args[0], - "dim": args[1] if len(args) >= 2 else None, - "keepdim": args[2] if len(args) >= 3 else False + "keepdim": False } + if len(args) >= 2: + acc_kwargs["dim"] = args[1] + + if len(args) >= 3: + acc_kwargs["keepdim"] = args[2] + return acc_ops_converters.acc_ops_min(mgx_module, node, (), acc_kwargs) @migraphx_converter(torch.ops.aten.stack.default) diff --git a/tests/dynamo/converters/test_reduce_ops_dynamo.py b/tests/dynamo/converters/test_reduce_ops_dynamo.py index 9c727107..76497731 100644 --- a/tests/dynamo/converters/test_reduce_ops_dynamo.py +++ b/tests/dynamo/converters/test_reduce_ops_dynamo.py @@ -1,6 +1,6 @@ import pytest import torch -from dynamo_test_utils import FuncModule, convert_to_mgx, verify_outputs +from dynamo_test_utils import FuncModule, convert_to_mgx, verify_outputs, acc_tracer import 
torch_migraphx if not hasattr(torch_migraphx, "dynamo"): @@ -18,3 +18,31 @@ def test_reduce_ops(op_alias, dim, keepdim): mod = FuncModule(op_alias, dim, keepdim).cuda() mgx_mod = convert_to_mgx(mod, [inp]) verify_outputs(mod, mgx_mod, inp) + + +@pytest.mark.parametrize('op_alias', [ + torch.ops.aten.max.dim, + torch.ops.aten.min.dim, +]) +@pytest.mark.parametrize('dim, keepdim', [ + (0, True), + (1, False), + (3, False), + (2, True) +]) +def test_reduce_maxmin_ops_dim(op_alias, dim, keepdim): + inp = torch.randn(32, 43, 11, 2, 12).cuda() + mod = FuncModule(op_alias, dim, keepdim).cuda() + mgx_mod = convert_to_mgx(mod, [inp], tracer=acc_tracer) + verify_outputs(mod, mgx_mod, inp) + + +@pytest.mark.parametrize('op_alias', [ + torch.ops.aten.max.default, + torch.ops.aten.min.default, +]) +def test_reduce_maxmin_ops_no_param(op_alias): + inp = torch.randn(32, 43, 11, 2, 12).cuda() + mod = FuncModule(op_alias).cuda() + mgx_mod = convert_to_mgx(mod, [inp]) + verify_outputs(mod, mgx_mod, inp) From f0fc8d49cc5a8c081cf2e9be90a006fbcd797218 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 15 Nov 2023 07:03:09 -0800 Subject: [PATCH 14/14] Changes based on review comments - remove print in stack - merge argmin argmax in testing - rename test_leak_relu -> test_single_param_activation_funcs - Add default for min dim=None --- .../fx/converters/aten_ops_converters.py | 2 -- .../fx/tracer/acc_tracer/acc_ops.py | 2 +- .../converters/test_activations_dynamo.py | 2 +- tests/dynamo/converters/test_maxmin_dynamo.py | 18 ++++-------------- 4 files changed, 6 insertions(+), 18 deletions(-) diff --git a/py/torch_migraphx/fx/converters/aten_ops_converters.py b/py/torch_migraphx/fx/converters/aten_ops_converters.py index 3512fab2..a7b543f6 100644 --- a/py/torch_migraphx/fx/converters/aten_ops_converters.py +++ b/py/torch_migraphx/fx/converters/aten_ops_converters.py @@ -814,8 +814,6 @@ def aten_ops_stack(mgx_module, node, args, kwargs): inputs = args[0] assert isinstance(inputs, Sequence) - print(inputs) - dims = args[1] if len(args) > 1 else 0 unsqueeze_kwargs={ diff --git a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py index 14aab29e..f3f23c35 100644 --- a/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py @@ -207,7 +207,7 @@ def max(*, input, dim=None, keepdim=False): ], ) @register_acc_op -def min(*, input, dim, keepdim=False): +def min(*, input, dim=None, keepdim=False): if dim is not None: return torch.min(input, dim=dim, keepdim=keepdim) else: diff --git a/tests/dynamo/converters/test_activations_dynamo.py b/tests/dynamo/converters/test_activations_dynamo.py index 5f7a5d7d..455a89bd 100644 --- a/tests/dynamo/converters/test_activations_dynamo.py +++ b/tests/dynamo/converters/test_activations_dynamo.py @@ -45,7 +45,7 @@ def test_noparam_activation_funcs(op_alias): ((6, 12, 32, 6), 0.05), ((2, ), 0), ]) -def test_leaky_relu(op_alias, inp_size, alpha): +def test_single_param_activation_funcs(op_alias, inp_size, alpha): inp = torch.randn(inp_size).cuda() mod = FuncModule(op_alias, alpha).cuda() mgx_mod = convert_to_mgx(mod, [inp]) diff --git a/tests/dynamo/converters/test_maxmin_dynamo.py b/tests/dynamo/converters/test_maxmin_dynamo.py index abcf3d92..6e221a20 100644 --- a/tests/dynamo/converters/test_maxmin_dynamo.py +++ b/tests/dynamo/converters/test_maxmin_dynamo.py @@ -7,26 +7,16 @@ pytest.skip(allow_module_level=True) -@pytest.mark.parametrize('op_alias', [torch.ops.aten.argmax.default]) 
-@pytest.mark.parametrize('dim, keepdim', [
-    (2, True),
-    (-1, False),
-    (0, False),
+@pytest.mark.parametrize('op_alias', [
+    torch.ops.aten.argmax.default,
+    torch.ops.aten.argmin.default,
 ])
 @pytest.mark.parametrize('dim, keepdim', [
     (2, True),
     (-1, False),
     (0, False),
 ])
-def test_argmax(op_alias, dim, keepdim):
-    inp = torch.randn(10, 2, 12, 8, 14).cuda()
-    mod = FuncModule(op_alias, dim, keepdim)
-    mgx_mod = convert_to_mgx(mod, [inp])
-    verify_outputs(mod, mgx_mod, inp)
-
-
-@pytest.mark.parametrize('op_alias', [torch.ops.aten.argmin.default])
-@pytest.mark.parametrize('dim, keepdim', [
-    (2, True),
-    (-1, False),
-    (0, False),
-])
-def test_argmin(op_alias, dim, keepdim):
+def test_argmax_argmin(op_alias, dim, keepdim):
     inp = torch.randn(10, 2, 12, 8, 14).cuda()
     mod = FuncModule(op_alias, dim, keepdim)
     mgx_mod = convert_to_mgx(mod, [inp])
     verify_outputs(mod, mgx_mod, inp)
\ No newline at end of file
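
Usage sketch (illustrative only, not part of the patches above): the snippet below exercises the ops this series adds converters for, end to end through the dynamo path. It assumes the "migraphx" torch.compile backend that torch_migraphx registers and a ROCm-capable GPU; the module and tensor names are made up for the example.

import torch
import torch_migraphx  # noqa: F401  (assumed to register the dynamo backend)

class SmallNet(torch.nn.Module):
    def forward(self, x, y):
        a = torch.nn.functional.hardswish(x)       # aten.hardswish
        b = torch.nn.functional.hardtanh(a)        # aten.hardtanh
        c = torch.nn.functional.elu(b, alpha=0.5)  # aten.elu
        d = torch.stack([c, y], dim=1)              # aten.stack -> unsqueeze + cat
        # aten.max (no dim), aten.min.dim, aten.argmin
        return d.max(), d.min(dim=-1).values, d.argmin(dim=-1)

mod = SmallNet().cuda()
x, y = torch.randn(8, 16).cuda(), torch.randn(8, 16).cuda()
compiled = torch.compile(mod, backend="migraphx")  # assumed backend name
print(compiled(x, y))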