Skip to content

Commit

Permalink
Fix linearize and transpose.
Browse files Browse the repository at this point in the history
  • Loading branch information
tongxin committed Apr 13, 2022
1 parent 5d7cda7 commit 1c163aa
Showing 1 changed file with 36 additions and 23 deletions.
59 changes: 36 additions & 23 deletions python/paddle/autograd/primx.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ def topo_path(xs, ys, block=None):
ins = set(get_input_vars(op))
if any(ins.intersection(def_vars)):
path.append(op)
outs = set(get_output_vars(op))
for out in outs:
if any(out is y for y in ys):
# Found an output op
assert not any(out is y for y in sink_ops)
sink_ops[id(out)] = op
else:
def_vars.append(out)
outs = set(get_output_vars(op))
for out in outs:
if any(out is y for y in ys):
# Found an output op
assert not any(out is y for y in sink_ops)
sink_ops[id(out)] = op
else:
def_vars.append(out)
if len(sink_ops) != len(ys):
not_reachable = (var for var in ys if id(var) not in sink_ops)
raise f"Output vars: {' '.join(not_reachable)} are not reachable from inputs."
Expand All @@ -78,7 +78,7 @@ def __init__(self, name, varset):
self.varset = varset
self.tab = OrderedDict()

def set(self, key_var, value_var):
def add(self, key_var, value_var):
self.tab[id(key_var)] = id(value_var)

def lookup(self, key_var):
Expand Down Expand Up @@ -122,7 +122,7 @@ def init_varset(self, block):
self.vars[id(var)] = var

def update_varset(self, new_vars):
self.vars.update({id(v) : v for v in new_vars})
self.vars.update({id(v) : v for v in new_vars if v is not None})

def erase_dots(self, vars_to_erase):
for var in vars_to_erase:
Expand All @@ -149,14 +149,21 @@ def linearize(self, xs, ys, xs_dot=None):
assert all(x.shape == x_dot.shape for x, x_dot in zip(xs, xs_dot))

self.update_varset(xs_dot)
map(self.var2dot.set, xs, xs_dot)
for x, dot in zip(xs, xs_dot):
self.var2dot.add(x, dot)
for op in topo_path(xs, ys, self.block):
xs_dot = list(map(self.var2dot.lookup, get_input_vars(op)))
ys_dot = _jvp(op, *xs_dot)
self.update_varset(ys_dot)
map(self.var2dot.set, op.get_output_vars(op), ys_dot)
ins = get_input_vars(op)
ins_dot = [self.var2dot.lookup(var) for var in ins]
jvp_ins = [x if dot is None else dot for dot, x in zip(ins_dot, ins)]
outs_dot = _jvp(op, *jvp_ins)
if not isinstance(outs_dot, list):
outs_dot = [outs_dot]
self.update_varset(outs_dot)

for x, dot in zip(get_output_vars(op), outs_dot):
self.var2dot.add(x, dot)

ys_dot = map(self.var2dot.lookup, ys)
ys_dot = [self.var2dot.lookup(y) for y in ys]
return xs_dot, ys_dot

def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
Expand All @@ -171,14 +178,20 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
assert y_dot.dtype == y_bar.dtype

self.update_varset(ys_bar)
map(self.dot2bar.set, ys_dot, ys_bar)
for dot, bar in zip(ys_dot, ys_bar):
self.dot2bar.add(dot, bar)
for op in reversed(topo_path(xs_dot, ys_dot, self.block)):
ys_bar = list(map(self.dot2bar.lookup, get_output_vars(op)))
xs_bar = _transpose(op, self.is_dot, *ys_bar)
self.update_varset(xs_bar)
map(self.dot2bar.set, op.get_input_vars(), xs_bar)

xs_bar = list(map(self.dot2bar.lookup, xs_dot))
outs_bar = [self.dot2bar.lookup(var) for var in get_output_vars(op)]
ins_bar = _transpose(op, self.is_dot, *outs_bar)
if isinstance(ins_bar, (list, tuple)):
ins_bar = list(ins_bar)
else:
ins_bar = [ins_bar]
self.update_varset(ins_bar)
for dot, bar in zip(op.get_input_vars(), ins_bar):
if bar is not None:
self.dot2bar.add(dot, bar)
xs_bar = [self.dot2bar.lookup(x) for x in xs_dot]

if not retain_fwd:
dots_to_remove = set()
Expand Down

0 comments on commit 1c163aa

Please sign in to comment.