Skip to content

Commit

Permalink
bug fix shard_index (PaddlePaddle#37042)
Browse files Browse the repository at this point in the history
  • Loading branch information
lilong12 committed Nov 22, 2021
1 parent 9ffb43b commit ec5bbe6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/shard_index_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ShardIndexOp : public framework::OperatorWithKernel {
"but the value given is %d.",
x_dims.size()));
if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) {
PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U,
PADDLE_ENFORCE_EQ(x_dims[x_dims.size() - 1], 1U,
platform::errors::InvalidArgument(
"The last dimension of Input(X) should be 1, "
"but the value given is %d.",
Expand Down
31 changes: 20 additions & 11 deletions python/paddle/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14904,28 +14904,37 @@ def deformable_roi_pooling(input,
@deprecated(since="2.0.0", update_to="paddle.shard_index")
def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
"""
Recompute the `input` indices according to the offset of the
shard. The length of the indices is evenly divided into N shards, and if
the `shard_id` matches the shard with the input index inside, the index is
recomputed on the basis of the shard offset, elsewise it is set to
`ignore_value`. The detail is as follows:
Reset the values of `input` according to the shard it beloning to.
Every value in `input` must be a non-negative integer, and
the parameter `index_num` represents the integer above the maximum
value of `input`. Thus, all values in `input` must be in the range
[0, index_num) and each value can be regarded as the offset to the beginning
of the range. The range is further split into multiple shards. Specifically,
we first compute the `shard_size` according to the following formula,
which represents the number of integers each shard can hold. So for the
i'th shard, it can hold values in the range [i*shard_size, (i+1)*shard_size).
::

shard_size = (index_num + nshards - 1) // nshards
y = x % shard_size if x // shard_size == shard_id else ignore_value

NOTE: If the length of indices cannot be evely divided by the shard number,
the size of the last shard will be less than the calculated `shard_size`
For each value `v` in `input`, we reset it to a new value according to the
following formula:
::

v = v - shard_id * shard_size if shard_id * shard_size <= v < (shard_id+1) * shard_size else ignore_value

That is, the value `v` is set to the new offset within the range represented by the shard `shard_id`
if it in the range. Otherwise, we reset it to be `ignore_value`.

Args:
input (Tensor): Input indices with data type int64 or int32. It's last dimension must be 1.
index_num (int): An integer defining the range of the index.
input (Tensor): Input tensor with data type int64 or int32. It's last dimension must be 1.
index_num (int): An integer represents the integer above the maximum value of `input`.
nshards (int): The number of shards.
shard_id (int): The index of the current shard.
ignore_value (int): An integer value out of sharded index range.

Returns:
Tensor: The sharded index of input.
Tensor.

Examples:
.. code-block:: python
Expand Down

0 comments on commit ec5bbe6

Please sign in to comment.