From a2261330c227be8ef26172dbe355a617d653553a Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Thu, 23 Jul 2015 07:55:15 -0700 Subject: [PATCH 1/9] Multilayer Perceptron regressor and classifier ANN test --- .../MultilayerPerceptronClassifier.scala | 130 +++ .../MultilayerPerceptronRegressor.scala | 206 +++++ .../apache/spark/mllib/ann/BreezeUtil.scala | 67 ++ .../org/apache/spark/mllib/ann/Layer.scala | 856 ++++++++++++++++++ .../MultilayerPerceptronClassifierSuite.scala | 52 ++ .../MultilayerPerceptronRegressorSuite.scala | 59 ++ .../org/apache/spark/mllib/ann/ANNSuite.scala | 74 ++ 7 files changed, 1444 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressor.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/ann/BreezeUtil.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressorSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala new file mode 100644 index 0000000000000..b40597c556221 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import breeze.linalg.{argmax => Bargmax} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.regression.MultilayerPerceptronParams +import org.apache.spark.mllib.ann.{FeedForwardTrainer, FeedForwardTopology} +import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql.DataFrame + +/** + * :: Experimental :: + * Label to vector converter. + */ +@Experimental +private object LabelConverter { + + /** + * Encodes a label as a vector. + * Returns a vector of given length with zeroes at all positions + * and value 1.0 at the position that corresponds to the label. 
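+   * For example, with labelCount = 4 a label of 2.0 is encoded as the vector
+   * [0.0, 0.0, 1.0, 0.0].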
+ * + * @param labeledPoint labeled point + * @param labelCount total number of labels + * @return vector encoding of a label + */ + def apply(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = { + val output = Array.fill(labelCount){0.0} + output(labeledPoint.label.toInt) = 1.0 + (labeledPoint.features, Vectors.dense(output)) + } + + /** + * Converts a vector to a label. + * Returns the position of the maximal element of a vector. + * + * @param output label encoded with a vector + * @return label + */ + def apply(output: Vector): Double = { + Bargmax(output.toBreeze.toDenseVector).toDouble + } +} + +/** + * :: Experimental :: + * Classifier trainer based on the Multilayer Perceptron. + * Each layer has sigmoid activation function, output layer has softmax. + * Number of inputs has to be equal to the size of feature vectors. + * Number of outputs has to be equal to the total number of labels. + * + */ +@Experimental +class MultilayerPerceptronClassifier (override val uid: String) + extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassifierModel] + with MultilayerPerceptronParams { + + override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) + + def this() = this(Identifiable.randomUID("mlpc")) + + /** + * Train a model using the given dataset and parameters. + * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation + * and copying parameters into the model. + * + * @param dataset Training dataset + * @return Fitted model + */ + override protected def train(dataset: DataFrame): MultilayerPerceptronClassifierModel = { + val labels = getLayers.last.toInt + val lpData = extractLabeledPoints(dataset) + val data = lpData.map(lp => LabelConverter(lp, labels)) + val myLayers = getLayers.map(_.toInt) + val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true) + val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) + FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol(getTol).setNumIterations(getMaxIter) + FeedForwardTrainer.setStackSize(getBlockSize) + val mlpModel = FeedForwardTrainer.train(data) + new MultilayerPerceptronClassifierModel(uid, myLayers, mlpModel.weights()) + } +} + +/** + * :: Experimental :: + * Classifier model based on the Multilayer Perceptron. + * Each layer has sigmoid activation function, output layer has softmax. + */ +@Experimental +class MultilayerPerceptronClassifierModel private[ml] (override val uid: String, + layers: Array[Int], + weights: Vector) + extends PredictionModel[Vector, MultilayerPerceptronClassifierModel] + with Serializable { + + private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) + + /** + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. 
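+   * The label is recovered as the index of the maximal element of the network
+   * output (see [[LabelConverter]]).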
+ */ + override protected def predict(features: Vector): Double = { + LabelConverter(mlpModel.predict(features)) + } + + override def copy(extra: ParamMap): MultilayerPerceptronClassifierModel = { + copyValues(new MultilayerPerceptronClassifierModel(uid, layers, weights), extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressor.scala new file mode 100644 index 0000000000000..28deeaab8428b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressor.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import breeze.linalg.{argmax => Bargmax} + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.{Model, Transformer, Estimator, PredictorParams} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.ann.{FeedForwardTopology, FeedForwardTrainer} +import org.apache.spark.mllib.linalg.{VectorUDT, Vector} +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * Params for Multilayer Perceptron. + */ +private[ml] trait MultilayerPerceptronParams extends PredictorParams +with HasSeed with HasMaxIter with HasTol { + /** + * Layer sizes including input size and output size. + * @group param + */ + final val layers: IntArrayParam = + // TODO: we need IntegerArrayParam! + new IntArrayParam(this, "layers", + "Sizes of layers including input and output from bottom to the top." + + " E.g., Array(780, 100, 10) means 780 inputs, " + + "hidden layer with 100 neurons and output layer of 10 neurons." + // TODO: how to check that array is not empty? + ) + + /** + * Block size for stacking input data in matrices. Speeds up the computations. + * Cannot be more than the size of the dataset. + * @group expertParam + */ + final val blockSize: IntParam = new IntParam(this, "blockSize", + "Block size for stacking input data in matrices.", + ParamValidators.gt(0)) + + /** @group setParam */ + def setLayers(value: Array[Int]): this.type = set(layers, value) + + /** @group getParam */ + final def getLayers: Array[Int] = $(layers) + + /** @group setParam */ + def setBlockSize(value: Int): this.type = set(blockSize, value) + + /** @group getParam */ + final def getBlockSize: Int = $(blockSize) + + /** + * Set the maximum number of iterations. + * Default is 100. 
+ * @group setParam + */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-4. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + + /** + * Set the seed for weights initialization. + * Default is 11L. + * @group setParam + */ + def setSeed(value: Long): this.type = set(seed, value) + + setDefault(seed -> 11L, maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 1) +} + +/** + * :: Experimental :: + * Multi-layer perceptron regression. Contains sigmoid activation function on all layers. + * See https://en.wikipedia.org/wiki/Multilayer_perceptron for details. + * + */ +@Experimental +class MultilayerPerceptronRegressor (override val uid: String) + extends Estimator[MultilayerPerceptronRegressorModel] + with MultilayerPerceptronParams with HasInputCol with HasOutputCol with HasRawPredictionCol + with Logging { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * Fits a model to the input and output data. + * InputCol has to contain input vectors. + * OutputCol has to contain output vectors. + */ + override def fit(dataset: DataFrame): MultilayerPerceptronRegressorModel = { + val data = dataset.select($(inputCol), $(outputCol)).map { + case Row(x: Vector, y: Vector) => (x, y) + } + val myLayers = getLayers + val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, false) + val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) + FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol(getTol).setNumIterations(getMaxIter) + FeedForwardTrainer.setStackSize(getBlockSize) + val mlpModel = FeedForwardTrainer.train(data) + new MultilayerPerceptronRegressorModel(uid, myLayers, mlpModel.weights()) + } + + /** + * :: DeveloperApi :: + * + * Derives the output schema from the input schema. + */ + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + val outputType = schema($(outputCol)).dataType + require(outputType.isInstanceOf[VectorUDT], + s"Input column ${$(outputCol)} must be a vector column") + require(!schema.fieldNames.contains($(rawPredictionCol)), + s"Output column ${$(rawPredictionCol)} already exists.") + val outputFields = schema.fields :+ StructField($(rawPredictionCol), new VectorUDT, false) + StructType(outputFields) + } + + def this() = this(Identifiable.randomUID("mlpr")) + + override def copy(extra: ParamMap): MultilayerPerceptronRegressor = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Multi-layer perceptron regression model. + * + * @param layers array of layer sizes including input and output + * @param weights weights (or parameters) of the model + */ +@Experimental +class MultilayerPerceptronRegressorModel private[ml] (override val uid: String, + layers: Array[Int], + weights: Vector) + extends Model[MultilayerPerceptronRegressorModel] + with HasInputCol with HasRawPredictionCol { + + private val mlpModel = + FeedForwardTopology.multiLayerPerceptron(layers, false).getInstance(weights) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** + * Transforms the input dataset. 
+ * InputCol has to contain input vectors. + * RawPrediction column will contain predictions (outputs of the regressor). + */ + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + val pcaOp = udf { mlpModel.predict _ } + dataset.withColumn($(rawPredictionCol), pcaOp(col($(inputCol)))) + } + + /** + * :: DeveloperApi :: + * + * Derives the output schema from the input schema. + */ + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(rawPredictionCol)), + s"Output column ${$(rawPredictionCol)} already exists.") + val outputFields = schema.fields :+ StructField($(rawPredictionCol), new VectorUDT, false) + StructType(outputFields) + } + + override def copy(extra: ParamMap): MultilayerPerceptronRegressorModel = { + copyValues(new MultilayerPerceptronRegressorModel(uid, layers, weights), extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/ann/BreezeUtil.scala b/mllib/src/main/scala/org/apache/spark/mllib/ann/BreezeUtil.scala new file mode 100644 index 0000000000000..02e3601532cb7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/ann/BreezeUtil.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.ann + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} +import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} + +/** + * In-place DGEMM and DGEMV for Breeze + */ +object BreezeUtil { + + private def transposeString(a: BDM[Double]): String = if (a.isTranspose) "T" else "N" + + /** + * DGEMM: C := alpha * A * B + beta * C + * @param alpha alpha + * @param a A + * @param b B + * @param beta beta + * @param c C + */ + def dgemm(alpha: Double, a: BDM[Double], b: BDM[Double], beta: Double, c: BDM[Double]): Unit = { + // TODO: add code if matrices isTranspose!!! 
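+    // Note: the BLAS call below writes into c assuming dense column-major
+    // (non-transposed) storage, since ldc is passed as c.rows.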
+ require(a.cols == b.rows, "A & B Dimension mismatch!") + require(a.rows == c.rows, "A & C Dimension mismatch!") + require(b.cols == c.cols, "A & C Dimension mismatch!") + if(a.rows == 0 || b.rows == 0 || a.cols == 0 || b.cols == 0) { + } else { + NativeBLAS.dgemm(transposeString(a), transposeString(b), c.rows, c.cols, a.cols, + alpha, a.data, a.offset, a.majorStride, b.data, b.offset, b.majorStride, + beta, c.data, c.offset, c.rows) + } + } + + /** + * DGEMV: y := alpha * A * x + beta * y + * @param alpha alpha + * @param a A + * @param x x + * @param beta beta + * @param y y + */ + def dgemv(alpha: Double, a: BDM[Double], x: BDV[Double], beta: Double, y: BDV[Double]): Unit = { + + require(a.cols == x.length, "A & b Dimension mismatch!") + + NativeBLAS.dgemv(transposeString(a), a.rows, a.cols, + alpha, a.data, a.offset, a.majorStride, x.data, x.offset, x.stride, + beta, y.data, y.offset, y.stride) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala new file mode 100644 index 0000000000000..3f85b80f112b7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala @@ -0,0 +1,856 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.ann + +import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => brzAxpy, +sum => Bsum} +import breeze.numerics.{log => Blog, sigmoid => Bsigmoid} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom + +/** + * Trait that holds Layer properties, that are needed to instantiate it. + * Implements Layer instantiation. + * + */ +private[ann] trait Layer extends Serializable { + /** + * Returns the instance of the layer based on weights provided + * @param weights vector with layer weights + * @param position position of weights in the vector + * @return the layer model + */ + def getInstance(weights: Vector, position: Int): LayerModel + + /** + * Returns the instance of the layer with random generated weights + * @param seed seed + * @return the layer model + */ + def getInstance(seed: Long): LayerModel +} + +/** + * Trait that holds Layer weights (or parameters). + * Implements functions needed for forward propagation, computing delta and gradient. + * Can return weights in Vector format. 
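+ * Note that implementations may reuse internal buffers, so matrices returned
+ * by eval and prevDelta are only valid until the next call.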
+ */ +private[ann] trait LayerModel extends Serializable { + /** + * number of weights + */ + val size: Int + + /** + * Evaluates the data (process the data through the layer) + * @param data data + * @return processed data + */ + def eval(data: BDM[Double]): BDM[Double] + + /** + * Computes the delta for back propagation + * @param nextDelta delta of the next layer + * @param input input data + * @return delta + */ + def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] + + /** + * Computes the gradient + * @param delta delta for this layer + * @param input input data + * @return gradient + */ + def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] + + /** + * Returns weights for the layer in a single vector + * @return layer weights + */ + def weights(): Vector +} + +/** + * Layer properties of affine transformations, that is y=A*x+b + * @param numIn number of inputs + * @param numOut number of outputs + */ +private[ann] class AffineLayer(val numIn: Int, val numOut: Int) extends Layer { + + override def getInstance(weights: Vector, position: Int): LayerModel = { + AffineLayerModel(this, weights, position) + } + + override def getInstance(seed: Long = 11L): LayerModel = { + AffineLayerModel(this, seed) + } +} + +/** + * Model of Affine layer y=A*x+b + * @param w weights (matrix A) + * @param b bias (vector b) + */ +private[ann] class AffineLayerModel private(w: BDM[Double], b: BDV[Double]) extends LayerModel { + val size = w.size + b.length + val gwb = new Array[Double](size) + private lazy val gw: BDM[Double] = new BDM[Double](w.rows, w.cols, gwb) + private lazy val gb: BDV[Double] = new BDV[Double](gwb, w.size) + private var z: BDM[Double] = null + private var d: BDM[Double] = null + private var ones: BDV[Double] = null + + override def eval(data: BDM[Double]): BDM[Double] = { + if (z == null || z.cols != data.cols) z = new BDM[Double](w.rows, data.cols) + z(::, *) := b + BreezeUtil.dgemm(1.0, w, data, 1.0, z) + z + } + + override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { + if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](w.cols, nextDelta.cols) + BreezeUtil.dgemm(1.0, w.t, nextDelta, 0.0, d) + d + } + + override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = { + BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 0.0, gw) + if (ones == null || ones.length != delta.cols) ones = BDV.ones[Double](delta.cols) + BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 0.0, gb) + gwb + } + + override def weights(): Vector = AffineLayerModel.roll(w, b) +} + +/** + * Fabric for Affine layer models + */ +private[ann] object AffineLayerModel { + + /** + * Creates a model of Affine layer + * @param layer layer properties + * @param weights vector with weights + * @param position position of weights in the vector + * @return model of Affine layer + */ + def apply(layer: AffineLayer, weights: Vector, position: Int): AffineLayerModel = { + val (w, b) = unroll(weights, position, layer.numIn, layer.numOut) + new AffineLayerModel(w, b) + } + + /** + * Creates a model of Affine layer + * @param layer layer properties + * @param seed seed + * @return model of Affine layer + */ + def apply(layer: AffineLayer, seed: Long): AffineLayerModel = { + val (w, b) = randomWeights(layer.numIn, layer.numOut, seed) + new AffineLayerModel(w, b) + } + + /** + * Unrolls the weights from the vector + * @param weights vector with weights + * @param position position of weights for this layer + * @param numIn number of layer inputs + * 
@param numOut number of layer outputs + * @return matrix A and vector b + */ + def unroll(weights: Vector, position: Int, + numIn: Int, numOut: Int): (BDM[Double], BDV[Double]) = { + val weightsCopy = weights.toArray + // TODO: the array is not copied to BDMs, make sure this is OK! + val a = new BDM[Double](numOut, numIn, weightsCopy, position) + val b = new BDV[Double](weightsCopy, position + (numOut * numIn), 1, numOut) + (a, b) + } + + /** + * Roll the layer weights into a vector + * @param a matrix A + * @param b vector b + * @return vector of weights + */ + def roll(a: BDM[Double], b: BDV[Double]): Vector = { + val result = new Array[Double](a.size + b.length) + // TODO: make sure that we need to copy! + System.arraycopy(a.toArray, 0, result, 0, a.size) + System.arraycopy(b.toArray, 0, result, a.size, b.length) + Vectors.dense(result) + } + + /** + * Generate random weights for the layer + * @param numIn number of inputs + * @param numOut number of outputs + * @param seed seed + * @return (matrix A, vector b) + */ + def randomWeights(numIn: Int, numOut: Int, seed: Long = 11L): (BDM[Double], BDV[Double]) = { + val rand: XORShiftRandom = new XORShiftRandom(seed) + val weights = BDM.fill[Double](numOut, numIn){ (rand.nextDouble * 4.8 - 2.4) / numIn } + val bias = BDV.fill[Double](numOut){ (rand.nextDouble * 4.8 - 2.4) / numIn } + (weights, bias) + } +} + +/** + * Trait for functions and their derivatives for functional layers + */ +private[ann] trait ActivationFunction extends Serializable { + + /** + * Implements a function + * @param x input data + * @param y output data + */ + def eval(x: BDM[Double], y: BDM[Double]): Unit + + /** + * Implements a derivative of a function (needed for the back propagation) + * @param x input data + * @param y output data + */ + def derivative(x: BDM[Double], y: BDM[Double]): Unit + + /** + * Implements a cross entropy error of a function. + * Needed if the functional layer that contains this function is the output layer + * of the network. 
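+   * In the implementations below this is computed as
+   * -sum(target * log(output)) / batchSize.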
+ * @param target target output + * @param output computed output + * @param result intermediate result + * @return cross-entropy + */ + def crossEntropy(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double + + /** + * Implements a mean squared error of a function + * @param target target output + * @param output computed output + * @param result intermediate result + * @return mean squared error + */ + def squared(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double +} + +/** + * Implements in-place application of functions + */ +private[ann] object ActivationFunction { + + def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = { + var i = 0 + while (i < x.rows) { + var j = 0 + while (j < x.cols) { + y(i, j) = func(x(i, j)) + j += 1 + } + i += 1 + } + } + + def apply(x1: BDM[Double], x2: BDM[Double], y: BDM[Double], + func: (Double, Double) => Double): Unit = { + var i = 0 + while (i < x1.rows) { + var j = 0 + while (j < x1.cols) { + y(i, j) = func(x1(i, j), x2(i, j)) + j += 1 + } + i += 1 + } + } + +} + +/** + * Implements SoftMax activation function + */ +private[ann] class SoftmaxFunction extends ActivationFunction { + override def eval(x: BDM[Double], y: BDM[Double]): Unit = { + var j = 0 + // find max value to make sure later that exponent is computable + while (j < x.cols) { + var i = 0 + var max = Double.MinValue + while (i < x.rows) { + if (x(i, j) > max) { + max = x(i, j) + } + i += 1 + } + var sum = 0.0 + i = 0 + while (i < x.rows) { + val res = Math.exp(x(i, j) - max) + y(i, j) = res + sum += res + i += 1 + } + i = 0 + while (i < x.rows) { + y(i, j) /= sum + i += 1 + } + j += 1 + } + } + + override def crossEntropy(output: BDM[Double], target: BDM[Double], + result: BDM[Double]): Double = { + def m(o: Double, t: Double): Double = o - t + ActivationFunction(output, target, result, m) + -Bsum( target :* Blog(output)) / output.cols + } + + override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { + def sd(z: Double): Double = (1 - z) * z + ActivationFunction(x, y, sd) + } + + override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { + throw new UnsupportedOperationException("Sorry, squared error is not defined for SoftMax.") + } +} + +/** + * Implements Sigmoid activation function + */ +private[ann] class SigmoidFunction extends ActivationFunction { + override def eval(x: BDM[Double], y: BDM[Double]): Unit = { + def s(z: Double): Double = Bsigmoid(z) + ActivationFunction(x, y, s) + } + + override def crossEntropy(output: BDM[Double], target: BDM[Double], + result: BDM[Double]): Double = { + def m(o: Double, t: Double): Double = o - t + ActivationFunction(output, target, result, m) + -Bsum( target :* Blog(output)) / output.cols + } + + override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { + def sd(z: Double): Double = (1 - z) * z + ActivationFunction(x, y, sd) + } + + override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { + // TODO: make it readable + def m(o: Double, t: Double): Double = (o - t) + ActivationFunction(output, target, result, m) + val e = Bsum(result :* result) / 2 / output.cols + def m2(x: Double, o: Double) = x * (o - o * o) + ActivationFunction(result, output, result, m2) + e + } +} + +/** + * Functional layer properties, y = f(x) + * @param activationFunction activation function + */ +private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) extends Layer { + override def 
getInstance(weights: Vector, position: Int): LayerModel = getInstance(0L) + + override def getInstance(seed: Long): LayerModel = + FunctionalLayerModel(this) +} + +/** + * Functional layer model. Holds no weights. + * @param activationFunction activation function + */ +private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction + ) extends LayerModel { + val size = 0 + + private var f: BDM[Double] = null + private var d: BDM[Double] = null + private var e: BDM[Double] = null + private lazy val dg = new Array[Double](0) + + override def eval(data: BDM[Double]): BDM[Double] = { + if (f == null || f.cols != data.cols) f = new BDM[Double](data.rows, data.cols) + activationFunction.eval(data, f) + f + } + + override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { + if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](nextDelta.rows, nextDelta.cols) + activationFunction.derivative(input, d) + d :*= nextDelta + d + } + + override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = dg + + override def weights(): Vector = Vectors.dense(new Array[Double](0)) + + def crossEntropy(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) + val error = activationFunction.crossEntropy(output, target, e) + (e, error) + } + + def squared(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) + val error = activationFunction.squared(output, target, e) + (e, error) + } + + def error(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + // TODO: allow user pick error + activationFunction match { + case sigmoid: SigmoidFunction => squared(output, target) + case softmax: SoftmaxFunction => crossEntropy(output, target) + } + } +} + +/** + * Fabric of functional layer models + */ +private[ann] object FunctionalLayerModel { + def apply(layer: FunctionalLayer): FunctionalLayerModel = + new FunctionalLayerModel(layer.activationFunction) +} + +/** + * Trait for the artificial neural network (ANN) topology properties + */ +private[ann] trait Topology extends Serializable{ + def getInstance(weights: Vector): TopologyModel + def getInstance(seed: Long): TopologyModel +} + +/** + * Trait for ANN topology model + */ +private[ann] trait TopologyModel extends Serializable{ + /** + * Forward propagation + * @param data input data + * @return array of outputs for each of the layers + */ + def forward(data: BDM[Double]): Array[BDM[Double]] + + /** + * Prediction of the model + * @param data input data + * @return prediction + */ + def predict(data: Vector): Vector + + /** + * Computes gradient for the network + * @param data input data + * @param target target output + * @param cumGradient cumulative gradient + * @param blockSize block size + * @return error + */ + def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector, + blockSize: Int): Double + + /** + * Returns the weights of the ANN + * @return weights + */ + def weights(): Vector +} + +/** + * Feed forward ANN + * @param layers + */ +class FeedForwardTopology private(val layers: Array[Layer]) extends Topology { + override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights) + + override def getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed) +} + +/** + * Factory for some of the frequently-used 
topologies + */ +object FeedForwardTopology { + /** + * Creates a feed forward topology from the array of layers + * @param layers array of layers + * @return feed forward topology + */ + def apply(layers: Array[Layer]): FeedForwardTopology = { + new FeedForwardTopology(layers) + } + + /** + * Creates a multi-layer perceptron + * @param layerSizes sizes of layers including input and output size + * @param softmax wether to use SoftMax or Sigmoid function for an output layer. + * Softmax is default + * @return multilayer perceptron topology + */ + def multiLayerPerceptron(layerSizes: Array[Int], softmax: Boolean = true): FeedForwardTopology = { + val layers = new Array[Layer]((layerSizes.length - 1) * 2) + for(i <- 0 until layerSizes.length - 1){ + layers(i * 2) = new AffineLayer(layerSizes(i), layerSizes(i + 1)) + layers(i * 2 + 1) = + if (softmax && i == layerSizes.length - 2) { + new FunctionalLayer(new SoftmaxFunction()) + } else { + new FunctionalLayer(new SigmoidFunction()) + } + } + FeedForwardTopology(layers) + } +} + +/** + * Model of Feed Forward Neural Network. + * Implements forward, gradient computation and can return weights in vector format. + * @param layerModels models of layers + * @param topology topology of the network + */ +class FeedForwardModel private(val layerModels: Array[LayerModel], + val topology: FeedForwardTopology) extends TopologyModel { + override def forward(data: BDM[Double]): Array[BDM[Double]] = { + val outputs = new Array[BDM[Double]](layerModels.length) + outputs(0) = layerModels(0).eval(data) + for(i <- 1 until layerModels.length){ + outputs(i) = layerModels(i).eval(outputs(i-1)) + } + outputs + } + + override def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector, + realBatchSize: Int): Double = { + val outputs = forward(data) + val deltas = new Array[BDM[Double]](layerModels.length) + val L = layerModels.length - 1 + val (newE, newError) = layerModels.last match { + case flm: FunctionalLayerModel => flm.error(outputs.last, target) + case _ => + throw new UnsupportedOperationException("Non-functional layer not supported at the top") + } + deltas(L) = new BDM[Double](0, 0) + deltas(L - 1) = newE + for (i <- (L - 2) to (0, -1)) { + deltas(i) = layerModels(i + 1).prevDelta(deltas(i + 1), outputs(i + 1)) + } + val grads = new Array[Array[Double]](layerModels.length) + for (i <- 0 until layerModels.length) { + val input = if (i==0) data else outputs(i - 1) + grads(i) = layerModels(i).grad(deltas(i), input) + } + // update cumGradient + val cumGradientArray = cumGradient.toArray + var offset = 0 + // TODO: extract roll + for (i <- 0 until grads.length) { + val gradArray = grads(i) + var k = 0 + while (k < gradArray.length) { + cumGradientArray(offset + k) += gradArray(k) + k += 1 + } + offset += gradArray.length + } + newError + } + + // TODO: do we really need to copy the weights? 
they should be read-only + override def weights(): Vector = { + // TODO: extract roll + var size = 0 + for(i <- 0 until layerModels.length) { + size += layerModels(i).size + } + val array = new Array[Double](size) + var offset = 0 + for(i <- 0 until layerModels.length) { + val layerWeights = layerModels(i).weights().toArray + System.arraycopy(layerWeights, 0, array, offset, layerWeights.length) + offset += layerWeights.length + } + Vectors.dense(array) + } + + override def predict(data: Vector): Vector = { + val result = forward(data.toBreeze.toDenseVector.toDenseMatrix.t) + Vectors.dense(result.last.toArray) + } +} + +/** + * Fabric for feed forward ANN models + */ +private[ann] object FeedForwardModel { + + /** + * Creates a model from a topology and weights + * @param topology topology + * @param weights weights + * @return model + */ + def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = { + val layers = topology.layers + val layerModels = new Array[LayerModel](layers.length) + var offset = 0 + for(i <- 0 until layers.length){ + layerModels(i) = layers(i).getInstance(weights, offset) + offset += layerModels(i).size + } + new FeedForwardModel(layerModels, topology) + } + + /** + * Creates a model given a topology and seed + * @param topology topology + * @param seed seed for generating the weights + * @return model + */ + def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = { + val layers = topology.layers + val layerModels = new Array[LayerModel](layers.length) + var offset = 0 + for(i <- 0 until layers.length){ + layerModels(i) = layers(i).getInstance(seed) + offset += layerModels(i).size + } + new FeedForwardModel(layerModels, topology) + } +} + +/** + * Neural network gradient. Does nothing but calling Model's gradient + * @param topology topology + * @param dataStacker data stacker + */ +private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) extends Gradient { + + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) + (gradient, loss) + } + + override def compute(data: Vector, label: Double, weights: Vector, + cumGradient: Vector): Double = { + val (input, target, realBatchSize) = dataStacker.unstack(data) + val model = topology.getInstance(weights) + model.computeGradient(input, target, cumGradient, realBatchSize) + } +} + +/** + * Class that stacks the training samples RDD[(Vector, Vector)] in one vector allowing them to pass + * through Optimizer/Gradient interfaces and thus allowing batch computations. + * Can unstack the training samples into matrices. 
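+ * For a stack of s samples the resulting vector is laid out as
+ * [x_1, ..., x_s, y_1, ..., y_s], that is, all inputs followed by all outputs.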
+ * @param stackSize stack size + * @param inputSize size of the input vectors + * @param outputSize size of the output vectors + */ +private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) + extends Serializable { + + /** + * Stacks the data + * @param data RDD of vector pairs + * @return RDD of double (always zero) and vector that contains the stacked vectors + */ + def stack(data: RDD[(Vector, Vector)]): RDD[(Double, Vector)] = { + val stackedData = if (stackSize == 1) { + data.map(v => + (0.0, + Vectors.fromBreeze(BDV.vertcat( + v._1.toBreeze.toDenseVector, + v._2.toBreeze.toDenseVector)) + )) + } else { + data.mapPartitions { it => + it.grouped(stackSize).map { seq => + val size = seq.size + val bigVector = new Array[Double](inputSize * size + outputSize * size) + var i = 0 + seq.foreach { case (in, out) => + System.arraycopy(in.toArray, 0, bigVector, i * inputSize, inputSize) + System.arraycopy(out.toArray, 0, bigVector, + inputSize * size + i * outputSize, outputSize) + i += 1 + } + (0.0, Vectors.dense(bigVector)) + } + } + } + stackedData + } + + /** + * Unstack the stacked vectors into matrices for batch operations + * @param data stacked vector + * @return pair of matrices holding input and output data and the real stack size + */ + def unstack(data: Vector): (BDM[Double], BDM[Double], Int) = { + val arrData = data.toArray + val realStackSize = arrData.length / (inputSize + outputSize) + val input = new BDM(inputSize, realStackSize, arrData) + val target = new BDM(outputSize, realStackSize, arrData, inputSize * realStackSize) + (input, target, realStackSize) + } +} + +/** + * Simple updater + */ +private[ann] class ANNUpdater extends Updater { + + override def compute(weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { + val thisIterStepSize = stepSize + val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) + (Vectors.fromBreeze(brzWeights), 0) + } +} + +/** + * Llib-style trainer class that trains a network given the data and topology + * @param topology topology of ANN + * @param inputSize input size + * @param outputSize output size + */ +class FeedForwardTrainer (topology: Topology, val inputSize: Int, + val outputSize: Int) extends Serializable { + + // TODO: what if we need to pass random seed? 
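+  // Initial weights come from the fixed seed 11L until a seed parameter is exposed.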
+ private var _weights = topology.getInstance(11L).weights() + private var _stackSize = 1 + private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize) + private var _gradient: Gradient = new ANNGradient(topology, dataStacker) + private var _updater: Updater = new ANNUpdater() + private var optimizer: Optimizer = LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(100) + + /** + * Returns weights + * @return weights + */ + def getWeights: Vector = _weights + + /** + * Sets weights + * @param value weights + * @return trainer + */ + def setWeights(value: Vector): FeedForwardTrainer = { + _weights = value + this + } + + /** + * Sets the stack size + * @param value stack size + * @return trainer + */ + def setStackSize(value: Int): FeedForwardTrainer = { + _stackSize = value + dataStacker = new DataStacker(value, inputSize, outputSize) + this + } + + /** + * Sets the SGD optimizer + * @return SGD optimizer + */ + def SGDOptimizer: GradientDescent = { + val sgd = new GradientDescent(_gradient, _updater) + optimizer = sgd + sgd + } + + /** + * Sets the LBFGS optimizer + * @return LBGS optimizer + */ + def LBFGSOptimizer: LBFGS = { + val lbfgs = new LBFGS(_gradient, _updater) + optimizer = lbfgs + lbfgs + } + + /** + * Sets the updater + * @param value updater + * @return trainer + */ + def setUpdater(value: Updater): FeedForwardTrainer = { + _updater = value + updateUpdater(value) + this + } + + /** + * Sets the gradient + * @param value gradient + * @return trainer + */ + def setGradient(value: Gradient): FeedForwardTrainer = { + _gradient = value + updateGradient(value) + this + } + + private[this] def updateGradient(gradient: Gradient): Unit = { + optimizer match { + case lbfgs: LBFGS => lbfgs.setGradient(gradient) + case sgd: GradientDescent => sgd.setGradient(gradient) + case other => throw new UnsupportedOperationException( + s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.") + } + } + + private[this] def updateUpdater(updater: Updater): Unit = { + optimizer match { + case lbfgs: LBFGS => lbfgs.setUpdater(updater) + case sgd: GradientDescent => sgd.setUpdater(updater) + case other => throw new UnsupportedOperationException( + s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.") + } + } + + /** + * Trains the ANN + * @param data RDD of input and output vector pairs + * @return model + */ + def train(data: RDD[(Vector, Vector)]): TopologyModel = { + val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights) + topology.getInstance(newWeights) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala new file mode 100644 index 0000000000000..13bf7f707ced9 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row + +class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("XOR function learning as 2-class classification problem") { + val inputs = Array[Array[Double]]( + Array[Double](0, 0), + Array[Double](0, 1), + Array[Double](1, 0), + Array[Double](1, 1) + ) + val outputs = Array[Double](0, 1, 1, 0) + val data = inputs.zip(outputs).map{ case(input, output) => + new LabeledPoint(output, Vectors.dense(input))} + val rddData = sc.parallelize(data, 2) + val layers = Array[Int](2, 5, 2) + val dataFrame = sqlContext.createDataFrame(rddData) + val trainer = new MultilayerPerceptronClassifier("mlpc") + .setLayers(layers) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(100) + val model = trainer.fit(dataFrame) + val result = model.transform(dataFrame) + val predictionAndLabels = result.select("prediction", "label").collect() + assert(predictionAndLabels.forall { case Row (p: Double, l: Double) => + (p - l) == 0 }) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressorSuite.scala new file mode 100644 index 0000000000000..78e21877b72dc --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressorSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row +import org.apache.spark.mllib.util.TestingUtils._ + +class MultilayerPerceptronRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("XOR function learning") { + val inputs = Array[Array[Double]]( + Array[Double](0, 0), + Array[Double](0, 1), + Array[Double](1, 0), + Array[Double](1, 1) + ) + val outputs = Array[Double](0, 1, 1, 0) + val data = inputs.zip(outputs).map { case (features, label) => + (Vectors.dense(features), Vectors.dense(Array(label))) + } + val rddData = sc.parallelize(data, 1) + val dataFrame = sqlContext.createDataFrame(rddData).toDF("inputCol","outputCol") + val hiddenLayersTopology = Array[Int](5) + val dataSample = rddData.first() + val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size + val trainer = new MultilayerPerceptronRegressor("mlpr") + .setInputCol("inputCol") + .setOutputCol("outputCol") + .setBlockSize(1) + .setLayers(layerSizes) + .setMaxIter(100) + .setTol(1e-4) + .setSeed(11L) + val model = trainer.fit(dataFrame) + .setInputCol("inputCol") + model.transform(dataFrame) + .select("rawPrediction", "outputCol").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x ~== y absTol 1e-3, "Transformed vector is different with expected vector.") + } } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala new file mode 100644 index 0000000000000..81184c7b59ea9 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala @@ -0,0 +1,74 @@ +package org.apache.spark.mllib.ann + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.scalatest.FunSuite + + +class ANNSuite extends FunSuite with MLlibTestSparkContext { + + // TODO: add test for uneven batching + // TODO: add test for gradient + + test("ANN with Sigmoid learns XOR function with LBFGS optimizer") { + val inputs = Array[Array[Double]]( + Array[Double](0, 0), + Array[Double](0, 1), + Array[Double](1, 0), + Array[Double](1, 1) + ) + val outputs = Array[Double](0, 1, 1, 0) + val data = inputs.zip(outputs).map { case (features, label) => + (Vectors.dense(features), Vectors.dense(Array(label))) + } + val rddData = sc.parallelize(data, 1) + val hiddenLayersTopology = Array[Int](5) + val dataSample = rddData.first() + val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size + val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) + val initialWeights = FeedForwardModel(topology, 23124).weights() + val trainer = new FeedForwardTrainer(topology, 2, 1) + trainer.setWeights(initialWeights) + trainer.LBFGSOptimizer.setNumIterations(20) + val model = trainer.train(rddData) + //val model = FeedForwardTrainer.train(rddData, 1, 20, topology, initialWeights) + val predictionAndLabels = rddData.map { case (input, label) => + (model.predict(input)(0), label(0)) + }.collect() + assert(predictionAndLabels.forall { case (p, l) => (math.round(p) - l) == 0}) + } + + test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") { + val inputs = Array[Array[Double]]( + Array[Double](0, 0), + Array[Double](0, 1), + Array[Double](1, 0), + Array[Double](1, 1) + ) + val outputs = Array[Array[Double]]( + 
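+      // one-hot targets for XOR: class 0 -> [1, 0], class 1 -> [0, 1]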
Array[Double](1, 0), + Array[Double](0, 1), + Array[Double](0, 1), + Array[Double](1, 0) + ) + val data = inputs.zip(outputs).map { case (features, label) => + (Vectors.dense(features), Vectors.dense(label)) + } + val rddData = sc.parallelize(data, 1) + val hiddenLayersTopology = Array[Int](5) + val dataSample = rddData.first() + val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size + val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) + val initialWeights = FeedForwardModel(topology, 23124).weights() + val trainer = new FeedForwardTrainer(topology, 2, 2) + trainer.SGDOptimizer.setNumIterations(2000) + //trainer.LBFGSOptimizer.setNumIterations(100) + trainer.setWeights(initialWeights) + val model = trainer.train(rddData) + val predictionAndLabels = rddData.map { case (input, label) => + (model.predict(input).toArray.map(math.round(_)), label.toArray) + }.collect() + assert(predictionAndLabels.forall { case (p, l) => p.deep == l.deep}) + } + +} From e191301d8ffa1f904482bc1d3213576f86ce1bff Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Thu, 23 Jul 2015 09:15:39 -0700 Subject: [PATCH 2/9] Apache header --- .../org/apache/spark/mllib/ann/ANNSuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala index 81184c7b59ea9..e25121adcf5bb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.mllib.ann import org.apache.spark.mllib.linalg.Vectors From 35125ab01f21e123b76ce33ae6d6977767a69152 Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Thu, 23 Jul 2015 09:25:59 -0700 Subject: [PATCH 3/9] Style fix in tests --- .../MultilayerPerceptronRegressorSuite.scala | 2 +- .../scala/org/apache/spark/mllib/ann/ANNSuite.scala | 11 ++--------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressorSuite.scala index 78e21877b72dc..eef266e91f8da 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressorSuite.scala @@ -37,7 +37,7 @@ class MultilayerPerceptronRegressorSuite extends SparkFunSuite with MLlibTestSpa (Vectors.dense(features), Vectors.dense(Array(label))) } val rddData = sc.parallelize(data, 1) - val dataFrame = sqlContext.createDataFrame(rddData).toDF("inputCol","outputCol") + val dataFrame = sqlContext.createDataFrame(rddData).toDF("inputCol", "outputCol") val hiddenLayersTopology = Array[Int](5) val dataSample = rddData.first() val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size diff --git a/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala index e25121adcf5bb..340bcf0a9d60b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala @@ -17,15 +17,11 @@ package org.apache.spark.mllib.ann +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.scalatest.FunSuite - -class ANNSuite extends FunSuite with MLlibTestSparkContext { - - // TODO: add test for uneven batching - // TODO: add test for gradient +class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { test("ANN with Sigmoid learns XOR function with LBFGS optimizer") { val inputs = Array[Array[Double]]( @@ -48,7 +44,6 @@ class ANNSuite extends FunSuite with MLlibTestSparkContext { trainer.setWeights(initialWeights) trainer.LBFGSOptimizer.setNumIterations(20) val model = trainer.train(rddData) - //val model = FeedForwardTrainer.train(rddData, 1, 20, topology, initialWeights) val predictionAndLabels = rddData.map { case (input, label) => (model.predict(input)(0), label(0)) }.collect() @@ -79,7 +74,6 @@ class ANNSuite extends FunSuite with MLlibTestSparkContext { val initialWeights = FeedForwardModel(topology, 23124).weights() val trainer = new FeedForwardTrainer(topology, 2, 2) trainer.SGDOptimizer.setNumIterations(2000) - //trainer.LBFGSOptimizer.setNumIterations(100) trainer.setWeights(initialWeights) val model = trainer.train(rddData) val predictionAndLabels = rddData.map { case (input, label) => @@ -87,5 +81,4 @@ class ANNSuite extends FunSuite with MLlibTestSparkContext { }.collect() assert(predictionAndLabels.forall { case (p, l) => p.deep == l.deep}) } - } From 9d18469b0e4831c391a8c263fd576752fb162e2b Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Fri, 24 Jul 2015 08:22:33 -0700 Subject: [PATCH 4/9] Addressing reviewers comments: unnecessary copy of data in predict --- mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala | 3 ++- 1 file changed, 2 insertions(+), 1 
deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala index 3f85b80f112b7..27b1d95794e76 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala @@ -599,7 +599,8 @@ class FeedForwardModel private(val layerModels: Array[LayerModel], } override def predict(data: Vector): Vector = { - val result = forward(data.toBreeze.toDenseVector.toDenseMatrix.t) + val size = data.size + val result = forward(new BDM[Double](size, 1, data.toArray)) Vectors.dense(result.last.toArray) } } From 43b0ae249bfed6497fdb3f79a2de395c628891b3 Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Wed, 29 Jul 2015 05:08:04 -0700 Subject: [PATCH 5/9] Addressing reviewers comments. Adding multiclass test. --- .../MultilayerPerceptronClassifier.scala | 93 ++++++-- .../org/apache/spark/ml/param/params.scala | 14 ++ .../MultilayerPerceptronRegressor.scala | 206 ------------------ .../org/apache/spark/mllib/ann/Layer.scala | 2 +- .../MultilayerPerceptronClassifierSuite.scala | 69 ++++-- .../MultilayerPerceptronRegressorSuite.scala | 59 ----- .../org/apache/spark/mllib/ann/ANNSuite.scala | 1 + 7 files changed, 146 insertions(+), 298 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressor.scala delete mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index b40597c556221..3c2badb528f12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -20,22 +20,78 @@ package org.apache.spark.ml.classification import breeze.linalg.{argmax => Bargmax} import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} +import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} +import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.regression.MultilayerPerceptronParams import org.apache.spark.mllib.ann.{FeedForwardTrainer, FeedForwardTopology} import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql.DataFrame -/** - * :: Experimental :: - * Label to vector converter. - */ -@Experimental -private object LabelConverter { +/** Params for Multilayer Perceptron. */ +private[ml] trait MultilayerPerceptronParams extends PredictorParams +with HasSeed with HasMaxIter with HasTol { + /** + * Layer sizes including input size and output size. + * @group param + */ + final val layers: IntArrayParam = new IntArrayParam(this, "layers", + "Sizes of layers from input layer to output layer" + + " E.g., Array(780, 100, 10) means 780 inputs, " + + "one hidden layer with 100 neurons and output layer of 10 neurons.", + // TODO: how to check ALSO that all elements are greater than 0? 
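+    // lengthGt(1) only guarantees at least two sizes (input and output).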
+ ParamValidators.lengthGt(1) + ) + + /** @group setParam */ + def setLayers(value: Array[Int]): this.type = set(layers, value) + + /** @group getParam */ + final def getLayers: Array[Int] = $(layers) + + /** + * Block size for stacking input data in matrices. Speeds up the computations. + * Cannot be more than the size of the dataset. + * @group expertParam + */ + final val blockSize: IntParam = new IntParam(this, "blockSize", + "Block size for stacking input data in matrices.", ParamValidators.gt(0)) + + /** @group setParam */ + def setBlockSize(value: Int): this.type = set(blockSize, value) + + /** @group getParam */ + final def getBlockSize: Int = $(blockSize) + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-4. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + + /** + * Set the seed for weights initialization. + * @group setParam + */ + def setSeed(value: Long): this.type = set(seed, value) + + setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 1) +} + + +/** Label to vector converter. */ +private object LabelConverter { + // TODO: Use OneHotEncoder instead /** * Encodes a label as a vector. * Returns a vector of given length with zeroes at all positions @@ -59,7 +115,7 @@ private object LabelConverter { * @return label */ def apply(output: Vector): Double = { - Bargmax(output.toBreeze.toDenseVector).toDouble + output.argmax.toDouble } } @@ -72,14 +128,14 @@ private object LabelConverter { * */ @Experimental -class MultilayerPerceptronClassifier (override val uid: String) +class MultilayerPerceptronClassifier(override val uid: String) extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassifierModel] with MultilayerPerceptronParams { - override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) - def this() = this(Identifiable.randomUID("mlpc")) + override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) + /** * Train a model using the given dataset and parameters. * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation @@ -106,11 +162,16 @@ class MultilayerPerceptronClassifier (override val uid: String) * :: Experimental :: * Classifier model based on the Multilayer Perceptron. * Each layer has sigmoid activation function, output layer has softmax. 
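As a usage sketch of the parameters defined above (the layer sizes are examples only, not values prescribed by the patch):

    // Illustrative wiring of the trainer params introduced in this patch.
    val mlpc = new MultilayerPerceptronClassifier()
      .setLayers(Array(780, 100, 10)) // 780 inputs, one hidden layer of 100, 10 labels
      .setBlockSize(1)
      .setMaxIter(100)
      .setTol(1e-4)
      .setSeed(11L)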
+ * @param uid uid + * @param layers array of layer sizes including input and output layers + * @param weights vector of initial weights for the model + * @return prediction model */ @Experimental -class MultilayerPerceptronClassifierModel private[ml] (override val uid: String, - layers: Array[Int], - weights: Vector) +class MultilayerPerceptronClassifierModel private[ml]( + override val uid: String, + layers: Array[Int], + weights: Vector) extends PredictionModel[Vector, MultilayerPerceptronClassifierModel] with Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 824efa5ed4b28..749d2a47682e5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -166,6 +166,20 @@ object ParamValidators { def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) => allowed.contains(value) } + + /** Private method for checking array types and converting to Array. */ + private def getArray[T](value: T): Array[_] = value match { + case x: Array[_] => x + case _ => + // The type should be checked before this is ever called. + throw new IllegalArgumentException("Array Param validation failed because" + + s" of unexpected input type: ${value.getClass}") + } + + /** Check that the array length is greater than lowerBound. */ + def lengthGt[T](lowerBound: Double): T => Boolean = { (value: T) => + getArray(value).length > lowerBound + } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressor.scala deleted file mode 100644 index 28deeaab8428b..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressor.scala +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.regression - -import breeze.linalg.{argmax => Bargmax} - -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.{Model, Transformer, Estimator, PredictorParams} -import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.ann.{FeedForwardTopology, FeedForwardTrainer} -import org.apache.spark.mllib.linalg.{VectorUDT, Vector} -import org.apache.spark.sql.{Row, DataFrame} -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{StructField, StructType} - -/** - * Params for Multilayer Perceptron. 
- */ -private[ml] trait MultilayerPerceptronParams extends PredictorParams -with HasSeed with HasMaxIter with HasTol { - /** - * Layer sizes including input size and output size. - * @group param - */ - final val layers: IntArrayParam = - // TODO: we need IntegerArrayParam! - new IntArrayParam(this, "layers", - "Sizes of layers including input and output from bottom to the top." + - " E.g., Array(780, 100, 10) means 780 inputs, " + - "hidden layer with 100 neurons and output layer of 10 neurons." - // TODO: how to check that array is not empty? - ) - - /** - * Block size for stacking input data in matrices. Speeds up the computations. - * Cannot be more than the size of the dataset. - * @group expertParam - */ - final val blockSize: IntParam = new IntParam(this, "blockSize", - "Block size for stacking input data in matrices.", - ParamValidators.gt(0)) - - /** @group setParam */ - def setLayers(value: Array[Int]): this.type = set(layers, value) - - /** @group getParam */ - final def getLayers: Array[Int] = $(layers) - - /** @group setParam */ - def setBlockSize(value: Int): this.type = set(blockSize, value) - - /** @group getParam */ - final def getBlockSize: Int = $(blockSize) - - /** - * Set the maximum number of iterations. - * Default is 100. - * @group setParam - */ - def setMaxIter(value: Int): this.type = set(maxIter, value) - - /** - * Set the convergence tolerance of iterations. - * Smaller value will lead to higher accuracy with the cost of more iterations. - * Default is 1E-4. - * @group setParam - */ - def setTol(value: Double): this.type = set(tol, value) - - /** - * Set the seed for weights initialization. - * Default is 11L. - * @group setParam - */ - def setSeed(value: Long): this.type = set(seed, value) - - setDefault(seed -> 11L, maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 1) -} - -/** - * :: Experimental :: - * Multi-layer perceptron regression. Contains sigmoid activation function on all layers. - * See https://en.wikipedia.org/wiki/Multilayer_perceptron for details. - * - */ -@Experimental -class MultilayerPerceptronRegressor (override val uid: String) - extends Estimator[MultilayerPerceptronRegressorModel] - with MultilayerPerceptronParams with HasInputCol with HasOutputCol with HasRawPredictionCol - with Logging { - - /** @group setParam */ - def setInputCol(value: String): this.type = set(inputCol, value) - - /** @group setParam */ - def setOutputCol(value: String): this.type = set(outputCol, value) - - /** - * Fits a model to the input and output data. - * InputCol has to contain input vectors. - * OutputCol has to contain output vectors. - */ - override def fit(dataset: DataFrame): MultilayerPerceptronRegressorModel = { - val data = dataset.select($(inputCol), $(outputCol)).map { - case Row(x: Vector, y: Vector) => (x, y) - } - val myLayers = getLayers - val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, false) - val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) - FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol(getTol).setNumIterations(getMaxIter) - FeedForwardTrainer.setStackSize(getBlockSize) - val mlpModel = FeedForwardTrainer.train(data) - new MultilayerPerceptronRegressorModel(uid, myLayers, mlpModel.weights()) - } - - /** - * :: DeveloperApi :: - * - * Derives the output schema from the input schema. 
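The schema derivation described above amounts to appending one vector column; a minimal sketch, with illustrative column names:

    import org.apache.spark.mllib.linalg.VectorUDT
    import org.apache.spark.sql.types.{StructField, StructType}

    // The output schema is the input schema plus the raw prediction column.
    val inputSchema = StructType(Seq(StructField("inputCol", new VectorUDT, false)))
    val outputSchema = StructType(
      inputSchema.fields :+ StructField("rawPrediction", new VectorUDT, false))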
- */ - override def transformSchema(schema: StructType): StructType = { - val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") - val outputType = schema($(outputCol)).dataType - require(outputType.isInstanceOf[VectorUDT], - s"Input column ${$(outputCol)} must be a vector column") - require(!schema.fieldNames.contains($(rawPredictionCol)), - s"Output column ${$(rawPredictionCol)} already exists.") - val outputFields = schema.fields :+ StructField($(rawPredictionCol), new VectorUDT, false) - StructType(outputFields) - } - - def this() = this(Identifiable.randomUID("mlpr")) - - override def copy(extra: ParamMap): MultilayerPerceptronRegressor = defaultCopy(extra) -} - -/** - * :: Experimental :: - * Multi-layer perceptron regression model. - * - * @param layers array of layer sizes including input and output - * @param weights weights (or parameters) of the model - */ -@Experimental -class MultilayerPerceptronRegressorModel private[ml] (override val uid: String, - layers: Array[Int], - weights: Vector) - extends Model[MultilayerPerceptronRegressorModel] - with HasInputCol with HasRawPredictionCol { - - private val mlpModel = - FeedForwardTopology.multiLayerPerceptron(layers, false).getInstance(weights) - - /** @group setParam */ - def setInputCol(value: String): this.type = set(inputCol, value) - - /** - * Transforms the input dataset. - * InputCol has to contain input vectors. - * RawPrediction column will contain predictions (outputs of the regressor). - */ - override def transform(dataset: DataFrame): DataFrame = { - transformSchema(dataset.schema, logging = true) - val pcaOp = udf { mlpModel.predict _ } - dataset.withColumn($(rawPredictionCol), pcaOp(col($(inputCol)))) - } - - /** - * :: DeveloperApi :: - * - * Derives the output schema from the input schema. 
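The transform above follows the usual UDF pattern; a sketch, assuming `mlpModel` and `dataset` are in scope as in the surrounding code and the column names are illustrative:

    import org.apache.spark.sql.functions.{col, udf}

    // Wrap the network's predict function in a UDF and append its output
    // as a new column.
    val predictUDF = udf { mlpModel.predict _ }
    val predicted = dataset.withColumn("rawPrediction", predictUDF(col("inputCol")))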
- */ - override def transformSchema(schema: StructType): StructType = { - val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") - require(!schema.fieldNames.contains($(rawPredictionCol)), - s"Output column ${$(rawPredictionCol)} already exists.") - val outputFields = schema.fields :+ StructField($(rawPredictionCol), new VectorUDT, false) - StructType(outputFields) - } - - override def copy(extra: ParamMap): MultilayerPerceptronRegressorModel = { - copyValues(new MultilayerPerceptronRegressorModel(uid, layers, weights), extra) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala index 27b1d95794e76..cedaccc6192bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala @@ -534,7 +534,7 @@ object FeedForwardTopology { * @param layerModels models of layers * @param topology topology of the network */ -class FeedForwardModel private(val layerModels: Array[LayerModel], +private[spark] class FeedForwardModel private(val layerModels: Array[LayerModel], val topology: FeedForwardTopology) extends TopologyModel { override def forward(data: BDM[Double]): Array[BDM[Double]] = { val outputs = new Array[BDM[Double]](layerModels.length) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 13bf7f707ced9..0d2016daff2dd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -18,27 +18,25 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { - test("XOR function learning as 2-class classification problem") { - val inputs = Array[Array[Double]]( - Array[Double](0, 0), - Array[Double](0, 1), - Array[Double](1, 0), - Array[Double](1, 1) - ) - val outputs = Array[Double](0, 1, 1, 0) - val data = inputs.zip(outputs).map{ case(input, output) => - new LabeledPoint(output, Vectors.dense(input))} - val rddData = sc.parallelize(data, 2) + test("XOR function learning as binary classification problem with two outputs.") { + val dataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(0.0, 0.0), 0.0), + (Vectors.dense(0.0, 1.0), 1.0), + (Vectors.dense(1.0, 0.0), 1.0), + (Vectors.dense(1.0, 1.0), 0.0)) + ).toDF("features", "label") val layers = Array[Int](2, 5, 2) - val dataFrame = sqlContext.createDataFrame(rddData) - val trainer = new MultilayerPerceptronClassifier("mlpc") + val trainer = new MultilayerPerceptronClassifier() .setLayers(layers) .setBlockSize(1) .setSeed(11L) @@ -46,7 +44,46 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp val 
model = trainer.fit(dataFrame) val result = model.transform(dataFrame) val predictionAndLabels = result.select("prediction", "label").collect() - assert(predictionAndLabels.forall { case Row (p: Double, l: Double) => - (p - l) == 0 }) + predictionAndLabels.foreach { case Row(p: Double, l: Double) => + assert(p == l) } + } + + test("3 class classification with 2 hidden layers") { + val nPoints = 1000 + + // The following weights are taken from OneVsRestSuite.scala + // they represent 3-class iris dataset + val weights = Array( + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) + + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val rdd = sc.parallelize(generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 2) + val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") + val numClasses = 3 + val numIterations = 100 + val layers = Array[Int](4, 5, 4, numClasses) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(numIterations) + val model = trainer.fit(dataFrame) + val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label") + .map { case Row(p: Double, l: Double) => (p, l) } + // train multinomial logistic regression + val lr = new LogisticRegressionWithLBFGS() + .setIntercept(true) + .setNumClasses(numClasses) + lr.optimizer.setRegParam(0.0) + .setNumIterations(numIterations) + val lrModel = lr.run(rdd) + val lrPredictionAndLabels = lrModel.predict(rdd.map(_.features)).zip(rdd.map(_.label)) + // MLP's predictions should not differ a lot from LR's. + val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) + val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) + assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressorSuite.scala deleted file mode 100644 index eef266e91f8da..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/MultilayerPerceptronRegressorSuite.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.ml.regression - -import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row -import org.apache.spark.mllib.util.TestingUtils._ - -class MultilayerPerceptronRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { - - test("XOR function learning") { - val inputs = Array[Array[Double]]( - Array[Double](0, 0), - Array[Double](0, 1), - Array[Double](1, 0), - Array[Double](1, 1) - ) - val outputs = Array[Double](0, 1, 1, 0) - val data = inputs.zip(outputs).map { case (features, label) => - (Vectors.dense(features), Vectors.dense(Array(label))) - } - val rddData = sc.parallelize(data, 1) - val dataFrame = sqlContext.createDataFrame(rddData).toDF("inputCol", "outputCol") - val hiddenLayersTopology = Array[Int](5) - val dataSample = rddData.first() - val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size - val trainer = new MultilayerPerceptronRegressor("mlpr") - .setInputCol("inputCol") - .setOutputCol("outputCol") - .setBlockSize(1) - .setLayers(layerSizes) - .setMaxIter(100) - .setTol(1e-4) - .setSeed(11L) - val model = trainer.fit(dataFrame) - .setInputCol("inputCol") - model.transform(dataFrame) - .select("rawPrediction", "outputCol").collect().foreach { - case Row(x: Vector, y: Vector) => - assert(x ~== y absTol 1e-3, "Transformed vector is different with expected vector.") - } } -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala index 340bcf0a9d60b..24cbb1d1b47a6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { + // TODO: test for weights comparison with Weka MLP test("ANN with Sigmoid learns XOR function with LBFGS optimizer") { val inputs = Array[Array[Double]]( Array[Double](0, 0), From 374bea6068856bea61efe9377d4c88e17194dd6b Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Wed, 29 Jul 2015 15:19:09 -0700 Subject: [PATCH 6/9] Moving ANN to ML package. GradientDescent constructor is now spark private. --- .../org/apache/spark/{mllib => ml}/ann/BreezeUtil.scala | 2 +- .../scala/org/apache/spark/{mllib => ml}/ann/Layer.scala | 6 +++--- .../ml/classification/MultilayerPerceptronClassifier.scala | 5 +---- .../apache/spark/mllib/optimization/GradientDescent.scala | 2 +- .../scala/org/apache/spark/{mllib => ml}/ann/ANNSuite.scala | 2 +- 5 files changed, 7 insertions(+), 10 deletions(-) rename mllib/src/main/scala/org/apache/spark/{mllib => ml}/ann/BreezeUtil.scala (98%) rename mllib/src/main/scala/org/apache/spark/{mllib => ml}/ann/Layer.scala (99%) rename mllib/src/test/scala/org/apache/spark/{mllib => ml}/ann/ANNSuite.scala (99%) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/ann/BreezeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala similarity index 98% rename from mllib/src/main/scala/org/apache/spark/mllib/ann/BreezeUtil.scala rename to mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala index 02e3601532cb7..ca049dcb7fffd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/ann/BreezeUtil.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.mllib.ann +package org.apache.spark.ml.ann import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala similarity index 99% rename from mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala rename to mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index cedaccc6192bc..422cebe343db2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.ann +package org.apache.spark.ml.ann import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => brzAxpy, sum => Bsum} @@ -741,12 +741,12 @@ private[ann] class ANNUpdater extends Updater { } /** - * Llib-style trainer class that trains a network given the data and topology + * MLlib-style trainer class that trains a network given the data and topology * @param topology topology of ANN * @param inputSize input size * @param outputSize output size */ -class FeedForwardTrainer (topology: Topology, val inputSize: Int, +private[ml] class FeedForwardTrainer (topology: Topology, val inputSize: Int, val outputSize: Int) extends Serializable { // TODO: what if we need to pass random seed? diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 3c2badb528f12..8702abf964246 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -17,14 +17,12 @@ package org.apache.spark.ml.classification -import breeze.linalg.{argmax => Bargmax} - import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.ann.{FeedForwardTrainer, FeedForwardTopology} +import org.apache.spark.ml.ann.{FeedForwardTrainer, FeedForwardTopology} import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql.DataFrame @@ -88,7 +86,6 @@ with HasSeed with HasMaxIter with HasTol { setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 1) } - /** Label to vector converter. */ private object LabelConverter { // TODO: Use OneHotEncoder instead diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index ab7611fd077ef..8f0d1e4aa010a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * @param gradient Gradient function to be used. * @param updater Updater to be used to update weights after every iteration. 
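For reference, the simple update rule the ANN updater feeds through this interface can be sketched with Breeze's in-place axpy; the numbers below are illustrative:

    import breeze.linalg.{DenseVector => BDV, axpy => Baxpy}

    // weights := weights - stepSize * gradient, with no regularization term.
    val weights = BDV(0.5, -0.2, 0.1)
    val gradient = BDV(0.1, 0.0, -0.3)
    val stepSize = 1.0
    Baxpy(-stepSize, gradient, weights) // modifies weights in place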
*/ -class GradientDescent private[mllib] (private var gradient: Gradient, private var updater: Updater) +class GradientDescent private[spark] (private var gradient: Gradient, private var updater: Updater) extends Optimizer with Logging { private var stepSize: Double = 1.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala similarity index 99% rename from mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala index 24cbb1d1b47a6..fcda6c64a3c0b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/ann/ANNSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.ann +package org.apache.spark.ml.ann import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors From f69bb3db2e7d370206fea87a81ccfbf6ab5fe54a Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Thu, 30 Jul 2015 03:23:38 -0700 Subject: [PATCH 7/9] Addressing reviewers comments. --- .../org/apache/spark/ml/ann/BreezeUtil.scala | 14 +-- .../scala/org/apache/spark/ml/ann/Layer.scala | 98 ++++++++++++------- .../MultilayerPerceptronClassifier.scala | 30 +++--- .../org/apache/spark/ml/param/params.scala | 20 ++-- .../org/apache/spark/ml/ann/ANNSuite.scala | 30 +++--- .../MultilayerPerceptronClassifierSuite.scala | 1 + 6 files changed, 106 insertions(+), 87 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala index ca049dcb7fffd..7429f9d652ac5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala @@ -23,8 +23,9 @@ import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} /** * In-place DGEMM and DGEMV for Breeze */ -object BreezeUtil { +private[ann] object BreezeUtil { + // TODO: switch to MLlib BLAS interface private def transposeString(a: BDM[Double]): String = if (a.isTranspose) "T" else "N" /** @@ -40,12 +41,9 @@ object BreezeUtil { require(a.cols == b.rows, "A & B Dimension mismatch!") require(a.rows == c.rows, "A & C Dimension mismatch!") require(b.cols == c.cols, "A & C Dimension mismatch!") - if(a.rows == 0 || b.rows == 0 || a.cols == 0 || b.cols == 0) { - } else { - NativeBLAS.dgemm(transposeString(a), transposeString(b), c.rows, c.cols, a.cols, - alpha, a.data, a.offset, a.majorStride, b.data, b.offset, b.majorStride, - beta, c.data, c.offset, c.rows) - } + NativeBLAS.dgemm(transposeString(a), transposeString(b), c.rows, c.cols, a.cols, + alpha, a.data, a.offset, a.majorStride, b.data, b.offset, b.majorStride, + beta, c.data, c.offset, c.rows) } /** @@ -57,9 +55,7 @@ object BreezeUtil { * @param y y */ def dgemv(alpha: Double, a: BDM[Double], x: BDV[Double], beta: Double, y: BDV[Double]): Unit = { - require(a.cols == x.length, "A & b Dimension mismatch!") - NativeBLAS.dgemv(transposeString(a), a.rows, a.cols, alpha, a.data, a.offset, a.majorStride, x.data, x.offset, x.stride, beta, y.data, y.offset, y.stride) diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index 422cebe343db2..98168a5c7ad35 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -17,9 +17,10 @@ package 
org.apache.spark.ml.ann -import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => brzAxpy, -sum => Bsum} +import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => Baxpy, + sum => Bsum} import breeze.numerics.{log => Blog, sigmoid => Bsigmoid} + import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.optimization._ import org.apache.spark.rdd.RDD @@ -177,8 +178,11 @@ private[ann] object AffineLayerModel { * @param numOut number of layer outputs * @return matrix A and vector b */ - def unroll(weights: Vector, position: Int, - numIn: Int, numOut: Int): (BDM[Double], BDV[Double]) = { + def unroll( + weights: Vector, + position: Int, + numIn: Int, + numOut: Int): (BDM[Double], BDV[Double]) = { val weightsCopy = weights.toArray // TODO: the array is not copied to BDMs, make sure this is OK! val a = new BDM[Double](numOut, numIn, weightsCopy, position) @@ -272,8 +276,11 @@ private[ann] object ActivationFunction { } } - def apply(x1: BDM[Double], x2: BDM[Double], y: BDM[Double], - func: (Double, Double) => Double): Unit = { + def apply( + x1: BDM[Double], + x2: BDM[Double], + y: BDM[Double], + func: (Double, Double) => Double): Unit = { var i = 0 while (i < x1.rows) { var j = 0 @@ -284,7 +291,6 @@ private[ann] object ActivationFunction { i += 1 } } - } /** @@ -320,8 +326,10 @@ private[ann] class SoftmaxFunction extends ActivationFunction { } } - override def crossEntropy(output: BDM[Double], target: BDM[Double], - result: BDM[Double]): Double = { + override def crossEntropy( + output: BDM[Double], + target: BDM[Double], + result: BDM[Double]): Double = { def m(o: Double, t: Double): Double = o - t ActivationFunction(output, target, result, m) -Bsum( target :* Blog(output)) / output.cols @@ -346,11 +354,13 @@ private[ann] class SigmoidFunction extends ActivationFunction { ActivationFunction(x, y, s) } - override def crossEntropy(output: BDM[Double], target: BDM[Double], - result: BDM[Double]): Double = { + override def crossEntropy( + output: BDM[Double], + target: BDM[Double], + result: BDM[Double]): Double = { def m(o: Double, t: Double): Double = o - t ActivationFunction(output, target, result, m) - -Bsum( target :* Blog(output)) / output.cols + -Bsum(target :* Blog(output)) / output.cols } override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { @@ -384,13 +394,17 @@ private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) * Functional layer model. Holds no weights. 
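The error function shared by the sigmoid and softmax layer models above can be reproduced in isolation; a sketch with illustrative matrices, storing one training example per column as the layer code does:

    import breeze.linalg.{DenseMatrix => BDM, sum => Bsum}
    import breeze.numerics.{log => Blog, sigmoid => Bsigmoid}

    // cross entropy = -sum(target :* log(output)) / numExamples
    val output = Bsigmoid(BDM((0.2, -1.5), (1.0, 0.3)))
    val target = BDM((1.0, 0.0), (0.0, 1.0))
    val crossEntropy = -Bsum(target :* Blog(output)) / output.cols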
* @param activationFunction activation function */ -private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction - ) extends LayerModel { +private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction) + extends LayerModel { val size = 0 - + // matrices for in-place computations + // outputs private var f: BDM[Double] = null + // delta private var d: BDM[Double] = null + // matrix for error computation private var e: BDM[Double] = null + // delta gradient private lazy val dg = new Array[Double](0) override def eval(data: BDM[Double]): BDM[Double] = { @@ -487,7 +501,7 @@ private[ann] trait TopologyModel extends Serializable{ * Feed forward ANN * @param layers */ -class FeedForwardTopology private(val layers: Array[Layer]) extends Topology { +private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends Topology { override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights) override def getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed) @@ -496,7 +510,7 @@ class FeedForwardTopology private(val layers: Array[Layer]) extends Topology { /** * Factory for some of the frequently-used topologies */ -object FeedForwardTopology { +private[ml] object FeedForwardTopology { /** * Creates a feed forward topology from the array of layers * @param layers array of layers @@ -534,19 +548,23 @@ object FeedForwardTopology { * @param layerModels models of layers * @param topology topology of the network */ -private[spark] class FeedForwardModel private(val layerModels: Array[LayerModel], - val topology: FeedForwardTopology) extends TopologyModel { +private[ml] class FeedForwardModel private( + val layerModels: Array[LayerModel], + val topology: FeedForwardTopology) extends TopologyModel { override def forward(data: BDM[Double]): Array[BDM[Double]] = { val outputs = new Array[BDM[Double]](layerModels.length) outputs(0) = layerModels(0).eval(data) - for(i <- 1 until layerModels.length){ + for (i <- 1 until layerModels.length) { outputs(i) = layerModels(i).eval(outputs(i-1)) } outputs } - override def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector, - realBatchSize: Int): Double = { + override def computeGradient( + data: BDM[Double], + target: BDM[Double], + cumGradient: Vector, + realBatchSize: Int): Double = { val outputs = forward(data) val deltas = new Array[BDM[Double]](layerModels.length) val L = layerModels.length - 1 @@ -585,12 +603,12 @@ private[spark] class FeedForwardModel private(val layerModels: Array[LayerModel] override def weights(): Vector = { // TODO: extract roll var size = 0 - for(i <- 0 until layerModels.length) { + for (i <- 0 until layerModels.length) { size += layerModels(i).size } val array = new Array[Double](size) var offset = 0 - for(i <- 0 until layerModels.length) { + for (i <- 0 until layerModels.length) { val layerWeights = layerModels(i).weights().toArray System.arraycopy(layerWeights, 0, array, offset, layerWeights.length) offset += layerWeights.length @@ -620,7 +638,7 @@ private[ann] object FeedForwardModel { val layers = topology.layers val layerModels = new Array[LayerModel](layers.length) var offset = 0 - for(i <- 0 until layers.length){ + for (i <- 0 until layers.length) { layerModels(i) = layers(i).getInstance(weights, offset) offset += layerModels(i).size } @@ -658,8 +676,11 @@ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) ext (gradient, loss) } - override def compute(data: Vector, 
label: Double, weights: Vector, - cumGradient: Vector): Double = { + override def compute( + data: Vector, + label: Double, + weights: Vector, + cumGradient: Vector): Double = { val (input, target, realBatchSize) = dataStacker.unstack(data) val model = topology.getInstance(weights) model.computeGradient(input, target, cumGradient, realBatchSize) @@ -684,12 +705,12 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) */ def stack(data: RDD[(Vector, Vector)]): RDD[(Double, Vector)] = { val stackedData = if (stackSize == 1) { - data.map(v => + data.map { v => (0.0, Vectors.fromBreeze(BDV.vertcat( v._1.toBreeze.toDenseVector, v._2.toBreeze.toDenseVector)) - )) + ) } } else { data.mapPartitions { it => it.grouped(stackSize).map { seq => @@ -728,14 +749,15 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) */ private[ann] class ANNUpdater extends Updater { - override def compute(weightsOld: Vector, - gradient: Vector, - stepSize: Double, - iter: Int, - regParam: Double): (Vector, Double) = { + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { val thisIterStepSize = stepSize val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector - brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) + Baxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) (Vectors.fromBreeze(brzWeights), 0) } } @@ -746,8 +768,10 @@ private[ann] class ANNUpdater extends Updater { * @param inputSize input size * @param outputSize output size */ -private[ml] class FeedForwardTrainer (topology: Topology, val inputSize: Int, - val outputSize: Int) extends Serializable { +private[ml] class FeedForwardTrainer( + topology: Topology, + val inputSize: Int, + val outputSize: Int) extends Serializable { // TODO: what if we need to pass random seed? private var _weights = topology.getInstance(11L).weights() diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 8702abf964246..8b608e2c3b50b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.DataFrame /** Params for Multilayer Perceptron. */ private[ml] trait MultilayerPerceptronParams extends PredictorParams -with HasSeed with HasMaxIter with HasTol { + with HasSeed with HasMaxIter with HasTol { /** * Layer sizes including input size and output size. * @group param @@ -39,7 +39,7 @@ with HasSeed with HasMaxIter with HasTol { " E.g., Array(780, 100, 10) means 780 inputs, " + "one hidden layer with 100 neurons and output layer of 10 neurons.", // TODO: how to check ALSO that all elements are greater than 0? - ParamValidators.lengthGt(1) + ParamValidators.arrayLengthGt(1) ) /** @group setParam */ @@ -94,12 +94,12 @@ private object LabelConverter { * Returns a vector of given length with zeroes at all positions * and value 1.0 at the position that corresponds to the label. 
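The round trip performed by this converter, sketched on a concrete label (LabelConverter is private, so the direct call is for illustration only):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // Label 2.0 out of 3 classes encodes to [0.0, 0.0, 1.0];
    // decoding takes the argmax back to 2.0.
    val lp = LabeledPoint(2.0, Vectors.dense(0.1, 0.9))
    val (features, encoded) = LabelConverter.encodeLabeledPoint(lp, 3)
    val decoded = LabelConverter.decodeLabel(encoded) // 2.0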
* - * @param labeledPoint labeled point + * @param labeledPoint labeled point * @param labelCount total number of labels - * @return vector encoding of a label + * @return pair of features and vector encoding of a label */ - def apply(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = { - val output = Array.fill(labelCount){0.0} + def encodeLabeledPoint(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = { + val output = Array.fill(labelCount)(0.0) output(labeledPoint.label.toInt) = 1.0 (labeledPoint.features, Vectors.dense(output)) } @@ -108,10 +108,10 @@ private object LabelConverter { * Converts a vector to a label. * Returns the position of the maximal element of a vector. * - * @param output label encoded with a vector - * @return label + * @param output label encoded with a vector + * @return label */ - def apply(output: Vector): Double = { + def decodeLabel(output: Vector): Double = { output.argmax.toDouble } } @@ -138,14 +138,14 @@ class MultilayerPerceptronClassifier(override val uid: String) * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation * and copying parameters into the model. * - * @param dataset Training dataset - * @return Fitted model + * @param dataset Training dataset + * @return Fitted model */ override protected def train(dataset: DataFrame): MultilayerPerceptronClassifierModel = { - val labels = getLayers.last.toInt + val myLayers = $(layers) + val labels = myLayers.last val lpData = extractLabeledPoints(dataset) - val data = lpData.map(lp => LabelConverter(lp, labels)) - val myLayers = getLayers.map(_.toInt) + val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels)) val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true) val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol(getTol).setNumIterations(getMaxIter) @@ -179,7 +179,7 @@ class MultilayerPerceptronClassifierModel private[ml]( * This internal method is used to implement [[transform()]] and output [[predictionCol]]. */ override protected def predict(features: Vector): Double = { - LabelConverter(mlpModel.predict(features)) + LabelConverter.decodeLabel(mlpModel.predict(features)) } override def copy(extra: ParamMap): MultilayerPerceptronClassifierModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 749d2a47682e5..5e1855d6a50b9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -167,18 +167,16 @@ object ParamValidators { allowed.contains(value) } - /** Private method for checking array types and converting to Array. */ - private def getArray[T](value: T): Array[_] = value match { - case x: Array[_] => x - case _ => - // The type should be checked before this is ever called. - throw new IllegalArgumentException("Array Param validation failed because" + - s" of unexpected input type: ${value.getClass}") - } - /** Check that the array length is greater than lowerBound. */ - def lengthGt[T](lowerBound: Double): T => Boolean = { (value: T) => - getArray(value).length > lowerBound + def arrayLengthGt[T](lowerBound: Double): T => Boolean = { (value: T) => + val array: Array[_] = value match { + case x: Array[_] => x + case _ => + // The type should be checked before this is ever called. 
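Behaviorally, the validator reduces to a plain length check; a standalone sketch (`lengthGt` is a hypothetical free function, shown only to make the check concrete):

    // What arrayLengthGt(1) enforces for the layers param: the array must
    // contain at least an input size and an output size.
    def lengthGt(lowerBound: Double)(value: Array[Int]): Boolean =
      value.length > lowerBound

    lengthGt(1)(Array(780, 100, 10)) // true
    lengthGt(1)(Array(780))          // false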
+ throw new IllegalArgumentException("Array Param validation failed because" + + s" of unexpected input type: ${value.getClass}") + } + array.length > lowerBound } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala index fcda6c64a3c0b..449288a48da65 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala @@ -26,17 +26,17 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { // TODO: test for weights comparison with Weka MLP test("ANN with Sigmoid learns XOR function with LBFGS optimizer") { val inputs = Array[Array[Double]]( - Array[Double](0, 0), - Array[Double](0, 1), - Array[Double](1, 0), - Array[Double](1, 1) + Array(0.0, 0.0), + Array(0.0, 1.0), + Array(1.0, 0.0), + Array(1.0, 1.0) ) - val outputs = Array[Double](0, 1, 1, 0) + val outputs = Array(0.0, 1.0, 1.0, 0.0) val data = inputs.zip(outputs).map { case (features, label) => (Vectors.dense(features), Vectors.dense(Array(label))) } val rddData = sc.parallelize(data, 1) - val hiddenLayersTopology = Array[Int](5) + val hiddenLayersTopology = Array(5) val dataSample = rddData.first() val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) @@ -53,22 +53,22 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") { val inputs = Array[Array[Double]]( - Array[Double](0, 0), - Array[Double](0, 1), - Array[Double](1, 0), - Array[Double](1, 1) + Array(0.0, 0.0), + Array(0.0, 1.0), + Array(1.0, 0.0), + Array(1.0, 1.0) ) val outputs = Array[Array[Double]]( - Array[Double](1, 0), - Array[Double](0, 1), - Array[Double](0, 1), - Array[Double](1, 0) + Array(1.0, 0.0), + Array(0.0, 1.0), + Array(0.0, 1.0), + Array(1.0, 0.0) ) val data = inputs.zip(outputs).map { case (features, label) => (Vectors.dense(features), Vectors.dense(label)) } val rddData = sc.parallelize(data, 1) - val hiddenLayersTopology = Array[Int](5) + val hiddenLayersTopology = Array(5) val dataSample = rddData.first() val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 0d2016daff2dd..a42b0b362345f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -48,6 +48,7 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp assert(p == l) } } + // TODO: implement a more rigorous test test("3 class classification with 2 hidden layers") { val nPoints = 1000 From a7e7951a071462633cc100cd2ddb9ef9d23860fc Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Thu, 30 Jul 2015 04:25:03 -0700 Subject: [PATCH 8/9] Default blockSize: 100. 
Added documentation to blockSize parameter and DataStacker class --- .../main/scala/org/apache/spark/ml/ann/Layer.scala | 9 +++++---- .../MultilayerPerceptronClassifier.scala | 13 +++++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index 98168a5c7ad35..45e3da51a6c5d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -688,9 +688,10 @@ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) ext } /** - * Class that stacks the training samples RDD[(Vector, Vector)] in one vector allowing them to pass - * through Optimizer/Gradient interfaces and thus allowing batch computations. - * Can unstack the training samples into matrices. + * Stacks pairs of training samples (input, output) in one vector allowing them to pass + * through Optimizer/Gradient interfaces. If stackSize is more than one, makes blocks + * or matrices of inputs and outputs and then stack them in one vector. + * This can be used for further batch computations after unstacking. * @param stackSize stack size * @param inputSize size of the input vectors * @param outputSize size of the output vectors @@ -775,7 +776,7 @@ private[ml] class FeedForwardTrainer( // TODO: what if we need to pass random seed? private var _weights = topology.getInstance(11L).weights() - private var _stackSize = 1 + private var _stackSize = 100 private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize) private var _gradient: Gradient = new ANNGradient(topology, dataStacker) private var _updater: Updater = new ANNUpdater() diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 8b608e2c3b50b..f92baf41617f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -49,12 +49,17 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams final def getLayers: Array[Int] = $(layers) /** - * Block size for stacking input data in matrices. Speeds up the computations. - * Cannot be more than the size of the dataset. + * Block size for stacking input data in matrices to speed up the computation. + * Data is stacked within partitions. If block size is more than remaining data in + * a partition then it is adjusted to the size of this data. + * Recommended size is between 10 and 1000. * @group expertParam */ final val blockSize: IntParam = new IntParam(this, "blockSize", - "Block size for stacking input data in matrices.", ParamValidators.gt(0)) + "Block size for stacking input data in matrices. Data is stacked within partitions." + + " If block size is more than remaining data in a partition then " + + "it is adjusted to the size of this data. 
Recommended size is between 10 and 1000", + ParamValidators.gt(0)) /** @group setParam */ def setBlockSize(value: Int): this.type = set(blockSize, value) @@ -83,7 +88,7 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams */ def setSeed(value: Long): this.type = set(seed, value) - setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 1) + setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 100) } /** Label to vector converter. */ From 4806b6fa75d12002c1e19d929c23c7153a0bedd3 Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Thu, 30 Jul 2015 17:00:20 -0700 Subject: [PATCH 9/9] Addressing reviewers comments. --- .../scala/org/apache/spark/ml/ann/Layer.scala | 2 +- .../MultilayerPerceptronClassifier.scala | 10 +++++----- .../org/apache/spark/ml/param/params.scala | 11 ++-------- .../org/apache/spark/ml/ann/ANNSuite.scala | 20 ++++++++++++------- .../MultilayerPerceptronClassifierSuite.scala | 3 ++- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index 45e3da51a6c5d..b5258ff348477 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -776,7 +776,7 @@ private[ml] class FeedForwardTrainer( // TODO: what if we need to pass random seed? private var _weights = topology.getInstance(11L).weights() - private var _stackSize = 100 + private var _stackSize = 128 private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize) private var _gradient: Gradient = new ANNGradient(topology, dataStacker) private var _updater: Updater = new ANNUpdater() diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index f92baf41617f0..8cd2103d7d5e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -88,7 +88,7 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams */ def setSeed(value: Long): this.type = set(seed, value) - setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 100) + setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128) } /** Label to vector converter. */ @@ -153,8 +153,8 @@ class MultilayerPerceptronClassifier(override val uid: String) val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels)) val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true) val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) - FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol(getTol).setNumIterations(getMaxIter) - FeedForwardTrainer.setStackSize(getBlockSize) + FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) + FeedForwardTrainer.setStackSize($(blockSize)) val mlpModel = FeedForwardTrainer.train(data) new MultilayerPerceptronClassifierModel(uid, myLayers, mlpModel.weights()) } @@ -166,11 +166,11 @@ class MultilayerPerceptronClassifier(override val uid: String) * Each layer has sigmoid activation function, output layer has softmax. 
* @param uid uid * @param layers array of layer sizes including input and output layers - * @param weights vector of initial weights for the model + * @param weights vector of initial weights for the model that consists of the weights of layers * @return prediction model */ @Experimental -class MultilayerPerceptronClassifierModel private[ml]( +class MultilayerPerceptronClassifierModel private[ml] ( override val uid: String, layers: Array[Int], weights: Vector) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 5e1855d6a50b9..cbff804c806fc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -168,15 +168,8 @@ object ParamValidators { } /** Check that the array length is greater than lowerBound. */ - def arrayLengthGt[T](lowerBound: Double): T => Boolean = { (value: T) => - val array: Array[_] = value match { - case x: Array[_] => x - case _ => - // The type should be checked before this is ever called. - throw new IllegalArgumentException("Array Param validation failed because" + - s" of unexpected input type: ${value.getClass}") - } - array.length > lowerBound + def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) => + value.length > lowerBound } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala index 449288a48da65..1292e57d7c01a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala @@ -20,12 +20,14 @@ package org.apache.spark.ml.ann import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { // TODO: test for weights comparison with Weka MLP test("ANN with Sigmoid learns XOR function with LBFGS optimizer") { - val inputs = Array[Array[Double]]( + val inputs = Array( Array(0.0, 0.0), Array(0.0, 1.0), Array(1.0, 0.0), @@ -33,7 +35,7 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { ) val outputs = Array(0.0, 1.0, 1.0, 0.0) val data = inputs.zip(outputs).map { case (features, label) => - (Vectors.dense(features), Vectors.dense(Array(label))) + (Vectors.dense(features), Vectors.dense(label)) } val rddData = sc.parallelize(data, 1) val hiddenLayersTopology = Array(5) @@ -48,17 +50,19 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { val predictionAndLabels = rddData.map { case (input, label) => (model.predict(input)(0), label(0)) }.collect() - assert(predictionAndLabels.forall { case (p, l) => (math.round(p) - l) == 0}) + predictionAndLabels.foreach { case (p, l) => + assert(math.round(p) === l) + } } test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") { - val inputs = Array[Array[Double]]( + val inputs = Array( Array(0.0, 0.0), Array(0.0, 1.0), Array(1.0, 0.0), Array(1.0, 1.0) ) - val outputs = Array[Array[Double]]( + val outputs = Array( Array(1.0, 0.0), Array(0.0, 1.0), Array(0.0, 1.0), @@ -78,8 +82,10 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { trainer.setWeights(initialWeights) val model = trainer.train(rddData) val predictionAndLabels = rddData.map { case (input, label) => - 
(model.predict(input).toArray.map(math.round(_)), label.toArray)
+      (model.predict(input), label)
     }.collect()
-    assert(predictionAndLabels.forall { case (p, l) => p.deep == l.deep})
+    predictionAndLabels.foreach { case (p, l) =>
+      assert(p ~== l absTol 0.5)
+    }
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index a42b0b362345f..ddc948f65df45 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -45,7 +45,8 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp
     val result = model.transform(dataFrame)
     val predictionAndLabels = result.select("prediction", "label").collect()
     predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
-      assert(p == l) }
+      assert(p == l)
+    }
   }
 
   // TODO: implement a more rigorous test
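The tolerance-based comparisons that replace the exact equality checks above come from MLlib's TestingUtils; a minimal sketch:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.util.TestingUtils._

    // Vectors compare element-wise within an absolute tolerance.
    val predicted = Vectors.dense(0.9, 0.1)
    val expected = Vectors.dense(1.0, 0.0)
    assert(predicted ~== expected absTol 0.5)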