Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick] bug fix shard_index (#37042) #37421

Merged
merged 1 commit into from
Nov 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/operators/shard_index_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ShardIndexOp : public framework::OperatorWithKernel {
"but the value given is %d.",
x_dims.size()));
if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) {
PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U,
PADDLE_ENFORCE_EQ(x_dims[x_dims.size() - 1], 1U,
platform::errors::InvalidArgument(
"The last dimension of Input(X) should be 1, "
"but the value given is %d.",
Expand Down
31 changes: 20 additions & 11 deletions python/paddle/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14904,28 +14904,37 @@ def deformable_roi_pooling(input,
@deprecated(since="2.0.0", update_to="paddle.shard_index")
def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
"""
Recompute the `input` indices according to the offset of the
shard. The length of the indices is evenly divided into N shards, and if
the `shard_id` matches the shard with the input index inside, the index is
recomputed on the basis of the shard offset, elsewise it is set to
`ignore_value`. The detail is as follows:
Reset the values of `input` according to the shard it beloning to.
Every value in `input` must be a non-negative integer, and
the parameter `index_num` represents the integer above the maximum
value of `input`. Thus, all values in `input` must be in the range
[0, index_num) and each value can be regarded as the offset to the beginning
of the range. The range is further split into multiple shards. Specifically,
we first compute the `shard_size` according to the following formula,
which represents the number of integers each shard can hold. So for the
i'th shard, it can hold values in the range [i*shard_size, (i+1)*shard_size).
::

shard_size = (index_num + nshards - 1) // nshards
y = x % shard_size if x // shard_size == shard_id else ignore_value

NOTE: If the length of indices cannot be evely divided by the shard number,
the size of the last shard will be less than the calculated `shard_size`
For each value `v` in `input`, we reset it to a new value according to the
following formula:
::

v = v - shard_id * shard_size if shard_id * shard_size <= v < (shard_id+1) * shard_size else ignore_value

That is, the value `v` is set to the new offset within the range represented by the shard `shard_id`
if it in the range. Otherwise, we reset it to be `ignore_value`.

Args:
input (Tensor): Input indices with data type int64 or int32. It's last dimension must be 1.
index_num (int): An integer defining the range of the index.
input (Tensor): Input tensor with data type int64 or int32. It's last dimension must be 1.
index_num (int): An integer represents the integer above the maximum value of `input`.
nshards (int): The number of shards.
shard_id (int): The index of the current shard.
ignore_value (int): An integer value out of sharded index range.

Returns:
Tensor: The sharded index of input.
Tensor.

Examples:
.. code-block:: python
Expand Down