diff --git a/README.md b/README.md
index ab449b94c..6b3fb60c6 100644
--- a/README.md
+++ b/README.md
@@ -136,6 +136,7 @@ List of supported libraries:
- [exposed](https://github.com/JetBrains/Exposed) - Kotlin SQL framework
- [mysql](https://github.com/mysql/mysql-connector-j) - MySql JDBC Connector
- [smile](https://github.com/haifengl/smile) - Statistical Machine Intelligence and Learning Engine
+ - [deeplearning4j](https://github.com/eclipse/deeplearning4j) - Deep learning library for the JVM
### Rich output
diff --git a/libraries/deeplearning4j-cuda.json b/libraries/deeplearning4j-cuda.json
new file mode 100644
index 000000000..3041a081c
--- /dev/null
+++ b/libraries/deeplearning4j-cuda.json
@@ -0,0 +1,43 @@
+{
+ "properties": {
+ "v": "1.0.0-beta6",
+ "cuda": "10.2",
+ "slf4j": "1.7.25",
+ "freemarker": "2.3.29"
+ },
+ "link": "https://github.com/eclipse/deeplearning4j",
+ "dependencies": [
+ "org.freemarker:freemarker:$freemarker",
+ "org.nd4j:nd4j-api:$v",
+ "org.nd4j:nd4j-cuda-$cuda:$v",
+ "org.nd4j:nd4j-cuda-$cuda-platform:$v",
+ "org.deeplearning4j:deeplearning4j-core:$v",
+ "org.deeplearning4j:deeplearning4j-common:$v",
+ "org.deeplearning4j:deeplearning4j-datasets:$v",
+ "org.deeplearning4j:deeplearning4j-nn:$v",
+ "org.deeplearning4j:deeplearning4j-nlp:$v",
+ "org.deeplearning4j:deeplearning4j-ui:$v",
+ "org.deeplearning4j:deeplearning4j-cuda-$cuda:$v",
+ "org.slf4j:slf4j-simple:$slf4j",
+ "org.slf4j:slf4j-api:$slf4j"
+ ],
+ "imports": [
+ "org.nd4j.config.*",
+ "org.nd4j.linalg.activations.*",
+ "org.nd4j.linalg.api.ndarray.INDArray",
+ "org.nd4j.linalg.dataset.DataSet",
+ "org.nd4j.linalg.dataset.api.iterator.DataSetIterator",
+ "org.nd4j.linalg.factory.Nd4j",
+ "org.nd4j.linalg.learning.config.*",
+ "org.nd4j.linalg.lossfunctions.LossFunctions.*",
+ "org.deeplearning4j.eval.Evaluation",
+ "org.deeplearning4j.nn.conf.*",
+ "org.deeplearning4j.nn.conf.layers.*",
+ "org.deeplearning4j.nn.multilayer.MultiLayerNetwork",
+ "org.deeplearning4j.nn.weights.WeightInit",
+ "org.deeplearning4j.optimize.listeners.ScoreIterationListener",
+ "org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator"
+ ],
+ "init": [
+ ]
+}
diff --git a/libraries/deeplearning4j.json b/libraries/deeplearning4j.json
new file mode 100644
index 000000000..61b6603cc
--- /dev/null
+++ b/libraries/deeplearning4j.json
@@ -0,0 +1,41 @@
+{
+ "properties": {
+ "v": "1.0.0-beta6",
+ "slf4j": "1.7.25",
+ "freemarker": "2.3.29"
+ },
+ "link": "https://github.com/eclipse/deeplearning4j",
+ "dependencies": [
+ "org.freemarker:freemarker:$freemarker",
+ "org.nd4j:nd4j-api:$v",
+ "org.nd4j:nd4j-native:$v",
+ "org.nd4j:nd4j-native-platform:$v",
+ "org.deeplearning4j:deeplearning4j-core:$v",
+ "org.deeplearning4j:deeplearning4j-common:$v",
+ "org.deeplearning4j:deeplearning4j-datasets:$v",
+ "org.deeplearning4j:deeplearning4j-nn:$v",
+ "org.deeplearning4j:deeplearning4j-nlp:$v",
+ "org.deeplearning4j:deeplearning4j-ui:$v",
+ "org.slf4j:slf4j-simple:$slf4j",
+ "org.slf4j:slf4j-api:$slf4j"
+ ],
+ "imports": [
+ "org.nd4j.config.*",
+ "org.nd4j.linalg.activations.*",
+ "org.nd4j.linalg.api.ndarray.INDArray",
+ "org.nd4j.linalg.dataset.DataSet",
+ "org.nd4j.linalg.dataset.api.iterator.DataSetIterator",
+ "org.nd4j.linalg.factory.Nd4j",
+ "org.nd4j.linalg.learning.config.*",
+ "org.nd4j.linalg.lossfunctions.LossFunctions.*",
+ "org.deeplearning4j.eval.Evaluation",
+ "org.deeplearning4j.nn.conf.*",
+ "org.deeplearning4j.nn.conf.layers.*",
+ "org.deeplearning4j.nn.multilayer.MultiLayerNetwork",
+ "org.deeplearning4j.nn.weights.WeightInit",
+ "org.deeplearning4j.optimize.listeners.ScoreIterationListener",
+ "org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator"
+ ],
+ "init": [
+ ]
+}
\ No newline at end of file
diff --git a/samples/DeepLearning4j-Cuda.ipynb b/samples/DeepLearning4j-Cuda.ipynb
new file mode 100644
index 000000000..d816b512d
--- /dev/null
+++ b/samples/DeepLearning4j-Cuda.ipynb
@@ -0,0 +1,257 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Handwriting classification\n",
+ "\n",
+ "
\n",
+ "\n",
+ "In this sample, you will create a deep neural network using Deeplearning4j and train a model capable of classifying random handwriting digits. \n",
+ "\n",
+ "This example use the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset.\n",
+ "\n",
+ "This sample is written in Kotlin, a Java-based language that is well suited for notebooks like this one.\n",
+ "\n",
+ "### What you will learn\n",
+ "\n",
+ "1. Load a dataset for a neural network.\n",
+ "2. Format EMNIST for image recognition.\n",
+ "3. Create a deep neural network.\n",
+ "4. Train a model.\n",
+ "5. Evaluate the performance of your model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%use deeplearning4j-cuda(cuda=10.2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "//number of rows and columns in the input pictures\n",
+ "val numRows = 28\n",
+ "val numColumns = 28\n",
+ "val outputNum = 10 // number of output classes\n",
+ "val batchSize = 64 // batch size for each epoch\n",
+ "val rngSeed = 123 // random number seed for reproducibility\n",
+ "val numEpochs = 5 // number of epochs to perform\n",
+ "val rate = 0.0015 // learning rate"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator\n",
+ "\n",
+ "//Get the DataSetIterators:\n",
+ "val mnistTrain = MnistDataSetIterator(batchSize, true, rngSeed)\n",
+ "val mnistTest = MnistDataSetIterator(batchSize, false, rngSeed)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "=======================================================================\n",
+ "LayerName (LayerType) nIn,nOut TotalParams ParamsShape \n",
+ "=======================================================================\n",
+ "layer0 (DenseLayer) 784,500 392 500 W:{784,500}, b:{1,500}\n",
+ "layer1 (DenseLayer) 500,100 50 100 W:{500,100}, b:{1,100}\n",
+ "layer2 (OutputLayer) 100,10 1 010 W:{100,10}, b:{1,10} \n",
+ "-----------------------------------------------------------------------\n",
+ " Total Parameters: 443 610\n",
+ " Trainable Parameters: 443 610\n",
+ " Frozen Parameters: 0\n",
+ "=======================================================================\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "import org.nd4j.linalg.learning.config.Nesterovs\n",
+ "import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction\n",
+ "import org.nd4j.linalg.activations.Activation\n",
+ "\n",
+ "val conf = NeuralNetConfiguration.Builder()\n",
+ " .seed(rngSeed.toLong()) //include a random seed for reproducibility\n",
+ " // use stochastic gradient descent as an optimization algorithm\n",
+ "\n",
+ " .activation(Activation.RELU)\n",
+ " .weightInit(WeightInit.XAVIER)\n",
+ " .updater(Nesterovs(rate, 0.98)) //specify the rate of change of the learning rate.\n",
+ " .l2(rate * 0.005) // regularize learning model\n",
+ " .list()\n",
+ " .layer(DenseLayer.Builder() //create the first input layer.\n",
+ " .nIn(numRows * numColumns)\n",
+ " .nOut(500)\n",
+ " .build())\n",
+ " .layer(DenseLayer.Builder() //create the second input layer\n",
+ " .nIn(500)\n",
+ " .nOut(100)\n",
+ " .build())\n",
+ " .layer(OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer\n",
+ " .activation(Activation.SOFTMAX)\n",
+ " .nIn(100)\n",
+ " .nOut(outputNum)\n",
+ " .build())\n",
+ " .build()\n",
+ " \n",
+ "val model = MultiLayerNetwork(conf)\n",
+ "model.init()\n",
+ "println(model.summary())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Launching Deeplearning4j UI"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import org.deeplearning4j.ui.api.UIServer\n",
+ "import org.deeplearning4j.optimize.listeners.ScoreIterationListener\n",
+ "import org.deeplearning4j.api.storage.StatsStorageRouter\n",
+ "import org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter\n",
+ "import org.deeplearning4j.ui.stats.StatsListener\n",
+ "\n",
+ "val uiServer: UIServer = UIServer.getInstance()\n",
+ "uiServer.enableRemoteListener()\n",
+ "//Create the remote stats storage router - this sends the results to the UI via HTTP, assuming the UI is at http://localhost:9000\n",
+ "val remoteUIRouter: StatsStorageRouter = RemoteUIStatsStorageRouter(\"http://localhost:9000\")\n",
+ "model.setListeners(ScoreIterationListener(100), StatsListener(remoteUIRouter))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "open a new tab in your browser and go to [http://localhost:9000](http://localhost:9000)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Train model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.fit(mnistTrain, numEpochs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Evaluate model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "========================Evaluation Metrics========================\n",
+ " # of classes: 10\n",
+ " Accuracy: 0,9718\n",
+ " Precision: 0,9722\n",
+ " Recall: 0,9711\n",
+ " F1 Score: 0,9714\n",
+ "Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)\n",
+ "\n",
+ "\n",
+ "=========================Confusion Matrix=========================\n",
+ " 0 1 2 3 4 5 6 7 8 9\n",
+ "---------------------------------------------------\n",
+ " 967 0 0 1 1 0 4 2 2 3 | 0 = 0\n",
+ " 0 1123 2 1 0 0 5 1 3 0 | 1 = 1\n",
+ " 4 3 995 0 2 0 4 14 10 0 | 2 = 2\n",
+ " 0 1 2 988 0 1 0 10 6 2 | 3 = 3\n",
+ " 3 0 0 0 956 0 6 4 0 13 | 4 = 4\n",
+ " 7 1 0 19 2 825 20 1 11 6 | 5 = 5\n",
+ " 6 3 0 0 4 1 943 1 0 0 | 6 = 6\n",
+ " 0 8 5 2 0 0 0 1010 1 2 | 7 = 7\n",
+ " 5 0 2 6 4 1 8 7 937 4 | 8 = 8\n",
+ " 3 6 1 6 10 0 2 6 1 974 | 9 = 9\n",
+ "\n",
+ "Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times\n",
+ "==================================================================\n"
+ ]
+ }
+ ],
+ "source": [
+ "val eval: org.nd4j.evaluation.classification.Evaluation = model.evaluate(mnistTest)\n",
+ "println(eval.stats())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To stop the UI:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "uiServer.stop()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Kotlin",
+ "language": "kotlin",
+ "name": "kotlin"
+ },
+ "language_info": {
+ "codemirror_mode": "text/x-kotlin",
+ "file_extension": ".kt",
+ "mimetype": "text/x-kotlin",
+ "name": "kotlin",
+ "pygments_lexer": "kotlin",
+ "version": "1.4.0-dev-7568"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/samples/DeepLearning4j.ipynb b/samples/DeepLearning4j.ipynb
new file mode 100644
index 000000000..8bd9a81e8
--- /dev/null
+++ b/samples/DeepLearning4j.ipynb
@@ -0,0 +1,656 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "%use deeplearning4j\n",
+ "%use krangl\n",
+ "%use lets-plot"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "val iris_data = \"sepal-length,sepal-width,petal-length,petal-width,species\\n5.1,3.5,1.4,0.2,Iris-setosa\\n4.9,3.0,1.4,0.2,Iris-setosa\\n4.7,3.2,1.3,0.2,Iris-setosa\\n4.6,3.1,1.5,0.2,Iris-setosa\\n5.0,3.6,1.4,0.2,Iris-setosa\\n5.4,3.9,1.7,0.4,Iris-setosa\\n4.6,3.4,1.4,0.3,Iris-setosa\\n5.0,3.4,1.5,0.2,Iris-setosa\\n4.4,2.9,1.4,0.2,Iris-setosa\\n4.9,3.1,1.5,0.1,Iris-setosa\\n5.4,3.7,1.5,0.2,Iris-setosa\\n4.8,3.4,1.6,0.2,Iris-setosa\\n4.8,3.0,1.4,0.1,Iris-setosa\\n4.3,3.0,1.1,0.1,Iris-setosa\\n5.8,4.0,1.2,0.2,Iris-setosa\\n5.7,4.4,1.5,0.4,Iris-setosa\\n5.4,3.9,1.3,0.4,Iris-setosa\\n5.1,3.5,1.4,0.3,Iris-setosa\\n5.7,3.8,1.7,0.3,Iris-setosa\\n5.1,3.8,1.5,0.3,Iris-setosa\\n5.4,3.4,1.7,0.2,Iris-setosa\\n5.1,3.7,1.5,0.4,Iris-setosa\\n4.6,3.6,1.0,0.2,Iris-setosa\\n5.1,3.3,1.7,0.5,Iris-setosa\\n4.8,3.4,1.9,0.2,Iris-setosa\\n5.0,3.0,1.6,0.2,Iris-setosa\\n5.0,3.4,1.6,0.4,Iris-setosa\\n5.2,3.5,1.5,0.2,Iris-setosa\\n5.2,3.4,1.4,0.2,Iris-setosa\\n4.7,3.2,1.6,0.2,Iris-setosa\\n4.8,3.1,1.6,0.2,Iris-setosa\\n5.4,3.4,1.5,0.4,Iris-setosa\\n5.2,4.1,1.5,0.1,Iris-setosa\\n5.5,4.2,1.4,0.2,Iris-setosa\\n4.9,3.1,1.5,0.1,Iris-setosa\\n5.0,3.2,1.2,0.2,Iris-setosa\\n5.5,3.5,1.3,0.2,Iris-setosa\\n4.9,3.1,1.5,0.1,Iris-setosa\\n4.4,3.0,1.3,0.2,Iris-setosa\\n5.1,3.4,1.5,0.2,Iris-setosa\\n5.0,3.5,1.3,0.3,Iris-setosa\\n4.5,2.3,1.3,0.3,Iris-setosa\\n4.4,3.2,1.3,0.2,Iris-setosa\\n5.0,3.5,1.6,0.6,Iris-setosa\\n5.1,3.8,1.9,0.4,Iris-setosa\\n4.8,3.0,1.4,0.3,Iris-setosa\\n5.1,3.8,1.6,0.2,Iris-setosa\\n4.6,3.2,1.4,0.2,Iris-setosa\\n5.3,3.7,1.5,0.2,Iris-setosa\\n5.0,3.3,1.4,0.2,Iris-setosa\\n7.0,3.2,4.7,1.4,Iris-versicolor\\n6.4,3.2,4.5,1.5,Iris-versicolor\\n6.9,3.1,4.9,1.5,Iris-versicolor\\n5.5,2.3,4.0,1.3,Iris-versicolor\\n6.5,2.8,4.6,1.5,Iris-versicolor\\n5.7,2.8,4.5,1.3,Iris-versicolor\\n6.3,3.3,4.7,1.6,Iris-versicolor\\n4.9,2.4,3.3,1.0,Iris-versicolor\\n6.6,2.9,4.6,1.3,Iris-versicolor\\n5.2,2.7,3.9,1.4,Iris-versicolor\\n5.0,2.0,3.5,1.0,Iris-versicolor\\n5.9,3.0,4.2,1.5,Iris-versicolor\\n6.0,2.2,4.0,1.0,Iris-versicolor\\n6.1,2.9,4.7,1.4,Iris-versicolor\\n5.6,2.9,3.6,1.3,Iris-versicolor\\n6.7,3.1,4.4,1.4,Iris-versicolor\\n5.6,3.0,4.5,1.5,Iris-versicolor\\n5.8,2.7,4.1,1.0,Iris-versicolor\\n6.2,2.2,4.5,1.5,Iris-versicolor\\n5.6,2.5,3.9,1.1,Iris-versicolor\\n5.9,3.2,4.8,1.8,Iris-versicolor\\n6.1,2.8,4.0,1.3,Iris-versicolor\\n6.3,2.5,4.9,1.5,Iris-versicolor\\n6.1,2.8,4.7,1.2,Iris-versicolor\\n6.4,2.9,4.3,1.3,Iris-versicolor\\n6.6,3.0,4.4,1.4,Iris-versicolor\\n6.8,2.8,4.8,1.4,Iris-versicolor\\n6.7,3.0,5.0,1.7,Iris-versicolor\\n6.0,2.9,4.5,1.5,Iris-versicolor\\n5.7,2.6,3.5,1.0,Iris-versicolor\\n5.5,2.4,3.8,1.1,Iris-versicolor\\n5.5,2.4,3.7,1.0,Iris-versicolor\\n5.8,2.7,3.9,1.2,Iris-versicolor\\n6.0,2.7,5.1,1.6,Iris-versicolor\\n5.4,3.0,4.5,1.5,Iris-versicolor\\n6.0,3.4,4.5,1.6,Iris-versicolor\\n6.7,3.1,4.7,1.5,Iris-versicolor\\n6.3,2.3,4.4,1.3,Iris-versicolor\\n5.6,3.0,4.1,1.3,Iris-versicolor\\n5.5,2.5,4.0,1.3,Iris-versicolor\\n5.5,2.6,4.4,1.2,Iris-versicolor\\n6.1,3.0,4.6,1.4,Iris-versicolor\\n5.8,2.6,4.0,1.2,Iris-versicolor\\n5.0,2.3,3.3,1.0,Iris-versicolor\\n5.6,2.7,4.2,1.3,Iris-versicolor\\n5.7,3.0,4.2,1.2,Iris-versicolor\\n5.7,2.9,4.2,1.3,Iris-versicolor\\n6.2,2.9,4.3,1.3,Iris-versicolor\\n5.1,2.5,3.0,1.1,Iris-versicolor\\n5.7,2.8,4.1,1.3,Iris-versicolor\\n6.3,3.3,6.0,2.5,Iris-virginica\\n5.8,2.7,5.1,1.9,Iris-virginica\\n7.1,3.0,5.9,2.1,Iris-virginica\\n6.3,2.9,5.6,1.8,Iris-virginica\\n6.5,3.0,5.8,2.2,Iris-virginica\\n7.6,3.0,6.6,2.1,Iris-virginica\\n4.9,2.5,4.5,1.7,Iris-virginica\\n7.3,2.9,6.3,1.8,Iris-virginica\\n6.7,2.5,5.8,1.8,Iris-virginica\\n7.2,3.6,6.1,2.5,Iris-virginica\\n6.5,3.2,5.1,2.0,Iris-virginica\\n6.4,2.7,5.3,1.9,Iris-virginica\\n6.8,3.0,5.5,2.1,Iris-virginica\\n5.7,2.5,5.0,2.0,Iris-virginica\\n5.8,2.8,5.1,2.4,Iris-virginica\\n6.4,3.2,5.3,2.3,Iris-virginica\\n6.5,3.0,5.5,1.8,Iris-virginica\\n7.7,3.8,6.7,2.2,Iris-virginica\\n7.7,2.6,6.9,2.3,Iris-virginica\\n6.0,2.2,5.0,1.5,Iris-virginica\\n6.9,3.2,5.7,2.3,Iris-virginica\\n5.6,2.8,4.9,2.0,Iris-virginica\\n7.7,2.8,6.7,2.0,Iris-virginica\\n6.3,2.7,4.9,1.8,Iris-virginica\\n6.7,3.3,5.7,2.1,Iris-virginica\\n7.2,3.2,6.0,1.8,Iris-virginica\\n6.2,2.8,4.8,1.8,Iris-virginica\\n6.1,3.0,4.9,1.8,Iris-virginica\\n6.4,2.8,5.6,2.1,Iris-virginica\\n7.2,3.0,5.8,1.6,Iris-virginica\\n7.4,2.8,6.1,1.9,Iris-virginica\\n7.9,3.8,6.4,2.0,Iris-virginica\\n6.4,2.8,5.6,2.2,Iris-virginica\\n6.3,2.8,5.1,1.5,Iris-virginica\\n6.1,2.6,5.6,1.4,Iris-virginica\\n7.7,3.0,6.1,2.3,Iris-virginica\\n6.3,3.4,5.6,2.4,Iris-virginica\\n6.4,3.1,5.5,1.8,Iris-virginica\\n6.0,3.0,4.8,1.8,Iris-virginica\\n6.9,3.1,5.4,2.1,Iris-virginica\\n6.7,3.1,5.6,2.4,Iris-virginica\\n6.9,3.1,5.1,2.3,Iris-virginica\\n5.8,2.7,5.1,1.9,Iris-virginica\\n6.8,3.2,5.9,2.3,Iris-virginica\\n6.7,3.3,5.7,2.5,Iris-virginica\\n6.7,3.0,5.2,2.3,Iris-virginica\\n6.3,2.5,5.0,1.9,Iris-virginica\\n6.5,3.0,5.2,2.0,Iris-virginica\\n6.2,3.4,5.4,2.3,Iris-virginica\\n5.9,3.0,5.1,1.8,Iris-virginica\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "sepal-length | sepal-width | petal-length | petal-width | species |
---|
5.1 | 3.3 | 1.7 | 0.5 | Iris-setosa |
5.8 | 2.7 | 5.1 | 1.9 | Iris-virginica |
5.6 | 2.8 | 4.9 | 2.0 | Iris-virginica |
4.8 | 3.0 | 1.4 | 0.3 | Iris-setosa |
7.7 | 2.6 | 6.9 | 2.3 | Iris-virginica |
"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import java.util.*\n",
+ "import java.io.StringReader\n",
+ "\n",
+ "val iris = DataFrame.readDelim(StringReader(iris_data)).shuffle()\n",
+ "iris.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ " \n",
+ " "
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "val points = geom_point(\n",
+ " data = mapOf(\n",
+ " \"x\" to iris[\"sepal-length\"].asDoubles().toList(),\n",
+ " \"y\" to iris[\"sepal-width\"].asDoubles().toList(),\n",
+ " \"color\" to iris[\"species\"].asStrings().toList()\n",
+ " ), alpha=1.0)\n",
+ "{\n",
+ " x = \"x\" \n",
+ " y = \"y\"\n",
+ " color = \"color\"\n",
+ "}\n",
+ "\n",
+ "ggplot() + points"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "sepal-length | sepal-width | petal-length | petal-width |
---|
5.1 | 3.3 | 1.7 | 0.5 |
5.8 | 2.7 | 5.1 | 1.9 |
5.6 | 2.8 | 4.9 | 2.0 |
4.8 | 3.0 | 1.4 | 0.3 |
7.7 | 2.6 | 6.9 | 2.3 |
"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "val irisWithoutLabel = iris.remove(\"species\")\n",
+ "irisWithoutLabel.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[5.1, 3.3, 1.7, 0.5]\n",
+ "[5.8, 2.7, 5.1, 1.9]\n",
+ "[5.6, 2.8, 4.9, 2.0]\n",
+ "[4.8, 3.0, 1.4, 0.3]\n",
+ "[7.7, 2.6, 6.9, 2.3]\n",
+ "[5.6, 2.9, 3.6, 1.3]\n",
+ "[6.9, 3.1, 5.4, 2.1]\n",
+ "[5.9, 3.0, 4.2, 1.5]\n",
+ "[4.9, 3.1, 1.5, 0.1]\n",
+ "[6.8, 2.8, 4.8, 1.4]\n",
+ "[6.0, 2.2, 5.0, 1.5]\n",
+ "[6.0, 3.4, 4.5, 1.6]\n",
+ "[5.4, 3.9, 1.3, 0.4]\n",
+ "[5.7, 3.0, 4.2, 1.2]\n",
+ "[7.2, 3.0, 5.8, 1.6]\n",
+ "[6.0, 2.7, 5.1, 1.6]\n",
+ "[6.4, 3.2, 5.3, 2.3]\n",
+ "[5.7, 2.8, 4.1, 1.3]\n",
+ "[5.7, 2.5, 5.0, 2.0]\n",
+ "[6.2, 2.8, 4.8, 1.8]\n",
+ "[5.0, 3.5, 1.3, 0.3]\n",
+ "[5.7, 4.4, 1.5, 0.4]\n",
+ "[6.3, 2.5, 5.0, 1.9]\n",
+ "[7.7, 3.0, 6.1, 2.3]\n",
+ "[4.8, 3.0, 1.4, 0.1]\n",
+ "[5.8, 2.7, 3.9, 1.2]\n",
+ "[5.1, 2.5, 3.0, 1.1]\n",
+ "[6.4, 2.8, 5.6, 2.1]\n",
+ "[5.3, 3.7, 1.5, 0.2]\n",
+ "[4.6, 3.4, 1.4, 0.3]\n",
+ "[7.6, 3.0, 6.6, 2.1]\n",
+ "[4.5, 2.3, 1.3, 0.3]\n",
+ "[5.6, 2.7, 4.2, 1.3]\n",
+ "[5.7, 2.6, 3.5, 1.0]\n",
+ "[6.7, 3.0, 5.0, 1.7]\n",
+ "[6.5, 3.0, 5.8, 2.2]\n",
+ "[5.0, 2.3, 3.3, 1.0]\n",
+ "[6.1, 3.0, 4.9, 1.8]\n",
+ "[6.5, 3.0, 5.2, 2.0]\n",
+ "[6.2, 3.4, 5.4, 2.3]\n",
+ "[4.4, 2.9, 1.4, 0.2]\n",
+ "[5.2, 3.5, 1.5, 0.2]\n",
+ "[7.2, 3.6, 6.1, 2.5]\n",
+ "[5.5, 4.2, 1.4, 0.2]\n",
+ "[6.4, 2.9, 4.3, 1.3]\n",
+ "[4.9, 3.0, 1.4, 0.2]\n",
+ "[6.3, 2.5, 4.9, 1.5]\n",
+ "[5.5, 2.4, 3.7, 1.0]\n",
+ "[4.7, 3.2, 1.6, 0.2]\n",
+ "[6.3, 2.7, 4.9, 1.8]\n",
+ "[6.3, 2.3, 4.4, 1.3]\n",
+ "[7.1, 3.0, 5.9, 2.1]\n",
+ "[5.0, 3.5, 1.6, 0.6]\n",
+ "[6.8, 3.0, 5.5, 2.1]\n",
+ "[4.8, 3.4, 1.9, 0.2]\n",
+ "[6.7, 3.1, 5.6, 2.4]\n",
+ "[5.8, 2.6, 4.0, 1.2]\n",
+ "[5.0, 3.2, 1.2, 0.2]\n",
+ "[6.7, 3.3, 5.7, 2.5]\n",
+ "[5.1, 3.5, 1.4, 0.2]\n",
+ "[6.4, 2.7, 5.3, 1.9]\n",
+ "[7.0, 3.2, 4.7, 1.4]\n",
+ "[6.1, 2.8, 4.7, 1.2]\n",
+ "[5.4, 3.4, 1.7, 0.2]\n",
+ "[4.9, 2.4, 3.3, 1.0]\n",
+ "[5.2, 3.4, 1.4, 0.2]\n",
+ "[6.5, 2.8, 4.6, 1.5]\n",
+ "[5.4, 3.0, 4.5, 1.5]\n",
+ "[7.3, 2.9, 6.3, 1.8]\n",
+ "[5.2, 2.7, 3.9, 1.4]\n",
+ "[5.4, 3.9, 1.7, 0.4]\n",
+ "[6.2, 2.2, 4.5, 1.5]\n",
+ "[5.1, 3.5, 1.4, 0.3]\n",
+ "[4.8, 3.4, 1.6, 0.2]\n",
+ "[7.7, 3.8, 6.7, 2.2]\n",
+ "[5.6, 3.0, 4.5, 1.5]\n",
+ "[6.3, 3.4, 5.6, 2.4]\n",
+ "[5.8, 2.8, 5.1, 2.4]\n",
+ "[5.5, 2.3, 4.0, 1.3]\n",
+ "[4.9, 2.5, 4.5, 1.7]\n",
+ "[6.0, 2.2, 4.0, 1.0]\n",
+ "[5.0, 2.0, 3.5, 1.0]\n",
+ "[5.9, 3.2, 4.8, 1.8]\n",
+ "[5.4, 3.4, 1.5, 0.4]\n",
+ "[6.9, 3.1, 4.9, 1.5]\n",
+ "[4.9, 3.1, 1.5, 0.1]\n",
+ "[5.2, 4.1, 1.5, 0.1]\n",
+ "[5.1, 3.8, 1.5, 0.3]\n",
+ "[5.1, 3.8, 1.6, 0.2]\n",
+ "[6.7, 3.1, 4.7, 1.5]\n",
+ "[5.9, 3.0, 5.1, 1.8]\n",
+ "[5.8, 4.0, 1.2, 0.2]\n",
+ "[4.3, 3.0, 1.1, 0.1]\n",
+ "[6.7, 2.5, 5.8, 1.8]\n",
+ "[6.3, 3.3, 6.0, 2.5]\n",
+ "[5.6, 2.5, 3.9, 1.1]\n",
+ "[4.4, 3.2, 1.3, 0.2]\n",
+ "[4.6, 3.1, 1.5, 0.2]\n",
+ "[5.5, 2.6, 4.4, 1.2]\n",
+ "[6.9, 3.1, 5.1, 2.3]\n",
+ "[6.0, 2.9, 4.5, 1.5]\n",
+ "[7.2, 3.2, 6.0, 1.8]\n",
+ "[6.1, 2.8, 4.0, 1.3]\n",
+ "[5.7, 2.9, 4.2, 1.3]\n",
+ "[5.8, 2.7, 4.1, 1.0]\n",
+ "[4.8, 3.1, 1.6, 0.2]\n",
+ "[6.9, 3.2, 5.7, 2.3]\n",
+ "[5.5, 2.4, 3.8, 1.1]\n",
+ "[5.0, 3.4, 1.5, 0.2]\n",
+ "[4.6, 3.2, 1.4, 0.2]\n",
+ "[4.9, 3.1, 1.5, 0.1]\n",
+ "[6.0, 3.0, 4.8, 1.8]\n",
+ "[6.3, 2.9, 5.6, 1.8]\n",
+ "[6.6, 3.0, 4.4, 1.4]\n",
+ "[7.9, 3.8, 6.4, 2.0]\n",
+ "[5.6, 3.0, 4.1, 1.3]\n",
+ "[5.7, 3.8, 1.7, 0.3]\n",
+ "[5.0, 3.4, 1.6, 0.4]\n",
+ "[5.7, 2.8, 4.5, 1.3]\n",
+ "[6.7, 3.3, 5.7, 2.1]\n",
+ "[6.7, 3.1, 4.4, 1.4]\n",
+ "[6.7, 3.0, 5.2, 2.3]\n",
+ "[5.5, 2.5, 4.0, 1.3]\n",
+ "[5.0, 3.3, 1.4, 0.2]\n",
+ "[4.4, 3.0, 1.3, 0.2]\n",
+ "[6.6, 2.9, 4.6, 1.3]\n",
+ "[7.4, 2.8, 6.1, 1.9]\n",
+ "[6.5, 3.0, 5.5, 1.8]\n",
+ "[6.3, 2.8, 5.1, 1.5]\n",
+ "[6.4, 3.2, 4.5, 1.5]\n",
+ "[6.1, 2.9, 4.7, 1.4]\n",
+ "[4.6, 3.6, 1.0, 0.2]\n",
+ "[5.4, 3.7, 1.5, 0.2]\n",
+ "[5.5, 3.5, 1.3, 0.2]\n",
+ "[6.1, 3.0, 4.6, 1.4]\n",
+ "[5.8, 2.7, 5.1, 1.9]\n",
+ "[6.8, 3.2, 5.9, 2.3]\n",
+ "[6.4, 3.1, 5.5, 1.8]\n",
+ "[7.7, 2.8, 6.7, 2.0]\n",
+ "[5.0, 3.0, 1.6, 0.2]\n",
+ "[6.2, 2.9, 4.3, 1.3]\n",
+ "[5.1, 3.4, 1.5, 0.2]\n",
+ "[6.5, 3.2, 5.1, 2.0]\n",
+ "[5.1, 3.7, 1.5, 0.4]\n",
+ "[6.4, 2.8, 5.6, 2.2]\n",
+ "[5.1, 3.8, 1.9, 0.4]\n",
+ "[5.0, 3.6, 1.4, 0.2]\n",
+ "[4.7, 3.2, 1.3, 0.2]\n",
+ "[6.3, 3.3, 4.7, 1.6]\n",
+ "[6.1, 2.6, 5.6, 1.4]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "//Convert the iris data into 150x4 matrix\n",
+ "val row = 150\n",
+ "val col = 4\n",
+ "\n",
+ "val irisMatrix = Array(row) { DoubleArray(col) }\n",
+ "var i = 0\n",
+ "for (r in 0 until row) {\n",
+ " for (c in 0 until col) {\n",
+ " irisMatrix[r][c] = irisWithoutLabel[c][r] as Double\n",
+ " }\n",
+ "}\n",
+ "println(Arrays.deepToString(irisMatrix).replace(\"], \", \"]\\n\"))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[1.0, 0.0, 0.0]\n",
+ "[0.0, 1.0, 0.0]\n",
+ "[0.0, 0.0, 1.0]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "//Now do the same for the label data\n",
+ "val irisLabel = iris.select(\"species\")[0]\n",
+ "\n",
+ "val rowLabel = 150\n",
+ "val colLabel = 3\n",
+ "\n",
+ "val twodimLabel = Array(rowLabel) { DoubleArray(colLabel) }\n",
+ "for (r in 0 until rowLabel) {\n",
+ " when (irisLabel[r]) {\n",
+ " \"Iris-setosa\" -> twodimLabel[r][0] = 1.0\n",
+ " \"Iris-versicolor\" -> twodimLabel[r][1] = 1.0\n",
+ " \"Iris-virginica\" -> twodimLabel[r][2] = 1.0\n",
+ " }\n",
+ "}\n",
+ "println(Arrays.deepToString(twodimLabel).replace(\"], \", \"]\\n\"))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "//Convert the data matrices into training INDArrays\n",
+ "val dataIn = Nd4j.create(irisMatrix)\n",
+ "val dataOut = Nd4j.create(twodimLabel)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import org.nd4j.linalg.lossfunctions.LossFunctions\n",
+ "\n",
+ "val seed: Long = 6\n",
+ "\n",
+ "val conf = NeuralNetConfiguration.Builder()\n",
+ " .seed(seed) //include a random seed for reproducibility\n",
+ " // use stochastic gradient descent as an optimization algorithm\n",
+ " .updater(Nadam()) //specify the rate of change of the learning rate.\n",
+ " .l2(1e-4)\n",
+ " .list()\n",
+ " .layer(DenseLayer.Builder()\n",
+ " .nIn(4)\n",
+ " .nOut(3)\n",
+ " .activation(Activation.TANH)\n",
+ " .weightInit(WeightInit.XAVIER)\n",
+ " .build())\n",
+ " .layer(org.deeplearning4j.nn.conf.layers.DenseLayer.Builder()\n",
+ " .nIn(3)\n",
+ " .nOut(3)\n",
+ " .build())\n",
+ " .layer(OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)\n",
+ " .nIn(3)\n",
+ " .nOut(3)\n",
+ " .activation(Activation.SOFTMAX)\n",
+ " .weightInit(WeightInit.XAVIER)\n",
+ " .build())\n",
+ " .build()\n",
+ "\n",
+ "val model = MultiLayerNetwork(conf)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Score \n",
+ "\n",
+ "========================Evaluation Metrics========================\n",
+ " # of classes: 3\n",
+ " Accuracy: 1,0000\n",
+ " Precision: 1,0000\n",
+ " Recall: 1,0000\n",
+ " F1 Score: 1,0000\n",
+ "Precision, recall & F1: macro-averaged (equally weighted avg. of 3 classes)\n",
+ "\n",
+ "\n",
+ "=========================Confusion Matrix=========================\n",
+ " 0 1 2\n",
+ "-------\n",
+ " 5 0 0 | 0 = 0\n",
+ " 0 3 0 | 1 = 1\n",
+ " 0 0 7 | 2 = 2\n",
+ "\n",
+ "Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times\n",
+ "==================================================================\n"
+ ]
+ }
+ ],
+ "source": [
+ "import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization\n",
+ "import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize\n",
+ "\n",
+ "//Create a data set from the INDArrays and shuffle it \n",
+ "val fullDataSet = DataSet(dataIn, dataOut)\n",
+ "fullDataSet.shuffle(seed)\n",
+ "\n",
+ "val splitedSet = fullDataSet.splitTestAndTrain(0.90)\n",
+ "val trainingData = splitedSet.train;\n",
+ "val testData = splitedSet.test;\n",
+ "\n",
+ "//We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance):\n",
+ "val normalizer: DataNormalization = NormalizerStandardize()\n",
+ "normalizer.fit(trainingData) //Collect the statistics (mean/stdev) from the training data. This does not modify the input data\n",
+ "normalizer.transform(trainingData) //Apply normalization to the training data\n",
+ "normalizer.transform(testData) //Apply normalization to the test data. This is using statistics calculated from the *training* set\n",
+ "\n",
+ "// train the network\n",
+ "model.setListeners(ScoreIterationListener(100))\n",
+ "for (l in 0..2000) {\n",
+ " model.fit(trainingData)\n",
+ "}\n",
+ "\n",
+ "// evaluate the network\n",
+ "val eval = Evaluation()\n",
+ "val output: INDArray = model.output(testData.features)\n",
+ "eval.eval(testData.labels, output)\n",
+ "println(\"Score \" + eval.stats())"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Kotlin",
+ "language": "kotlin",
+ "name": "kotlin"
+ },
+ "language_info": {
+ "codemirror_mode": "text/x-kotlin",
+ "file_extension": ".kt",
+ "mimetype": "text/x-kotlin",
+ "name": "kotlin",
+ "pygments_lexer": "kotlin",
+ "version": "1.3.70-eap-274"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}