@@ -10,11 +10,13 @@ import TensorExtension.epsilon
10
10
trait APITrait {
11
11
12
12
sealed trait GraphMode
13
+
13
14
/** [[CompNode ]] created under this mode consumes less memory but does not support back
14
- * propagation. Use this mode during testing. */
15
+ * propagation. Use this mode during testing. */
15
16
case object ModeEval extends GraphMode
17
+
16
18
/** [[CompNode ]] created under this mode needs to keep the entire computation graph
17
- * in the memory to support gradient back propagation. Use this mode during training. */
19
+ * in the memory to support gradient back propagation. Use this mode during training. */
18
20
case object ModeTraining extends GraphMode
19
21
20
22
var debugOpTime = true
@@ -32,7 +34,7 @@ trait APITrait {
32
34
33
35
def funcNode (func : DiffFunc )(implicit mode : GraphMode ): CompNode =
34
36
mode match {
35
- case ModeEval => new CompNode (ConstFunc (func.value))
37
+ case ModeEval => new CompNode (ConstFunc (func.value))
36
38
case ModeTraining => new CompNode (func)
37
39
}
38
40
@@ -48,7 +50,8 @@ trait APITrait {
48
50
49
51
def mean (x1 : CompNode )(implicit mode : GraphMode ): CompNode = funcNode(Mean (x1))
50
52
51
- def mean (x1 : CompNode , axis : Int )(implicit mode : GraphMode ): CompNode = funcNode(MeanByAxis (x1, axis))
53
+ def mean (x1 : CompNode , axis : Int )(implicit mode : GraphMode ): CompNode =
54
+ funcNode(MeanByAxis (x1, axis))
52
55
53
56
def square (x1 : CompNode )(implicit mode : GraphMode ): CompNode = x1 * x1
54
57
@@ -62,7 +65,9 @@ trait APITrait {
62
65
63
66
def sum (x1 : CompNode )(implicit mode : GraphMode ): CompNode = funcNode(Sum (x1))
64
67
65
- def sum (x1 : CompNode , axis : Int , keepDim : Boolean = true )(implicit mode : GraphMode ): CompNode =
68
+ def sum (x1 : CompNode , axis : Int , keepDim : Boolean = true )(
69
+ implicit mode : GraphMode
70
+ ): CompNode =
66
71
funcNode(SumByAxis (x1, axis, keepDim))
67
72
68
73
def softmax (x1 : CompNode )(implicit mode : GraphMode ): CompNode = funcNode(Softmax (x1))
@@ -71,7 +76,8 @@ trait APITrait {
71
76
72
77
def abs (x1 : CompNode )(implicit mode : GraphMode ): CompNode = funcNode(Abs (x1))
73
78
74
- def max (x1 : CompNode , x2 : CompNode )(implicit mode : GraphMode ): CompNode = funcNode(MaxBinary (x1, x2))
79
+ def max (x1 : CompNode , x2 : CompNode )(implicit mode : GraphMode ): CompNode =
80
+ funcNode(MaxBinary (x1, x2))
75
81
76
82
def plusN (xs : IS [CompNode ])(implicit mode : GraphMode ): CompNode = funcNode(PlusN (xs))
77
83
@@ -89,22 +95,32 @@ trait APITrait {
89
95
90
96
def stackRows (xs : IS [CompNode ])(implicit mode : GraphMode ) = funcNode(StackRows (xs))
91
97
92
- def concatTupledRows (rows : IS [(CompNode , CompNode )])(implicit mode : GraphMode ): CompNode = {
98
+ def concatTupledRows (
99
+ rows : IS [(CompNode , CompNode )]
100
+ )(implicit mode : GraphMode ): CompNode = {
93
101
val (l, r) = rows.unzip
94
102
stackRows(l).concat(stackRows(r), axis = 1 )
95
103
}
96
104
97
- def crossEntropy (prediction : CompNode , targets : Tensor )(implicit mode : GraphMode ): CompNode =
105
+ def crossEntropy (prediction : CompNode , targets : Tensor )(
106
+ implicit mode : GraphMode
107
+ ): CompNode =
98
108
- sum(log(prediction + epsilon) * targets, axis = 1 )
99
109
100
- def crossEntropyOnSoftmax (logits : CompNode , targets : Tensor )(implicit mode : GraphMode ): CompNode =
110
+ def crossEntropyOnSoftmax (logits : CompNode , targets : Tensor )(
111
+ implicit mode : GraphMode
112
+ ): CompNode =
101
113
funcNode(CrossEntropyOnSoftmax (logits, targets))
102
114
103
- def crossEntropyOnSigmoid (logits : CompNode , targets : Tensor )(implicit mode : GraphMode ): CompNode = {
115
+ def crossEntropyOnSigmoid (logits : CompNode , targets : Tensor )(
116
+ implicit mode : GraphMode
117
+ ): CompNode = {
104
118
funcNode(CrossEntropyOnSigmoid (logits, targets))
105
119
}
106
120
107
- def crossEntropyOnSoftmaxIneff (logits : CompNode , targets : Tensor )(implicit mode : GraphMode ): CompNode =
121
+ def crossEntropyOnSoftmaxIneff (logits : CompNode , targets : Tensor )(
122
+ implicit mode : GraphMode
123
+ ): CompNode =
108
124
- sum(log(softmax(logits) + epsilon) * targets, axis = 1 )
109
125
110
126
def normSquared (x1 : CompNode )(implicit mode : GraphMode ): CompNode =
@@ -178,7 +194,7 @@ trait APITrait {
178
194
179
195
implicit class ExtendedFunctions (x1 : CompNode ) {
180
196
181
- def unary_- (implicit mode : GraphMode ) : CompNode = funcNode(Negate (x1))
197
+ def unary_- (implicit mode : GraphMode ): CompNode = funcNode(Negate (x1))
182
198
183
199
def t (implicit mode : GraphMode ): CompNode = funcNode(Transpose (x1))
184
200
@@ -195,7 +211,8 @@ trait APITrait {
195
211
def concat (x2 : CompNode , axis : Int )(implicit mode : GraphMode ): CompNode =
196
212
funcNode(Concat (x1, x2, axis))
197
213
198
- def slice (ranges : NumscaRange * )(implicit mode : GraphMode ): CompNode = funcNode(Slice (x1, ranges))
214
+ def slice (ranges : NumscaRange * )(implicit mode : GraphMode ): CompNode =
215
+ funcNode(Slice (x1, ranges))
199
216
200
217
def dot (x2 : CompNode )(implicit mode : GraphMode ): CompNode = funcNode(Dot (x1, x2))
201
218
0 commit comments