
Commit 6f41f48 (parent: 21194cb)

Reformat project.


46 files changed, +472 -328 lines

.scalafmt.conf (+1 -1)

@@ -1,6 +1,6 @@
 version = "2.0.0-RC8"
 align = some
-maxColumn = 100
+maxColumn = 92
 assumeStandardLibraryStripMargin = true
 continuationIndent.defnSite = 4
 continuationIndent.callSite = 2
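
Note: the only setting that changes here is maxColumn, 100 down to 92. Nearly every hunk below is scalafmt re-wrapping the codebase to the narrower limit (plus `align = some` token alignment and Scaladoc comment indentation). A representative before/after, taken from the API.scala hunks below:

// Over 92 columns, so scalafmt now breaks after `=`:
def mean(x1: CompNode, axis: Int)(implicit mode: GraphMode): CompNode = funcNode(MeanByAxis(x1, axis))

// After reformatting:
def mean(x1: CompNode, axis: Int)(implicit mode: GraphMode): CompNode =
  funcNode(MeanByAxis(x1, axis))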

src/main/scala/botkop/numsca/package.scala (-1)

@@ -167,7 +167,6 @@ package object numsca {
   def multiply(a: Tensor, b: Tensor): Tensor = a * b
   def dot(a: Tensor, b: Tensor): Tensor = a dot b

-
   def clip(t: Tensor, min: Double, max: Double): Tensor = t.clip(min, max)

   // def concatNd4j(ts: Seq[Tensor], axis: Int): Tensor = {

src/main/scala/funcdiff/API.scala (+30 -13)

@@ -10,11 +10,13 @@ import TensorExtension.epsilon
 trait APITrait {

   sealed trait GraphMode
+
   /** [[CompNode]] created under this mode consumes less memory but does not support back
-   * propagation. Use this mode during testing. */
+    * propagation. Use this mode during testing. */
   case object ModeEval extends GraphMode
+
   /** [[CompNode]] created under this mode needs to keep the entire computation graph
-   * in the memory to support gradient back propagation. Use this mode during training. */
+    * in the memory to support gradient back propagation. Use this mode during training. */
   case object ModeTraining extends GraphMode

   var debugOpTime = true
@@ -32,7 +34,7 @@ trait APITrait {

   def funcNode(func: DiffFunc)(implicit mode: GraphMode): CompNode =
     mode match {
-      case ModeEval => new CompNode(ConstFunc(func.value))
+      case ModeEval     => new CompNode(ConstFunc(func.value))
       case ModeTraining => new CompNode(func)
     }

@@ -48,7 +50,8 @@

   def mean(x1: CompNode)(implicit mode: GraphMode): CompNode = funcNode(Mean(x1))

-  def mean(x1: CompNode, axis: Int)(implicit mode: GraphMode): CompNode = funcNode(MeanByAxis(x1, axis))
+  def mean(x1: CompNode, axis: Int)(implicit mode: GraphMode): CompNode =
+    funcNode(MeanByAxis(x1, axis))

   def square(x1: CompNode)(implicit mode: GraphMode): CompNode = x1 * x1

@@ -62,7 +65,9 @@

   def sum(x1: CompNode)(implicit mode: GraphMode): CompNode = funcNode(Sum(x1))

-  def sum(x1: CompNode, axis: Int, keepDim: Boolean = true)(implicit mode: GraphMode): CompNode =
+  def sum(x1: CompNode, axis: Int, keepDim: Boolean = true)(
+      implicit mode: GraphMode
+  ): CompNode =
     funcNode(SumByAxis(x1, axis, keepDim))

   def softmax(x1: CompNode)(implicit mode: GraphMode): CompNode = funcNode(Softmax(x1))
@@ -71,7 +76,8 @@

   def abs(x1: CompNode)(implicit mode: GraphMode): CompNode = funcNode(Abs(x1))

-  def max(x1: CompNode, x2: CompNode)(implicit mode: GraphMode): CompNode = funcNode(MaxBinary(x1, x2))
+  def max(x1: CompNode, x2: CompNode)(implicit mode: GraphMode): CompNode =
+    funcNode(MaxBinary(x1, x2))

   def plusN(xs: IS[CompNode])(implicit mode: GraphMode): CompNode = funcNode(PlusN(xs))

@@ -89,22 +95,32 @@

   def stackRows(xs: IS[CompNode])(implicit mode: GraphMode) = funcNode(StackRows(xs))

-  def concatTupledRows(rows: IS[(CompNode, CompNode)])(implicit mode: GraphMode): CompNode = {
+  def concatTupledRows(
+      rows: IS[(CompNode, CompNode)]
+  )(implicit mode: GraphMode): CompNode = {
     val (l, r) = rows.unzip
     stackRows(l).concat(stackRows(r), axis = 1)
   }

-  def crossEntropy(prediction: CompNode, targets: Tensor)(implicit mode: GraphMode): CompNode =
+  def crossEntropy(prediction: CompNode, targets: Tensor)(
+      implicit mode: GraphMode
+  ): CompNode =
     -sum(log(prediction + epsilon) * targets, axis = 1)

-  def crossEntropyOnSoftmax(logits: CompNode, targets: Tensor)(implicit mode: GraphMode): CompNode =
+  def crossEntropyOnSoftmax(logits: CompNode, targets: Tensor)(
+      implicit mode: GraphMode
+  ): CompNode =
     funcNode(CrossEntropyOnSoftmax(logits, targets))

-  def crossEntropyOnSigmoid(logits: CompNode, targets: Tensor)(implicit mode: GraphMode): CompNode = {
+  def crossEntropyOnSigmoid(logits: CompNode, targets: Tensor)(
+      implicit mode: GraphMode
+  ): CompNode = {
     funcNode(CrossEntropyOnSigmoid(logits, targets))
   }

-  def crossEntropyOnSoftmaxIneff(logits: CompNode, targets: Tensor)(implicit mode: GraphMode): CompNode =
+  def crossEntropyOnSoftmaxIneff(logits: CompNode, targets: Tensor)(
+      implicit mode: GraphMode
+  ): CompNode =
     -sum(log(softmax(logits) + epsilon) * targets, axis = 1)

   def normSquared(x1: CompNode)(implicit mode: GraphMode): CompNode =
@@ -178,7 +194,7 @@

   implicit class ExtendedFunctions(x1: CompNode) {

-    def unary_-(implicit mode: GraphMode) : CompNode = funcNode(Negate(x1))
+    def unary_-(implicit mode: GraphMode): CompNode = funcNode(Negate(x1))

     def t(implicit mode: GraphMode): CompNode = funcNode(Transpose(x1))

@@ -195,7 +211,8 @@
     def concat(x2: CompNode, axis: Int)(implicit mode: GraphMode): CompNode =
       funcNode(Concat(x1, x2, axis))

-    def slice(ranges: NumscaRange*)(implicit mode: GraphMode): CompNode = funcNode(Slice(x1, ranges))
+    def slice(ranges: NumscaRange*)(implicit mode: GraphMode): CompNode =
+      funcNode(Slice(x1, ranges))

     def dot(x2: CompNode)(implicit mode: GraphMode): CompNode = funcNode(Dot(x1, x2))

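
Note on the GraphMode implicit threaded through every signature above: funcNode is the single point where the mode matters. Under ModeTraining it keeps the full DiffFunc, so the computation graph survives for backprop; under ModeEval it collapses the node to a ConstFunc holding the already-computed value, so the graph can be freed. A minimal usage sketch (hypothetical, not part of this commit: it assumes code inside package funcdiff so that ConstFunc is visible, an object mixing in APITrait, and numsca imported as ns, as elsewhere in this project):

package funcdiff

import botkop.{numsca => ns}

object GraphModeExample extends APITrait {
  def main(args: Array[String]): Unit = {
    // Training: the graph is retained, so gradients are available.
    implicit val mode: GraphMode = ModeTraining
    val x = new CompNode(ConstFunc(ns.zeros(2, 3)))
    val loss = sum(square(x)) // builds a small computation graph
    val grads = loss.backprop // Map[CompNode, Gradient], per CompNode.scala below
    println(grads.size)

    // Evaluation: with `implicit val mode: GraphMode = ModeEval` instead, the same
    // expressions yield constant nodes only, and the intermediate graph is dropped.
  }
}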

src/main/scala/funcdiff/CompNode.scala (+3 -3)

@@ -17,7 +17,8 @@ class CompNode(val func: DiffFunc) extends Serializable {

   def shape: Shape = value.shape

-  def reshape(newShape: Shape)(implicit mode: GraphMode): CompNode = funcNode(Reshape(newShape, this))
+  def reshape(newShape: Shape)(implicit mode: GraphMode): CompNode =
+    funcNode(Reshape(newShape, this))

   def backprop: Map[CompNode, Gradient] = CompNode.backprop(this)

@@ -49,8 +50,7 @@
 }

 @SerialVersionUID(2L)
-class ParamNode(v: Tensor, val path: SymbolPath)
-    extends CompNode(ConstFunc(v)) {
+class ParamNode(v: Tensor, val path: SymbolPath) extends CompNode(ConstFunc(v)) {

   override def toString: String = {
     s"param{$path, shape=${value.shape}"

src/main/scala/funcdiff/DiffFunc.scala (+7 -12)

@@ -147,8 +147,7 @@ private[funcdiff] object DiffFunc {
     def name: String = "sum"
   }

-  case class SumByAxis(x1: CompNode, axis: Int, keepDim: Boolean = true)
-      extends UnaryFunc {
+  case class SumByAxis(x1: CompNode, axis: Int, keepDim: Boolean = true) extends UnaryFunc {
     val value: Tensor = ns.sumAxis(x1.value, axis, keepDim)

     def backprop1(grad: Gradient): Gradient = {
@@ -174,8 +173,7 @@
   }

   case class LeakyRelu(x1: CompNode, slope: Double) extends UnaryFunc {
-    val value
-      : Tensor = ns.maximum(x1.value, 0.0) + ns.minimum(x1.value, 0.0) * slope
+    val value: Tensor = ns.maximum(x1.value, 0.0) + ns.minimum(x1.value, 0.0) * slope

     def backprop1(grad: Gradient): Gradient = {
       grad * ((x1.value > 0).boolToFloating + (x1.value < 0).boolToFloating * slope)
@@ -391,10 +389,9 @@
   }

   /**
-   * @param fromRows Whether all arguments are row vectors and have the same shape.
-   */
-  case class ConcatN(args: IS[CompNode], axis: Int, fromRows: Boolean)
-      extends DiffFunc {
+    * @param fromRows Whether all arguments are row vectors and have the same shape.
+    */
+  case class ConcatN(args: IS[CompNode], axis: Int, fromRows: Boolean) extends DiffFunc {

     val value: Tensor =
       if (fromRows) ns.fromRows(args.map(_.value), axis)
@@ -445,8 +442,7 @@
   }

   // ================ Loss functions ======================
-  case class CrossEntropyOnSoftmax(logits: CompNode, targets: Tensor)
-      extends UnaryFunc {
+  case class CrossEntropyOnSoftmax(logits: CompNode, targets: Tensor) extends UnaryFunc {
     require(
       targets.shape == logits.shape,
       s"Targets shape (${targets.shape}) is different from logits (${logits.shape})."
@@ -472,8 +468,7 @@
     }
   }

-  case class CrossEntropyOnSigmoid(logits: CompNode, targets: Tensor)
-    extends UnaryFunc {
+  case class CrossEntropyOnSigmoid(logits: CompNode, targets: Tensor) extends UnaryFunc {
     require(targets.shape(1) == 1)
     require(logits.shape(1) == 1)
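
One hunk above is worth reading past the formatting: LeakyRelu. The forward value is ns.maximum(x, 0) + ns.minimum(x, 0) * slope, and backprop1 scales the incoming gradient by 1 where x > 0 and by slope where x < 0 (both indicator terms vanish at x == 0, so the local gradient there is 0). The same rule applied to a raw tensor, as a hedged sketch (it assumes a varargs Tensor(...) constructor in numsca and that boolToFloating comes from funcdiff.TensorExtension; neither API is spelled out in this diff):

import botkop.numsca.Tensor
import botkop.{numsca => ns}
import funcdiff.TensorExtension._

val slope = 0.01
val x = Tensor(-2.0, 0.0, 3.0)
// forward: [-0.02, 0.0, 3.0]
val y = ns.maximum(x, 0.0) + ns.minimum(x, 0.0) * slope
// local derivative: [0.01, 0.0, 1.0]
val dydx = (x > 0).boolToFloating + (x < 0).boolToFloating * slope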

src/main/scala/funcdiff/LayerFactory.scala (+46 -44)

@@ -96,41 +96,42 @@ case class LayerFactory(
   def gru(
       name: SymbolPath,
       initializer: WeightsInitializer = LayerFactory.xavier
-  )(state: CompNode, input: CompNode)(implicit mode: GraphMode): CompNode = withPrefix(name) { prefix =>
-    val inputSize = input.shape(1)
-    val stateSize = state.shape(1)
-
-    val Wg = paramCollection
-      .getVar(prefix / 'Wg, attributes = Set(NeedRegularization)) {
-        initializer(inputSize, 2 * stateSize)
-      }
-    val Ug = paramCollection
-      .getVar(prefix / 'Ug, attributes = Set(NeedRegularization)) {
-        initializer(stateSize, 2 * stateSize)
+  )(state: CompNode, input: CompNode)(implicit mode: GraphMode): CompNode =
+    withPrefix(name) { prefix =>
+      val inputSize = input.shape(1)
+      val stateSize = state.shape(1)
+
+      val Wg = paramCollection
+        .getVar(prefix / 'Wg, attributes = Set(NeedRegularization)) {
+          initializer(inputSize, 2 * stateSize)
+        }
+      val Ug = paramCollection
+        .getVar(prefix / 'Ug, attributes = Set(NeedRegularization)) {
+          initializer(stateSize, 2 * stateSize)
+        }
+      val bg = paramCollection.getVar(prefix / 'bg) {
+        ns.zeros(1, 2 * stateSize)
       }
-    val bg = paramCollection.getVar(prefix / 'bg) {
-      ns.zeros(1, 2 * stateSize)
-    }

-    val gates = sigmoid(input.dot(Wg) + state.dot(Ug) + bg)
-    val updateGate = gates.slice(:>, 0 :> stateSize)
-    val restGate = gates.slice(:>, stateSize :>)
-
-    val Wh = paramCollection
-      .getVar(prefix / 'Wh, attributes = Set(NeedRegularization)) {
-        initializer(inputSize, stateSize)
+      val gates = sigmoid(input.dot(Wg) + state.dot(Ug) + bg)
+      val updateGate = gates.slice(:>, 0 :> stateSize)
+      val restGate = gates.slice(:>, stateSize :>)
+
+      val Wh = paramCollection
+        .getVar(prefix / 'Wh, attributes = Set(NeedRegularization)) {
+          initializer(inputSize, stateSize)
+        }
+      val Uh = paramCollection
+        .getVar(prefix / 'Uh, attributes = Set(NeedRegularization)) {
+          initializer(stateSize, stateSize)
+        }
+      val bh = paramCollection.getVar(prefix / 'bh) {
+        ns.zeros(1, stateSize)
      }
-    val Uh = paramCollection
-      .getVar(prefix / 'Uh, attributes = Set(NeedRegularization)) {
-        initializer(stateSize, stateSize)
-      }
-    val bh = paramCollection.getVar(prefix / 'bh) {
-      ns.zeros(1, stateSize)
-    }

-    val hHat = tanh(input.dot(Wh) + (state * restGate).dot(Uh) + bh)
-    updateGate * hHat + state * (-updateGate + 1)
-  }
+      val hHat = tanh(input.dot(Wh) + (state * restGate).dot(Uh) + bh)
+      updateGate * hHat + state * (-updateGate + 1)
+    }

   /**
     * Long short-term memory unit: [https://en.wikipedia.org/wiki/Long_short-term_memory]
@@ -195,19 +196,20 @@ case class LayerFactory(
       name: SymbolPath,
       stateShape: Shape,
       combiner: (CompNode, CompNode) => CompNode
-  )(inputs: IS[CompNode])(implicit mode: GraphMode): IS[CompNode] = withPrefix(name) { prefix =>
-    val leftInit: CompNode = getVar(prefix / 'leftInit) {
-      ns.randn(stateShape)
-    }
-    val states1 = inputs.scanLeft(leftInit)(gru(name / 'leftRNN))
-    val rightInit: CompNode = getVar(prefix / 'rightInit) {
-      ns.randn(stateShape)
-    }
-    val states2 =
-      inputs.reverse.scanLeft(rightInit)(gru(name / 'rightRNN)).reverse
-    states1.zip(states2).map {
-      case (l, r) => combiner(l, r)
-    }
+  )(inputs: IS[CompNode])(implicit mode: GraphMode): IS[CompNode] = withPrefix(name) {
+    prefix =>
+      val leftInit: CompNode = getVar(prefix / 'leftInit) {
+        ns.randn(stateShape)
+      }
+      val states1 = inputs.scanLeft(leftInit)(gru(name / 'leftRNN))
+      val rightInit: CompNode = getVar(prefix / 'rightInit) {
+        ns.randn(stateShape)
+      }
+      val states2 =
+        inputs.reverse.scanLeft(rightInit)(gru(name / 'rightRNN)).reverse
+      states1.zip(states2).map {
+        case (l, r) => combiner(l, r)
+      }
   }

   /** performs weighted-sum over ys using dot-product attention */
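
The gru hunk above is pure re-indentation, but since the whole body is visible, it is worth stating the computation it implements. Writing x for the input row, h for the previous state, and using elementwise products, the code is a standard GRU cell whose two gates are computed jointly:

    [z, r]  = \sigma(x W_g + h U_g + b_g)
    \hat{h} = \tanh(x W_h + (h \odot r) U_h + b_h)
    h'      = z \odot \hat{h} + (1 - z) \odot h

Here z is updateGate, r is restGate (the reset gate), and \hat{h} is hHat: the two slice calls split the 2 * stateSize-wide gate block into its z and r halves, and the last line of the body is exactly h' = updateGate * hHat + state * (-updateGate + 1).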

src/main/scala/funcdiff/Optimizer.scala (+1 -1)

@@ -57,7 +57,7 @@ trait Optimizer extends Serializable {
       (path, g) <- transformed
       p <- paramMap.get(path).toIterable
       delta = parameterChangeAmount(p, g) * scaleLearningRate
-    } yield p.synchronized{
+    } yield p.synchronized {
       if (newlyCreated contains p.node) {
         delta.addToTensor(p.node.value)
       } else {

src/main/scala/funcdiff/Param.scala (+9 -9)

@@ -13,14 +13,14 @@ object ParameterAttribute {
 }

 /**
- * Each [[Param]] contains a mutable [[ParamNode]], representing a trainable parameter.
- *
- * Note that in the future, [[Param]] will not be serializable and should only be provided by
- * [[ParamCollection]] to ensure reference consistency. (currently keep it serializable
- * to be able to load previously trained model)
- *
- * todo: make this not serializable
- */
+  * Each [[Param]] contains a mutable [[ParamNode]], representing a trainable parameter.
+  *
+  * Note that in the future, [[Param]] will not be serializable and should only be provided by
+  * [[ParamCollection]] to ensure reference consistency. (currently keep it serializable
+  * to be able to load previously trained model)
+  *
+  * todo: make this not serializable
+  */
 @SerialVersionUID(1L)
 class Param(
     var node: ParamNode,
@@ -41,7 +41,7 @@ case class SymbolPath(repr: Symbol) {
   })

   @inline
-  def /(seg: String): SymbolPath = this/Symbol(seg)
+  def /(seg: String): SymbolPath = this / Symbol(seg)

   def ++(other: SymbolPath): SymbolPath = this / other.repr


src/main/scala/funcdiff/SimpleMath.scala (+15 -6)

@@ -2,7 +2,15 @@ package funcdiff

 import ammonite.ops.Path

-import java.io.{File, FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream, ObjectStreamClass, Serializable}
+import java.io.{
+  File,
+  FileInputStream,
+  FileOutputStream,
+  ObjectInputStream,
+  ObjectOutputStream,
+  ObjectStreamClass,
+  Serializable
+}
 import scala.util.Random
 import collection.mutable
 import Numeric.Implicits._
@@ -106,7 +114,7 @@ object SimpleMath {
   def mean[T: Numeric](xs: Iterable[T]): Double = xs.sum.toDouble / xs.size

   def variance[T: Numeric](xs: Iterable[T]): Double = {
-    if(xs.size > 1) {
+    if (xs.size > 1) {
       val mu = mean(xs)
       xs.map(x => square(x.toDouble() - mu)).sum / (xs.size - 1)
     } else Double.NaN
@@ -465,10 +473,11 @@

     def show: String = {
       stat.toSeq
-        .sortBy(_._2).reverseMap {
-          case (name, time) =>
-            s"$name: ${prettyPrintTime(time)}"
-        }
+        .sortBy(_._2)
+        .reverseMap {
+          case (name, time) =>
+            s"$name: ${prettyPrintTime(time)}"
+        }
         .mkString("\n")
     }
   }
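
The variance that appears as context in the second hunk is the unbiased sample variance (Bessel's correction: divide by n - 1, with NaN for fewer than two samples). A quick sanity check against the definitions above:

import funcdiff.SimpleMath

SimpleMath.mean(Seq(1, 2, 3))     // (1 + 2 + 3) / 3 = 2.0
SimpleMath.variance(Seq(1, 2, 3)) // ((1-2)^2 + (2-2)^2 + (3-2)^2) / (3 - 1) = 1.0
SimpleMath.variance(Seq(42))      // Double.NaN: one sample gives no variance estimate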
