-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathresource_ops.py
139 lines (95 loc) · 3.55 KB
/
resource_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""Defines Ops for creating, updating, and reading variables."""
from collections import namedtuple
import numpy as np
from .mixins import _ScalarShape
from .operation import Operation
from .tensor_shape import TensorShape
class Placeholder(Operation):
    """Op that generates a `Placeholder` Tensor.

    Input (0): None.
    Output (1): a `Placeholder` Tensor.
    Side Effect: None.

    The Op does not compute its value: `_run` looks the value up in the
    `_placeholder_values` mapping of the `Runtime` attached to the parent
    `Graph`, keyed by this Op's `id`.
    """

    def __init__(self, shape, graph=None, name=None):
        """Constructor.

        Args:
            shape (List or Tuple): shape of the `Placeholder` Tensor.
            graph (Graph): (Optional) the parent `Graph`.
            name (str): (Optional) name of the Op.
        """
        # Assigned before the base constructor runs, in case that constructor
        # inspects the shape (NOTE(review): confirm against `Operation`).
        self._shape = list(shape)
        super().__init__(graph=graph, name=name)

    def _run(self):
        """Return the value supplied for this placeholder in the `Runtime`.

        Raises:
            ValueError: if no value has been supplied for this placeholder.
        """
        fed_values = self._graph._runtime._placeholder_values
        if self.id in fed_values:
            return fed_values[self.id]
        raise ValueError(f"Placeholder {self.id}'s value is not initialized.")

    def _compute_shapes(self):
        """Return a one-element list holding this placeholder's `TensorShape`."""
        return [TensorShape(self._shape)]
class CreateVariable(Operation, _ScalarShape):
    """Op that initializes a variable.

    Input (0): None.
    Output (1): ID of this `CreateVariable` Op.
    Side Effect: setting the initialized value in the `Runtime` in which the
      `Graph` runs.

    The variable is initialized at most once per `Runtime`. If a value is
    already stored under this Op's `id`, `_run` leaves it untouched. Shape
    inference is not defined here; it is inherited from a base class
    (presumably `_ScalarShape`, matching the ID output — confirm).
    """

    def __init__(self, shape, init_fn, graph=None, name=None):
        """Constructor.

        Args:
            shape (List or Tuple): shape of the variable.
            init_fn (callable): a function called as `init_fn(shape=shape)`
              that returns the initial value of the variable. The returned
              object must have a `.shape` equal to `tuple(shape)`.
            graph (Graph): (Optional) the parent `Graph`.
            name (str): (Optional) name of the Op.
        """
        # Stored as a tuple so it compares equal to a NumPy array's `.shape`.
        self._shape = tuple(shape)
        self._init_fn = init_fn
        super(CreateVariable, self).__init__(graph=graph, name=name)

    def _run(self):
        """Initialize the variable on first run and return this Op's ID.

        Returns:
            `self.id` wrapped by `np.asarray`. Downstream variable Ops recover
            the ID from it with `.item()`.

        Raises:
            ValueError: if `init_fn` returns a value whose shape differs from
              the declared shape.
        """
        variable_values = self._graph._runtime._variable_values
        if self.id not in variable_values:
            init_value = self._init_fn(shape=self._shape)
            # Explicit check rather than `assert`: asserts are stripped under
            # `python -O`, which would let a wrongly-shaped value be stored.
            if init_value.shape != self._shape:
                raise ValueError(
                    f"CreateVariable {self.id}: init_fn returned a value of "
                    f"shape {init_value.shape}, expected {self._shape}.")
            variable_values[self.id] = init_value
        return np.asarray(self.id)
class AssignVariable(Operation):
    """Op that has the side effect of updating the value of a variable with
    `new_value`.

    Input (2): a Tensor from `CreateVariable` Op and another Tensor whose value
      is used to assign to the variable.
    Output (0): None.
    Side Effect: setting the new value in the `Runtime` in which the `Graph`
      runs.
    """

    def _run(self, creator_id, new_value):
        """Store `new_value` as the variable's value in the graph's `Runtime`.

        Args:
            creator_id: output of the `CreateVariable` Op. `.item()` yields the
              key under which the variable's value is stored.
            new_value: the value to store. It is stored by reference, not
              copied (NOTE(review): an in-place update of the variable
              elsewhere may therefore also mutate this object — confirm
              whether a copy is wanted).
        """
        # Bug fix: this previously read `self._graph_runtime`, a typo. Every
        # other Op in this module reaches the runtime via `self._graph._runtime`.
        self._graph._runtime._variable_values[creator_id.item()] = new_value

    @property
    def num_outputs(self):
        """This Op produces no output Tensors."""
        return 0

    def _compute_shapes(self):
        """No outputs, so there are no shapes to report."""
        return None
class AddToVariable(Operation):
    """Op that updates a variable's value by adding `delta` to the original
    value.

    Input (2): a Tensor from `CreateVariable` Op and another Tensor whose value
      is added to the variable's current value.
    Output (0): None.
    Side Effect: setting the new value in the `Runtime` in which the `Graph`
      runs.
    """

    def _run(self, creator_id, delta):
        """Add `delta` to the stored value of the variable `creator_id` names.

        The update uses augmented assignment on the stored value, so a NumPy
        array is modified in place. NOTE(review): any other reference to that
        same array object observes the change — confirm this is intended.
        """
        store = self._graph._runtime._variable_values
        key = creator_id.item()
        store[key] += delta

    @property
    def num_outputs(self):
        """This Op produces no output Tensors."""
        return 0

    def _compute_shapes(self):
        """No outputs, so there are no shapes to report."""
        return None
class ReadVariable(Operation):
    """Op that returns the value of an initialized variable.

    Input (1): a Tensor from `CreateVariable` Op.
    Output (1): the value of the variable.
    Side Effect: None.
    """

    def _run(self, creator_id):
        """Fetch the variable's current value from the graph's `Runtime`."""
        return self._graph._runtime.get_variable_value(creator_id.item())

    def _compute_shapes(self):
        """Report the shape declared by the producing `CreateVariable` Op."""
        # The first input is expected to come from a `CreateVariable` Op. Its
        # `_shape` attribute is copied into a list before wrapping.
        creator_op = self._input_list[0].op
        return [TensorShape(list(creator_op._shape))]

    def mutable(self):
        """Always returns True.

        NOTE(review): defined as a plain method, whereas `num_outputs` on other
        Ops in this module is a property. Confirm which form `Operation`
        expects for `mutable`.
        """
        return True