[Tensor Parallelism] split fix bug (#33015)
JZ-LIANG authored May 26, 2021
1 parent a2a45d8 commit 20b9be6
Showing 6 changed files with 20 additions and 5 deletions.
11 changes: 11 additions & 0 deletions python/paddle/distributed/collective.py
@@ -977,6 +977,11 @@ def _parallel_linear(x,
group=None):
"""
Parallel Linear
    axis: the dimension of the linear layer's parameter along which it is split.
        axis = 0: the row dimension
        axis = 1: the column dimension
"""
if group is not None and not group.is_member():
return
@@ -1008,6 +1013,12 @@ def _parallel_linear(x,
main_block = paddle.static.default_main_program().global_block()
startup_block.vars[linear.weight.name].is_distributed = True
main_block.vars[linear.weight.name].is_distributed = True
    # set is_distributed for the split bias
    # if the linear layer is split by row, every rank holds the complete bias, which must stay identical across ranks.
    # if the linear layer is split by column, the bias is split across ranks along with the weight.
if axis == 1 and linear._bias_attr != False:
startup_block.vars[linear.bias.name].is_distributed = True
main_block.vars[linear.bias.name].is_distributed = True

if not gather_out: return linear_out

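A minimal sketch of driving paddle.distributed.split in the two modes described above, modeled on the test models further down in this commit; the input shape, the fleet initialization, and the two-GPU launch are assumptions rather than part of the change:

# sketch only: assumed to run under `python -m paddle.distributed.launch --gpus 0,1 this_script.py`
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

data = paddle.static.data(name='x', shape=[-1, 1000], dtype='float32')

# axis=0 (row split): the 1000-wide input dimension of the weight is partitioned;
# every rank keeps an identical full bias, so only the weight is marked is_distributed
row_out = paddle.distributed.split(
    data, size=(1000, 16), operation='linear',
    axis=0, num_partitions=2, bias_attr=True)

# axis=1 (column split): the 16-wide output dimension of the weight is partitioned;
# the bias is split with it, so both the weight and the bias are marked is_distributed
col_out = paddle.distributed.split(
    data, size=(1000, 16), operation='linear',
    axis=1, num_partitions=2, bias_attr=True)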
2 changes: 1 addition & 1 deletion python/paddle/distributed/fleet/base/distributed_strategy.py
100755 → 100644
@@ -814,7 +814,7 @@ def sharding_configs(self):
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 32,
"sharding_degree": 8,
"sharding_degree": 2,
"dp_degree": 2,
"gradient_merge_acc_step": 4,
}
"""
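A short sketch of plugging these keys into a DistributedStrategy, using the corrected sharding_degree; the surrounding fleet setup is assumed rather than shown in the docstring:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_segment_strategy": "segment_broadcast_MB",
    "segment_broadcast_MB": 32,
    "sharding_degree": 2,  # ranks that shard optimizer states and gradients
    "dp_degree": 2,        # data-parallel replicas layered on top of sharding
    "gradient_merge_acc_step": 4,
}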
1 change: 1 addition & 0 deletions python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
@@ -145,6 +145,7 @@ def _update_list(self):
'sign',
'cast',
'fused_bn_add_activation',
'c_identity',
}

# The set of ops that don't support fp16 calculation
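The c_identity op used by these parallel layers is added to the default AMP op lists here; as a rough sketch for custom cases, the AutoMixedPrecisionLists class defined in this file can also be asked to treat extra ops as fp16-safe, assuming its 2.1-era constructor signature:

from paddle.fluid.contrib.mixed_precision.fp16_lists import AutoMixedPrecisionLists

# force an op into the fp16 (white) list by name; attribute names follow this file
amp_lists = AutoMixedPrecisionLists(custom_white_list={'c_identity'})
print('c_identity' in amp_lists.white_list)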
@@ -69,7 +69,7 @@ def get_model(self, main_prog, startup_program, rank):
axis=1,
num_partitions=2,
weight_attr=param_attr,
- bias_attr=False, )
+ bias_attr=True, )

return [linear_out]

@@ -65,12 +65,12 @@ def get_model(self, main_prog, startup_program, rank):

linear_out = paddle.distributed.split(
data,
- size=(1000, 8),
+ size=(1000, 16),
operation='linear',
axis=0,
num_partitions=2,
weight_attr=param_attr,
- bias_attr=False, )
+ bias_attr=True, )

return [linear_out]

@@ -154,7 +154,10 @@ def _run_cluster(self, model_file, envs):
#update environment
env0.update(envs)
env1.update(envs)
tr_cmd = "%s %s"
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
tr_cmd = "%s -m coverage run --branch -p %s"
else:
tr_cmd = "%s %s"
tr0_cmd = tr_cmd % (self._python_interp, model_file)
tr1_cmd = tr_cmd % (self._python_interp, model_file)
tr0_pipe = open("/tmp/tr0_err_%d.log" % os.getpid(), "w")
