Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine distribute transpiler #8241

Merged
merged 3 commits into from
Feb 9, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 87 additions & 35 deletions python/paddle/v2/fluid/distribute_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,9 @@ def _get_optimizer_input_shape(self, op_type, varkey, orig_shape,
pass
return orig_shape

def _op_input_var(self, op, varname):
pass

def _is_op_on_pserver(self, endpoint, all_ops, idx):
"""
Recursively check if the op need to run on current server.
Expand All @@ -309,29 +312,35 @@ def _is_op_on_pserver(self, endpoint, all_ops, idx):
p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
]
op = all_ops[idx]
if op.inputs.has_key("Param"):
if op.inputs["Param"].name in param_names:
input_names = set(op.input_names)
# TODO(typhoonzero): using Param and Grad input name to identify
# that the operator is an optimization operator, need a better way.
if "Param" in input_names:
if op.input("Param")[0] in param_names:
return True
else:
for n in param_names:
if same_or_split_var(n, op.inputs[
"Param"].name) and n != op.inputs["Param"].name:
if same_or_split_var(n, op.input("Param")[0]) \
and n != op.input("Param")[0]:
return True
return False
else:
j = idx - 1
while j >= 0:
prev_op = all_ops[j]
prev_output_names = [o.name for o in prev_op.outputs.values()]
prev_input_names = [o.name for o in prev_op.inputs.values()]
# prev_output_names = [o.name for o in prev_op.outputs.values()]
# prev_input_names = [o.name for o in prev_op.inputs.values()]
# NOTE(typhoonzero): consider list input/output
prev_output_names = prev_op.desc.output_arg_names()
prev_input_names = prev_op.desc.input_arg_names()
found1 = False
found2 = False
for _, v in op.inputs.iteritems():
if v.name in prev_output_names:
for varname in op.desc.input_arg_names():
if varname in prev_output_names:
found1 = self._is_op_on_pserver(endpoint, all_ops, j)
# later ops may produce output for prev op's next batch use.
for _, v in op.outputs.iteritems():
if v.name in prev_input_names:
for varname in op.desc.output_arg_names():
if varname in prev_input_names:
found2 = self._is_op_on_pserver(endpoint, all_ops, j)
if found1 or found2:
return True
Expand All @@ -342,11 +351,11 @@ def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
new_inputs = dict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
for key, var in opt_op.inputs.iteritems():
for key in opt_op.input_names:
if key == "Grad":
grad_block = None
for g in self.param_grad_ep_mapping[endpoint]["grads"]:
if same_or_split_var(g.name, var.name):
if same_or_split_var(g.name, opt_op.input(key)[0]):
grad_block = g
break
if not grad_block:
Expand Down Expand Up @@ -376,7 +385,7 @@ def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
# param is already created on global program
param_block = None
for p in self.param_grad_ep_mapping[endpoint]["params"]:
if same_or_split_var(p.name, var.name):
if same_or_split_var(p.name, opt_op.input(key)[0]):
param_block = p
break
if not param_block:
Expand All @@ -389,11 +398,12 @@ def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):

new_inputs[key] = tmpvar

for key, var in opt_op.inputs.iteritems():
for key in opt_op.input_names:
if key in ["Param", "Grad"]:
continue
# update accumulator variable shape
param_shape = new_inputs["Param"].shape
var = program.global_block().vars[opt_op.input(key)[0]]
new_shape = self._get_optimizer_input_shape(opt_op.type, key,
var.shape, param_shape)
tmpvar = program.global_block().create_var(
Expand All @@ -412,30 +422,44 @@ def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
shape=new_shape)

# change output's ParamOut variable
opt_op.outputs["ParamOut"] = new_inputs["Param"]
outputs = self._get_output_map_from_op(program.global_block(), opt_op)
outputs["ParamOut"] = new_inputs["Param"]
program.global_block().append_op(
type=opt_op.type,
inputs=new_inputs,
outputs=opt_op.outputs,
outputs=outputs,
attrs=opt_op.attrs)

def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op):
# Append the ops for parameters that do not need to be optimized/updated
for _, var in opt_op.inputs.iteritems():
program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
pserver_program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
inputs = self._get_input_map_from_op(self.program.global_block().vars,
opt_op)
for var in inputs.itervalues():
if type(var) == list:
varlist = var
else:
varlist = [var]
for var in varlist:
# TODO(typhoonzero): will remove below line later.
program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
if not pserver_program.global_block().vars.has_key(var.name):
pserver_program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)

outputs = self._get_output_map_from_op(self.program.global_block().vars,
opt_op)

program.global_block().append_op(
type=opt_op.type,
inputs=opt_op.inputs,
outputs=opt_op.outputs,
inputs=inputs,
outputs=outputs,
attrs=opt_op.attrs)

def get_pserver_program(self, endpoint):
Expand Down Expand Up @@ -472,7 +496,7 @@ def get_pserver_program(self, endpoint):
self.optimize_ops, idx)
if not is_op_on_pserver:
continue
if opt_op.inputs.has_key("Grad"):
if "Grad" in opt_op.desc.input_arg_names():
self._append_pserver_ops(optimize_sub_program, pserver_program,
opt_op, endpoint)
else:
Expand All @@ -499,6 +523,30 @@ def get_pserver_program(self, endpoint):
pserver_program.sync_with_cpp()
return pserver_program

def _get_input_map_from_op(self, varmap, op):
iomap = dict()
for key in op.input_names:
vars = []
for varname in op.input(key):
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap

def _get_output_map_from_op(self, varmap, op):
iomap = dict()
for key in op.output_names:
vars = []
for varname in op.output(key):
vars.append(varmap[varname])
if len(vars) == 1:
iomap[key] = vars[0]
else:
iomap[key] = vars
return iomap

def get_startup_program(self, endpoint, pserver_program):
"""
Get startup program for current parameter server.
Expand Down Expand Up @@ -529,17 +577,21 @@ def _get_splited_name_and_shape(varname):

# 2. rename op outputs
for op in orig_s_prog.global_block().ops:
new_inputs = dict()
new_outputs = dict()
# do not append startup op if var is not on this pserver
op_on_pserver = False
for key, var in op.outputs.iteritems():
newname, _ = _get_splited_name_and_shape(var.name)
for key in op.output_names:
newname, _ = _get_splited_name_and_shape(op.output(key)[0])
if newname:
op_on_pserver = True
new_outputs[key] = created_var_map[newname]
elif var.name in pserver_vars:
elif op.output(key)[0] in pserver_vars:
op_on_pserver = True
new_outputs[key] = pserver_vars[var.name]
new_outputs[key] = pserver_vars[op.output(key)[0]]

# most startup program ops have no inputs
new_inputs = self._get_input_map_from_op(pserver_vars, op)

if op_on_pserver:
if op.type in [
Expand All @@ -548,7 +600,7 @@ def _get_splited_name_and_shape(varname):
op.attrs["shape"] = new_outputs["Out"].shape
s_prog.global_block().append_op(
type=op.type,
inputs=op.inputs,
inputs=new_inputs,
outputs=new_outputs,
attrs=op.attrs)
return s_prog