From 79e5740e4c6e0f62f89287084ea88366f6dd2142 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Mon, 13 Sep 2021 14:26:10 -0700 Subject: [PATCH 1/4] add momentum --- python/tvm/relay/frontend/onnx.py | 71 ++++++++++++++++++++++++------- 1 file changed, 56 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 48089d164a2f..50335cc22ffe 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -36,21 +36,9 @@ from .. import random as _random from .. import ty as _ty from .. import vision as _vision -from .common import ( - AttrCvt, - Renamer, - fold_constant, - get_name, - get_relay_op, - gru_cell, - infer_channels, - infer_shape, - infer_type, - infer_value, - lstm_cell, - new_var, - unbind, -) +from .common import (AttrCvt, Renamer, fold_constant, get_name, get_relay_op, + gru_cell, infer_channels, infer_shape, infer_type, + infer_value, lstm_cell, new_var, unbind) __all__ = ["from_onnx"] @@ -3698,6 +3686,59 @@ def _impl_v1(cls, inputs, attr, params): return _expr.TupleWrapper(_expr.Tuple(result), len(result)) +class Momentum(OnnxOpConverter): + """Operator converter for Momentum op.""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + alpha = attr["alpha"] + beta = attr["beta"] + mode = attr["mode"].decode("utf-8") + norm_coefficient = attr["norm_coefficient"] + + assert mode in ["nesterov", "standard"], f"Unknown momentum mode {mode}" + R = inputs[0] + T = inputs[1] + + assert ( + len(inputs) - 2 + ) % 3 == 0, f"Expect triplets for remaining inputs, found {len(inputs) - 2}" + # Remaining inputs are: + # [x_1, x_2 ..., x_1_gradient, x_2_gradient, ... x_1_momentum, x_2_momentum...] + num_input_tensors = (len(inputs) - 2) // 3 + + # convert attributes to constants + dtype_inputs = infer_type(inputs[3]).checked_type.dtype + alpha = relay.const(alpha, dtype=dtype_inputs) + beta = relay.const(beta, dtype=dtype_inputs) + norm_coefficient = relay.const(norm_coefficient, dtype=dtype_inputs) + default_beta = relay.const(1.0, dtype=dtype_inputs) + + # Calculate updated values for every input + output_tensors = [] + output_momentums = [] + for i in range(num_input_tensors): + x = inputs[i + 2] + gradient = inputs[i + 2 + num_input_tensors] + momentum = inputs[i + 2 + 2 * num_input_tensors] + g_regularized = norm_coefficient * x + gradient + beta_adjusted = relay.If(T > relay.const(0, dtype="int64"), beta, default_beta) + new_momentum = alpha * momentum + beta_adjusted * g_regularized + + if mode == "standard": + x_output = x - R * new_momentum + else: + # mode == 'nesterov' + x_output = x - R * (g_regularized + alpha * new_momentum) + + output_tensors.append(x_output) + output_momentums.append(new_momentum) + + # append lists together, momentums come after result tensors + result = output_tensors + output_momentums + return _expr.TupleWrapper(_expr.Tuple(result), len(result)) + + # compatible operators that do NOT require any conversion. _identity_list = [] From c52741b555b163f7e51b3e8fca408ad022f5de50 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Mon, 13 Sep 2021 14:46:08 -0700 Subject: [PATCH 2/4] make tests pass for momentum --- python/tvm/relay/frontend/onnx.py | 1 + tests/python/frontend/onnx/test_forward.py | 6 ++---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 50335cc22ffe..77b38bb6504e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3926,6 +3926,7 @@ def _get_convert_map(opset): "NegativeLogLikelihoodLoss": NegativeLogLikelihoodLoss.get_converter(opset), "Adagrad": Adagrad.get_converter(opset), "Adam": Adam.get_converter(opset), + "Momentum": Momentum.get_converter(opset), } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d9f2e97f8247..f5835ec925be 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -235,7 +235,8 @@ def verify_with_ort( def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev): - from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize_static + from onnxruntime.quantization import (CalibrationDataReader, QuantType, + quantize_static) input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes] @@ -4760,10 +4761,7 @@ def verify_eyelike(indata): "test_maxpool_with_argmax_2d_precomputed_pads", "test_maxpool_with_argmax_2d_precomputed_strides", "test_maxunpool_export_with_output_shape", - "test_momentum", - "test_momentum_multiple", "test_mvn", - "test_nesterov_momentum", # When unsqueeze is fully supported, remaining nllloss tests should work: "test_nllloss_NC_expanded", "test_nllloss_NCd1_expanded", From c8bd991e5b6b5e22c3a0550c514780a1537932ec Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Tue, 14 Sep 2021 23:25:54 -0700 Subject: [PATCH 3/4] blacking --- tests/python/frontend/onnx/test_forward.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index f5835ec925be..9084353b6d27 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -235,8 +235,7 @@ def verify_with_ort( def quantize_and_verify_with_ort(onnx_model, input_names, input_shapes, target, dev): - from onnxruntime.quantization import (CalibrationDataReader, QuantType, - quantize_static) + from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize_static input_arrays = [np.random.random(shape).astype("float32") for shape in input_shapes] From f830a13971f44723f72235984134b1e286e3a063 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Wed, 15 Sep 2021 10:49:02 -0700 Subject: [PATCH 4/4] lint --- python/tvm/relay/frontend/onnx.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 77b38bb6504e..ae229d016422 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -36,9 +36,21 @@ from .. import random as _random from .. import ty as _ty from .. import vision as _vision -from .common import (AttrCvt, Renamer, fold_constant, get_name, get_relay_op, - gru_cell, infer_channels, infer_shape, infer_type, - infer_value, lstm_cell, new_var, unbind) +from .common import ( + AttrCvt, + Renamer, + fold_constant, + get_name, + get_relay_op, + gru_cell, + infer_channels, + infer_shape, + infer_type, + infer_value, + lstm_cell, + new_var, + unbind, +) __all__ = ["from_onnx"]