Skip to content

Commit

Permalink
adding context parameter to infer api- imageclassifier and objectdete…
Browse files Browse the repository at this point in the history
…ctor (apache#10252)

* adding context parameter

* parameter description added
  • Loading branch information
Roshrini authored and Jin Huang committed Mar 30, 2018
1 parent f48ca01 commit 3770a6c
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package ml.dmlc.mxnet.infer

import ml.dmlc.mxnet.{DataDesc, NDArray, Shape}
import ml.dmlc.mxnet.{Context, DataDesc, NDArray, Shape}

import scala.collection.mutable.ListBuffer

Expand All @@ -37,13 +37,15 @@ import javax.imageio.ImageIO
* file://model-dir/synset.txt
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and Type parameters
* @param contexts Device Contexts on which you want to run Inference, defaults to CPU.
* @param epoch Model epoch to load, defaults to 0.
*/
class ImageClassifier(modelPathPrefix: String,
inputDescriptors: IndexedSeq[DataDesc])
inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0))
extends Classifier(modelPathPrefix,
inputDescriptors) {

val classifier: Classifier = getClassifier(modelPathPrefix, inputDescriptors)
inputDescriptors, contexts, epoch) {

protected[infer] val inputLayout = inputDescriptors.head.layout

Expand Down Expand Up @@ -108,8 +110,10 @@ class ImageClassifier(modelPathPrefix: String,
result
}

def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]): Classifier = {
new Classifier(modelPathPrefix, inputDescriptors)
def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)): Classifier = {
new Classifier(modelPathPrefix, inputDescriptors, contexts, epoch)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
*/

package ml.dmlc.mxnet.infer

// scalastyle:off
import java.awt.image.BufferedImage
// scalastyle:on
import ml.dmlc.mxnet.NDArray
import ml.dmlc.mxnet.DataDesc

import ml.dmlc.mxnet.{Context, DataDesc, NDArray}
import scala.collection.mutable.ListBuffer

/**
* A class for object detection tasks
*
Expand All @@ -32,11 +34,16 @@ import scala.collection.mutable.ListBuffer
* file://model-dir/synset.txt
* @param inputDescriptors Descriptors defining the input node names, shape,
* layout and Type parameters
* @param contexts Device Contexts on which you want to run Inference, defaults to CPU.
* @param epoch Model epoch to load, defaults to 0.
*/
class ObjectDetector(modelPathPrefix: String,
inputDescriptors: IndexedSeq[DataDesc]) {
inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)) {

val imgClassifier: ImageClassifier = getImageClassifier(modelPathPrefix, inputDescriptors)
val imgClassifier: ImageClassifier =
getImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)

val inputShape = imgClassifier.inputShape

Expand All @@ -54,7 +61,7 @@ class ObjectDetector(modelPathPrefix: String,
* To Detect bounding boxes and corresponding labels
*
* @param inputImage : PathPrefix of the input image
* @param topK : Get top k elements with maximum probability
* @param topK : Get top k elements with maximum probability
* @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax])
*/
def imageObjectDetect(inputImage: BufferedImage,
Expand All @@ -71,9 +78,10 @@ class ObjectDetector(modelPathPrefix: String,
/**
* Takes input images as NDArrays. Useful when you want to perform multiple operations on
* the input Array, or when you want to pass a batch of input images.
*
* @param input : Indexed Sequence of NDArrays
* @param topK : (Optional) How many top_k(sorting will be based on the last axis)
* elements to return. If not passed, returns all unsorted output.
* @param topK : (Optional) How many top_k(sorting will be based on the last axis)
* elements to return. If not passed, returns all unsorted output.
* @return List of List of tuples of (class, [probability, xmin, ymin, xmax, ymax])
*/
def objectDetectWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int])
Expand All @@ -90,10 +98,10 @@ class ObjectDetector(modelPathPrefix: String,
batchResult.toIndexedSeq
}

private def sortAndReformat(predictResultND : NDArray, topK: Option[Int])
private def sortAndReformat(predictResultND: NDArray, topK: Option[Int])
: IndexedSeq[(String, Array[Float])] = {
val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]()
val accuracy : ListBuffer[Float] = ListBuffer[Float]()
val accuracy: ListBuffer[Float] = ListBuffer[Float]()

// iterating over the all the predictions
val length = predictResultND.shape(0)
Expand All @@ -110,7 +118,7 @@ class ObjectDetector(modelPathPrefix: String,
handler.execute(r.dispose())
}
var result = IndexedSeq[(String, Array[Float])]()
if(topK.isDefined) {
if (topK.isDefined) {
var sortedIndices = accuracy.zipWithIndex.sortBy(-_._1).map(_._2)
sortedIndices = sortedIndices.take(topK.get)
// takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax
Expand All @@ -127,8 +135,9 @@ class ObjectDetector(modelPathPrefix: String,

/**
* To classify batch of input images according to the provided model
*
* @param inputBatch Input batch of Buffered images
* @param topK Get top k elements with maximum probability
* @param topK Get top k elements with maximum probability
* @return List of list of tuples of (class, probability)
*/
def imageBatchObjectDetect(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None):
Expand All @@ -148,9 +157,11 @@ class ObjectDetector(modelPathPrefix: String,
result
}

def getImageClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]):
def getImageClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)):
ImageClassifier = {
new ImageClassifier(modelPathPrefix, inputDescriptors)
new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@

package ml.dmlc.mxnet.infer

import ml.dmlc.mxnet.{DType, DataDesc, Shape, NDArray}

import ml.dmlc.mxnet._
import org.mockito.Matchers._
import org.mockito.Mockito
import org.scalatest.{BeforeAndAfterAll}
import org.scalatest.BeforeAndAfterAll

// scalastyle:off
import java.awt.image.BufferedImage
Expand All @@ -33,15 +32,16 @@ import java.awt.image.BufferedImage
class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {

class MyImageClassifier(modelPathPrefix: String,
inputDescriptors: IndexedSeq[DataDesc])
inputDescriptors: IndexedSeq[DataDesc])
extends ImageClassifier(modelPathPrefix, inputDescriptors) {

override def getPredictor(): MyClassyPredictor = {
Mockito.mock(classOf[MyClassyPredictor])
}

override def getClassifier(modelPathPrefix: String, inputDescriptors:
IndexedSeq[DataDesc]): Classifier = {
IndexedSeq[DataDesc], contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)): Classifier = {
Mockito.mock(classOf[Classifier])
}

Expand Down Expand Up @@ -84,7 +84,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {

val synset = testImageClassifier.synset

val predictExpectedOp : List[(String, Float)] =
val predictExpectedOp: List[(String, Float)] =
List[(String, Float)]((synset(1), .98f), (synset(2), .97f),
(synset(3), .96f), (synset(0), .99f))

Expand All @@ -93,13 +93,14 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
Mockito.doReturn(IndexedSeq(predictExpectedND)).when(testImageClassifier.predictor)
.predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))

Mockito.doReturn(IndexedSeq(predictExpectedOp)).when(testImageClassifier.classifier)
Mockito.doReturn(IndexedSeq(predictExpectedOp))
.when(testImageClassifier.getClassifier(modelPath, inputDescriptor))
.classifyWithNDArray(any(classOf[IndexedSeq[NDArray]]), Some(anyInt()))

val predictResult: IndexedSeq[IndexedSeq[(String, Float)]] =
testImageClassifier.classifyImage(inputImage, Some(4))

for(i <- predictExpected.indices) {
for (i <- predictExpected.indices) {
assertResult(predictExpected(i).sortBy(-_)) {
predictResult(i).map(_._2).toArray
}
Expand All @@ -119,23 +120,24 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {

val predictExpected: IndexedSeq[Array[Array[Float]]] =
IndexedSeq[Array[Array[Float]]](Array(Array(.98f, 0.97f, 0.96f, 0.99f),
Array(.98f, 0.97f, 0.96f, 0.99f)))
Array(.98f, 0.97f, 0.96f, 0.99f)))

val synset = testImageClassifier.synset

val predictExpectedOp : List[List[(String, Float)]] =
val predictExpectedOp: List[List[(String, Float)]] =
List[List[(String, Float)]](List((synset(1), .98f), (synset(2), .97f),
(synset(3), .96f), (synset(0), .99f)),
List((synset(1), .98f), (synset(2), .97f),
(synset(3), .96f), (synset(0), .99f)))
(synset(3), .96f), (synset(0), .99f)))

val predictExpectedND: NDArray = NDArray.array(predictExpected.flatten.flatten.toArray,
Shape(2, 4))

Mockito.doReturn(IndexedSeq(predictExpectedND)).when(testImageClassifier.predictor)
.predictWithNDArray(any(classOf[IndexedSeq[NDArray]]))

Mockito.doReturn(IndexedSeq(predictExpectedOp)).when(testImageClassifier.classifier)
Mockito.doReturn(IndexedSeq(predictExpectedOp))
.when(testImageClassifier.getClassifier(modelPath, inputDescriptor))
.classifyWithNDArray(any(classOf[IndexedSeq[NDArray]]), Some(anyInt()))

val result: IndexedSeq[IndexedSeq[(String, Float)]] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.awt.image.BufferedImage
// scalastyle:on
import ml.dmlc.mxnet.Context
import ml.dmlc.mxnet.DataDesc
import ml.dmlc.mxnet.{NDArray, Shape}
import ml.dmlc.mxnet.{Context, NDArray, Shape}
import org.mockito.Matchers.any
import org.mockito.Mockito
import org.scalatest.BeforeAndAfterAll
Expand All @@ -36,21 +36,24 @@ class ObjectDetectorSuite extends ClassifierSuite with BeforeAndAfterAll {
extends ObjectDetector(modelPathPrefix, inputDescriptors) {

override def getImageClassifier(modelPathPrefix: String, inputDescriptors:
IndexedSeq[DataDesc]): ImageClassifier = {
IndexedSeq[DataDesc], contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)): ImageClassifier = {
new MyImageClassifier(modelPathPrefix, inputDescriptors)
}

}

class MyImageClassifier(modelPathPrefix: String,
protected override val inputDescriptors: IndexedSeq[DataDesc])
extends ImageClassifier(modelPathPrefix, inputDescriptors) {
extends ImageClassifier(modelPathPrefix, inputDescriptors, Context.cpu(), Some(0)) {

override def getPredictor(): MyClassyPredictor = {
Mockito.mock(classOf[MyClassyPredictor])
}

override def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc]):
override def getClassifier(modelPathPrefix: String, inputDescriptors: IndexedSeq[DataDesc],
contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)):
Classifier = {
new MyClassifier(modelPathPrefix, inputDescriptors)
}
Expand Down

0 comments on commit 3770a6c

Please sign in to comment.