Skip to content

Commit

Permalink
Fix roll_op by avoiding DivisionByZeroError, test=develop (#34499)
Browse files Browse the repository at this point in the history
  • Loading branch information
haohongxiang authored Jul 30, 2021
1 parent ba19398 commit 5571c98
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
17 changes: 10 additions & 7 deletions paddle/fluid/operators/roll_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ class RollKernel<platform::CUDADeviceContext, T>
int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size();
int64_t size = input_dim[dim];

shifts[i] = (shifts[i] % size + size) % size;
strides[i] = stride_dim[dim];
sizes[i] = size;
if (size != 0) {
shifts[i] = (shifts[i] % size + size) % size;
strides[i] = stride_dim[dim];
sizes[i] = size;
}
}
}

Expand Down Expand Up @@ -151,10 +153,11 @@ class RollGradKernel<platform::CUDADeviceContext, T>
for (size_t i = 0; i < nums; i++) {
int dim = dims[i] >= 0 ? dims[i] : dims[i] + input_dim.size();
int64_t size = input_dim[dim];

shifts[i] = ((-shifts[i]) % size + size) % size;
strides[i] = stride_dim[dim];
sizes[i] = size;
if (size != 0) {
shifts[i] = ((-shifts[i]) % size + size) % size;
strides[i] = stride_dim[dim];
sizes[i] = size;
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/operators/roll_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ inline void shift_along_dim(T* data, const DDim& input_dim, int64_t dim,
if (dim < 0) {
dim += input_dim.size();
}
if (input_dim[dim] == 0) {
return;
}
shift = shift % input_dim[dim];
if (shift < 0) {
shift += input_dim[dim];
Expand Down

0 comments on commit 5571c98

Please sign in to comment.