Skip to content

Commit

Permalink
add program_mode in minimize for pslib mode;test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
danleifeng committed Sep 28, 2021
1 parent 97843de commit 844b2f0
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,8 @@ def minimize(self,
scopes=None,
startup_programs=None,
parameter_list=None,
no_grad_set=None):
no_grad_set=None,
program_mode="all_reduce"):
"""
minimize a program through loss, loss can be a list in DistributedOptimizer.
Note that in parameter server mode, a worker will not get anything about optimize_ops
Expand All @@ -1105,6 +1106,7 @@ def minimize(self,
in `parameter_list`.
parameter_list (list): list of Variables to update.
no_grad_set (set|None): set of Variables should be ignored.
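program_mode (str): gradient communication mode used to transpile the program when use_ps_gpu is set; one of "all_reduce", "fuse_all_reduce" or "all_gather". Default is "all_reduce".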
Returns:
tuple: (optimize_ops, params_grads) which are, list of operators appended;
and list of (param, grad) Variables pair for optimization.
Expand Down Expand Up @@ -1139,12 +1141,17 @@ def minimize(self,
if opt_info["use_ps_gpu"]:
from paddle.fluid.transpiler.collective import MultiThread
# check start program

if program_mode not in [
"all_reduce", "fuse_all_reduce", "all_gather"
]:
raise ValueError("You should set program_mode in [ all_reduce, \
fuse_all_reduce, all_gather ]")
env = self.get_dist_env()
if not isinstance(losses, list):
startup_programs = [startup_programs]
for i in range(0, len(startup_programs)):
t = MultiThread()

t = MultiThread(trans_mode=program_mode)
start_program = startup_programs[i]
main_program = programs[i]
t.transpile(
Expand Down
152 changes: 148 additions & 4 deletions python/paddle/fluid/transpiler/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,12 +441,14 @@ class MultiThread(GradAllReduce):
'''
'''

def __init__(self, nrings=1):
def __init__(self, nrings=1, trans_mode="all_reduce"):
GradAllReduce.__init__(self, nrings)
self.mode = "box"
# use fuse_allreduce by default in gpubox
self.fuse_all_reduce_ops = True
self.trans_mode = trans_mode
self.fuse_grad_size_in_num = 128
gpu_nums = os.getenv("FLAGS_selected_gpus",
"0,1,2,3,4,5,6,7,8").split(",")
self.gpu_num = len(gpu_nums)

def _transpile_startup_program(self):
if len(self.endpoints) > 1:
Expand All @@ -466,11 +468,153 @@ def _transpile_startup_program(self):

def _transpile_main_program(self):
self._insert_scale_loss_grad_ops()
if self.fuse_all_reduce_ops:
if self.trans_mode == "all_gather":
print("begin to transpile in all-gather mode")
self.allgather_ranks = self.nranks * self.gpu_num
self._insert_allgather_ops()
self._update_adam_ops()
elif self.trans_mode == "fuse_all_reduce":
print("begin to transpile in fuse all-reduce mode")
self._insert_fuse_allreduce_ops()
else:
print("begin to transpile in all-reduce mode")
self._insert_allreduce_ops()

def _insert_allgather_ops(self):
    """
    Insert c_allgather ops for parameter gradients into the main program.

    Backward ops carrying a non-empty op-role-var attribute are visited from
    the end of the global block towards its start. For every (param, grad)
    name pair listed there, a non-persistable ``<param>_allgather`` variable
    of shape ``[self.allgather_ranks] + param.shape`` is created, and (unless
    the parameter is distributed) a ``c_allgather`` op is inserted that
    gathers the gradient into it. One ``c_sync_calc_stream`` op is inserted
    directly after the backward op before its first ``c_allgather``.

    Finally, one ``c_sync_comm_stream`` op per ring is inserted in front of
    the first optimizer op. Nothing is inserted there when no gradient was
    seen at all.
    """
    block = self.main_program.global_block()
    ring_id = -1
    grad = None

    indexed_ops = list(enumerate(block.ops))
    for idx, op in reversed(indexed_ops):
        if not self._is_backward_op(op):
            continue
        if self.op_role_var_key not in op.attr_names:
            continue

        op_role_var = op.all_attrs()[self.op_role_var_key]
        num_entries = len(op_role_var)
        if num_entries == 0:
            continue
        # entries are laid out as param name, grad name, param name, ...
        assert num_entries % 2 == 0

        calc_stream_synced = False
        for pos in range(0, num_entries, 2):
            param_name = op_role_var[pos]
            param = block.vars[param_name]
            gathered_shape = [self.allgather_ranks] + list(param.shape)
            new_grad_var = block.create_var(
                name=param_name + "_allgather",
                shape=gathered_shape,
                persistable=False,
                dtype=core.VarDesc.VarType.FP32,
                stop_gradient=True)
            grad = block.vars[op_role_var[pos + 1]]
            if param.is_distributed:
                # distributed params (used in PLSC) get no allgather op
                continue

            if not calc_stream_synced:
                # one calc-stream sync right behind the backward op
                block._insert_op(
                    idx + 1,
                    type='c_sync_calc_stream',
                    inputs={'X': grad},
                    outputs={'Out': grad},
                    attrs={self.op_role_key: OpRole.Backward})
                calc_stream_synced = True

            # Ops are walked in reverse order, so the ring id is advanced
            # in that same order to keep it alternating between rings.
            ring_id = (ring_id + 1) % self.nrings
            # every allgather lands right behind the calc-stream sync
            block._insert_op(
                idx + 2,
                type='c_allgather',
                inputs={'X': grad},
                outputs={'Out': new_grad_var},
                attrs={
                    'nranks': self.allgather_ranks,
                    'ring_id': ring_id,
                    self.op_role_key: OpRole.Backward
                })

    if grad is None:
        return

    # locate the first optimizer op; comm-stream syncs go right before it
    first_opt_idx = next(
        (pos for pos, candidate in enumerate(block.ops)
         if self._is_optimizer_op(candidate)), None)
    if first_opt_idx is None:
        return

    for ring in range(self.nrings):
        block._insert_op(
            first_opt_idx + ring,
            type='c_sync_comm_stream',
            inputs={'X': grad},
            outputs={'Out': grad},
            attrs={'ring_id': ring,
                   self.op_role_key: OpRole.Backward})

def _update_adam_ops(self):
    """
    Replace every adam/lamb op by a split op plus one update op per rank.

    For each optimizer op of type ``adam`` or ``lamb`` in the global block,
    the ``<param>_allgather`` variable is split along axis 0 into
    ``self.allgather_ranks`` pieces named ``<param>_<i>``. One op of the
    original type is then inserted per piece, using that piece as ``Grad``
    and the original op's other inputs, outputs and attributes. The
    original op is removed afterwards.
    """
    input_slots = ("Param", "LearningRate", "Moment1", "Moment2",
                   "Beta1Pow", "Beta2Pow")
    output_slots = ("ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut",
                    "Beta2PowOut")
    # NOTE(review): these attribute names are read for 'lamb' ops as well;
    # confirm that 'lamb' defines lazy_mode and
    # min_row_size_to_use_multithread.
    attr_names = ("epsilon", "beta1", "beta2", "lazy_mode",
                  "min_row_size_to_use_multithread")

    block = self.main_program.global_block()

    # walk backwards so insertions never shift ops that are still to visit
    for idx, op in reversed(list(enumerate(block.ops))):
        if not self._is_optimizer_op(op):
            continue
        op_type = op.type
        if op_type not in ('adam', 'lamb'):  # filter out scale op
            continue

        param_name = op.input("Param")[0]
        inputs = {slot: block.vars[op.input(slot)[0]]
                  for slot in input_slots}
        outputs = {slot: block.vars[op.output(slot)[0]]
                   for slot in output_slots}
        attrs = {key: op.attr(key) for key in attr_names}

        piece_shape = block.vars[param_name].shape
        split_vars = []
        for rank in range(self.allgather_ranks):
            split_vars.append(
                block.create_var(
                    name=param_name + "_" + str(rank),
                    shape=piece_shape,
                    persistable=False,
                    dtype=core.VarDesc.VarType.FP32,
                    stop_gradient=True))

        # the split op takes the slot of the original optimizer op
        block._insert_op(
            idx,
            type="split",
            inputs={'X': block.vars[param_name + "_allgather"]},
            outputs={'Out': split_vars},
            attrs={'num': self.allgather_ranks,
                   'axis': 0})

        # per-rank update ops follow the split op one after another
        for step, piece in enumerate(split_vars, start=1):
            inputs["Grad"] = piece
            block._insert_op(
                idx + step,
                type=op_type,
                inputs=inputs,
                outputs=outputs,
                attrs=attrs)

        # the original op now sits behind the split and all new update ops
        block._remove_op(idx + 1 + len(split_vars))

def _insert_fuse_allreduce_ops(self):
"""
insert coalesce_tensor and all reduce ops
Expand Down

1 comment on commit 844b2f0

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.