Skip to content

Commit

Permalink
fix bug of recompute in hybridparallel (#35588)
Browse files Browse the repository at this point in the history
  • Loading branch information
ForFishes authored Sep 10, 2021
1 parent 652da1f commit d53e567
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 0 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/operators/flatten_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/flatten_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
flatten, ops::FlattenKernel<paddle::platform::CUDADeviceContext, float>,
Expand Down Expand Up @@ -50,6 +51,8 @@ REGISTER_OP_CUDA_KERNEL(
flatten_contiguous_range,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
float>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
Expand All @@ -63,6 +66,8 @@ REGISTER_OP_CUDA_KERNEL(
flatten_contiguous_range_grad,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def _split_activation(tensor):

# use inplace operation to save memory
data = tensor.flatten_()

part_size = tensor_numel // mp_degree
start = part_size * mp_rank
end = start + part_size
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _update_list(self):
'softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'c_softmax_with_cross_entropy',
'cross_entropy',
'cross_entropy2',
# fp16 is slower than fp32, though fp16 is supported.
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/dygraph/amp/auto_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
'softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'c_softmax_with_cross_entropy',
'cross_entropy',
'cross_entropy2',
# default fp32 can avoid return inf when the sum value large than 65504
Expand Down

0 comments on commit d53e567

Please sign in to comment.