from pytensor.link.basic import JITLinker
from pytensor.link.utils import unique_name_generator


class PytorchLinker(JITLinker):
    """A `Linker` that compiles NumPy-based operations using torch.compile."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.gen_functors = []

    def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
        from pytensor.link.pytorch.dispatch import pytorch_funcify

        # We want names that are globally unique across the entire
        # pytensor graph, not just unique within this subgraph.
        generator = unique_name_generator(["torch_linker"])

        # Register every generated functor so that torch is aware of the
        # generated code and we can compile without graph breaks.
        def conversion_func_register(*args, **kwargs):
            functor = pytorch_funcify(*args, **kwargs)
            name = kwargs["unique_name"](functor)
            self.gen_functors.append((f"_{name}", functor))
            return functor

        built_kwargs = {
            "unique_name": generator,
            "conversion_func": conversion_func_register,
            **kwargs,
        }
        return pytorch_funcify(
            fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
        )
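
    # Note (illustrative, not in the original source): after `fgraph_convert`
    # runs, `self.gen_functors` holds (name, functor) pairs, e.g. a
    # hypothetical ("_subgraph_fn_1", <generated function>). `jit_compile`
    # below strips the leading underscore and installs each functor on
    # `pytensor.link.utils`, so compiled code resolves it as a module
    # attribute rather than as a closure cell that torch might
    # garbage-collect (see the `wrapper` docstring).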

    def jit_compile(self, fn):
        import torch

        # A flag that tends to help our graphs compile without graph breaks.
        torch._dynamo.config.capture_dynamic_output_shape_ops = True

        from pytensor.link.pytorch.dispatch import pytorch_typify

        class wrapper:
            """
            PyTorch can fail to compile our function when it tries to
            resolve some of the callables returned from dispatch calls.
            To avoid leaking those callables, this class holds on to them
            and installs them at the expected location only for the
            duration of each call.

            https://discuss.pytorch.org/t/closures-are-being-gcd-and-causing-failures-to-compile/213319
            """

            def __init__(self, fn, gen_functors):
                self.fn = torch.compile(fn)
                self.gen_functors = gen_functors.copy()

            def __call__(self, *inputs, **kwargs):
                import pytensor.link.utils

                # Install the generated functors where the compiled
                # code expects to find them.
                for n, fn in self.gen_functors:
                    setattr(pytensor.link.utils, n[1:], fn)

                # Torch does not accept numpy inputs and may return GPU objects.
                outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)

                # Remove the functors again so they do not leak.
                for n, _ in self.gen_functors:
                    if getattr(pytensor.link.utils, n[1:], False):
                        delattr(pytensor.link.utils, n[1:])

                return tuple(out.cpu().numpy() for out in outs)

            def __del__(self):
                del self.gen_functors

        inner_fn = wrapper(fn, self.gen_functors)
        self.gen_functors = []

        return inner_fn

    def create_thunk_inputs(self, storage_map):
        # Pass the input storage cells through unchanged; conversion to
        # torch tensors happens inside the compiled wrapper at call time.
        thunk_inputs = []
        for n in self.fgraph.inputs:
            sinput = storage_map[n]
            thunk_inputs.append(sinput)

        return thunk_inputs
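

# Usage sketch (an assumption, not part of the original module): one way to
# route a pytensor function through this linker. It presumes a working torch
# install and pytensor's public `Mode`/`function` APIs; recent pytensor
# versions also register a ready-made "PYTORCH" mode that bundles this linker
# with the appropriate graph rewrites.
if __name__ == "__main__":
    import numpy as np

    import pytensor
    import pytensor.tensor as pt
    from pytensor.compile.mode import Mode

    x = pt.vector("x")
    y = pt.exp(x) + 1.0

    # Compile the graph via torch.compile using the linker defined above.
    f = pytensor.function([x], y, mode=Mode(linker=PytorchLinker()))

    # exp([0, 1, 2]) + 1 is approximately [2.0, 3.7183, 8.3891]
    print(f(np.arange(3.0)))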