-
Notifications
You must be signed in to change notification settings - Fork 0
/
reparam_module.py
200 lines (171 loc) · 8.17 KB
/
reparam_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import torch
import torch.nn as nn
import numpy as np
import warnings
import types
from collections import namedtuple
from contextlib import contextmanager
from utils.utils_baseline import get_network
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
class ReparamModule(nn.Module):
    """Wrap ``module`` so that all of its parameters live in one flat vector.

    On construction every distinct parameter tensor of ``module`` is
    detached, flattened and concatenated into a single ``nn.Parameter``
    registered as ``flat_param``. The original parameter attributes are then
    deleted and replaced by plain-tensor views (``split`` + ``view``) into a
    flat vector. This lets ``forward`` run the wrapped network with an
    arbitrary externally supplied flat parameter vector (and, optionally,
    replacement buffers). Gradients flow back into that vector because the
    views are built with differentiable ops.

    A tensor reachable under several names (shared parameters) is stored
    only once. Every alias name is re-pointed at the source view whenever
    views are rebuilt. Buffers are not flattened; they are only recorded so
    ``replaced_buffers`` can swap them temporarily.
    """

    def _get_module_from_name(self, mn):
        """Return the submodule at dotted path ``mn`` (``''`` resolves to ``self``)."""
        if mn == '':
            return self
        m = self
        for p in mn.split('.'):
            m = getattr(m, p)
        return m

    def __init__(self, module):
        super(ReparamModule, self).__init__()
        self.module = module

        param_infos = []  # (module name/path, param name)
        shared_param_memo = {}  # tensor -> (module name/path, param name) of its first occurrence
        shared_param_infos = []  # (module name/path, param name, src module name/path, src param_name)
        params = []
        param_numels = []
        param_shapes = []
        # Paths are relative to self, so entries of the wrapped network start with "module".
        for mn, m in self.named_modules():
            for n, p in m.named_parameters(recurse=False):
                if p is None:
                    continue
                if p in shared_param_memo:
                    # Same tensor already seen under another name: record an alias only,
                    # so it occupies a single slice of the flat vector.
                    shared_mn, shared_n = shared_param_memo[p]
                    shared_param_infos.append((mn, n, shared_mn, shared_n))
                else:
                    shared_param_memo[p] = (mn, n)
                    param_infos.append((mn, n))
                    params.append(p.detach())
                    param_numels.append(p.numel())
                    param_shapes.append(p.size())

        # torch.cat below needs a single dtype for the flat vector.
        assert len(set(p.dtype for p in params)) <= 1, \
            "expects all parameters in module to have same dtype"

        # store the info for unflatten
        self._param_infos = tuple(param_infos)
        self._shared_param_infos = tuple(shared_param_infos)
        self._param_numels = tuple(param_numels)
        self._param_shapes = tuple(param_shapes)

        # flatten
        flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
        self.register_parameter('flat_param', flat_param)
        self.param_numel = flat_param.numel()
        del params
        del shared_param_memo

        # Deregister the names as parameters so they can be re-set as plain tensors.
        for mn, n in self._param_infos:
            delattr(self._get_module_from_name(mn), n)
        for mn, n, _, _ in self._shared_param_infos:
            delattr(self._get_module_from_name(mn), n)

        # register the views as plain attributes
        self._unflatten_param(self.flat_param)

        # Buffers are not reparametrized; just store (module path, name, buffer)
        # so replaced_buffers() can restore the originals.
        buffer_infos = []
        for mn, m in self.named_modules():
            for n, b in m.named_buffers(recurse=False):
                if b is not None:
                    buffer_infos.append((mn, n, b))
        self._buffer_infos = tuple(buffer_infos)

        self._traced_self = None

    def trace(self, example_input, **trace_kwargs):
        """Replace the two forward helpers with TorchScript-traced versions.

        Args:
            example_input: a tensor or an iterable of tensors used as the
                example network input for tracing.
            **trace_kwargs: forwarded to ``torch.jit.trace_module``.

        Returns:
            ``self``, so calls can be chained. May only be called once.
        """
        assert self._traced_self is None, 'This ReparamModule is already traced'
        if isinstance(example_input, torch.Tensor):
            example_input = (example_input,)
        example_input = tuple(example_input)
        example_param = (self.flat_param.detach().clone(),)
        example_buffers = (tuple(b.detach().clone() for _, _, b in self._buffer_infos),)
        self._traced_self = torch.jit.trace_module(
            self,
            inputs=dict(
                _forward_with_param=example_param + example_input,
                _forward_with_param_and_buffers=example_param + example_buffers + example_input,
            ),
            **trace_kwargs,
        )
        # replace forwards with traced versions
        self._forward_with_param = self._traced_self._forward_with_param
        self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers
        return self

    def clear_views(self):
        """Set every (non-alias) parameter attribute of the wrapped module to ``None``."""
        for mn, n in self._param_infos:
            setattr(self._get_module_from_name(mn), n, None)  # This will set as plain attr

    def _apply(self, *args, **kwargs):
        # nn.Module methods such as .to()/.cuda() route through _apply; once
        # traced, apply them to the traced module instead.
        if self._traced_self is not None:
            self._traced_self._apply(*args, **kwargs)
            return self
        return super(ReparamModule, self)._apply(*args, **kwargs)

    def _unflatten_param(self, flat_param):
        """Point every parameter name at its slice of ``flat_param``, reshaped to its original shape."""
        ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))
        for (mn, n), p in zip(self._param_infos, ps):
            setattr(self._get_module_from_name(mn), n, p)  # This will set as plain attr
        # Aliases share the view of their source parameter.
        for (mn, n, shared_mn, shared_n) in self._shared_param_infos:
            setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n))

    @contextmanager
    def unflattened_param(self, flat_param):
        """Temporarily use views of ``flat_param`` as the module's parameters.

        The previously installed views are restored on exit, including when
        the body raises.
        """
        saved_views = [getattr(self._get_module_from_name(mn), n) for mn, n in self._param_infos]
        try:
            self._unflatten_param(flat_param)
            yield
        finally:
            # Why not just `self._unflatten_param(self.flat_param)`?
            # 1. because of https://github.com/pytorch/pytorch/issues/17583
            # 2. slightly faster since it does not require reconstruct the split+view
            #    graph
            for (mn, n), p in zip(self._param_infos, saved_views):
                setattr(self._get_module_from_name(mn), n, p)
            for (mn, n, shared_mn, shared_n) in self._shared_param_infos:
                setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n))

    @contextmanager
    def replaced_buffers(self, buffers):
        """Temporarily install ``buffers`` (ordered like ``self._buffer_infos``).

        The originally recorded buffers are restored on exit, including when
        the body raises.
        """
        try:
            for (mn, n, _), new_b in zip(self._buffer_infos, buffers):
                setattr(self._get_module_from_name(mn), n, new_b)
            yield
        finally:
            for mn, n, old_b in self._buffer_infos:
                setattr(self._get_module_from_name(mn), n, old_b)

    def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs):
        with self.unflattened_param(flat_param), self.replaced_buffers(buffers):
            return self.module(*inputs, **kwinputs)

    def _forward_with_param(self, flat_param, *inputs, **kwinputs):
        with self.unflattened_param(flat_param):
            return self.module(*inputs, **kwinputs)

    def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs):
        """Run the wrapped module, optionally with external parameters/buffers.

        Args:
            *inputs, **kwinputs: forwarded to the wrapped module.
            flat_param: flat parameter vector to use. ``None`` uses the
                registered ``self.flat_param``. A supplied vector is squeezed,
                so e.g. a ``(1, N)`` row is accepted.
            buffers: optional iterable of replacement buffers, ordered like the
                buffers recorded at construction.

        Returns:
            Whatever the wrapped module returns.
        """
        # Check for None BEFORE squeezing: torch.squeeze(None) raises TypeError.
        if flat_param is None:
            flat_param = self.flat_param
        else:
            flat_param = torch.squeeze(flat_param)
        if buffers is None:
            return self._forward_with_param(flat_param, *inputs, **kwinputs)
        return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs)

    def recover_loss_to_params(self, flat_loss):
        """Average a flat per-element quantity over each (non-bias) parameter.

        ``flat_loss`` is split with the same sizes as the flat parameter
        vector (presumably one value per parameter element — confirm with
        callers). Parameters named exactly ``'bias'`` are skipped.

        Returns:
            dict mapping ``"Module {path}/Parameter {name}"`` to the mean of
            that parameter's slice, as a Python float.
        """
        layer_to_loss = {}
        with torch.no_grad():
            grouped_loss = flat_loss.split(self._param_numels)
            for (mn, n), group in zip(self._param_infos, grouped_loss):
                if n == 'bias':
                    continue
                layer_name = "Module {}/Parameter {}".format(mn, n)
                layer_to_loss[layer_name] = torch.mean(group).detach().cpu().item()
        return layer_to_loss
if __name__ == '__main__':
    # Build the baseline ConvNet and wrap it so its weights live in one flat vector.
    # NOTE(review): the positional args to get_network are presumably
    # (model name, channels, num_classes) — confirm against utils.utils_baseline.
    student_net = ReparamModule(get_network("ConvNet", 3, 10, dist=False))

    # Disabled example: load saved per-element losses, average them per layer
    # with recover_loss_to_params, and save a colour-mapped bar chart.
    # losses = torch.load("losses.pt")
    # layer_to_loss = student_net.recover_loss_to_params(losses)
    # layers = list(layer_to_loss.keys())
    # losses = list(layer_to_loss.values())
    # norm = Normalize(vmin=np.min(losses), vmax=np.max(losses))
    # cmap = plt.get_cmap('Reds')
    # color_params = cmap(norm(losses))
    # plt.figure(figsize=(12, 8))
    # plt.bar(np.arange(len(losses)), losses, color=color_params)
    # plt.xlabel('Layer')
    # plt.ylabel('Average Loss')
    # plt.title('Average Loss across Different Layers')
    # plt.colorbar(ScalarMappable(norm=norm, cmap=cmap), label='Average Loss')
    # plt.savefig("loss.png")
    # print(layer_to_loss)