Skip to content

Commit

Permalink
pylint, do not use -1 for default value
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent 968f3bd commit ce7848b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
8 changes: 2 additions & 6 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,10 +1102,6 @@ def gather_nd_shape_func(attrs, inputs, _):
batch_dims = get_const_int(attrs.batch_dims)
index_rank = get_const_int(attrs.index_rank)

assert (
index_rank > 0
), "index_rank needs to be specified for dynamic gather_nd"
assert index_rank > 0, "index_rank needs to be specified for dynamic gather_nd"

return [
_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))
]
return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))]
4 changes: 2 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ def gather(data, axis, indices):
return _make.gather(data, axis, indices)


def gather_nd(data, indices, batch_dims=0, index_rank=-1):
def gather_nd(data, indices, batch_dims=0, index_rank=None):
"""Gather elements or slices from data and store to a tensor whose shape is
defined by indices.
Expand All @@ -1087,7 +1087,7 @@ def gather_nd(data, indices, batch_dims=0, index_rank=-1):
batch_dims : int
The number of batch dimensions.
index_rank : int
index_rank : int, optional
The size of an indexing tuple, which is a fixed value and the same as indices.shape[0]
Only needed when other dimensions of indices are dynamic.
Expand Down
5 changes: 3 additions & 2 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3373,11 +3373,12 @@ Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& i
return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
}

Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int index_rank = -1) {
Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0,
Optional<Integer> index_rank = NullValue<Integer>()) {
static const Op& op = Op::Get("gather_nd");
auto attrs = make_object<GatherNDAttrs>();
attrs->batch_dims = batch_dims;
attrs->index_rank = Integer(index_rank);
attrs->index_rank = index_rank;
return Call(op, {data, indices}, Attrs(attrs));
}

Expand Down

0 comments on commit ce7848b

Please sign in to comment.