From 132017de96c0009412720663d70acb235f5efeac Mon Sep 17 00:00:00 2001
From: Samuel
Date: Fri, 8 May 2020 02:20:12 +0530
Subject: [PATCH] [RELAY][ONNX]ReduceLogSumExp Operator support (#5453)

* [RELAY]LogSumExp Op Support

* [ONNX]LogSumExp Op Support
---
 python/tvm/relay/frontend/onnx.py          |  8 +++--
 python/tvm/relay/op/reduce.py              | 47 ++++++++++++++++++++++++
 tests/python/frontend/onnx/test_forward.py | 38 +++++++++++++++++++
 tests/python/relay/test_op_level4.py       | 14 +++++++-
 4 files changed, 103 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 52d87f6af30c..4ae083ca8e48 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1076,6 +1076,11 @@ class ReduceProd(Reduce):
     """
     name = 'prod'
 
+class ReduceLogSumExp(Reduce):
+    """ Operator converter for ReduceLogSumExp.
+    """
+    name = 'logsumexp'
+
 class ArgMax(OnnxOpConverter):
     """ Operator converter for ArgMax.
     """
@@ -1640,8 +1645,7 @@ def _get_convert_map(opset):
         'ReduceSum': ReduceSum.get_converter(opset),
         'ReduceMean': ReduceMean.get_converter(opset),
         'ReduceProd': ReduceProd.get_converter(opset),
-        # 'ReduceProd'
-        # 'ReduceLogSumExp'
+        'ReduceLogSumExp': ReduceLogSumExp.get_converter(opset),
 
         #defs/sorting
         'ArgMax': ArgMax.get_converter(opset),
diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py
index d3226012e887..988c94928d33 100644
--- a/python/tvm/relay/op/reduce.py
+++ b/python/tvm/relay/op/reduce.py
@@ -18,7 +18,7 @@
 # pylint: disable=redefined-builtin
 
 from . import _make
-from .tensor import sqrt
+from .tensor import sqrt, log, exp
 from .transform import squeeze
 from ..expr import Tuple, TupleWrapper
 
@@ -475,3 +475,48 @@ def prod(data, axis=None, keepdims=False, exclude=False):
     """
     axis = [axis] if isinstance(axis, int) else axis
     return _make.prod(data, axis, keepdims, exclude)
+
+
+def logsumexp(data, axis=None, keepdims=False):
+    """Compute the log of the sum of exponentials of input elements over given axes.
+
+    This function is more numerically stable than log(sum(exp(input))).
+    It avoids overflows caused by taking the exp of large inputs and underflows
+    caused by taking the log of small inputs.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    axis : None or int or tuple of int
+        Axis or axes along which the logsumexp operation is performed.
+        The default, axis=None, will compute the log of the sum of exponentials of all elements
+        in the input array. If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
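+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = relay.var("x", relay.TensorType((2, 3), "float32"))
+        # Illustrative sketch: reduce along axis 1; keepdims=True retains the reduced axis as size 1.
+        y = relay.logsumexp(x, axis=1, keepdims=True)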
+ """ + + axis = [axis] if isinstance(axis, int) else axis + max_x = max(data, axis, True) + exp_x = exp(data - max_x) + sum_x = sum(exp_x, axis, True) + out_x = log(sum_x) + max_x + if not keepdims: + out_x = squeeze(out_x, axis) + return out_x diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index dc832aaa2570..78658e76060d 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1307,6 +1307,15 @@ def verify_reduce_x(name, indata, axis, keepdims): outdata = np.sum(indata, axis=axis, keepdims=keepdims == 1) elif name == 'ReduceMean': outdata = np.mean(indata, axis=axis, keepdims=keepdims == 1) + elif name == 'ReduceLogSumExp': + def _np_log_sum_exp(x, axis, keepdims=False): + max_x = np.max(x, axis=axis, keepdims=True) + x = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) + x = x + max_x + if not keepdims: + x = np.squeeze(x, axis=axis) + return x + outdata = _np_log_sum_exp(indata, axis=axis, keepdims=keepdims == 1) else: raise Exception('unsupport op: {}'.format(name)) if len(np.asarray(outdata).shape) == 0: @@ -1380,6 +1389,34 @@ def test_reduce_mean(): axis=(1,), keepdims=1) +def test_reduce_logsumexp(): + + for keepdims in [True, False]: + verify_reduce_x("ReduceLogSumExp", + np.random.randn(3, 2, 2).astype(np.float32), + axis=None, keepdims=keepdims) + + verify_reduce_x("ReduceLogSumExp", + np.random.randn(3, 2, 3).astype(np.float32), + axis=None, keepdims=keepdims) + + verify_reduce_x("ReduceLogSumExp", + np.random.randn(3, 3, 3).astype(np.float32), + axis=(1,), keepdims=keepdims) + + verify_reduce_x("ReduceLogSumExp", + np.random.randn(3, 3, 3, 1).astype(np.float32), + axis=(1, 2), keepdims=keepdims) + + verify_reduce_x("ReduceLogSumExp", + np.random.randn(3, 3, 3, 1).astype(np.float32), + axis=(1), keepdims=keepdims) + + verify_reduce_x("ReduceLogSumExp", + np.random.randn(1, 3, 4, 1).astype(np.float32), + axis=(1), keepdims=keepdims) + + def verify_split(indata, outdatas, split, axis=0): indata = np.array(indata).astype(np.float32) outdatas = [np.array(o).astype(np.float32) for o in outdatas] @@ -2557,6 +2594,7 @@ def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ test_reduce_min() test_reduce_sum() test_reduce_mean() + test_reduce_logsumexp() test_pad() test_split() test_binary_ops() diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index bbe2c69d6294..947a4bfd0b3b 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -165,7 +165,10 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") dtype = "bool" if ref_func in [np.all, np.any] else dtype x = relay.var("x", relay.TensorType(data, dtype)) - z = test_func(x, axis, keepdims, exclude) + if test_func == relay.logsumexp: + z = test_func(x, axis, keepdims) + else: + z = test_func(x, axis, keepdims, exclude) zz = run_infer_type(z) if axis: assert "axis=" in z.astext() @@ -215,6 +218,14 @@ def _wrapper(data, axis=None, keepdims=False): return func(data, axis=axis).reshape(out_shape) return _wrapper + def _np_log_sum_exp(x, axis, keepdims=False): + max_x = np.max(x, axis=axis, keepdims=True) + x = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) + x = x + max_x + if not keepdims: + x = np.squeeze(x, axis=axis) + return x + d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4") for func in [[relay.sum, np.sum], [relay.max, np.max], @@ -225,6 +236,7 @@ 
                 [relay.prod, np.prod],
                 [relay.all, np.all],
                 [relay.any, np.any],
+                [relay.logsumexp, _np_log_sum_exp],
                 [relay.argmin, _with_keepdims(np.argmin)],
                 [relay.argmax, _with_keepdims(np.argmax)]]:
        verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
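
Usage sketch (not part of the patch above): the snippet below exercises the new relay.logsumexp operator end to end, the same way the level-4 tests do, and checks it against the numerically stable NumPy reference. The "graph" executor and "llvm" CPU target are assumptions; any TVM build with an LLVM-enabled target should work.

    import numpy as np
    import tvm
    from tvm import relay

    # Build a small Relay function around the new operator.
    x = relay.var("x", relay.TensorType((3, 4), "float32"))
    func = relay.Function([x], relay.logsumexp(x, axis=1))

    # Compile and run on CPU.
    data = np.random.randn(3, 4).astype("float32")
    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    out = intrp.evaluate(func)(data).asnumpy()

    # Stable NumPy reference: subtract the max before exponentiating.
    max_x = np.max(data, axis=1, keepdims=True)
    ref = np.squeeze(np.log(np.sum(np.exp(data - max_x), axis=1, keepdims=True)) + max_x, axis=1)
    np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)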