Skip to content

Commit

Permalink
Fix reflection mode in grid sample op
Browse files Browse the repository at this point in the history
test=develop
  • Loading branch information
wanghaoshuang committed Sep 7, 2020
1 parent 4e90409 commit e285162
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/grid_sampler_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::string>(
"padding_mode",
"(bool, default true) The padding method used when source"
"index is out of input images. It can be 'zeros', 'reflect' and "
"index is out of input images. It can be 'zeros', 'reflection' and "
"'border'.")
.SetDefault("zeros");

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/grid_sampler_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ class GridSampleOpCUDAKernel : public framework::OpKernel<T> {
Mode mode;
if (padding_mode_s == "border") {
padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflect") {
} else if (padding_mode_s == "reflection") {
padding_mode = PaddingMode::reflect;
} else {
padding_mode = PaddingMode::zeros;
Expand Down Expand Up @@ -432,7 +432,7 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
Mode mode;
if (padding_mode_s == "border") {
padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflect") {
} else if (padding_mode_s == "reflection") {
padding_mode = PaddingMode::reflect;
} else {
padding_mode = PaddingMode::zeros;
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/grid_sampler_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ static inline void clip(const platform::CPUDeviceContext& ctx,
if (padding_mode == "border") {
grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
} else if (padding_mode == "reflect") {
} else if (padding_mode == "reflection") {
if (align_corners) {
auto double_range = static_cast<T>(max_val * 2);
auto grid_abs = grid_slice_t.abs();
Expand Down Expand Up @@ -117,7 +117,7 @@ static inline void clipWithMask(const platform::CPUDeviceContext& ctx,
auto in_bound = (res == grid_slice_t);
grid_scale_t.device(place) = grid_scale_t * in_bound.template cast<T>();
grid_slice_t.device(place) = res;
} else if (padding_mode == "reflect") {
} else if (padding_mode == "reflection") {
if (align_corners) {
auto double_range = static_cast<T>(max_val * 2);
auto is_neg = (grid_slice_t < static_cast<T>(0));
Expand Down

0 comments on commit e285162

Please sign in to comment.