invert.py
import torch
import torch.fx as fx

# An inverse mapping is one that takes a function f(x) and returns a function g
# such that f(g(x)) == x. For example, since log(exp(x)) == x, exp and log are
# inverses.
invert_mapping = {}
def add_inverse(a, b):
    invert_mapping[a] = b
    invert_mapping[b] = a

inverses = [
    (torch.sin, torch.arcsin),
    (torch.cos, torch.arccos),
    (torch.tan, torch.arctan),
    (torch.exp, torch.log),
]

for a, b in inverses:
    add_inverse(a, b)
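
# Quick sanity check: `add_inverse` registers both directions, so looking up
# either function of a pair yields the other.
assert invert_mapping[torch.exp] is torch.log
assert invert_mapping[torch.log] is torch.exp
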
# The general strategy is that we walk the graph backwards, transforming each
# node into its inverse. To do so, we swap each node's inputs and outputs and
# then look up the inverse operation in `invert_mapping`. Note that this
# transform assumes that all operations take in only one input and return
# one output.
def invert(model: torch.nn.Module) -> torch.nn.Module:
    fx_model = fx.symbolic_trace(model)
    new_graph = fx.Graph()
    # As we build up the new graph, `env` maps each original node's name to the
    # node that produces the corresponding value in the new (inverted) graph.
    env = {}
    for node in reversed(fx_model.graph.nodes):
        if node.op == 'call_function':
            # This creates a node in the new graph with the inverse function,
            # and passes `env[node.name]` (i.e. the previous output node) as
            # input.
            new_node = new_graph.call_function(invert_mapping[node.target], (env[node.name],))
            env[node.args[0].name] = new_node
        elif node.op == 'output':
            # We turn the output into an input placeholder
            new_node = new_graph.placeholder(node.name)
            env[node.args[0].name] = new_node
        elif node.op == 'placeholder':
            # We turn the input placeholder into an output
            new_graph.output(env[node.name])
        else:
            raise RuntimeError("Not implemented")
    new_graph.lint()
    return fx.GraphModule(fx_model, new_graph)
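
# For the example below, f(x) = exp(tan(x)); walking its traced graph in
# reverse visits output -> exp -> tan -> placeholder, emitting
# placeholder -> log -> arctan -> output, which is exactly the code printed
# further down.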
def f(x):
    return torch.exp(torch.tan(x))
res = invert(f)
print(res.code)
"""
def forward(self, output):
log_1 = torch.log(output); output = None
arctan_1 = torch.arctan(log_1); log_1 = None
return arctan_1
"""
print(f(res(torch.arange(5) + 1)))  # [1., 2., 3., 4., 5.]
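
# A minimal round-trip check (an illustrative addition; the helper `g` and the
# test values below are not part of the example above): for any composition of
# the registered unary ops, applying the inverted module and then the original
# function should reproduce the input, as long as every intermediate value
# stays inside the domain of each inverse. Here arccos(y) must be positive so
# that log is defined, which holds for y < 1.
def g(x):
    return torch.cos(torch.exp(x))

inv_g = invert(g)  # the inverted graph computes log(arccos(y))
y = torch.tensor([-0.5, 0.25, 0.75])
assert torch.allclose(g(inv_g(y)), y, atol=1e-5)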