-
Notifications
You must be signed in to change notification settings - Fork 3
/
op.py
133 lines (100 loc) · 2.81 KB
/
op.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from typing import List, Optional

import numpy as np

from .graph import Graph, Operation, Variable
# Public API of this module: the get_weight helper plus the op node classes
# defined below. Kept as a literal list so tools that read __all__ statically
# can see it.
__all__ = [
    "get_weight",
    "Conv2dOp",
    "DenseOp",
    "PoolOp",
    "MaxPoolOp",
    "AvgPoolOp",
    "ReluOp",
    "TransposeOp",
    "PadOp",
    "ReshapeOp",
    "BatchNormOp",
    "AddOp",
    "SqueezeOp",
    "ConcatOp",
    "MeanOp",
]
class DenseOp(Operation):
    """Graph node for a dense op that owns a weight and a bias variable.

    Both variables are appended to the base-class ``variables`` list
    (weight first, then bias) and are also exposed directly as
    ``self.weight`` and ``self.bias``.
    """

    def __init__(self, graph: Graph, name: str, weight: Variable, bias: Variable):
        super().__init__(graph, name)
        params = (weight, bias)
        # Register on the base-class list, preserving weight-then-bias order.
        self.variables.extend(params)
        self.weight, self.bias = params
class Conv2dOp(Operation):
    """Graph node for a 2-D convolution that owns a kernel and a bias variable.

    ``kernel`` and ``bias`` are appended to the base-class ``variables`` list
    (in that order) and are also kept as ``self.kernel`` / ``self.bias``.
    The remaining configuration arguments are stored as attributes unchanged.

    Args:
        graph: Graph the op belongs to (forwarded to ``Operation``).
        name: Op name (forwarded to ``Operation``).
        kernel: Convolution kernel variable.
        bias: Bias variable.
        padding: Padding mode, stored as ``self.padding``
            (presumably e.g. "SAME"/"VALID" -- TODO confirm).
        strides: Stored as ``self.strides``; ``None`` when not given.
        dilations: Stored as ``self.dilation_rate`` (note the different
            attribute name); ``None`` when not given.
        data_format: Stored as ``self.data_format``; ``None`` when not given
            (presumably a layout string such as "NHWC" -- TODO confirm).
    """

    def __init__(
        self,
        graph: Graph,
        name: str,
        kernel: Variable,
        bias: Variable,
        padding: str,
        strides=None,
        dilations=None,
        data_format: Optional[str] = None,
    ):
        super().__init__(graph, name)
        self.variables.extend([kernel, bias])
        self.kernel = kernel
        self.bias = bias
        self.padding = padding
        self.strides = strides
        # The keyword is ``dilations`` but the attribute is ``dilation_rate``;
        # renaming the attribute would break any code that reads it.
        self.dilation_rate = dilations
        self.data_format = data_format
def get_weight(op: Operation) -> Variable:
    """Return the weight-like variable owned by ``op``.

    A ``DenseOp`` yields its ``weight`` and a ``Conv2dOp`` yields its
    ``kernel``. The types are tested in that order with ``isinstance``, so
    subclasses are accepted too.

    Raises:
        RuntimeError: if ``op`` is neither a ``DenseOp`` nor a ``Conv2dOp``.
    """
    weight_attr_by_type = ((DenseOp, "weight"), (Conv2dOp, "kernel"))
    for op_type, attr_name in weight_attr_by_type:
        if isinstance(op, op_type):
            return getattr(op, attr_name)
    raise RuntimeError(f"op {op.name} with type {type(op)} doesn't have weight")
class PoolOp(Operation):
    """Base graph node for pooling ops (subclassed by MaxPoolOp / AvgPoolOp).

    It registers no variables; it only stores its configuration as
    attributes.

    Args:
        graph: Graph the op belongs to (forwarded to ``Operation``).
        name: Op name (forwarded to ``Operation``).
        filter_shape: Pooling window shape, stored as ``self.filter_shape``.
        padding: Padding mode, stored as ``self.padding``
            (presumably e.g. "SAME"/"VALID" -- TODO confirm).
        strides: Stored as ``self.strides``; ``None`` when not given.
        data_format: Stored as ``self.data_format``; ``None`` when not given.
    """

    def __init__(
        self,
        graph: Graph,
        name: str,
        filter_shape: List[int],
        padding: str,
        strides: Optional[List[int]] = None,
        data_format: Optional[str] = None,
    ):
        super().__init__(graph, name)
        self.filter_shape = filter_shape
        self.padding = padding
        self.strides = strides
        self.data_format = data_format
class MaxPoolOp(PoolOp):
    """Pooling node with no state beyond PoolOp's configuration.

    The distinct type presumably selects max pooling for consumers of the
    graph -- TODO confirm.
    """
class AvgPoolOp(PoolOp):
    """Pooling node with no state beyond PoolOp's configuration.

    The distinct type presumably selects average pooling for consumers of
    the graph -- TODO confirm.
    """
class TransposeOp(Operation):
    """Graph node for a transpose; keeps the axis permutation as ``self.perm``."""

    def __init__(self, graph: Graph, name: str, perm: List[int]):
        super().__init__(graph, name)
        self.perm: List[int] = perm
class ConcatOp(Operation):
    """Graph node for a concatenation; keeps the join axis as ``self.axis``."""

    def __init__(self, graph: Graph, name: str, axis: int):
        super().__init__(graph, name)
        self.axis: int = axis
class SqueezeOp(Operation):
    """Graph node for a squeeze; keeps the dimensions to drop as ``self.squeeze_dims``."""

    def __init__(self, graph: Graph, name: str, squeeze_dims: List[int]):
        super().__init__(graph, name)
        self.squeeze_dims: List[int] = squeeze_dims
class PadOp(Operation):
    """Graph node for padding; keeps the padding array as ``self.paddings``.

    The array's layout is not established here (presumably one
    (before, after) pair per dimension -- TODO confirm).
    """

    def __init__(self, graph: Graph, name: str, paddings: np.ndarray):
        super().__init__(graph, name)
        self.paddings: np.ndarray = paddings
class MeanOp(Operation):
    """Graph node for a mean reduction; keeps the reduced axes as ``self.reduction_indices``."""

    def __init__(self, graph: Graph, name: str, reduction_indices: List[int]):
        super().__init__(graph, name)
        self.reduction_indices: List[int] = reduction_indices
class ReshapeOp(Operation):
    """Reshape graph node; adds no state or constructor beyond ``Operation``."""
class ReluOp(Operation):
    """ReLU graph node; adds no state or constructor beyond ``Operation``."""
class BatchNormOp(Operation):
    """Batch-norm graph node; adds no state or constructor beyond ``Operation``."""
class AddOp(Operation):
    """Addition graph node; adds no state or constructor beyond ``Operation``."""