-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathexecutor.py
145 lines (132 loc) · 5.27 KB
/
executor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
Executes a program represented by a list of operations.
(Do not use these functions directly; use the Program class instead.)
"""
import pickle
import torch
from typing import List
import traceback
from mlca.program import Program
import mlca.operations as operations
class ProgramExecutionError(RuntimeError):
    """Raised when executing a program's operations, validating their
    outputs, or backpropagating their loss contributions fails."""
def _execute_program(
program_operations: List[operations.Operation], input_values, data_structure_values,
optimizer_values, perform_mutations, print_on_error=True, profiler=None,
i_episode=None):
intermediate_values = {
** input_values,
** data_structure_values,
** optimizer_values
}
values_to_add_to_loss = []
output_values = []
for operation in program_operations:
input_values = [intermediate_values[i] for i in operation.inputs]
output_value = "UNSET"
try:
output_value = operation.execute(
input_values, profiler=profiler, i_episode=i_episode)
output_values.append(output_value)
intermediate_values[operation] = output_value
if operation.add_to_loss:
values_to_add_to_loss.append(operation.value_to_add_to_loss(
input_values, output_value, profiler, i_episode
))
if profiler is not None:
profiler.tick(i_episode, str(type(operation)))
except Exception as e:
if (type(output_value) == str and output_value == "UNSET") or not operation.cached_output_type.is_correctly_formatted_value(output_value):
if print_on_error or True:
print("\n\n!!!!!!!!!!!!!!!!")
print("Operation failed")
print(operation)
# print(input_values)
print(e)
traceback.print_exc()
for inp in input_values:
if inp is None:
print("inp", inp)
elif type(inp) is list:
print("inp list", inp[0].shape)
elif type(inp) is torch.Tensor:
print("inp", inp.shape, inp.device)
else:
print("inp", inp)
if output_value is None:
print(type(operation), output_value)
elif type(output_value) is list:
print(type(operation), "list", output_value[0].shape)
elif type(output_value) is torch.Tensor:
print(type(operation), output_value.shape)
else:
print(type(operation), output_value)
print("!!!!!!!!!!!!!!!!")
raise ProgramExecutionError(e)
for operation, output_value in zip(program_operations, output_values):
try:
assert operation.cached_output_type.value_class == type(output_value), (
"wanted", operation.cached_output_type.value_class, "got", type(output_value), operation)
assert operation.cached_output_type.is_correctly_formatted_value(output_value)
assert operation.cached_output_type.is_valid_value(
output_value)
except Exception as e:
if (type(output_value) == str and output_value == "UNSET") or not operation.cached_output_type.is_correctly_formatted_value(output_value):
if print_on_error or True:
print("\n\n!!!!!!!!!!!!!!!!")
print("Operation failed")
print(operation)
# print(input_values)
print(e)
traceback.print_exc()
for inp in input_values:
if inp is None:
print("inp", inp)
elif type(inp) is list:
print("inp list", inp[0].shape)
elif type(inp) is torch.Tensor:
print("inp", inp.shape, inp.device)
else:
print("inp", inp)
if output_value is None:
print(type(operation), output_value)
elif type(output_value) is list:
print(type(operation), "list", output_value[0].shape)
elif type(output_value) is torch.Tensor:
print(type(operation), output_value.shape)
else:
print(type(operation), output_value)
print("!!!!!!!!!!!!!!!!")
raise ProgramExecutionError(e)
if perform_mutations:
try:
if len(values_to_add_to_loss) > 0:
assert len(optimizer_values) == 1, f"Wrong # of optimizers given {optimizers}"
optimizer = list(optimizer_values.values())[0]
for v in values_to_add_to_loss:
assert v.shape == tuple() or len(v.shape) == 1, v.shape
loss = torch.sum(torch.stack([v.mean() for v in values_to_add_to_loss]))
if profiler is not None:
profiler.tick(i_episode, "Executor MinimizeValue: setup")
optimizer.zero_grad()
if profiler is not None:
profiler.tick(i_episode, "Executor MinimizeValue: zero grad")
loss.backward()
if profiler is not None:
profiler.tick(i_episode, "Executor MinimizeValue: backward")
# for param in self.q_net.parameters():
# if param.grad is not None:
# param.grad.data.clamp_(-1, 1)
optimizer.step()
if profiler is not None:
profiler.tick(i_episode, "Executor MinimizeValue: step")
except Exception as e:
if print_on_error or True:
print("\n\n!!!!!!!!!!!!!!!!")
print("Backprop gradients failed")
for p in program_operations:
print(p)
print(values_to_add_to_loss)
print("!!!!!!!!!!!!!!!!")
raise ProgramExecutionError(e)
return intermediate_values