class Momentum(OnnxOpConverter):
    """Operator converter for Momentum op.

    Builds the Relay expressions for one momentum-SGD update step over one or
    more tensors. The converter is registered under the ONNX op name
    ``"Momentum"`` in ``_get_convert_map``.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        """Convert an ONNX Momentum node into Relay expressions.

        Parameters
        ----------
        inputs : sequence of relay expressions
            Laid out as ``[R, T, x_1..x_n, g_1..g_n, v_1..v_n]`` where ``R`` is
            used as the step-size multiplier, ``T`` is compared against zero
            to detect the first update, and ``x``/``g``/``v`` are the tensors,
            their gradients and their momentums.
        attr : dict
            Must contain ``alpha``, ``beta``, ``norm_coefficient`` and ``mode``
            (a bytes value that decodes to ``"standard"`` or ``"nesterov"``).
        params : dict
            Unused by this converter.

        Returns
        -------
        TupleWrapper
            ``[x_1_new..x_n_new, v_1_new..v_n_new]``: updated tensors first,
            then updated momentums.

        Notes
        -----
        For every tensor the following is emitted::

            g_reg  = norm_coefficient * x + g
            v_new  = alpha * v + (beta if T > 0 else 1) * g_reg
            x_new  = x - R * v_new                      # mode == "standard"
            x_new  = x - R * (g_reg + alpha * v_new)    # mode == "nesterov"
        """
        alpha = attr["alpha"]
        beta = attr["beta"]
        mode = attr["mode"].decode("utf-8")
        norm_coefficient = attr["norm_coefficient"]

        assert mode in ["nesterov", "standard"], f"Unknown momentum mode {mode}"
        R = inputs[0]
        T = inputs[1]

        assert (
            len(inputs) - 2
        ) % 3 == 0, f"Expect triplets for remaining inputs, found {len(inputs) - 2}"
        # Remaining inputs are:
        # [x_1, x_2 ..., x_1_gradient, x_2_gradient, ... x_1_momentum, x_2_momentum...]
        num_input_tensors = (len(inputs) - 2) // 3
        # Zero triplets satisfies the modulo check above, but leaves no tensor
        # to take the dtype from; fail here with a clear message instead.
        assert num_input_tensors > 0, "Momentum expects at least one tensor to update"

        # Convert attributes to constants. The dtype is probed from the first
        # tensor being updated, since that is what the constants multiply.
        # NOTE(review): assumes all x/gradient/momentum inputs share one dtype.
        dtype_inputs = infer_type(inputs[2]).checked_type.dtype
        alpha = relay.const(alpha, dtype=dtype_inputs)
        beta = relay.const(beta, dtype=dtype_inputs)
        norm_coefficient = relay.const(norm_coefficient, dtype=dtype_inputs)
        default_beta = relay.const(1.0, dtype=dtype_inputs)

        # The gradient scale depends only on T and the attributes, so build the
        # conditional once and share it across all tensors.
        beta_adjusted = relay.If(T > relay.const(0, dtype="int64"), beta, default_beta)

        # Calculate updated values for every input
        output_tensors = []
        output_momentums = []
        for i in range(num_input_tensors):
            x = inputs[i + 2]
            gradient = inputs[i + 2 + num_input_tensors]
            momentum = inputs[i + 2 + 2 * num_input_tensors]
            g_regularized = norm_coefficient * x + gradient
            new_momentum = alpha * momentum + beta_adjusted * g_regularized

            if mode == "standard":
                x_output = x - R * new_momentum
            else:
                # mode == 'nesterov'
                x_output = x - R * (g_regularized + alpha * new_momentum)

            output_tensors.append(x_output)
            output_momentums.append(new_momentum)

        # append lists together, momentums come after result tensors
        result = output_tensors + output_momentums
        return _expr.TupleWrapper(_expr.Tuple(result), len(result))