convert_parameters.py
import torch
from torch.autograd import Variable


def parameters_to_vector(parameters):
    """Convert parameters to one vector.

    Arguments:
        parameters (Iterable[Variable]): an iterator of Variables that are the
            parameters of a model.

    Returns:
        The parameters of the model represented by a single vector.
    """
    vec = []
    for param in parameters:
        # Flatten each parameter into a 1-D view
        vec.append(param.view(-1))
    # Concatenate all flattened parameters into a single vector
    return torch.cat(vec)
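
# A minimal usage sketch (not part of the original file): flatten the
# parameters of a toy model into one 1-D Variable. `nn.Linear(3, 2)` is only
# an illustrative stand-in for a real model.
#
#     >>> import torch.nn as nn
#     >>> model = nn.Linear(3, 2)
#     >>> flat = parameters_to_vector(model.parameters())
#     >>> flat.size()  # 3 * 2 weights + 2 biases
#     torch.Size([8])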


def vector_to_parameters(vec, parameters):
    """Convert one vector to the parameters.

    Arguments:
        vec (Variable): a single vector representing the parameters of a model.
        parameters (Iterable[Variable]): an iterator of Variables that are the
            parameters of a model.
    """
    # Ensure vec is of type Variable
    if not isinstance(vec, Variable):
        raise TypeError('expected torch.autograd.Variable, but got: {}'
                        .format(torch.typename(vec)))
    # Flag for the device where the parameters are located (-1 means CPU)
    param_device = None
    # Pointer for slicing the vector for each parameter
    pointer = 0
    for param in parameters:
        # Ensure all parameters are located on the same device; only query
        # get_device() for CUDA tensors, since it is undefined on CPU tensors
        device = param.get_device() if param.is_cuda else -1
        if param_device is None:
            param_device = device
        elif device != param_device:
            raise TypeError('Found two parameters on different devices, '
                            'this is currently not supported.')
        # The number of elements in the parameter
        num_param = param.numel()
        # Slice the vector, reshape it, and replace the old data of the parameter
        param.data = vec[pointer:pointer + num_param].view(param.size()).data
        # Increment the pointer
        pointer += num_param
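
# A minimal round-trip sketch (not part of the original file): write a
# modified flat vector back into the model's parameters. Continues the toy
# `nn.Linear` example above; `vec * 0` is just an arbitrary modification.
#
#     >>> vec = parameters_to_vector(model.parameters())
#     >>> vector_to_parameters(vec * 0, model.parameters())
#     >>> # every parameter of `model` is now zero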