diff --git a/python/paddle/autograd/primrules.py b/python/paddle/autograd/primrules.py
index 61dd16224d63f..a950ea44e2a2f 100644
--- a/python/paddle/autograd/primrules.py
+++ b/python/paddle/autograd/primrules.py
@@ -14,7 +14,8 @@
 import paddle
 
 from .primreg import REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, REGISTER_JVP, REGISTER_TRANSPOSE
-from .primreg import lookup_fn, lookup_orig2prim, lookup_prim2orig, lookup_jvp, lookup_transpose
+from .primreg import (lookup_fn, lookup_orig2prim, lookup_prim2orig, lookup_jvp,
+                      lookup_transpose, op_position_inputs, op_position_output)
 from .primops import (neg, add, sub, mul, div, sqrt, tanh, reshape, broadcast,
                       transpose, split, concat, reduce, matmul, slice_select,
                       slice_assign, gather, scatter_add, fill_const, set_value)
@@ -361,7 +362,7 @@ def sub_jvp(op, x_dot, y_dot):
 @REGISTER_JVP('mul_p')
 def mul_jvp(op, x_dot, y_dot):
     assert op.type == 'mul_p'
-    x, y = get_input_vars(op)
+    x, y = op_position_inputs(op)
     t1, t2 = mul(x_dot, y), mul(x, y_dot)
     z_dot = add(t1, t2)
     return z_dot
@@ -369,7 +370,7 @@ def mul_jvp(op, x_dot, y_dot):
 
 @REGISTER_JVP('div_p')
 def div_jvp(op, x_dot, y_dot):
-    x, y = get_input_vars(op)
+    x, y = op_position_inputs(op)
     t1, t2 = div(x_dot, y), div(mul(x, y_dot), mul(y, y))
     z_dot = sub(t1, t2)
     return z_dot
@@ -377,7 +378,7 @@ def div_jvp(op, x_dot, y_dot):
 
 @REGISTER_JVP('sqrt_p')
 def sqrt_jvp(op, x_dot):
-    x, = get_input_vars(op)
+    x, = op_position_inputs(op)
     c2 = fill_const(value=2.0, shape=x.shape, dtype=x.dtype)
     y_dot = div(x_dot, mul(c2, sqrt(x)))
     return y_dot
@@ -385,7 +386,7 @@ def sqrt_jvp(op, x_dot):
 
 @REGISTER_JVP('tanh_p')
 def tanh_jvp(op, x_dot):
-    y, = get_output_vars(op)
+    y = op_position_output(op)
     c1 = fill_const(value=1.0, shape=y.shape, dtype=y.dtype)
     y_dot = mul(x_dot, sub(c1, mul(y, y)))
     return y_dot
@@ -431,7 +432,7 @@ def reduce_jvp(op, x_dot):
 
 @REGISTER_JVP('matmul_p')
 def matmul_jvp(op, x_dot, y_dot):
-    x, y = get_input_vars(op)
+    x, y = op_position_inputs(op)
     t1 = matmul(x, y_dot)
     t2 = matmul(x_dot, y)
     z_dot = add(t1, t2)
@@ -460,14 +461,14 @@ def slice_assign_jvp(op, x_dot, y_dot):
 
 @REGISTER_JVP('gather_p')
 def gather_jvp(op, x_dot):
-    _, indextensor = get_input_vars(op)
+    _, indextensor = op_position_inputs(op)
     axis = op.attr('axis')
     return linear_jvp(op, x_dot, indextensor, axis=axis)
 
 
 @REGISTER_JVP('scatter_add_p')
 def scatter_add_jvp(op, x_dot, y_dot):
-    _, _, indextensor = get_input_vars(op)
+    _, _, indextensor = op_position_inputs(op)
     axis = op.attr('axis')
     return linear_jvp(op, x_dot, y_dot, indextensor, axis=axis)
@@ -477,7 +478,7 @@ def scatter_add_jvp(op, x_dot, y_dot):
 
 @REGISTER_TRANSPOSE('add_p')
 def add_transpose(op, check_dot, z_bar):
-    x, y = get_input_vars(op)
+    x, y = op_position_inputs(op)
     assert check_dot(x) or check_dot(y)
     x_bar = z_bar if check_dot(x) else None
     y_bar = z_bar if check_dot(y) else None
@@ -486,7 +487,7 @@ def add_transpose(op, check_dot, z_bar):
 
 @REGISTER_TRANSPOSE('sub_p')
 def sub_transpose(op, check_dot, z_bar):
-    x, y = get_input_vars(op)
+    x, y = op_position_inputs(op)
     assert check_dot(x) or check_dot(y)
     x_bar = z_bar if check_dot(x) else None
     y_bar = neg(z_bar) if check_dot(y) else None
@@ -495,7 +496,7 @@ def sub_transpose(op, check_dot, z_bar):
 
 @REGISTER_TRANSPOSE('mul_p')
 def mul_transpose(op, check_dot, z_bar):
-    x, y = get_input_vars(op)
+    x, y = op_position_inputs(op)
     assert check_dot(x) ^ check_dot(y)
     if check_dot(x):
         return mul(z_bar, y), None
@@ -505,7 +506,7 @@ def mul_transpose(op, check_dot, z_bar):
 
 @REGISTER_TRANSPOSE('div_p')
 def div_transpose(op, check_dot, z_bar):
-    x, y = get_input_vars(op)
+    x, y = op_position_inputs(op)
     assert not check_dot(y)
     x_bar = div(z_bar, y) if check_dot(x) else None
     return x_bar, None
@@ -513,14 +514,14 @@ def div_transpose(op, check_dot, z_bar):
 
 @REGISTER_TRANSPOSE('reshape_p')
 def reshape_transpose(op, check_dot, y_bar):
-    x, = get_input_vars(op)
+    x, = op_position_inputs(op)
     assert check_dot(x)
     return reshape(y_bar, shape=x.shape)
 
 
 @REGISTER_TRANSPOSE('broadcast_p')
 def broadcast_transpose(op, check_dot, y_bar):
-    x, = get_input_vars(op)
+    x, = op_position_inputs(op)
     assert check_dot(x)
     bat = len(y_bar.shape) - len(x.shape)
     axis = list(range(bat))
@@ -532,7 +533,7 @@ def broadcast_transpose(op, check_dot, y_bar):
 
 @REGISTER_TRANSPOSE('transpose_p')
 def transpose_transpose(op, check_dot, y_bar):
-    x, = get_input_vars(op)
+    x, = op_position_inputs(op)
     assert check_dot(x)
     axis = op.attr('axis')
     reordered = sorted((k, i) for i, k in enumerate(axis))
@@ -542,14 +543,14 @@ def transpose_transpose(op, check_dot, y_bar):
 
 @REGISTER_TRANSPOSE('split_p')
 def split_transpose(op, check_dot, ys_bar):
-    x, = get_input_vars(op)
+    x, = op_position_inputs(op)
     assert check_dot(x)
     return concat(ys_bar, axis=op.attr('axis'))
 
 
 @REGISTER_TRANSPOSE('concat_p')
 def concat_transpose(op, check_dot, y_bar):
-    xs = get_input_vars(op)
+    xs, = op_position_inputs(op)
     for x in xs:
         assert check_dot(x)
     axis = op.attr('axis')
@@ -559,7 +560,7 @@ def concat_transpose(op, check_dot, y_bar):
 
 @REGISTER_TRANSPOSE('reduce_p')
 def reduce_transpose(op, check_dot, y_bar):
-    x, = get_input_vars(op)
+    x, = op_position_inputs(op)
     assert check_dot(x)
     axes = op.attr('axis')
     shape = tuple(1 if i in axes else size for i, size in enumerate(x.shape))
@@ -569,7 +570,7 @@ def reduce_transpose(op, check_dot, y_bar):
 
 @REGISTER_TRANSPOSE('matmul_p')
 def matmul_transpose(op, check_dot, z_bar):
-    x, y = get_input_vars(op)
+    x, y = op_position_inputs(op)
     assert check_dot(x) ^ check_dot(y)
     # TODO: replace it. this is hacky
     axis = [1, 0] if len(x.shape) == 2 else [0, 2, 1]
@@ -581,7 +582,7 @@ def matmul_transpose(op, check_dot, z_bar):
 
 @REGISTER_TRANSPOSE('slice_select_p')
 def slice_select_transpose(op, check_dot, y_bar):
-    x, = get_input_vars(op)
+    x, = op_position_inputs(op)
     assert check_dot(x)
     zeros = fill_const(value=0.0, shape=x.shape, dtype=x.dtype)
     axis = op.attr('axis')
@@ -594,7 +595,7 @@ def slice_select_transpose(op, check_dot, y_bar):
 
 @REGISTER_TRANSPOSE('slice_assign_p')
 def slice_assign_transpose(op, check_dot, z_bar):
-    x, y = get_input_vars(op)
+    x, y = op_position_inputs(op)
     assert check_dot(x) and check_dot(y)
     zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
     axis = op.attr('axis')
@@ -610,10 +611,11 @@ def slice_assign_transpose(op, check_dot, z_bar):
 
 @REGISTER_TRANSPOSE('gather_p')
 def gather_transpose(op, check_dot, y_bar):
-    x, indextensor = get_input_vars(op)
+    x, indextensor = op_position_inputs(op)
     assert check_dot(x)
     axis = op.attr('axis')
-    return scatter_add(y_bar, indextensor, axis=axis)
+    zeros = fill_const(0.0, x.shape, x.dtype)
+    return scatter_add(zeros, y_bar, indextensor, axis=axis), None
 
 
 @REGISTER_TRANSPOSE('scatter_add_p')
diff --git a/python/paddle/autograd/primx.py b/python/paddle/autograd/primx.py
index c489619d41cf3..8949f1003f1e3 100644
--- a/python/paddle/autograd/primx.py
+++ b/python/paddle/autograd/primx.py
@@ -17,88 +17,87 @@
 from paddle.fluid.framework import default_main_program, default_startup_program
 from paddle.fluid import unique_name, core
 from paddle.fluid.framework import Operator
+from paddle import compat as cpt
 from .primops import fill_const, add
 from .primreg import op_position_inputs, op_position_output, lookup_orig2prim, lookup_prim2orig
 from .primrules import get_input_vars, get_output_vars, _orig2prim, _prim2orig, _jvp, _transpose
 from collections import OrderedDict
 
 
+def flatten(inp):
+    if inp is None or isinstance(inp, paddle.fluid.framework.Variable):
+        return [inp]
+    flattened = []
+    for part in inp:
+        flattened += flatten(part)
+    return flattened
+
 
 def topo_path(xs, ys, block=None):
     """ Returns the list of ops on the path from `xs` to `ys` in topological
    order.
 
     TODO(Tongxin): supporting control flow and nested blocks.
-
     Args:
-
         xs: a list|tuple of vars as source
         ys: a list|tuple of vars as sink
         block: the program block containing the path, optional
-
     Returns:
         path: a list of ops
     """
     if block is None:
         block = default_main_program().current_block()
+
     path = []
-    reached_vars = list(xs)
-    sink_vars = list(ys)
-    sink_ops = {}
+    backpath = []
+    reached_vars = OrderedDict()
+    used_vars = OrderedDict()
 
-    # block.ops are supposedly in topological order as of now
+    # Initialize reached vars
+    for x in xs:
+        assert x.block == block
+        reached_vars[id(x)] = x
+
+    # block.ops are supposedly in the order that preserves correct data dependence.
+    reaching = lambda op: any(id(v) in reached_vars for v in get_input_vars(op))
+
+    # Forward pass to identify all reached variables and ops
     for op in block.ops:
-        if len(sink_ops) == len(ys):
-            break
-        ins = set(get_input_vars(op))
-        if any(ins.intersection(reached_vars)):
+        if reaching(op):
             path.append(op)
-            outs = set(get_output_vars(op))
-            for out in outs:
-                if any(out is y for y in sink_vars):
-                    # Found an output op
-                    assert not any(out is y for y in sink_ops)
-                    sink_ops[id(out)] = op
-                    # TODO(Tongxin): handling cases where sink vars
-                    # have dependencies themselves
-                else:
-                    reached_vars.append(out)
-
-    # if len(sink_ops) != len(sink_vars):
-    #     for var in sink_vars:
-    #         assert id(var) in sink_ops, (
-    #             f"{var} is not reachable from input vars.")
-
-    return path
-
-
-def output_vars_on_path(xs, ys, block=None):
+            for var in get_output_vars(op):
+                reached_vars[id(var)] = var
+
+    # Backward pass to find all used variables
+    used_vars = OrderedDict((id(y), y) for y in ys if id(y) in reached_vars)
+    back_reaching = lambda op: any(id(out) in used_vars for out in get_output_vars(op))
+
+    for op in reversed(path):
+        if back_reaching(op):
+            backpath.append(op)
+            for var in get_input_vars(op):
+                used_vars[id(var)] = var
+
+    unused_xs = [x for x in xs if id(x) not in used_vars]
+    unreached_ys = [y for y in ys if id(y) not in reached_vars]
+
+    return list(reversed(backpath)), unused_xs, unreached_ys
+
+
+def output_vars_on_path(path):
     """ Returns the output variables of all the ops on the path from `xs`
     to `ys`.
 
     Args:
-
-        xs: a list|tuple of vars as source
-        ys: a list|tuple of vars as sink
-        block: the program block containing the path, optional
+        path: a list of ops on which to find the output variables
 
     Returns:
         vars: the output vars
     """
-    if block is None:
-        block = default_main_program().current_block()
-
     vars = OrderedDict()
-
-    for var in xs + ys:
-        vars[id(var)] = var
-
-    sink_ops = set(y.op for y in ys)
-
-    for op in topo_path(xs, ys, block):
-        if op not in sink_ops:
-            for out in get_output_vars(op):
-                vars[id(out)] = out
+    for op in path:
+        for out in get_output_vars(op):
+            vars[id(out)] = out
 
     return vars
 
@@ -175,6 +174,8 @@ def add_vars(self, new_vars):
         self.vars.update({id(v): v for v in new_vars if v is not None})
 
     def add_vars_rec(self, new_vars):
+        if new_vars is None:
+            return
         if isinstance(new_vars, paddle.fluid.framework.Variable):
             self.vars.update({id(new_vars): new_vars})
             return
@@ -188,8 +189,10 @@ def erase_dots(self, vars_to_erase):
                 del self.vars[id(var)]
         self.dot2bar.delete_keyvars(vars_to_erase)
         self.var2dot.delete_valuevars(vars_to_erase)
+        block = self.block
         for var in vars_to_erase:
-            del var.block.vars[var.name]
+            block.desc._remove_var(cpt.to_bytes(var.name))
+        block._sync_with_cpp()
 
     def var2dot_rec(self, vars, defaults=None):
@@ -206,10 +209,42 @@ def var2dot_rec(self, vars, defaults=None):
             self.var2dot_rec(var, default)
             for var, default in zip(vars, defaults)
         ]
-
         return dots
 
+    def dot2bar_rec(self, dots, defaults=None):
+
+        if isinstance(dots, paddle.fluid.framework.Variable):
+            bar = self.dot2bar.lookup(dots)
+            if bar is None and defaults is not None:
+                bar = defaults
+            return bar
+
+        if defaults is None:
+            defaults = [None for _ in range(len(dots))]
+
+        bars = [
+            self.dot2bar_rec(dot, default)
+            for dot, default in zip(dots, defaults)
+        ]
+        return bars
+
     def linearize(self, xs, ys, xs_dot=None):
+        """ Performs the linearization transform, a.k.a., forward mode AD
+        transform, on a primitive lowered program.
+
+        Args:
+            xs: a list of input variables
+            ys: a list of output variables
+            xs_dot: optional, a list of gradient input variables. The list size
+                must be equal to `len(xs)`. The shape and dtype of each element
+                must be the same as in `xs`
+
+        Returns:
+            (xs_dot, ys_dot): a tuple of two lists. `xs_dot` is the list of
+                gradient inputs of the resulting linearized program. `ys_dot` is
+                the list of gradient outputs of the resulting linearized program
+
+        """
         if xs_dot is None:
             xs_dot = [fill_const(1.0, shape=x.shape, dtype=x.dtype) for x in xs]
             self.add_vars(xs_dot)
@@ -221,7 +256,9 @@ def linearize(self, xs, ys, xs_dot=None):
             assert x.shape == dot.shape
             self.var2dot.add(x, dot)
 
-        for op in topo_path(xs, ys, self.block):
+        path, _, _ = topo_path(xs, ys, self.block)
+
+        for op in path:
             # An input var may not be on the input-output path, which implies
             # there may be None's in `ins_dot`. In this case we place
             # the original input in the position of the otherwise forward
@@ -238,6 +275,24 @@ def linearize(self, xs, ys, xs_dot=None):
         return xs_dot, ys_dot
 
     def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
+        """ Performs the transpose transform, a.k.a., reverse mode AD
+        transform, on a linearized primitive program.
+
+        Note, `transpose` is supposed to be used in tandem with `linearize`.
+
+        Args:
+            ys_dot: a list of outputs of the linearized program.
+            xs_dot: a list of inputs of the linearized program.
+            ys_bar: optional, a list of inputs of the resulting transposed
+                program. The list size must be equal to `len(ys_dot)`. The shape
+                and dtype of each element must be the same as in `ys_dot`
+
+        Returns:
+            (ys_bar, xs_bar): a tuple of two lists. `ys_bar` is the list of
+                inputs of the resulting transposed program. `xs_bar` is
+                the list of outputs of the resulting transposed program
+
+        """
         if ys_bar is None:
             ys_bar = []
             for y in ys_dot:
@@ -253,20 +308,30 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
             self.dot2bar.add(dot, bar)
 
         # find all the relevant forward gradients
-        dotvars = output_vars_on_path(xs_dot, ys_dot)
+        path, _, _ = topo_path(xs_dot, ys_dot, self.block)
+        dotvars = output_vars_on_path(path)
+        dotvars.update((id(var), var) for var in xs_dot)
         is_dot = lambda v: id(v) in dotvars
-
-        for op in reversed(topo_path(xs_dot, ys_dot, self.block)):
-            outs_bar = [self.dot2bar.lookup(var) for var in get_output_vars(op)]
-            ins_bar = _transpose(op, is_dot, *outs_bar)
-            if isinstance(ins_bar, (list, tuple)):
-                ins_bar = list(ins_bar)
-            else:
-                ins_bar = [ins_bar]
-            self.add_vars(ins_bar)
-            assert len(get_input_vars(op)) == len(ins_bar)
-            for dot, bar in zip(get_input_vars(op), ins_bar):
+
+        for op in reversed(path):
+            out = op_position_output(op)
+            out_bar_rec = self.dot2bar_rec(out, defaults=out)
+            ins_bar_rec = _transpose(op, is_dot, out_bar_rec)
+
+            # TODO(Tongxin): this is hacky. Tuple implies the Transpose rule
+            # returns multiple entities
+            if isinstance(ins_bar_rec, tuple):
+                ins_bar_rec = list(ins_bar_rec)
+            else:
+                ins_bar_rec = [ins_bar_rec]
+            self.add_vars_rec(ins_bar_rec)
+
+            ins_bar = flatten(ins_bar_rec)
+            ins = get_input_vars(op)
+            assert len(ins) == len(ins_bar)
+
+            for dot, bar in zip(ins, ins_bar):
                 if bar is not None:
                     # aggregate gradient
                     grad = self.dot2bar.lookup(dot)
@@ -279,17 +344,28 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
 
         xs_bar = [self.dot2bar.lookup(x) for x in xs_dot]
 
-        if not retain_fwd:
-            dots_to_remove = set()
-            for op in topo_path(xs_dot, ys_dot):
+        if not retain_fwd and len(path) > 0:
+            vars_to_remove = set()
+            for op in path:
                 for var in get_input_vars(op):
                     if is_dot(var):
-                        dots_to_remove.add(var)
-                block = op.block
-                op_idx = block.ops.index(op)
-                block._remove_op(op_idx)
+                        vars_to_remove.add(var)
+
+            op_indexes = []
+
+            block = self.block
+            for i, op in enumerate(block.ops):
+                if op in path:
+                    op_indexes.append(i)
+                    path.pop(0)
+                    if len(path) == 0:
+                        break
 
-            self.erase_dots(dots_to_remove)
+            # remove ops
+            for op_index in reversed(op_indexes):
+                block.desc._remove_op(op_index, op_index + 1)
+
+            self.erase_dots(vars_to_remove)
 
         return ys_bar, xs_bar
@@ -314,10 +390,7 @@ def _gradients(ys, xs, ys_bar=None):
     # is completely lowered to primitive ops, it's mandatory to run the lowering
     # pass once and again. This is obviously inefficient and needs to be
     # optimized.
-    new_vars = []
-    for var in xs + ys:
-        new_vars.append(var)
-
+    new_vars = xs + ys
     orig2prim(block, new_vars)
 
     ad = Transform(block)
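
A minimal usage sketch of the pipeline this patch reworks, assuming the experimental names it touches (`orig2prim`, `Transform`, `topo_path`) are importable from `paddle.autograd.primx` and that orig2prim lowering rules are registered for the ops used; the program, shapes, and variable names below are illustrative only, and the call signatures follow the docstrings added in the diff.

```python
import paddle
# Assumed import path, matching the file touched by this diff; these APIs are
# experimental and may move.
from paddle.autograd.primx import Transform, orig2prim, topo_path

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    # A small static-graph program to differentiate (illustrative).
    x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
    y = paddle.tanh(paddle.matmul(x, paddle.transpose(x, perm=[1, 0])))

    block = main.current_block()
    # Lower original ops to primitive ops, mirroring the call pattern of
    # `_gradients` in the diff; passing [x, y] lets the lowering pass update
    # these Variable handles in place.
    orig2prim(block, [x, y])

    # topo_path now returns (path, unused_xs, unreached_ys) instead of just
    # the op list, so callers can detect disconnected inputs and outputs.
    path, unused_xs, unreached_ys = topo_path([x], [y], block)

    ad = Transform(block)
    # Forward-mode (linearize) pass: builds the dot program.
    xs_dot, ys_dot = ad.linearize([x], [y])
    # Reverse-mode (transpose) pass over the linearized program.
    ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot)
```

Since `retain_fwd` defaults to False, the intermediate dot ops and vars are removed from the block once `transpose` finishes, which is why `topo_path` is queried before the two AD passes in this sketch.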