Skip to content

Commit

Permalink
[Relay] Fixes to sum (apache#2439)
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch authored and tqchen committed Jan 16, 2019
1 parent 967bcb3 commit a527b58
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
35 changes: 18 additions & 17 deletions python/tvm/relay/op/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of maximum element all of the elements of
Axis or axes along which a argmax operation is performed.
The default, axis=None, will find the indices of the maximum element of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
Expand Down Expand Up @@ -73,14 +73,14 @@ def sum(data, axis=None, keepdims=False, exclude=False):
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
Axis or axes along which a sum is performed. The default, axis=None,
will sum all of the elements of the input array. If axis is
negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
If this is set to True, the axes which are reduced are left in the result as
dimensions with size one. With this option, the result will broadcast
correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
Expand All @@ -91,7 +91,7 @@ def sum(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
axis = [axis] if axis and isinstance(axis, int) else axis
return _make.sum(data, axis, keepdims, exclude)


Expand All @@ -104,9 +104,9 @@ def max(data, axis=None, keepdims=False, exclude=False):
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
Axis or axes along which the max operation is performed.
The default, axis=None, will find the max element from all of the elements of the input
array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
Expand Down Expand Up @@ -135,9 +135,10 @@ def min(data, axis=None, keepdims=False, exclude=False):
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
Axis or axes along which a minimum operation is performed.
The default, axis=None, will find the minimum element from all
of the elements of the input array. If axis is negative it counts from
the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
Expand Down Expand Up @@ -166,7 +167,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
Axis or axes along which a mean operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
Expand Down Expand Up @@ -197,7 +198,7 @@ def prod(data, axis=None, keepdims=False, exclude=False):
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
Axis or axes along which a product is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
Expand Down
1 change: 1 addition & 0 deletions tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def _wrapper(data, axis=None, keepdims=False):
[relay.prod, np.prod],
[relay.argmin, _with_keepdims(np.argmin)],
[relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
verify_reduce(func, (d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4))
verify_reduce(func, (d1, d2, d3), 1, True, False, (d1, 1, d3))
verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))
Expand Down

0 comments on commit a527b58

Please sign in to comment.