[RELAY][ONNX] ReduceLogSumExp Operator support #5453

Merged (2 commits, May 7, 2020)
8 changes: 6 additions & 2 deletions python/tvm/relay/frontend/onnx.py
@@ -1067,6 +1067,11 @@ class ReduceProd(Reduce):
"""
name = 'prod'

class ReduceLogSumExp(Reduce):
""" Operator converter for ReduceLogSumExp.
"""
name = 'logsumexp'

class ArgMax(OnnxOpConverter):
""" Operator converter for ArgMax.
"""
@@ -1598,8 +1603,7 @@ def _get_convert_map(opset):
'ReduceSum': ReduceSum.get_converter(opset),
'ReduceMean': ReduceMean.get_converter(opset),
'ReduceProd': ReduceProd.get_converter(opset),
# 'ReduceProd'
# 'ReduceLogSumExp'
'ReduceLogSumExp': ReduceLogSumExp.get_converter(opset),

#defs/sorting
'ArgMax': ArgMax.get_converter(opset),
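
A minimal sketch (not part of this PR) of how the new mapping is exercised: build an ONNX graph with a single ReduceLogSumExp node and import it through relay.frontend.from_onnx. The graph name, shapes, and attribute values below are illustrative.

from onnx import helper, TensorProto
from tvm import relay

# One ReduceLogSumExp node reducing over axis 1, keeping the reduced dim.
node = helper.make_node("ReduceLogSumExp", inputs=["x"], outputs=["y"],
                        axes=[1], keepdims=1)
graph = helper.make_graph(
    [node], "reduce_logsumexp_example",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4, 5])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 1, 5])])
model = helper.make_model(graph, producer_name="reduce_logsumexp_example")

# The converter registered above maps the node to relay.logsumexp.
mod, params = relay.frontend.from_onnx(model, shape={"x": (3, 4, 5)})
print(mod)
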
39 changes: 38 additions & 1 deletion python/tvm/relay/op/reduce.py
@@ -18,7 +18,7 @@
# pylint: disable=redefined-builtin

from . import _make
from .tensor import sqrt
from .tensor import sqrt, log, exp
from .transform import squeeze
from ..expr import Tuple, TupleWrapper

@@ -475,3 +475,40 @@ def prod(data, axis=None, keepdims=False, exclude=False):
"""
axis = [axis] if isinstance(axis, int) else axis
return _make.prod(data, axis, keepdims, exclude)


def logsumexp(data, axis=None, keepdims=False):
"""Compute the log of the sum of exponentials of input elements over given axes.

This function is more numerically stable than log(sum(exp(input))).
It avoids overflows caused by taking the exp of large inputs and underflows
caused by taking the log of small inputs.

Parameters
----------
data : relay.Expr
The input data

axis : None or int or tuple of int
Axis or axes along which the log-sum-exp operation is performed.
The default, axis=None, will compute the log of the sum of exponentials of all elements
in the input array. If axis is negative it counts from the last to the first axis.

keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.

Returns
-------
result : relay.Expr
The computed result.
"""

axis = [axis] if isinstance(axis, int) else axis
max_x = max(data, axis, True)
exp_x = exp(data - max_x)
sum_x = sum(exp_x, axis, True)
out_x = log(sum_x) + max_x
if not keepdims:
out_x = squeeze(out_x, axis)
return out_x
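
A quick NumPy check (illustrative, not part of the diff) of why the max-subtraction form above is used: the naive log(sum(exp(x))) overflows for large inputs, while the rewritten form stays finite.

import numpy as np

x = np.array([1000.0, 1000.0])

# Naive form: exp(1000) overflows to inf, so the result is inf.
naive = np.log(np.sum(np.exp(x)))

# Max-subtraction form, matching logsumexp above: log(2) + 1000.
max_x = np.max(x)
stable = np.log(np.sum(np.exp(x - max_x))) + max_x

print(naive)   # inf (with an overflow RuntimeWarning)
print(stable)  # ~1000.6931
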
38 changes: 38 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
@@ -1306,6 +1306,15 @@ def verify_reduce_x(name, indata, axis, keepdims):
outdata = np.sum(indata, axis=axis, keepdims=keepdims == 1)
elif name == 'ReduceMean':
outdata = np.mean(indata, axis=axis, keepdims=keepdims == 1)
elif name == 'ReduceLogSumExp':
def _np_log_sum_exp(x, axis, keepdims=False):
max_x = np.max(x, axis=axis, keepdims=True)
x = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True))
x = x + max_x
if not keepdims:
x = np.squeeze(x, axis=axis)
return x
outdata = _np_log_sum_exp(indata, axis=axis, keepdims=keepdims == 1)
else:
raise Exception('unsupport op: {}'.format(name))
if len(np.asarray(outdata).shape) == 0:
@@ -1379,6 +1388,34 @@ def test_reduce_mean():
axis=(1,), keepdims=1)


def test_reduce_logsumexp():

for keepdims in [True, False]:
verify_reduce_x("ReduceLogSumExp",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=keepdims)

verify_reduce_x("ReduceLogSumExp",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=keepdims)

verify_reduce_x("ReduceLogSumExp",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=keepdims)

verify_reduce_x("ReduceLogSumExp",
np.random.randn(3, 3, 3, 1).astype(np.float32),
axis=(1, 2), keepdims=keepdims)

verify_reduce_x("ReduceLogSumExp",
np.random.randn(3, 3, 3, 1).astype(np.float32),
axis=(1), keepdims=keepdims)

verify_reduce_x("ReduceLogSumExp",
np.random.randn(1, 3, 4, 1).astype(np.float32),
axis=(1), keepdims=keepdims)


def verify_split(indata, outdatas, split, axis=0):
indata = np.array(indata).astype(np.float32)
outdatas = [np.array(o).astype(np.float32) for o in outdatas]
@@ -2466,6 +2503,7 @@ def verify_topk(input_dims, K, axis=-1):
test_reduce_min()
test_reduce_sum()
test_reduce_mean()
test_reduce_logsumexp()
test_pad()
test_split()
test_binary_ops()
14 changes: 13 additions & 1 deletion tests/python/relay/test_op_level4.py
@@ -165,7 +165,10 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
dtype = "bool" if ref_func in [np.all, np.any] else dtype

x = relay.var("x", relay.TensorType(data, dtype))
z = test_func(x, axis, keepdims, exclude)
if test_func == relay.logsumexp:
z = test_func(x, axis, keepdims)
else:
z = test_func(x, axis, keepdims, exclude)
zz = run_infer_type(z)
if axis:
assert "axis=" in z.astext()
@@ -215,6 +218,14 @@ def _wrapper(data, axis=None, keepdims=False):
return func(data, axis=axis).reshape(out_shape)
return _wrapper

def _np_log_sum_exp(x, axis, keepdims=False):
max_x = np.max(x, axis=axis, keepdims=True)
x = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True))
x = x + max_x
if not keepdims:
x = np.squeeze(x, axis=axis)
return x

d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4")
for func in [[relay.sum, np.sum],
[relay.max, np.max],
@@ -225,6 +236,7 @@ def _wrapper(data, axis=None, keepdims=False):
[relay.prod, np.prod],
[relay.all, np.all],
[relay.any, np.any],
[relay.logsumexp, _np_log_sum_exp],
[relay.argmin, _with_keepdims(np.argmin)],
[relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
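
For completeness, a hedged sketch (not in the PR) of checking relay.logsumexp end to end against the NumPy reference, following the executor pattern these tests use; the "graph" executor, ctx argument, and asnumpy() call assume an LLVM-enabled build and the executor API of this TVM version.

import numpy as np
import tvm
from tvm import relay

x_data = np.random.randn(3, 4, 5).astype("float32")
x = relay.var("x", shape=(3, 4, 5), dtype="float32")
func = relay.Function([x], relay.logsumexp(x, axis=1, keepdims=True))

# Run through the graph executor on CPU.
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
op_res = intrp.evaluate(func)(x_data)

# NumPy reference, same max-subtraction formulation as the tests above.
max_x = np.max(x_data, axis=1, keepdims=True)
ref = np.log(np.sum(np.exp(x_data - max_x), axis=1, keepdims=True)) + max_x
np.testing.assert_allclose(op_res.asnumpy(), ref, rtol=1e-5)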