forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlayer_test_util.py
139 lines (114 loc) · 4.74 KB
/
layer_test_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
## @package layer_test_util
# Module caffe2.python.layer_test_util
from collections import namedtuple
from caffe2.python import (
core,
layer_model_instantiator,
layer_model_helper,
schema,
test_util,
workspace,
utils,
)
from caffe2.proto import caffe2_pb2
import numpy as np
# pyre-fixme[13]: Pyre can't detect attribute initialization through the
# super().__new__ call
class OpSpec(namedtuple("OpSpec", "type input output arg")):
    """Expected description of one operator, consumed by
    ``LayersTestCase.assertNetContainOps``.

    Fields:
        type: expected operator type.
        input: expected input blob names; ``None`` skips the input check,
            and ``None`` entries skip individual positions.
        output: expected output blob names, with the same ``None`` rules
            as ``input``.
        arg: mapping of expected argument name to value; ``None`` (the
            default) skips the argument check.
    """

    def __new__(cls, op_type, op_input, op_output, op_arg=None):
        fields = (op_type, op_input, op_output, op_arg)
        return super().__new__(cls, *fields)
class LayersTestCase(test_util.TestCase):
    """Base test case for caffe2 layer tests.

    Every test starts from a reset workspace and a fresh
    ``layer_model_helper.LayerModelHelper`` stored in ``self.model``. The
    class provides helpers to build and run the model's nets and to assert
    on the operators a net contains.
    """

    def setUp(self):
        super().setUp()
        self.setup_example()

    def setup_example(self):
        """Reset the workspace and rebuild ``self.model``.

        This is an undocumented feature in hypothesis, which calls
        ``setup_example`` before each generated example; see
        https://github.com/HypothesisWorks/hypothesis-python/issues/59
        """
        workspace.ResetWorkspace()
        self.reset_model()

    def reset_model(self, input_feature_schema=None, trainer_extra_schema=None):
        """Replace ``self.model`` with a new ``LayerModelHelper``.

        Args:
            input_feature_schema: schema of the model input features. When
                ``None``, a struct with a single ``float_features`` scalar of
                dtype float32 and shape (32,) is used.
            trainer_extra_schema: extra trainer schema. When ``None``, an
                empty ``schema.Struct()`` is used.
        """
        # Test against None instead of truthiness so an explicitly passed
        # schema is always honoured, even one that evaluates as falsy
        # (possibly an empty struct, if schema.Struct defines __len__).
        if input_feature_schema is None:
            input_feature_schema = schema.Struct(
                ('float_features', schema.Scalar((np.float32, (32,)))),
            )
        if trainer_extra_schema is None:
            trainer_extra_schema = schema.Struct()
        self.model = layer_model_helper.LayerModelHelper(
            'test_model',
            input_feature_schema=input_feature_schema,
            trainer_extra_schema=trainer_extra_schema)

    def new_record(self, schema_obj):
        """Return ``schema.NewRecord`` for ``schema_obj`` on the model's net."""
        return schema.NewRecord(self.model.net, schema_obj)

    def get_training_nets(self, add_constants=False):
        """Build ``(train_init_net, train_net)`` directly from the model layers.

        We don't use
        ``layer_model_instantiator.generate_training_nets_forward_only()``
        here because it includes initialization of global constants, which
        makes testing tricky.

        Args:
            add_constants: when True, the init net comes from
                ``self.model.create_init_net('train_init_net')`` (presumably
                carrying the model's global constants; confirm). Otherwise
                it is an empty ``core.Net``.

        Returns:
            Tuple ``(train_init_net, train_net)``. Every layer has added its
            operators to these two nets.
        """
        train_net = core.Net('train_net')
        if add_constants:
            train_init_net = self.model.create_init_net('train_init_net')
        else:
            train_init_net = core.Net('train_init_net')
        for layer in self.model.layers:
            layer.add_operators(train_net, train_init_net)
        return train_init_net, train_net

    def get_eval_net(self):
        """Return the eval net generated by ``layer_model_instantiator``."""
        return layer_model_instantiator.generate_eval_net(self.model)

    def get_predict_net(self):
        """Return the predict net generated by ``layer_model_instantiator``."""
        return layer_model_instantiator.generate_predict_net(self.model)

    def run_train_net(self):
        """Generate the training nets and run the init net and train net once.

        ``self.model.output_schema`` is reset to an empty struct beforehand.
        """
        self.model.output_schema = schema.Struct()
        train_init_net, train_net = \
            layer_model_instantiator.generate_training_nets(self.model)
        workspace.RunNetOnce(train_init_net)
        workspace.RunNetOnce(train_net)

    def run_train_net_forward_only(self, num_iter=1):
        """Run the forward-only training net ``num_iter`` times.

        ``self.model.output_schema`` is reset to an empty struct. The
        forward-only training nets are generated, and the init net is run
        once. The train net is then created in the workspace and run
        ``num_iter`` times.

        Args:
            num_iter: number of iterations of the train net; must be > 0.
        """
        # Validate before mutating the model or running anything in the
        # workspace, so a bad argument leaves no side effects behind.
        assert num_iter > 0, 'num_iter must be larger than 0'
        self.model.output_schema = schema.Struct()
        train_init_net, train_net = \
            layer_model_instantiator.generate_training_nets_forward_only(
                self.model)
        workspace.RunNetOnce(train_init_net)
        workspace.CreateNet(train_net)
        workspace.RunNet(train_net.Proto().name, num_iter=num_iter)

    def assertBlobsEqual(self, spec_blobs, op_blobs):
        """Assert that ``op_blobs`` matches ``spec_blobs`` position by position.

        ``spec_blobs`` is either None or a list of blob names. If it is None,
        no assertion is performed. Individual list elements may be None, in
        which case that position is not checked. The lengths must match.
        """
        if spec_blobs is None:
            return
        self.assertEqual(len(spec_blobs), len(op_blobs))
        for spec_blob, op_blob in zip(spec_blobs, op_blobs):
            if spec_blob is None:
                continue
            self.assertEqual(spec_blob, op_blob)

    def assertArgsEqual(self, spec_args, op_args):
        """Assert that an operator's arguments equal ``spec_args``.

        Args:
            spec_args: mapping from argument name to expected value.
            op_args: the operator's ``arg`` field. Each element has a
                ``name`` attribute.
        """
        self.assertEqual(len(spec_args), len(op_args))
        keys = [a.name for a in op_args]
        # Compare the argument names first so a mismatch is reported as an
        # assertion failure instead of a KeyError while building the
        # expected arguments below.
        self.assertEqual(set(spec_args), set(keys))

        def build_expected(arg_map):
            op_def = caffe2_pb2.OperatorDef()
            # Generate the expected arguments in the same order as op_args
            # so the two repeated fields can be compared directly.
            for name in keys:
                expected = utils.MakeArgument(name, arg_map[name])
                op_def.arg.add().CopyFrom(expected)
            return op_def.arg

        self.assertEqual(build_expected(spec_args), op_args)

    def assertNetContainOps(self, net, op_specs):
        """Assert that ``net`` consists exactly of the ops described by ``op_specs``.

        Operators are compared in order: type, inputs, outputs, and
        (when ``op_spec.arg`` is not None) arguments.

        Returns:
            The net's operator list (``net.Proto().op``).
        """
        ops = net.Proto().op
        self.assertEqual(len(op_specs), len(ops))
        for op, op_spec in zip(ops, op_specs):
            self.assertEqual(op_spec.type, op.type)
            self.assertBlobsEqual(op_spec.input, op.input)
            self.assertBlobsEqual(op_spec.output, op.output)
            if op_spec.arg is not None:
                self.assertArgsEqual(op_spec.arg, op.arg)
        return ops