diff --git a/README.md b/README.md index ab449b94c..6b3fb60c6 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,7 @@ List of supported libraries: - [exposed](https://github.com/JetBrains/Exposed) - Kotlin SQL framework - [mysql](https://github.com/mysql/mysql-connector-j) - MySql JDBC Connector - [smile](https://github.com/haifengl/smile) - Statistical Machine Intelligence and Learning Engine + - [deeplearning4j](https://github.com/eclipse/deeplearning4j) - Deep learning library for the JVM ### Rich output diff --git a/libraries/deeplearning4j-cuda.json b/libraries/deeplearning4j-cuda.json new file mode 100644 index 000000000..3041a081c --- /dev/null +++ b/libraries/deeplearning4j-cuda.json @@ -0,0 +1,43 @@ +{ + "properties": { + "v": "1.0.0-beta6", + "cuda": "10.2", + "slf4j": "1.7.25", + "freemarker": "2.3.29" + }, + "link": "https://github.com/eclipse/deeplearning4j", + "dependencies": [ + "org.freemarker:freemarker:$freemarker", + "org.nd4j:nd4j-api:$v", + "org.nd4j:nd4j-cuda-$cuda:$v", + "org.nd4j:nd4j-cuda-$cuda-platform:$v", + "org.deeplearning4j:deeplearning4j-core:$v", + "org.deeplearning4j:deeplearning4j-common:$v", + "org.deeplearning4j:deeplearning4j-datasets:$v", + "org.deeplearning4j:deeplearning4j-nn:$v", + "org.deeplearning4j:deeplearning4j-nlp:$v", + "org.deeplearning4j:deeplearning4j-ui:$v", + "org.deeplearning4j:deeplearning4j-cuda-$cuda:$v", + "org.slf4j:slf4j-simple:$slf4j", + "org.slf4j:slf4j-api:$slf4j" + ], + "imports": [ + "org.nd4j.config.*", + "org.nd4j.linalg.activations.*", + "org.nd4j.linalg.api.ndarray.INDArray", + "org.nd4j.linalg.dataset.DataSet", + "org.nd4j.linalg.dataset.api.iterator.DataSetIterator", + "org.nd4j.linalg.factory.Nd4j", + "org.nd4j.linalg.learning.config.*", + "org.nd4j.linalg.lossfunctions.LossFunctions.*", + "org.deeplearning4j.eval.Evaluation", + "org.deeplearning4j.nn.conf.*", + "org.deeplearning4j.nn.conf.layers.*", + "org.deeplearning4j.nn.multilayer.MultiLayerNetwork", + "org.deeplearning4j.nn.weights.WeightInit", + "org.deeplearning4j.optimize.listeners.ScoreIterationListener", + "org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator" + ], + "init": [ + ] +} diff --git a/libraries/deeplearning4j.json b/libraries/deeplearning4j.json new file mode 100644 index 000000000..61b6603cc --- /dev/null +++ b/libraries/deeplearning4j.json @@ -0,0 +1,41 @@ +{ + "properties": { + "v": "1.0.0-beta6", + "slf4j": "1.7.25", + "freemarker": "2.3.29" + }, + "link": "https://github.com/eclipse/deeplearning4j", + "dependencies": [ + "org.freemarker:freemarker:$freemarker", + "org.nd4j:nd4j-api:$v", + "org.nd4j:nd4j-native:$v", + "org.nd4j:nd4j-native-platform:$v", + "org.deeplearning4j:deeplearning4j-core:$v", + "org.deeplearning4j:deeplearning4j-common:$v", + "org.deeplearning4j:deeplearning4j-datasets:$v", + "org.deeplearning4j:deeplearning4j-nn:$v", + "org.deeplearning4j:deeplearning4j-nlp:$v", + "org.deeplearning4j:deeplearning4j-ui:$v", + "org.slf4j:slf4j-simple:$slf4j", + "org.slf4j:slf4j-api:$slf4j" + ], + "imports": [ + "org.nd4j.config.*", + "org.nd4j.linalg.activations.*", + "org.nd4j.linalg.api.ndarray.INDArray", + "org.nd4j.linalg.dataset.DataSet", + "org.nd4j.linalg.dataset.api.iterator.DataSetIterator", + "org.nd4j.linalg.factory.Nd4j", + "org.nd4j.linalg.learning.config.*", + "org.nd4j.linalg.lossfunctions.LossFunctions.*", + "org.deeplearning4j.eval.Evaluation", + "org.deeplearning4j.nn.conf.*", + "org.deeplearning4j.nn.conf.layers.*", + "org.deeplearning4j.nn.multilayer.MultiLayerNetwork", + "org.deeplearning4j.nn.weights.WeightInit", + "org.deeplearning4j.optimize.listeners.ScoreIterationListener", + "org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator" + ], + "init": [ + ] +} \ No newline at end of file diff --git a/samples/DeepLearning4j-Cuda.ipynb b/samples/DeepLearning4j-Cuda.ipynb new file mode 100644 index 000000000..d816b512d --- /dev/null +++ b/samples/DeepLearning4j-Cuda.ipynb @@ -0,0 +1,257 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Handwriting classification\n", + "\n", + "\"MNIST\n", + "\n", + "In this sample, you will create a deep neural network using Deeplearning4j and train a model capable of classifying random handwriting digits. \n", + "\n", + "This example use the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset.\n", + "\n", + "This sample is written in Kotlin, a Java-based language that is well suited for notebooks like this one.\n", + "\n", + "### What you will learn\n", + "\n", + "1. Load a dataset for a neural network.\n", + "2. Format EMNIST for image recognition.\n", + "3. Create a deep neural network.\n", + "4. Train a model.\n", + "5. Evaluate the performance of your model." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%use deeplearning4j-cuda(cuda=10.2)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "//number of rows and columns in the input pictures\n", + "val numRows = 28\n", + "val numColumns = 28\n", + "val outputNum = 10 // number of output classes\n", + "val batchSize = 64 // batch size for each epoch\n", + "val rngSeed = 123 // random number seed for reproducibility\n", + "val numEpochs = 5 // number of epochs to perform\n", + "val rate = 0.0015 // learning rate" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator\n", + "\n", + "//Get the DataSetIterators:\n", + "val mnistTrain = MnistDataSetIterator(batchSize, true, rngSeed)\n", + "val mnistTest = MnistDataSetIterator(batchSize, false, rngSeed)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "=======================================================================\n", + "LayerName (LayerType) nIn,nOut TotalParams ParamsShape \n", + "=======================================================================\n", + "layer0 (DenseLayer) 784,500 392 500 W:{784,500}, b:{1,500}\n", + "layer1 (DenseLayer) 500,100 50 100 W:{500,100}, b:{1,100}\n", + "layer2 (OutputLayer) 100,10 1 010 W:{100,10}, b:{1,10} \n", + "-----------------------------------------------------------------------\n", + " Total Parameters: 443 610\n", + " Trainable Parameters: 443 610\n", + " Frozen Parameters: 0\n", + "=======================================================================\n", + "\n" + ] + } + ], + "source": [ + "import org.nd4j.linalg.learning.config.Nesterovs\n", + "import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction\n", + "import org.nd4j.linalg.activations.Activation\n", + "\n", + "val conf = NeuralNetConfiguration.Builder()\n", + " .seed(rngSeed.toLong()) //include a random seed for reproducibility\n", + " // use stochastic gradient descent as an optimization algorithm\n", + "\n", + " .activation(Activation.RELU)\n", + " .weightInit(WeightInit.XAVIER)\n", + " .updater(Nesterovs(rate, 0.98)) //specify the rate of change of the learning rate.\n", + " .l2(rate * 0.005) // regularize learning model\n", + " .list()\n", + " .layer(DenseLayer.Builder() //create the first input layer.\n", + " .nIn(numRows * numColumns)\n", + " .nOut(500)\n", + " .build())\n", + " .layer(DenseLayer.Builder() //create the second input layer\n", + " .nIn(500)\n", + " .nOut(100)\n", + " .build())\n", + " .layer(OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer\n", + " .activation(Activation.SOFTMAX)\n", + " .nIn(100)\n", + " .nOut(outputNum)\n", + " .build())\n", + " .build()\n", + " \n", + "val model = MultiLayerNetwork(conf)\n", + "model.init()\n", + "println(model.summary())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Launching Deeplearning4j UI" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import org.deeplearning4j.ui.api.UIServer\n", + "import org.deeplearning4j.optimize.listeners.ScoreIterationListener\n", + "import org.deeplearning4j.api.storage.StatsStorageRouter\n", + "import org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter\n", + "import org.deeplearning4j.ui.stats.StatsListener\n", + "\n", + "val uiServer: UIServer = UIServer.getInstance()\n", + "uiServer.enableRemoteListener()\n", + "//Create the remote stats storage router - this sends the results to the UI via HTTP, assuming the UI is at http://localhost:9000\n", + "val remoteUIRouter: StatsStorageRouter = RemoteUIStatsStorageRouter(\"http://localhost:9000\")\n", + "model.setListeners(ScoreIterationListener(100), StatsListener(remoteUIRouter))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "open a new tab in your browser and go to [http://localhost:9000](http://localhost:9000)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "model.fit(mnistTrain, numEpochs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluate model" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "========================Evaluation Metrics========================\n", + " # of classes: 10\n", + " Accuracy: 0,9718\n", + " Precision: 0,9722\n", + " Recall: 0,9711\n", + " F1 Score: 0,9714\n", + "Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)\n", + "\n", + "\n", + "=========================Confusion Matrix=========================\n", + " 0 1 2 3 4 5 6 7 8 9\n", + "---------------------------------------------------\n", + " 967 0 0 1 1 0 4 2 2 3 | 0 = 0\n", + " 0 1123 2 1 0 0 5 1 3 0 | 1 = 1\n", + " 4 3 995 0 2 0 4 14 10 0 | 2 = 2\n", + " 0 1 2 988 0 1 0 10 6 2 | 3 = 3\n", + " 3 0 0 0 956 0 6 4 0 13 | 4 = 4\n", + " 7 1 0 19 2 825 20 1 11 6 | 5 = 5\n", + " 6 3 0 0 4 1 943 1 0 0 | 6 = 6\n", + " 0 8 5 2 0 0 0 1010 1 2 | 7 = 7\n", + " 5 0 2 6 4 1 8 7 937 4 | 8 = 8\n", + " 3 6 1 6 10 0 2 6 1 974 | 9 = 9\n", + "\n", + "Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times\n", + "==================================================================\n" + ] + } + ], + "source": [ + "val eval: org.nd4j.evaluation.classification.Evaluation = model.evaluate(mnistTest)\n", + "println(eval.stats())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To stop the UI:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "uiServer.stop()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Kotlin", + "language": "kotlin", + "name": "kotlin" + }, + "language_info": { + "codemirror_mode": "text/x-kotlin", + "file_extension": ".kt", + "mimetype": "text/x-kotlin", + "name": "kotlin", + "pygments_lexer": "kotlin", + "version": "1.4.0-dev-7568" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/samples/DeepLearning4j.ipynb b/samples/DeepLearning4j.ipynb new file mode 100644 index 000000000..8bd9a81e8 --- /dev/null +++ b/samples/DeepLearning4j.ipynb @@ -0,0 +1,656 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%use deeplearning4j\n", + "%use krangl\n", + "%use lets-plot" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "val iris_data = \"sepal-length,sepal-width,petal-length,petal-width,species\\n5.1,3.5,1.4,0.2,Iris-setosa\\n4.9,3.0,1.4,0.2,Iris-setosa\\n4.7,3.2,1.3,0.2,Iris-setosa\\n4.6,3.1,1.5,0.2,Iris-setosa\\n5.0,3.6,1.4,0.2,Iris-setosa\\n5.4,3.9,1.7,0.4,Iris-setosa\\n4.6,3.4,1.4,0.3,Iris-setosa\\n5.0,3.4,1.5,0.2,Iris-setosa\\n4.4,2.9,1.4,0.2,Iris-setosa\\n4.9,3.1,1.5,0.1,Iris-setosa\\n5.4,3.7,1.5,0.2,Iris-setosa\\n4.8,3.4,1.6,0.2,Iris-setosa\\n4.8,3.0,1.4,0.1,Iris-setosa\\n4.3,3.0,1.1,0.1,Iris-setosa\\n5.8,4.0,1.2,0.2,Iris-setosa\\n5.7,4.4,1.5,0.4,Iris-setosa\\n5.4,3.9,1.3,0.4,Iris-setosa\\n5.1,3.5,1.4,0.3,Iris-setosa\\n5.7,3.8,1.7,0.3,Iris-setosa\\n5.1,3.8,1.5,0.3,Iris-setosa\\n5.4,3.4,1.7,0.2,Iris-setosa\\n5.1,3.7,1.5,0.4,Iris-setosa\\n4.6,3.6,1.0,0.2,Iris-setosa\\n5.1,3.3,1.7,0.5,Iris-setosa\\n4.8,3.4,1.9,0.2,Iris-setosa\\n5.0,3.0,1.6,0.2,Iris-setosa\\n5.0,3.4,1.6,0.4,Iris-setosa\\n5.2,3.5,1.5,0.2,Iris-setosa\\n5.2,3.4,1.4,0.2,Iris-setosa\\n4.7,3.2,1.6,0.2,Iris-setosa\\n4.8,3.1,1.6,0.2,Iris-setosa\\n5.4,3.4,1.5,0.4,Iris-setosa\\n5.2,4.1,1.5,0.1,Iris-setosa\\n5.5,4.2,1.4,0.2,Iris-setosa\\n4.9,3.1,1.5,0.1,Iris-setosa\\n5.0,3.2,1.2,0.2,Iris-setosa\\n5.5,3.5,1.3,0.2,Iris-setosa\\n4.9,3.1,1.5,0.1,Iris-setosa\\n4.4,3.0,1.3,0.2,Iris-setosa\\n5.1,3.4,1.5,0.2,Iris-setosa\\n5.0,3.5,1.3,0.3,Iris-setosa\\n4.5,2.3,1.3,0.3,Iris-setosa\\n4.4,3.2,1.3,0.2,Iris-setosa\\n5.0,3.5,1.6,0.6,Iris-setosa\\n5.1,3.8,1.9,0.4,Iris-setosa\\n4.8,3.0,1.4,0.3,Iris-setosa\\n5.1,3.8,1.6,0.2,Iris-setosa\\n4.6,3.2,1.4,0.2,Iris-setosa\\n5.3,3.7,1.5,0.2,Iris-setosa\\n5.0,3.3,1.4,0.2,Iris-setosa\\n7.0,3.2,4.7,1.4,Iris-versicolor\\n6.4,3.2,4.5,1.5,Iris-versicolor\\n6.9,3.1,4.9,1.5,Iris-versicolor\\n5.5,2.3,4.0,1.3,Iris-versicolor\\n6.5,2.8,4.6,1.5,Iris-versicolor\\n5.7,2.8,4.5,1.3,Iris-versicolor\\n6.3,3.3,4.7,1.6,Iris-versicolor\\n4.9,2.4,3.3,1.0,Iris-versicolor\\n6.6,2.9,4.6,1.3,Iris-versicolor\\n5.2,2.7,3.9,1.4,Iris-versicolor\\n5.0,2.0,3.5,1.0,Iris-versicolor\\n5.9,3.0,4.2,1.5,Iris-versicolor\\n6.0,2.2,4.0,1.0,Iris-versicolor\\n6.1,2.9,4.7,1.4,Iris-versicolor\\n5.6,2.9,3.6,1.3,Iris-versicolor\\n6.7,3.1,4.4,1.4,Iris-versicolor\\n5.6,3.0,4.5,1.5,Iris-versicolor\\n5.8,2.7,4.1,1.0,Iris-versicolor\\n6.2,2.2,4.5,1.5,Iris-versicolor\\n5.6,2.5,3.9,1.1,Iris-versicolor\\n5.9,3.2,4.8,1.8,Iris-versicolor\\n6.1,2.8,4.0,1.3,Iris-versicolor\\n6.3,2.5,4.9,1.5,Iris-versicolor\\n6.1,2.8,4.7,1.2,Iris-versicolor\\n6.4,2.9,4.3,1.3,Iris-versicolor\\n6.6,3.0,4.4,1.4,Iris-versicolor\\n6.8,2.8,4.8,1.4,Iris-versicolor\\n6.7,3.0,5.0,1.7,Iris-versicolor\\n6.0,2.9,4.5,1.5,Iris-versicolor\\n5.7,2.6,3.5,1.0,Iris-versicolor\\n5.5,2.4,3.8,1.1,Iris-versicolor\\n5.5,2.4,3.7,1.0,Iris-versicolor\\n5.8,2.7,3.9,1.2,Iris-versicolor\\n6.0,2.7,5.1,1.6,Iris-versicolor\\n5.4,3.0,4.5,1.5,Iris-versicolor\\n6.0,3.4,4.5,1.6,Iris-versicolor\\n6.7,3.1,4.7,1.5,Iris-versicolor\\n6.3,2.3,4.4,1.3,Iris-versicolor\\n5.6,3.0,4.1,1.3,Iris-versicolor\\n5.5,2.5,4.0,1.3,Iris-versicolor\\n5.5,2.6,4.4,1.2,Iris-versicolor\\n6.1,3.0,4.6,1.4,Iris-versicolor\\n5.8,2.6,4.0,1.2,Iris-versicolor\\n5.0,2.3,3.3,1.0,Iris-versicolor\\n5.6,2.7,4.2,1.3,Iris-versicolor\\n5.7,3.0,4.2,1.2,Iris-versicolor\\n5.7,2.9,4.2,1.3,Iris-versicolor\\n6.2,2.9,4.3,1.3,Iris-versicolor\\n5.1,2.5,3.0,1.1,Iris-versicolor\\n5.7,2.8,4.1,1.3,Iris-versicolor\\n6.3,3.3,6.0,2.5,Iris-virginica\\n5.8,2.7,5.1,1.9,Iris-virginica\\n7.1,3.0,5.9,2.1,Iris-virginica\\n6.3,2.9,5.6,1.8,Iris-virginica\\n6.5,3.0,5.8,2.2,Iris-virginica\\n7.6,3.0,6.6,2.1,Iris-virginica\\n4.9,2.5,4.5,1.7,Iris-virginica\\n7.3,2.9,6.3,1.8,Iris-virginica\\n6.7,2.5,5.8,1.8,Iris-virginica\\n7.2,3.6,6.1,2.5,Iris-virginica\\n6.5,3.2,5.1,2.0,Iris-virginica\\n6.4,2.7,5.3,1.9,Iris-virginica\\n6.8,3.0,5.5,2.1,Iris-virginica\\n5.7,2.5,5.0,2.0,Iris-virginica\\n5.8,2.8,5.1,2.4,Iris-virginica\\n6.4,3.2,5.3,2.3,Iris-virginica\\n6.5,3.0,5.5,1.8,Iris-virginica\\n7.7,3.8,6.7,2.2,Iris-virginica\\n7.7,2.6,6.9,2.3,Iris-virginica\\n6.0,2.2,5.0,1.5,Iris-virginica\\n6.9,3.2,5.7,2.3,Iris-virginica\\n5.6,2.8,4.9,2.0,Iris-virginica\\n7.7,2.8,6.7,2.0,Iris-virginica\\n6.3,2.7,4.9,1.8,Iris-virginica\\n6.7,3.3,5.7,2.1,Iris-virginica\\n7.2,3.2,6.0,1.8,Iris-virginica\\n6.2,2.8,4.8,1.8,Iris-virginica\\n6.1,3.0,4.9,1.8,Iris-virginica\\n6.4,2.8,5.6,2.1,Iris-virginica\\n7.2,3.0,5.8,1.6,Iris-virginica\\n7.4,2.8,6.1,1.9,Iris-virginica\\n7.9,3.8,6.4,2.0,Iris-virginica\\n6.4,2.8,5.6,2.2,Iris-virginica\\n6.3,2.8,5.1,1.5,Iris-virginica\\n6.1,2.6,5.6,1.4,Iris-virginica\\n7.7,3.0,6.1,2.3,Iris-virginica\\n6.3,3.4,5.6,2.4,Iris-virginica\\n6.4,3.1,5.5,1.8,Iris-virginica\\n6.0,3.0,4.8,1.8,Iris-virginica\\n6.9,3.1,5.4,2.1,Iris-virginica\\n6.7,3.1,5.6,2.4,Iris-virginica\\n6.9,3.1,5.1,2.3,Iris-virginica\\n5.8,2.7,5.1,1.9,Iris-virginica\\n6.8,3.2,5.9,2.3,Iris-virginica\\n6.7,3.3,5.7,2.5,Iris-virginica\\n6.7,3.0,5.2,2.3,Iris-virginica\\n6.3,2.5,5.0,1.9,Iris-virginica\\n6.5,3.0,5.2,2.0,Iris-virginica\\n6.2,3.4,5.4,2.3,Iris-virginica\\n5.9,3.0,5.1,1.8,Iris-virginica\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
sepal-lengthsepal-widthpetal-lengthpetal-widthspecies
5.13.31.70.5Iris-setosa
5.82.75.11.9Iris-virginica
5.62.84.92.0Iris-virginica
4.83.01.40.3Iris-setosa
7.72.66.92.3Iris-virginica
" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import java.util.*\n", + "import java.io.StringReader\n", + "\n", + "val iris = DataFrame.readDelim(StringReader(iris_data)).shuffle()\n", + "iris.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "val points = geom_point(\n", + " data = mapOf(\n", + " \"x\" to iris[\"sepal-length\"].asDoubles().toList(),\n", + " \"y\" to iris[\"sepal-width\"].asDoubles().toList(),\n", + " \"color\" to iris[\"species\"].asStrings().toList()\n", + " ), alpha=1.0)\n", + "{\n", + " x = \"x\" \n", + " y = \"y\"\n", + " color = \"color\"\n", + "}\n", + "\n", + "ggplot() + points" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
sepal-lengthsepal-widthpetal-lengthpetal-width
5.13.31.70.5
5.82.75.11.9
5.62.84.92.0
4.83.01.40.3
7.72.66.92.3
" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "val irisWithoutLabel = iris.remove(\"species\")\n", + "irisWithoutLabel.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[5.1, 3.3, 1.7, 0.5]\n", + "[5.8, 2.7, 5.1, 1.9]\n", + "[5.6, 2.8, 4.9, 2.0]\n", + "[4.8, 3.0, 1.4, 0.3]\n", + "[7.7, 2.6, 6.9, 2.3]\n", + "[5.6, 2.9, 3.6, 1.3]\n", + "[6.9, 3.1, 5.4, 2.1]\n", + "[5.9, 3.0, 4.2, 1.5]\n", + "[4.9, 3.1, 1.5, 0.1]\n", + "[6.8, 2.8, 4.8, 1.4]\n", + "[6.0, 2.2, 5.0, 1.5]\n", + "[6.0, 3.4, 4.5, 1.6]\n", + "[5.4, 3.9, 1.3, 0.4]\n", + "[5.7, 3.0, 4.2, 1.2]\n", + "[7.2, 3.0, 5.8, 1.6]\n", + "[6.0, 2.7, 5.1, 1.6]\n", + "[6.4, 3.2, 5.3, 2.3]\n", + "[5.7, 2.8, 4.1, 1.3]\n", + "[5.7, 2.5, 5.0, 2.0]\n", + "[6.2, 2.8, 4.8, 1.8]\n", + "[5.0, 3.5, 1.3, 0.3]\n", + "[5.7, 4.4, 1.5, 0.4]\n", + "[6.3, 2.5, 5.0, 1.9]\n", + "[7.7, 3.0, 6.1, 2.3]\n", + "[4.8, 3.0, 1.4, 0.1]\n", + "[5.8, 2.7, 3.9, 1.2]\n", + "[5.1, 2.5, 3.0, 1.1]\n", + "[6.4, 2.8, 5.6, 2.1]\n", + "[5.3, 3.7, 1.5, 0.2]\n", + "[4.6, 3.4, 1.4, 0.3]\n", + "[7.6, 3.0, 6.6, 2.1]\n", + "[4.5, 2.3, 1.3, 0.3]\n", + "[5.6, 2.7, 4.2, 1.3]\n", + "[5.7, 2.6, 3.5, 1.0]\n", + "[6.7, 3.0, 5.0, 1.7]\n", + "[6.5, 3.0, 5.8, 2.2]\n", + "[5.0, 2.3, 3.3, 1.0]\n", + "[6.1, 3.0, 4.9, 1.8]\n", + "[6.5, 3.0, 5.2, 2.0]\n", + "[6.2, 3.4, 5.4, 2.3]\n", + "[4.4, 2.9, 1.4, 0.2]\n", + "[5.2, 3.5, 1.5, 0.2]\n", + "[7.2, 3.6, 6.1, 2.5]\n", + "[5.5, 4.2, 1.4, 0.2]\n", + "[6.4, 2.9, 4.3, 1.3]\n", + "[4.9, 3.0, 1.4, 0.2]\n", + "[6.3, 2.5, 4.9, 1.5]\n", + "[5.5, 2.4, 3.7, 1.0]\n", + "[4.7, 3.2, 1.6, 0.2]\n", + "[6.3, 2.7, 4.9, 1.8]\n", + "[6.3, 2.3, 4.4, 1.3]\n", + "[7.1, 3.0, 5.9, 2.1]\n", + "[5.0, 3.5, 1.6, 0.6]\n", + "[6.8, 3.0, 5.5, 2.1]\n", + "[4.8, 3.4, 1.9, 0.2]\n", + "[6.7, 3.1, 5.6, 2.4]\n", + "[5.8, 2.6, 4.0, 1.2]\n", + "[5.0, 3.2, 1.2, 0.2]\n", + "[6.7, 3.3, 5.7, 2.5]\n", + "[5.1, 3.5, 1.4, 0.2]\n", + "[6.4, 2.7, 5.3, 1.9]\n", + "[7.0, 3.2, 4.7, 1.4]\n", + "[6.1, 2.8, 4.7, 1.2]\n", + "[5.4, 3.4, 1.7, 0.2]\n", + "[4.9, 2.4, 3.3, 1.0]\n", + "[5.2, 3.4, 1.4, 0.2]\n", + "[6.5, 2.8, 4.6, 1.5]\n", + "[5.4, 3.0, 4.5, 1.5]\n", + "[7.3, 2.9, 6.3, 1.8]\n", + "[5.2, 2.7, 3.9, 1.4]\n", + "[5.4, 3.9, 1.7, 0.4]\n", + "[6.2, 2.2, 4.5, 1.5]\n", + "[5.1, 3.5, 1.4, 0.3]\n", + "[4.8, 3.4, 1.6, 0.2]\n", + "[7.7, 3.8, 6.7, 2.2]\n", + "[5.6, 3.0, 4.5, 1.5]\n", + "[6.3, 3.4, 5.6, 2.4]\n", + "[5.8, 2.8, 5.1, 2.4]\n", + "[5.5, 2.3, 4.0, 1.3]\n", + "[4.9, 2.5, 4.5, 1.7]\n", + "[6.0, 2.2, 4.0, 1.0]\n", + "[5.0, 2.0, 3.5, 1.0]\n", + "[5.9, 3.2, 4.8, 1.8]\n", + "[5.4, 3.4, 1.5, 0.4]\n", + "[6.9, 3.1, 4.9, 1.5]\n", + "[4.9, 3.1, 1.5, 0.1]\n", + "[5.2, 4.1, 1.5, 0.1]\n", + "[5.1, 3.8, 1.5, 0.3]\n", + "[5.1, 3.8, 1.6, 0.2]\n", + "[6.7, 3.1, 4.7, 1.5]\n", + "[5.9, 3.0, 5.1, 1.8]\n", + "[5.8, 4.0, 1.2, 0.2]\n", + "[4.3, 3.0, 1.1, 0.1]\n", + "[6.7, 2.5, 5.8, 1.8]\n", + "[6.3, 3.3, 6.0, 2.5]\n", + "[5.6, 2.5, 3.9, 1.1]\n", + "[4.4, 3.2, 1.3, 0.2]\n", + "[4.6, 3.1, 1.5, 0.2]\n", + "[5.5, 2.6, 4.4, 1.2]\n", + "[6.9, 3.1, 5.1, 2.3]\n", + "[6.0, 2.9, 4.5, 1.5]\n", + "[7.2, 3.2, 6.0, 1.8]\n", + "[6.1, 2.8, 4.0, 1.3]\n", + "[5.7, 2.9, 4.2, 1.3]\n", + "[5.8, 2.7, 4.1, 1.0]\n", + "[4.8, 3.1, 1.6, 0.2]\n", + "[6.9, 3.2, 5.7, 2.3]\n", + "[5.5, 2.4, 3.8, 1.1]\n", + "[5.0, 3.4, 1.5, 0.2]\n", + "[4.6, 3.2, 1.4, 0.2]\n", + "[4.9, 3.1, 1.5, 0.1]\n", + "[6.0, 3.0, 4.8, 1.8]\n", + "[6.3, 2.9, 5.6, 1.8]\n", + "[6.6, 3.0, 4.4, 1.4]\n", + "[7.9, 3.8, 6.4, 2.0]\n", + "[5.6, 3.0, 4.1, 1.3]\n", + "[5.7, 3.8, 1.7, 0.3]\n", + "[5.0, 3.4, 1.6, 0.4]\n", + "[5.7, 2.8, 4.5, 1.3]\n", + "[6.7, 3.3, 5.7, 2.1]\n", + "[6.7, 3.1, 4.4, 1.4]\n", + "[6.7, 3.0, 5.2, 2.3]\n", + "[5.5, 2.5, 4.0, 1.3]\n", + "[5.0, 3.3, 1.4, 0.2]\n", + "[4.4, 3.0, 1.3, 0.2]\n", + "[6.6, 2.9, 4.6, 1.3]\n", + "[7.4, 2.8, 6.1, 1.9]\n", + "[6.5, 3.0, 5.5, 1.8]\n", + "[6.3, 2.8, 5.1, 1.5]\n", + "[6.4, 3.2, 4.5, 1.5]\n", + "[6.1, 2.9, 4.7, 1.4]\n", + "[4.6, 3.6, 1.0, 0.2]\n", + "[5.4, 3.7, 1.5, 0.2]\n", + "[5.5, 3.5, 1.3, 0.2]\n", + "[6.1, 3.0, 4.6, 1.4]\n", + "[5.8, 2.7, 5.1, 1.9]\n", + "[6.8, 3.2, 5.9, 2.3]\n", + "[6.4, 3.1, 5.5, 1.8]\n", + "[7.7, 2.8, 6.7, 2.0]\n", + "[5.0, 3.0, 1.6, 0.2]\n", + "[6.2, 2.9, 4.3, 1.3]\n", + "[5.1, 3.4, 1.5, 0.2]\n", + "[6.5, 3.2, 5.1, 2.0]\n", + "[5.1, 3.7, 1.5, 0.4]\n", + "[6.4, 2.8, 5.6, 2.2]\n", + "[5.1, 3.8, 1.9, 0.4]\n", + "[5.0, 3.6, 1.4, 0.2]\n", + "[4.7, 3.2, 1.3, 0.2]\n", + "[6.3, 3.3, 4.7, 1.6]\n", + "[6.1, 2.6, 5.6, 1.4]]\n" + ] + } + ], + "source": [ + "//Convert the iris data into 150x4 matrix\n", + "val row = 150\n", + "val col = 4\n", + "\n", + "val irisMatrix = Array(row) { DoubleArray(col) }\n", + "var i = 0\n", + "for (r in 0 until row) {\n", + " for (c in 0 until col) {\n", + " irisMatrix[r][c] = irisWithoutLabel[c][r] as Double\n", + " }\n", + "}\n", + "println(Arrays.deepToString(irisMatrix).replace(\"], \", \"]\\n\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 0.0, 1.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 0.0, 1.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[1.0, 0.0, 0.0]\n", + "[0.0, 1.0, 0.0]\n", + "[0.0, 0.0, 1.0]]\n" + ] + } + ], + "source": [ + "//Now do the same for the label data\n", + "val irisLabel = iris.select(\"species\")[0]\n", + "\n", + "val rowLabel = 150\n", + "val colLabel = 3\n", + "\n", + "val twodimLabel = Array(rowLabel) { DoubleArray(colLabel) }\n", + "for (r in 0 until rowLabel) {\n", + " when (irisLabel[r]) {\n", + " \"Iris-setosa\" -> twodimLabel[r][0] = 1.0\n", + " \"Iris-versicolor\" -> twodimLabel[r][1] = 1.0\n", + " \"Iris-virginica\" -> twodimLabel[r][2] = 1.0\n", + " }\n", + "}\n", + "println(Arrays.deepToString(twodimLabel).replace(\"], \", \"]\\n\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "//Convert the data matrices into training INDArrays\n", + "val dataIn = Nd4j.create(irisMatrix)\n", + "val dataOut = Nd4j.create(twodimLabel)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import org.nd4j.linalg.lossfunctions.LossFunctions\n", + "\n", + "val seed: Long = 6\n", + "\n", + "val conf = NeuralNetConfiguration.Builder()\n", + " .seed(seed) //include a random seed for reproducibility\n", + " // use stochastic gradient descent as an optimization algorithm\n", + " .updater(Nadam()) //specify the rate of change of the learning rate.\n", + " .l2(1e-4)\n", + " .list()\n", + " .layer(DenseLayer.Builder()\n", + " .nIn(4)\n", + " .nOut(3)\n", + " .activation(Activation.TANH)\n", + " .weightInit(WeightInit.XAVIER)\n", + " .build())\n", + " .layer(org.deeplearning4j.nn.conf.layers.DenseLayer.Builder()\n", + " .nIn(3)\n", + " .nOut(3)\n", + " .build())\n", + " .layer(OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)\n", + " .nIn(3)\n", + " .nOut(3)\n", + " .activation(Activation.SOFTMAX)\n", + " .weightInit(WeightInit.XAVIER)\n", + " .build())\n", + " .build()\n", + "\n", + "val model = MultiLayerNetwork(conf)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Score \n", + "\n", + "========================Evaluation Metrics========================\n", + " # of classes: 3\n", + " Accuracy: 1,0000\n", + " Precision: 1,0000\n", + " Recall: 1,0000\n", + " F1 Score: 1,0000\n", + "Precision, recall & F1: macro-averaged (equally weighted avg. of 3 classes)\n", + "\n", + "\n", + "=========================Confusion Matrix=========================\n", + " 0 1 2\n", + "-------\n", + " 5 0 0 | 0 = 0\n", + " 0 3 0 | 1 = 1\n", + " 0 0 7 | 2 = 2\n", + "\n", + "Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times\n", + "==================================================================\n" + ] + } + ], + "source": [ + "import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization\n", + "import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize\n", + "\n", + "//Create a data set from the INDArrays and shuffle it \n", + "val fullDataSet = DataSet(dataIn, dataOut)\n", + "fullDataSet.shuffle(seed)\n", + "\n", + "val splitedSet = fullDataSet.splitTestAndTrain(0.90)\n", + "val trainingData = splitedSet.train;\n", + "val testData = splitedSet.test;\n", + "\n", + "//We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance):\n", + "val normalizer: DataNormalization = NormalizerStandardize()\n", + "normalizer.fit(trainingData) //Collect the statistics (mean/stdev) from the training data. This does not modify the input data\n", + "normalizer.transform(trainingData) //Apply normalization to the training data\n", + "normalizer.transform(testData) //Apply normalization to the test data. This is using statistics calculated from the *training* set\n", + "\n", + "// train the network\n", + "model.setListeners(ScoreIterationListener(100))\n", + "for (l in 0..2000) {\n", + " model.fit(trainingData)\n", + "}\n", + "\n", + "// evaluate the network\n", + "val eval = Evaluation()\n", + "val output: INDArray = model.output(testData.features)\n", + "eval.eval(testData.labels, output)\n", + "println(\"Score \" + eval.stats())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Kotlin", + "language": "kotlin", + "name": "kotlin" + }, + "language_info": { + "codemirror_mode": "text/x-kotlin", + "file_extension": ".kt", + "mimetype": "text/x-kotlin", + "name": "kotlin", + "pygments_lexer": "kotlin", + "version": "1.3.70-eap-274" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}