Merge pull request apache#37 from javelinjs/scala-package-l
Readme and assembly module
terrytangyuan committed Mar 5, 2016
2 parents 7351176 + 60a89c4 commit 9e5137f
Showing 12 changed files with 475 additions and 63 deletions.
130 changes: 127 additions & 3 deletions scala-package/README.md
@@ -1,5 +1,129 @@
<img src="https://raw.githubusercontent.com/dmlc/dmlc.github.io/master/img/logo-m/mxnet2.png" width="135"/> Deep Learning for Scala/Java
=====

[![Build Status](https://travis-ci.org/dmlc/mxnet.svg?branch=master)](https://travis-ci.org/dmlc/mxnet)
[![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE)
[![Join the chat at https://gitter.im/javelinjs/mxnet-scala](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/javelinjs/mxnet-scala?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)

# mxnet-scala
Here you find the MXNet Scala Package!
It brings flexible and efficient GPU/CPU computing and state-of-the-art deep learning to the JVM.

- It enables you to write seamless tensor/matrix computation with multiple GPUs
  in Scala, Java and other languages built on the JVM.
- It also enables you to construct and customize state-of-the-art deep learning models in JVM languages,
  and apply them to tasks such as image classification and data science challenges.

Build
------------

Check out the [Installation Guide](http://mxnet.readthedocs.org/en/latest/build.html) for instructions on installing MXNet.
You can then compile the Scala Package with

```bash
make scalapkg
```

Run the unit and integration tests with

```bash
make scalatest
```

If everything goes well, you will find a jar file named something like `mxnet_2.10-osx-x86_64-0.1-SNAPSHOT-full.jar` under `assembly/target`. You can then use this jar in your own project.

The `scalapkg` target also builds jars for the `core` and `example` modules. If you have already downloaded and unpacked the MNIST dataset to `./data/`, you can run the training example with

```bash
java -Xmx4G -cp scala-package/assembly/target/*:scala-package/examples/target/mxnet-scala-examples_2.10-0.1-SNAPSHOT.jar:scala-package/examples/target/classes/lib/args4j-2.0.29.jar ml.dmlc.mxnet.examples.imclassification.TrainMnist --data-dir=./data/ --num-epochs=10 --network=mlp --cpus=0,1,2,3
```

Change the arguments and have fun!
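The `--cpus` flag takes a comma-separated list of device ids, one per CPU context to train on. A minimal sketch of parsing such a flag (the `parseDevices` helper is illustrative, not part of the example code, which maps each id to `Context.cpu(id)`):

```scala
// Hypothetical helper: turn "0,1,2,3" into an array of device ids.
def parseDevices(spec: String): Array[Int] =
  spec.split(',').map(_.trim.toInt)

println(parseDevices("0,1,2,3").mkString(","))  // prints 0,1,2,3
```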

Usage
-------
Here is a Scala example of training a simple 3-layer MLP on MNIST:

```scala
import ml.dmlc.mxnet._
import ml.dmlc.mxnet.optimizer.SGD
import org.slf4j.LoggerFactory

// model definition
val data = Symbol.Variable("data")
val fc1 = Symbol.FullyConnected(name = "fc1")(Map("data" -> data, "num_hidden" -> 128))
val act1 = Symbol.Activation(name = "relu1")(Map("data" -> fc1, "act_type" -> "relu"))
val fc2 = Symbol.FullyConnected(name = "fc2")(Map("data" -> act1, "num_hidden" -> 64))
val act2 = Symbol.Activation(name = "relu2")(Map("data" -> fc2, "act_type" -> "relu"))
val fc3 = Symbol.FullyConnected(name = "fc3")(Map("data" -> act2, "num_hidden" -> 10))
val mlp = Symbol.SoftmaxOutput(name = "sm")(Map("data" -> fc3))

// load MNIST dataset
val batchSize = 100 // mini-batch size, used by both iterators below
val trainDataIter = IO.MNISTIter(Map(
  "image" -> "data/train-images-idx3-ubyte",
  "label" -> "data/train-labels-idx1-ubyte",
  "data_shape" -> "(1, 28, 28)",
  "label_name" -> "sm_label",
  "batch_size" -> batchSize.toString,
  "shuffle" -> "1",
  "flat" -> "0",
  "silent" -> "0",
  "seed" -> "10"))

val valDataIter = IO.MNISTIter(Map(
  "image" -> "data/t10k-images-idx3-ubyte",
  "label" -> "data/t10k-labels-idx1-ubyte",
  "data_shape" -> "(1, 28, 28)",
  "label_name" -> "sm_label",
  "batch_size" -> batchSize.toString,
  "shuffle" -> "1",
  "flat" -> "0",
  "silent" -> "0"))

// setup model
val model = new FeedForward(mlp, Context.cpu(), numEpoch = 10,
  optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f, wd = 0.0001f))
model.fit(trainDataIter, valDataIter)
```

You can then use the trained model for prediction:

```scala
val probArrays = model.predict(valDataIter)
// in this case, we do not have multiple outputs
require(probArrays.length == 1)
val prob = probArrays(0)

// get real labels
import scala.collection.mutable.ListBuffer
valDataIter.reset()
val labels = ListBuffer.empty[NDArray]
var evalData = valDataIter.next()
while (evalData != null) {
  labels += evalData.label(0).copy()
  evalData = valDataIter.next()
}
val y = NDArray.concatenate(labels)

// get predicted labels
val py = NDArray.argmaxChannel(prob)
require(y.shape == py.shape)

// calculate accuracy
var numCorrect = 0
var numInst = 0
for ((labelElem, predElem) <- y.toArray zip py.toArray) {
  if (labelElem == predElem) {
    numCorrect += 1
  }
  numInst += 1
}
val acc = numCorrect.toFloat / numInst
println(s"Final accuracy = $acc")
```
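The accuracy loop above can also be written with Scala collection operations. A self-contained sketch, using plain arrays as stand-ins for the `y.toArray` and `py.toArray` values:

```scala
// Plain-array stand-ins for the label and prediction NDArrays above.
val labels      = Array(0f, 1f, 2f, 1f, 0f)
val predictions = Array(0f, 1f, 1f, 1f, 0f)

// count matching (label, prediction) pairs and divide by the total
val acc = labels.zip(predictions).count { case (l, p) => l == p }.toFloat / labels.length
println(s"Final accuracy = $acc")  // prints Final accuracy = 0.8
```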

You can refer to the [MXNet Scala Package Examples](https://github.com/javelinjs/mxnet-scala-example)
for more information about how to integrate the MXNet Scala Package into your own project.
Currently you have to put the jars on your project's build classpath manually.
We will provide a pre-built binary package on the [Maven Repository](http://mvnrepository.com) soon.
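Until the Maven artifact is published, one way to put the assembly jar on the classpath of an sbt project is as an unmanaged dependency (the path below is an assumption; point it at wherever you copied the jar):

```scala
// build.sbt — add the assembly jar as an unmanaged dependency
// (the lib/ path is hypothetical)
unmanagedJars in Compile += file("lib/mxnet_2.10-osx-x86_64-0.1-SNAPSHOT-full.jar")
```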

License
-------
MXNet Scala Package is licensed under the [BSD](https://github.com/dmlc/mxnet/blob/master/scala-package/LICENSE) license.
94 changes: 94 additions & 0 deletions scala-package/assembly/pom.xml
@@ -0,0 +1,94 @@
<project>
  <modelVersion>4.0.0</modelVersion>
  <parent>
    <groupId>ml.dmlc.mxnet</groupId>
    <artifactId>mxnet-scala-parent_2.10</artifactId>
    <version>0.1-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
  </parent>

  <groupId>ml.dmlc.mxnet</groupId>
  <artifactId>mxnet-scala-assembly_2.10</artifactId>
  <name>MXNet Scala Package - Project Assembly</name>
  <packaging>pom</packaging>

  <profiles>
    <profile>
      <id>osx-x86_64</id>
      <activation>
        <os>
          <family>mac</family>
          <arch>x86_64</arch>
        </os>
      </activation>
      <properties>
        <platform>osx-x86_64</platform>
        <filetype>jnilib</filetype>
      </properties>
    </profile>
    <profile>
      <id>linux-x86_64</id>
      <activation>
        <os>
          <family>linux</family>
        </os>
      </activation>
      <properties>
        <platform>linux-x86_64</platform>
        <filetype>so</filetype>
      </properties>
    </profile>
  </profiles>

  <dependencies>
    <dependency>
      <groupId>ml.dmlc.mxnet</groupId>
      <artifactId>mxnet-scala-core_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>ml.dmlc.mxnet</groupId>
      <artifactId>libmxnet-scala</artifactId>
      <version>${project.version}</version>
      <type>${filetype}</type>
    </dependency>
  </dependencies>

  <build>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-deploy-plugin</artifactId>
        <configuration>
          <skip>true</skip>
        </configuration>
      </plugin>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-install-plugin</artifactId>
        <configuration>
          <skip>true</skip>
        </configuration>
      </plugin>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-assembly-plugin</artifactId>
        <executions>
          <execution>
            <id>full</id>
            <phase>package</phase>
            <goals>
              <goal>single</goal>
            </goals>
            <configuration>
              <finalName>mxnet_${scala.binary.version}-${platform}-${project.version}</finalName>
              <descriptors>
                <descriptor>src/main/assembly/assembly.xml</descriptor>
              </descriptors>
            </configuration>
          </execution>
        </executions>
      </plugin>
    </plugins>
  </build>
</project>
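The two profiles above pick a platform classifier and native-library extension based on the build OS. The same mapping, sketched in Scala for illustration (`platformOf` is not part of the build, purely a mirror of the profile activation rules):

```scala
// Mirror of the Maven profile activation: mac -> (osx-x86_64, jnilib),
// otherwise linux -> (linux-x86_64, so).
def platformOf(osName: String): (String, String) =
  if (osName.toLowerCase.contains("mac")) ("osx-x86_64", "jnilib")
  else ("linux-x86_64", "so")

println(platformOf(System.getProperty("os.name")))
```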
30 changes: 30 additions & 0 deletions scala-package/assembly/src/main/assembly/assembly.xml
@@ -0,0 +1,30 @@
<assembly>
  <id>full</id>
  <formats>
    <format>jar</format>
  </formats>
  <includeBaseDirectory>false</includeBaseDirectory>
  <dependencySets>
    <dependencySet>
      <includes>
        <include>*:*:jar</include>
      </includes>
      <outputDirectory>/</outputDirectory>
      <useProjectArtifact>true</useProjectArtifact>
      <unpack>true</unpack>
      <scope>runtime</scope>
    </dependencySet>
    <dependencySet>
      <outputDirectory>lib/native</outputDirectory>
      <outputFileNameMapping>${artifact.artifactId}${dashClassifier?}.${artifact.extension}</outputFileNameMapping>
      <unpack>false</unpack>
      <useProjectArtifact>false</useProjectArtifact>
      <useStrictFiltering>false</useStrictFiltering>
      <includes>
        <include>*:*:dll:*</include>
        <include>*:*:so:*</include>
        <include>*:*:jnilib:*</include>
      </includes>
    </dependencySet>
  </dependencySets>
</assembly>
14 changes: 7 additions & 7 deletions scala-package/core/pom.xml
@@ -39,16 +39,16 @@

<build>
<plugins>
<plugin>
  <artifactId>maven-resources-plugin</artifactId>
</plugin>
<plugin>
  <groupId>org.apache.maven.plugins</groupId>
  <artifactId>maven-dependency-plugin</artifactId>
</plugin>
<plugin>
  <groupId>org.apache.maven.plugins</groupId>
  <artifactId>maven-jar-plugin</artifactId>
  <configuration>
    <excludes>
      <exclude>META-INF/*.SF</exclude>
      <exclude>META-INF/*.DSA</exclude>
      <exclude>META-INF/*.RSA</exclude>
    </excludes>
  </configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
20 changes: 19 additions & 1 deletion scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
@@ -1,6 +1,11 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.util.NativeLibraryLoader
import org.slf4j.{LoggerFactory, Logger}

object Base {
  private val logger: Logger = LoggerFactory.getLogger("MXNetJVM")

  // type definitions
  class RefInt(val value: Int = 0)
  class RefLong(val value: Long = 0)
@@ -29,10 +34,23 @@ object Base {
  type ExecutorHandleRef = RefLong
  type SymbolHandleRef = RefLong

System.loadLibrary("mxnet-scala")
  try {
    try {
      System.loadLibrary("mxnet-scala")
    } catch {
      case e: UnsatisfiedLinkError =>
        NativeLibraryLoader.loadLibrary("mxnet-scala")
    }
  } catch {
    case e: UnsatisfiedLinkError =>
      logger.error("Couldn't find native library mxnet-scala")
      throw e
  }

  val _LIB = new LibInfo
  checkCall(_LIB.nativeLibInit())

  // TODO: shutdown hook won't work on Windows
  Runtime.getRuntime.addShutdownHook(new Thread() {
    override def run(): Unit = {
      notifyShutdown()
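The change above makes native-library loading fall back to `NativeLibraryLoader` (which extracts the library from the jar) when `System.loadLibrary` fails on the system path. The pattern, sketched without the MXNet-specific loader (`loadWithFallback` and its fallback argument are illustrative names):

```scala
// Two-stage loading: try the system library path first; on failure run a
// fallback action (in Base.scala, NativeLibraryLoader.loadLibrary); if the
// fallback also fails with UnsatisfiedLinkError, log and rethrow.
def loadWithFallback(name: String)(fallback: String => Unit): Unit =
  try {
    try System.loadLibrary(name)
    catch { case _: UnsatisfiedLinkError => fallback(name) }
  } catch {
    case e: UnsatisfiedLinkError =>
      println(s"Couldn't find native library $name")
      throw e
  }
```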
