Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cherry-pick] Add enable_partial_send_recv switch in pipeline_configs (#46992) #47083

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ message PipelineConfig {
optional int32 accumulate_steps = 2 [ default = 1 ];
optional string schedule_mode = 3 [ default = '1F1B' ];
optional bool p2p_cache_shape = 4 [ default = true ];
optional bool enable_partial_send_recv = 5 [ default = true ];
}

message TensorParallelConfig {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ def __init__(self, layers, hcg, strategy):
'micro_batch_size']
self.accumulate_steps = self._strategy.pipeline_configs[
'accumulate_steps']

# If sent tensor are not the same from different hosts,
# they shouldn't been sent partially and then concated as a whole tensor.
self._enable_partial_send_recv = self._strategy.pipeline_configs[
'enable_partial_send_recv']
self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

self.num_stages = self._hcg.get_pipe_parallel_world_size()
Expand All @@ -58,7 +61,8 @@ def __init__(self, layers, hcg, strategy):
self._real_pp_world_size = self.num_stages
self._real_pp_rank = self.stage_id

p2p.initialize_p2p_groups(hcg, self._using_cache)
p2p.initialize_p2p_groups(hcg, self._using_cache,
self._enable_partial_send_recv)

self.global_rank = self._hcg.get_global_rank()
self.micro_batch_id = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@

_hcg = None
_use_cache = False
_enable_partial_send_recv = True


def initialize_p2p_groups(hcg, use_cache=True):
global _hcg, _use_cache
def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True):
global _hcg, _use_cache, _enable_partial_send_recv
_hcg = hcg
_use_cache = use_cache
_enable_partial_send_recv = enable_partial_send_recv
send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups(
)

Expand Down Expand Up @@ -157,7 +159,8 @@ def set_send_message(self, tensor):


def _is_valid_send_recv_partial(tensor, mp_degree):
    # A tensor may be sent/received "partially" (each model-parallel rank
    # handling a slice) only when the feature is enabled globally and the
    # element count divides evenly across the model-parallel group.
    if not _enable_partial_send_recv:
        return False
    element_count = np.prod(tensor.shape)
    assert element_count != 0, "can't send/recv zero element"
    if mp_degree <= 1:
        return False
    return element_count % mp_degree == 0
Expand Down