Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Java Inference api and SSD example #12830

Merged
merged 17 commits into from
Oct 19, 2018
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,21 @@
# specific language governing permissions and limitations
# under the License.

hw_type=cpu
if [[ $1 = gpu ]]
then
hw_type=gpu
fi

platform=linux-x86_64

if [[ $OSTYPE = [darwin]* ]]
Copy link
Member

@nswamy nswamy Oct 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where are you getting this variables from? $OSTYPE, and $1 is set to MODEL_DIR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixing the $1 thing. I missed this when I took the suggestion to update the script to not be osx specific.

The $OSTYPE is an env variable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

$1 thing is to pick up CPU or GPU because you need to append that as well to the path.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

$1 was to MODEL_DIR

then
platform=osx-x86_64
fi

MXNET_ROOT=$(cd "$(dirname $0)/../../../../../"; pwd)
CLASS_PATH=$MXNET_ROOT/scala-package/assembly/osx-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*

# model dir and prefix
MODEL_DIR=$1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

hw_type=cpu
if [[ $1 = gpu ]]
then
hw_type=gpu
fi

platform=linux-x86_64

if [[ $OSTYPE = [darwin]* ]]
then
platform=osx-x86_64
fi

MXNET_ROOT=$(cd "$(dirname $0)/../../../../../"; pwd)
CLASS_PATH=$MXNET_ROOT/scala-package/assembly/$platform-$hw_type/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*:$MXNET_ROOT/scala-package/examples/src/main/scala/org/apache/mxnetexamples/api/java/infer/imageclassifier/*

# model dir and prefix
MODEL_DIR=$1
# input image
INPUT_IMG=$2
# which input image dir
INPUT_DIR=$3

java -Xmx8G -cp $CLASS_PATH \
org.apache.mxnetexamples.infer.javapi.objectdetector.SSDClassifierExample \
--model-path-prefix $MODEL_DIR \
--input-image $INPUT_IMG \
--input-dir $INPUT_DIR
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Single Shot Multi Object Detection using Scala Inference API

In this example, you will learn how to use Scala Inference API to run Inference on pre-trained Single Shot Multi Object Detection (SSD) MXNet model.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has someone followed this README and verified that it works for Java. I still see paths referring to Scala folders?
@piyushghai @zachgk @lanking520 could one of you try to follow step by step as a naive user and verify this README works for Java

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are still some references to the word Scala but the commands have all been updated and should work. Someone else running through the steps would definitely be welcome.

I've got a task today to start updating documentation for the Java API and I plan to include this as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll also be running through these steps since I also need to work on portions of the documentation

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

offline conversation with Andrew - he will do a cleanup along with the other docs he is writing for Java API


The model is trained on the [Pascal VOC 2012 dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html). The network is a SSD model built on Resnet50 as base network to extract image features. The model is trained to detect the following entities (classes): ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']. For more details about the model, you can refer to the [MXNet SSD example](https://github.com/apache/incubator-mxnet/tree/master/example/ssd).


## Contents

1. [Prerequisites](#prerequisites)
2. [Download artifacts](#download-artifacts)
3. [Setup datapath and parameters](#setup-datapath-and-parameters)
4. [Run the image inference example](#run-the-image-inference-example)
5. [Infer APIs](#infer-api-details)
6. [Next steps](#next-steps)


## Prerequisites

1. MXNet
2. MXNet Scala Package
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be reworded to MXNet Java and Scala Package ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, going to get all these fixes in one fell swoop during a big doc update.

3. [IntelliJ IDE (or alternative IDE) project setup](http://mxnet.incubator.apache.org/tutorials/scala/mxnet_scala_on_intellij.html) with the MXNet Scala Package
4. wget


## Setup Guide

### Download Artifacts
#### Step 1
You can download the files using the script `get_ssd_data.sh`. It will download and place the model files in a `model` folder and the test image files in a `image` folder in the current directory.
From the `scala-package/examples/scripts/infer/imageclassifier/` folder run:

```bash
./get_ssd_data.sh
```

**Note**: You may need to run `chmod +x get_resnet_data.sh` before running this script.

Alternatively use the following links to download the Symbol and Params files via your browser:
- [resnet50_ssd_model-symbol.json](https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-symbol.json)
- [resnet50_ssd_model-0000.params](https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-0000.params)
- [synset.txt](https://github.com/awslabs/mxnet-model-server/blob/master/examples/ssd/synset.txt)

In the pre-trained model, the `input_name` is `data` and shape is `(1, 3, 512, 512)`.
This shape translates to: a batch of `1` image, the image has color and uses `3` channels (RGB), and the image has the dimensions of `512` pixels in height by `512` pixels in width.

`image/jpeg` is the expected input type, since this example's image pre-processor only supports the handling of binary JPEG images.

The output shape is `(1, 6132, 6)`. As with the input, the `1` is the number of images. `6132` is the number of prediction results, and `6` is for the size of each prediction. Each prediction contains the following components:
- `Class`
- `Accuracy`
- `Xmin`
- `Ymin`
- `Xmax`
- `Ymax`


### Setup Datapath and Parameters
#### Step 2
The code `Line 31: val baseDir = System.getProperty("user.dir")` in the example will automatically searches the work directory you have defined. Please put the files in your [work directory](https://stackoverflow.com/questions/16239130/java-user-dir-property-what-exactly-does-it-mean). <!-- how do you define the work directory? -->

Alternatively, if you would like to use your own path, please change line 31 into your own path
```scala
val baseDir = <Your Own Path>
```

The followings is the parameters defined for this example, you can find more information in the `class SSDClassifierExample`.

| Argument | Comments |
| ----------------------------- | ---------------------------------------- |
| `model-path-prefix` | Folder path with prefix to the model (including json, params, and any synset file). |
| `input-image` | The image to run inference on. |
| `input-dir` | The directory of images to run inference on. |


## How to Run Inference
After the previous steps, you should be able to run the code using the following script that will pass all of the required parameters to the Infer API.

From the `scala-package/examples/scripts/inferexample/objectdetector/` folder run:

```bash
./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
```

**Notes**:
* These are relative paths to this script.
* You may need to run `chmod +x run_ssd_example.sh` before running this script.

The example should give expected output as shown below:
```
Class: car
Probabilties: 0.99847263
(Coord:,312.21335,72.0291,456.01443,150.66176)
Class: bicycle
Probabilties: 0.90473825
(Coord:,155.95807,149.96362,383.8369,418.94513)
Class: dog
Probabilties: 0.8226818
(Coord:,83.82353,179.13998,206.63783,476.7875)
```
the outputs come from the the input image, with top3 predictions picked.


## Infer API Details
This example uses ObjectDetector class provided by MXNet's scala package Infer APIs. It provides methods to load the images, create NDArray out of Java BufferedImage and run prediction using Classifier and Predictor APIs.


## References
This documentation used the model and inference setup guide from the [MXNet Model Server SSD example](https://github.com/awslabs/mxnet-model-server/blob/master/examples/ssd/README.md).


## Next Steps

Check out the following related tutorials and examples for the Infer API:

* [Image Classification with the MXNet Scala Infer API](../imageclassifier/README.md)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we keep this as a placeholder and point the users to another Java API example in the future ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that's my plan. I'm updating all the Java Documentation throughout today and tomorrow. The imageClassifier is actually mostly written sitting on my machine so I figured I'd just leave this here for now.

Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnetexamples.infer.javapi.objectdetector;

import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.mxnet.javaapi.*;
import org.apache.mxnet.infer.javaapi.ObjectDetector;

// scalastyle:off
import java.awt.image.BufferedImage;
// scalastyle:on

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.tuple.ImmutablePair;

import java.io.File;

public class SSDClassifierExample {
@Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
private String modelPathPrefix = "/model/ssd_resnet50_512";
@Option(name = "--input-image", usage = "the input image")
private String inputImagePath = "/images/dog.jpg";
@Option(name = "--input-dir", usage = "the input batch of images directory")
private String inputImageDir = "/images/";

final static Logger logger = LoggerFactory.getLogger(SSDClassifierExample.class);

static List<List<ImmutablePair<String, List<Float>>>>
runObjectDetectionSingle(String modelPathPrefix, String inputImagePath, List<Context> context) {
Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
List<DataDesc> inputDescriptors = new ArrayList<DataDesc>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can reduce 2 lines into one with Arrays.asList()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because of the line limits imposed by ScalaStyle this still spans 2 lines. Because of this my preference is to leave as 2 commands because I think it's slightly more readable than a single command spanning 2 lines. I don't have strong feelings on this though.

inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
BufferedImage img = ObjectDetector.loadImageFromFile(inputImagePath);
ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);
return objDet.imageObjectDetect(img, 3);
}

static List<List<List<ImmutablePair<String, List<Float>>>>>
runObjectDetectionBatch(String modelPathPrefix, String inputImageDir, List<Context> context) {
Shape inputShape = new Shape(new int[]{1, 3, 512, 512});
List<DataDesc> inputDescriptors = new ArrayList<DataDesc>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
ObjectDetector objDet = new ObjectDetector(modelPathPrefix, inputDescriptors, context, 0);

// Loading batch of images from the directory path
List<List<String>> batchFiles = generateBatches(inputImageDir, 20);
List<List<List<ImmutablePair<String, List<Float>>>>> outputList
= new ArrayList<List<List<ImmutablePair<String, List<Float>>>>>();

for (List<String> batchFile : batchFiles) {
List<BufferedImage> imgList = ObjectDetector.loadInputBatch(batchFile);
// Running inference on batch of images loaded in previous step
List<List<ImmutablePair<String, List<Float>>>> tmp
= objDet.imageBatchObjectDetect(imgList, 5);
outputList.add(tmp);
}
return outputList;
}

static List<List<String>> generateBatches(String inputImageDirPath, int batchSize) {
File dir = new File(inputImageDirPath);

List<List<String>> output = new ArrayList<List<String>>();
List<String> batch = new ArrayList<String>();
for (File imgFile : dir.listFiles()) {
batch.add(imgFile.getPath());
if (batch.size() == batchSize) {
output.add(batch);
batch = new ArrayList<String>();
}
}
if (batch.size() > 0) {
output.add(batch);
}
return output;
}

public static void main(String[] args) {
SSDClassifierExample inst = new SSDClassifierExample();
CmdLineParser parser = new CmdLineParser(inst);
try {
parser.parseArgument(args);
} catch (Exception e) {
logger.error(e.getMessage(), e);
parser.printUsage(System.err);
System.exit(1);
}

String mdprefixDir = inst.modelPathPrefix;
String imgPath = inst.inputImagePath;
String imgDir = inst.inputImageDir;

if (!checkExist(Arrays.asList(mdprefixDir + "-symbol.json", imgDir, imgPath))) {
logger.error("Model or input image path does not exist");
System.exit(1);
}

List<Context> context = new ArrayList<Context>();
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this flag still be SCALA_TEST_ON_GPU ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh... that's a good question

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, I think that this is correct given that the plan is for the java tests to run as part of the scala test suite instead of as a separate thing. This is far less important for this example than others since there isn't any training involved and the whole thing runs in a matter of seconds.

Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
context.add(Context.gpu());
} else {
context.add(Context.cpu());
}

try {
Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
Shape outputShape = new Shape(new int[] {1, 6132, 6});


int width = inputShape.get(2);
int height = inputShape.get(3);
String outputStr = "\n";

List<List<ImmutablePair<String, List<Float>>>> output
= runObjectDetectionSingle(mdprefixDir, imgPath, context);

for (List<ImmutablePair<String, List<Float>>> ele : output) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These complex nested structures make it a bit hard to read. Going by @nswamy 's suggestion, a POJO will definitely help ease out the readability.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I agree. Already started writing it.

for (ImmutablePair<String, List<Float>> i : ele) {
outputStr += "Class: " + i.getKey() + "\n";
List<Float> arr = i.getValue();
outputStr += "Probabilties: " + arr.get(0) + "\n";

List<Float> coord = Arrays.asList(arr.get(1) * width,
arr.get(2) * height, arr.get(3) * width, arr.get(4) * height);
StringBuilder sb = new StringBuilder();
for (float c: coord) {
sb.append(", ").append(c);
}
outputStr += "Coord:" + sb.substring(2)+ "\n";
}
}
logger.info(outputStr);

List<List<List<ImmutablePair<String, List<Float>>>>> outputList =
runObjectDetectionBatch(mdprefixDir, imgDir, context);

outputStr = "\n";
int index = 0;
for (List<List<ImmutablePair<String, List<Float>>>> i: outputList) {
for (List<ImmutablePair<String, List<Float>>> j : i) {
outputStr += "*** Image " + (index + 1) + "***" + "\n";
for (ImmutablePair<String, List<Float>> k : j) {
outputStr += "Class: " + k.getKey() + "\n";
List<Float> arr = k.getValue();
outputStr += "Probabilties: " + arr.get(0) + "\n";
List<Float> coord = Arrays.asList(arr.get(1) * width,
arr.get(2) * height, arr.get(3) * width, arr.get(4) * height);

StringBuilder sb = new StringBuilder();
for (float c : coord) {
sb.append(", ").append(c);
}
outputStr += "Coord:" + sb.substring(2) + "\n";
}
index++;
}
}
logger.info(outputStr);

} catch (Exception e) {
logger.error(e.getMessage(), e);
parser.printUsage(System.err);
System.exit(1);
}
System.exit(0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need an explicit exit from the main method ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not. The scala example I was working from had one so I left it in for the time being because it doesn't hurt anything and I wasn't 100% sure it wasn't being used for something.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay.

}

static Boolean checkExist(List<String> arr) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitPick: checkFileExist

Boolean exist = true;
for (String item : arr) {
exist = new File(item).exists() && exist;
if (!exist) {
logger.error("Cannot find: " + item);
}
}
return exist;
}
}
Loading