Skip to content

Commit

Permalink
Support allow_partial switch, which can be configured in
Browse files Browse the repository at this point in the history
pipeline_configs. If the sent tensors are not the same across
different hosts, they shouldn't be sent partially and
then concatenated into a whole tensor.
  • Loading branch information
GhostScreaming committed Oct 13, 2022
1 parent f5acb14 commit 9cd503c
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ message PipelineConfig {
optional int32 accumulate_steps = 2 [ default = 1 ];
optional string schedule_mode = 3 [ default = '1F1B' ];
optional bool p2p_cache_shape = 4 [ default = true ];
optional bool allow_partial = 5 [ default = true ];
}

message TensorParallelConfig {
Expand Down
1 change: 0 additions & 1 deletion paddle/phi/kernels/kps/reduce_sum_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <climits>
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include <limits>
#include "paddle/phi/core/enforce.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def __init__(self, layers, hcg, strategy):
'micro_batch_size']
self.accumulate_steps = self._strategy.pipeline_configs[
'accumulate_steps']

        # If the sent tensors are not the same across different hosts,
        # they shouldn't be sent partially and then concatenated into a whole tensor.
self._allow_partial = self._strategy.pipeline_configs['allow_partial']
self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

self.num_stages = self._hcg.get_pipe_parallel_world_size()
Expand All @@ -58,7 +60,7 @@ def __init__(self, layers, hcg, strategy):
self._real_pp_world_size = self.num_stages
self._real_pp_rank = self.stage_id

p2p.initialize_p2p_groups(hcg, self._using_cache)
p2p.initialize_p2p_groups(hcg, self._using_cache, self._allow_partial)

self.global_rank = self._hcg.get_global_rank()
self.micro_batch_id = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
_use_cache = False


def initialize_p2p_groups(hcg, use_cache=True):
global _hcg, _use_cache
def initialize_p2p_groups(hcg, use_cache=True, allow_partial=True):
global _hcg, _use_cache, _allow_partial
_hcg = hcg
_use_cache = use_cache
_allow_partial = allow_partial
send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups(
)

Expand Down Expand Up @@ -157,7 +158,8 @@ def set_send_message(self, tensor):


def _is_valid_send_recv_partial(tensor, mp_degree):
    """Return True iff *tensor* may be transferred as partial slices.

    Partial send/recv is only valid when it is globally enabled via the
    module-level ``_allow_partial`` switch, model parallelism is active
    (``mp_degree > 1``), and the tensor's element count divides evenly
    across the ``mp_degree`` ranks.
    """
    # Globally disabled: tensors from different hosts may differ, so they
    # must always travel as whole tensors.
    if not _allow_partial:
        return False
    numel = np.prod(tensor.shape)
    # Zero-element tensors cannot be transferred at all.
    assert numel != 0, "can't send/recv zero element"
    if mp_degree <= 1:
        return False
    return numel % mp_degree == 0
Expand Down

0 comments on commit 9cd503c

Please sign in to comment.