This repository has been archived by the owner on Jun 14, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathinference_parameters.py
281 lines (236 loc) · 11.6 KB
/
inference_parameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================
import warnings
import mxnet as mx
from mxnet import initializer
from mxnet.gluon import ParameterDict
from ..components.variables import VariableType, Variable
from ..components import ModelComponent
from ..util.inference import realize_shape
from ..common.config import get_default_device, get_default_dtype
class InferenceParameters(object):
    """
    The parameters and outcomes of an inference method.

    InferenceParameters is a pool of memory that contains a mapping from uuid to two types of memories
    (MXNet ParameterDict and Constants).

    :param constants: Specify model variables as constants. Keys may be ModelComponents or their uuids.
    :type constants: {ModelComponent or ModelComponent.uuid : mxnet.ndarray}
    :param dtype: data type for internal numerical representation
    :type dtype: {numpy.float64, numpy.float32, 'float64', 'float32'}
    :param context: The MXNet context
    :type context: {mxnet.cpu or mxnet.gpu}
    """
    def __init__(self, constants=None, dtype=None, context=None):
        self.dtype = dtype if dtype is not None else get_default_dtype()
        self.mxnet_context = context if context is not None else get_default_device()
        # uuid -> constant value (primitive number or MXNet NDArray).
        self._constants = {}
        # Exposed via the ``var_ties`` property; nothing in this class populates it at the moment.
        self._var_ties = {}
        if constants is not None:
            self._constants.update(self._keyed_by_uuid(constants))
        # uuid -> MXNet Parameter; rebuilt by initialize_params / load_parameters.
        self._params = ParameterDict()

    @staticmethod
    def _keyed_by_uuid(mapping):
        """
        Return a new dict with the same values as ``mapping`` whose ModelComponent keys are replaced by their
        uuid. Keys that are not ModelComponents are kept unchanged.
        """
        return {(k.uuid if isinstance(k, ModelComponent) else k): v
                for k, v in mapping.items()}

    def update_constants(self, constants):
        """
        Update the constants.

        :param constants: The constants to be updated, keyed by Variable (any ModelComponent) or by uuid.
        :type constants: {Variable: float or MXNet NDArray}
        """
        self._constants.update(self._keyed_by_uuid(constants))

    def initialize_params(self, graphs, observed_uuid):
        """
        Create and initialize an MXNet Parameter for every parameter variable of ``graphs`` (and every hidden
        parameter of the graphs' modules) that is neither a constant nor observed. Previously created
        parameters are discarded.

        :param graphs: a list of graphs in which the parameters will be optimized.
        :type graphs: a list of FactorGraph
        :param observed_uuid: Parameter Variables that are passed in directly as data, not to be inferred.
        :type observed_uuid: list, set
        """
        # self._params is always a ParameterDict (never None), so only warn when it actually
        # holds parameters from an earlier initialization.
        if self._params is not None and len(self._params.keys()) > 0:
            warnings.warn("InferenceParameters has already been initialized. The existing one will be overwritten.")

        self._params = ParameterDict()
        for g in graphs:
            # Record the graph's constants first: they are excluded from the parameters below and
            # are used to realize symbolic shapes.
            for var in g.get_constants():
                self._constants[var.uuid] = var.constant

            excluded = set(self._constants.keys()).union(observed_uuid)
            for var in g.get_parameters(excluded=excluded):
                var_shape = realize_shape(var.shape, self._constants)
                if var.initial_value is not None:
                    init = initializer.Constant(var.initial_value_before_transformation)
                else:
                    init = None
                # allow_deferred_init lets MXNet postpone allocation until the shape is fully known.
                self._params.get(name=var.uuid, shape=var_shape,
                                 dtype=self.dtype,
                                 allow_deferred_init=True, init=init)
            for m in g.modules.values():
                m.initialize_hidden_parameters(self._params, excluded, self._constants)

        self._params.initialize(ctx=self.mxnet_context)

    def initialize_with_carryover_params(self, graphs, observed_uuid, var_ties,
                                         carryover_params):
        """
        Initialize the parameters of ``graphs``, reusing the MXNet Parameters of earlier inference runs for any
        variable (or module hidden parameter) of ``graphs`` that appears in ``carryover_params``.

        :param graphs: a list of graphs in which the parameters will be optimized.
        :type graphs: a list of FactorGraph
        :param observed_uuid: Parameter Variables that are passed in directly as data, not to be inferred.
        :type observed_uuid: {UUID : mx.ndarray}
        :param var_ties: A dictionary of variable maps that are tied together and use the MXNet Parameter of the dict
            value's uuid. Currently ignored.
        :type var_ties: { UUID to tie from : UUID to tie to }
        :param carryover_params: list of InferenceParameters containing the outcomes of previous inference algorithms.
        :type carryover_params: [InferenceParameters]
        """
        # TODO: var_ties is discarded at the moment; tied variables are not merged into self._var_ties.
        var_uuid = set()
        for g in graphs:
            var_uuid.update(g.variables.keys())
            for m in g.modules.values():
                var_uuid.update(m.hidden_parameters)

        # Collect the carried-over Parameters relevant to these graphs; later sets win on conflict.
        carryover_pairs = {}
        for carryover in carryover_params:
            for uuid, v in carryover.param_dict.items():
                if uuid not in var_uuid:
                    continue
                if uuid in carryover_pairs:
                    warnings.warn('The variable with UUID '+uuid+' exists in multiple carryover parameter sets.')
                carryover_pairs[uuid] = v

        # Treat carried-over uuids as observed so initialize_params does not create fresh
        # Parameters for them, then insert the carried-over Parameter objects themselves.
        observed_uuid = set(observed_uuid).union(carryover_pairs.keys())
        self.initialize_params(graphs, observed_uuid)
        self._params.update(carryover_pairs)

    def fix_all(self):
        """
        Fix every parameter in the pool by setting its ``grad_req`` to ``'null'``.
        """
        for p in self.param_dict.values():
            p.grad_req = 'null'

    def fix_variable(self, variable, value=None):
        """
        Fixes a variable so that it isn't changed by the inference algorithm. Optionally a value can be specified to
        which the variable's value will be fixed. If a value is not specified, the variable will be fixed to its
        current value.

        :param variable: The variable to be fixed
        :type variable: MXFusion Variable
        :param value: (Optional) Value to which the variable will be fixed
        :type value: MXNet NDArray
        """
        if value is not None:
            self[variable] = value
        parameter = self.param_dict[variable.uuid]
        parameter.grad_req = 'null'

    def unfix_variable(self, variable):
        """
        Allows a variable to be changed by the inference algorithm if it has been previously fixed.

        :param variable: The variable to be unfixed
        :type variable: MXFusion Variable
        """
        parameter = self.param_dict[variable.uuid]
        parameter.grad_req = 'write'

    @property
    def param_dict(self):
        """The underlying uuid -> MXNet Parameter mapping."""
        return self._params

    @property
    def constants(self):
        """The uuid -> constant value mapping."""
        return self._constants

    @property
    def var_ties(self):
        """The variable-tie mapping (currently never populated by this class)."""
        return self._var_ties

    def __getitem__(self, key, ctx=None):
        """
        Return the current value of ``key``'s parameter, with the variable's transformation applied if it has one.

        :param key: The variable whose value is requested.
        :type key: Variable
        :param ctx: The MXNet context to read the data from (passed to ``Parameter.data``).
        :raises KeyError: if ``key`` is not a Variable or has no parameter in this pool.
        """
        if not isinstance(key, Variable):
            raise KeyError("The access key of inference parameter needs to be Variable, but got "+str(type(key))+".")
        # Index rather than ParameterDict.get: get() would create and keep a new, uninitialized
        # Parameter for an unknown uuid.
        val = self._params[key.uuid].data(ctx)
        if key.transformation is not None:
            val = key.transformation.transform(val)
        return val

    def __setitem__(self, key, item):
        """
        Set the value of ``key``. Parameters are stored (inverse-transformed if the variable has a
        transformation) in the ParameterDict; constants are stored in the constants mapping. Other
        variable types are ignored.

        :param key: The variable to set.
        :type key: Variable
        :param item: The new value.
        :raises KeyError: if ``key`` is not a Variable, or is a parameter with no entry in this pool.
        """
        if not isinstance(key, Variable):
            raise KeyError("The access key of inference parameter needs to be Variable, but get "+str(type(key))+".")

        if key.type == VariableType.PARAMETER:
            if key.transformation is not None:
                item = key.transformation.inverseTransform(item)
            self._params[key.uuid].set_data(item)
        elif key.type == VariableType.CONSTANT:
            # Constants are kept in self._constants (initialize_params excludes them from
            # self._params), so that is where the new value must go.
            self._constants[key.uuid] = item

    # Override contains so that it doesn't use the __getitem__ method.
    def __contains__(self, k):
        """
        Membership test against this object's instance attribute names (``k in self.__dict__``).

        NOTE(review): this does not test for Variables or uuids in the pool; confirm callers expect that.
        """
        return k in self.__dict__

    @staticmethod
    def load_parameters(uuid_map=None,
                        mxnet_parameters=None,
                        variable_constants=None,
                        mxnet_constants=None,
                        context=None, dtype=None,
                        current_params=None):
        """
        Loads back a set of InferenceParameters from files.

        :param uuid_map: Optional mapping from the stored uuids to the uuids of the current model. When None,
            uuids are used unchanged.
        :type uuid_map: {uuid: uuid}
        :param mxnet_parameters: These are the parameters of
            the previous inference algorithm.
            These are in a {uuid: mx.nd.array} mapping. None is treated as empty.
        :type mxnet_parameters: Dict of {uuid: mx.nd.array}
        :param variable_constants: These are the constants in
            primitive format from the previous
            inference algorithm. None is treated as empty.
        :type variable_constants: dict of {uuid: constant primitive}
        :param mxnet_constants: These are the constants in mxnet format
            from the previous inference algorithm.
            These are in a {uuid: mx.nd.array} mapping. None is treated as empty.
        :type mxnet_constants: Dict of {uuid: mx.nd.array}
        :param context: The MXNet context of the returned InferenceParameters.
        :param dtype: The dtype of the returned InferenceParameters.
        :param current_params: Parameters to load the data into; must contain an entry for every (mapped)
            uuid in ``mxnet_parameters``.
        :type current_params: ParameterDict
        :returns: The reconstructed InferenceParameters.
        :rtype: InferenceParameters
        :raises KeyError: if a uuid is missing from ``uuid_map`` or a mapped parameter is missing from
            ``current_params``.
        """
        def map_uuid(item):
            return uuid_map[item] if uuid_map is not None else item

        ip = InferenceParameters(context=context, dtype=dtype)

        # Map the uuids to the new Model before loading them into the ParameterDict.
        mapped_params = {map_uuid(k): v
                         for k, v in (mxnet_parameters or {}).items()}

        new_paramdict = ParameterDict()
        if current_params is not None:
            new_paramdict.update(current_params)
        for name, mapped_param in mapped_params.items():
            # _load_init is a private MXNet Parameter method that (re)initializes from loaded data.
            new_paramdict[name]._load_init(mapped_param, ip.mxnet_context)
        ip._params = new_paramdict

        ip._constants = {}
        ip._constants.update({map_uuid(k): v
                              for k, v in (variable_constants or {}).items()})
        ip._constants.update({map_uuid(k): v
                              for k, v in (mxnet_constants or {}).items()})
        return ip

    def get_serializable(self):
        """
        Returns three dicts:

        1. MXNet parameters {uuid: mxnet parameters, mx.nd.array}.
        2. MXNet constants {uuid: mxnet parameter (only constant types), mx.nd.array}
        3. Other constants {uuid: primitive numeric types (int, float)}

        :returns: Three dictionaries: MXNet parameters, MXNet constants, and other constants (in that order)
        :rtype: {uuid: mx.nd.array}, {uuid: mx.nd.array}, {uuid: primitive (int/float)}
        """
        # _reduce is a private MXNet Parameter method returning the parameter's data as one NDArray.
        mxnet_parameters = {key: value._reduce() for key, value in self._params.items()}
        mxnet_constants = {uuid: value
                           for uuid, value in self._constants.items()
                           if isinstance(value, mx.ndarray.ndarray.NDArray)}
        variable_constants = {uuid: value
                              for uuid, value in self._constants.items()
                              if uuid not in mxnet_constants}
        return mxnet_parameters, mxnet_constants, variable_constants