-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-53]Image classifier for scala-infer package #10054
Conversation
def classifyImageBatch(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None): | ||
List[List[(String, Float)]] = { | ||
val result = ListBuffer[List[(String, Float)]]() | ||
for (image <- inputBatch) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as discussed offline
the batch of input should be converted to an NDArray and run with ClassifywithNDArray
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you mind adding tests to verify the examples?
|
||
package ml.dmlc.mxnet.infer | ||
|
||
import ml.dmlc.mxnet._ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please import only the required modules
def classifyImage(inputImage: BufferedImage, | ||
topK: Option[Int] = None): IndexedSeq[List[(String, Float)]] = { | ||
|
||
val width = inputDescriptors(0).shape(2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be based on the layout parameter of the input descriptor
* @param newHeight rescale to new height | ||
* @return Rescaled BufferedImage | ||
*/ | ||
def getScaledImage(img: BufferedImage, newWidth: Int, newHeight: Int): BufferedImage = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not call this reshape Image? instead of getScaledImage. Also this can be static method, so users can just call this method without having to create an object. ie.., a method on the companion object
|
||
val pixels = new ListBuffer[Float]() | ||
|
||
for (x <- 0 until h) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
loops should be based on layout of the input descriptor. find out if H
is before W
} | ||
} | ||
|
||
val reshaped_pixels = NDArray.array(pixels.toArray, shape = Shape(224, 224, 3)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shape based on input descriptor
|
||
val reshaped_pixels = NDArray.array(pixels.toArray, shape = Shape(224, 224, 3)) | ||
|
||
val swapped_axis = NDArray.swapaxes(reshaped_pixels, 0, 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would you need to swap if you build the Input according to the input descriptor?
topK: Option[Int] = None): IndexedSeq[List[(String, Float)]] = { | ||
|
||
val width = inputDescriptors(0).shape(2) | ||
val height = inputDescriptors(0).shape(3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above.
* @param resizedImage BufferedImage to get pixels from | ||
* @return NDArray pixels array | ||
*/ | ||
def getPixelsFromImage(resizedImage: BufferedImage): NDArray = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
take layout parameter as input and make it static
|
||
val output = super.classifyWithNDArray(input, topK) | ||
|
||
IndexedSeq(output(0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dispose input NDArray before return the result.
@marcoabreu currently we don't have any hooks to run Scala integration tests. We will take up adding examples to CI later. |
@nswamy I'm afraid I don't understand. We already have a Scala-Job in CI. Could you elaborate? |
@marcoabreu there is no CI job for running integration tests for Scala examples or tests and I am not comfortable adding this to the unit test stage that is why I said this should be taken up separately and don't spend cycles here. Also I will need your help to set up a CI job for running Scala integration tests. |
I see. Do you think we can add them before the release is cut? I'd be happy to assist in creating the integration job. For the time being (until we got proper nightly tests), we can make a job in the PR pipeline and move it later on. |
@marcoabreu I don't think we'll have the time to make this improvement before the release, but surely have it on our plate to take it up soon after. |
But we can't really announce a new feature without having test coverage :/ |
@marcoabreu this is an example, are we testing python examples. there are already tests in this PR |
@marcoabreu, can you help create a CI Integration test job for Scala ?. We can add it there. I guess its a good thing to make sure the examples work. |
Sure thing! Would you want me to create a PR to your branch or just explain you here what you should do? |
could you please create a PR, we can pick up. |
Sure, no problem. Can you make the changes in order to allow unit testing and integration testing of the scala package and I'll make the integration into CI? At the moment, See https://github.com/apache/incubator-mxnet/blob/master/scala-package/pom.xml#L254 for reference |
@marcoabreu I will look into it. WIll try to run the example as a part of integration test |
b38551a
to
fd2c2ae
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A PR for the README.md is forthcoming...
CLASS_PATH=$MXNET_ROOT/scala-package/assembly/osx-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/* | ||
|
||
# model dir | ||
MODEL_DIR=$1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you have this print out usage info?
def runInference(modelPathPrefix: String, inputImagePath: String, inputImageDir: String): | ||
IndexedSeq[IndexedSeq[(String, Float)]] = { | ||
val dType = DType.Float32 | ||
val inputShape = Shape(1, 3, 224, 224) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the SSD example we're getting artifacts by a sh script and these include the signature.json. Shouldn't we be fetching the input info from this?
@Option(name = "--model-dir", usage = "the input model directory") | ||
private val modelPathPrefix: String = "/resnet/resnet-152" | ||
@Option(name = "--input-image", usage = "the input image") | ||
private val inputImagePath: String = "/images/Cat-hd-wallpapers.jpg" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prefer to keep the test image artifact consistent: kitten.jpg (plus this is easily grabbed from s3 with some links we have in the MMS repo).
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usage info?
"wget http://data.mxnet.io/models/imagenet/resnet/synset.txt -P resnet/ -q --show-progress"! | ||
|
||
"wget " + | ||
"http://thenotoriouspug.com/wp-content/uploads/2015/01/Pug-Cookie-1920x1080-1024x576.jpg " + |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we use something on s3 in case this image just disappears for some reason?
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
usage info?
c36b6a8
to
22fcac1
Compare
scala-package/examples/pom.xml
Outdated
<version>1.2.0-SNAPSHOT</version> | ||
<relativePath>../pom.xml</relativePath> | ||
</parent> | ||
<modelVersion>4.0.0</modelVersion> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove tabs and replace with spaces
|
||
val inputShape = inputDescriptors(0).shape | ||
|
||
// Considering 'NCHW' as default layout when not provided |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you need force user to pass the shape and layout for CHW, you can only use Batch size defaulted to the first axis when not passed
def classifyImage(inputImage: BufferedImage, | ||
topK: Option[Int] = None): IndexedSeq[IndexedSeq[(String, Float)]] = { | ||
|
||
val scaledImage = ImageClassifier.reshapeImage(inputImage, width, height) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also transpose if the layout is different, example channel is last. Lets do that as a separate task
val op = NDArray.concatenate(imageBatch) | ||
|
||
val result = super.classifyWithNDArray(IndexedSeq(op), topK) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also dispose the list of NDArrays imageBatch
} | ||
|
||
/** | ||
* Read image file from provided path |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what path?
var col = 0 | ||
while (col < w) { | ||
val rgb = pixels(row * w + col) | ||
result(0 * h * w + row * w + col) = (rgb >> 16) & 0xFF |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add comments
* | ||
* @param resizedImage BufferedImage to get pixels from | ||
* @param inputImageShape Should be same as inputDescriptor shape | ||
* @return NDArray pixels array |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a Note saying Caller is responsible to dispose the returned NDArray
val imageBatch = ListBuffer[NDArray]() | ||
for (image <- inputBatch) { | ||
val scaledImage = ImageClassifier.reshapeImage(image, width, height) | ||
val pixelsNdarray = ImageClassifier.bufferedImageToPixels(scaledImage, inputShape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dispose pixelsNDArray
ba3a43f
to
388fb4d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very good job Roshani!
last few changes before we can merge this.
@@ -44,12 +45,19 @@ class ImageClassifier(modelPathPrefix: String, | |||
|
|||
val classifier: Classifier = getClassifier(modelPathPrefix, inputDescriptors) | |||
|
|||
require(inputDescriptors.head.shape.length != 0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.length >= 3
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this check already comes from DataDesc. So, wasnt adding it here. But yeah I can add extra check
require(inputDescriptors.head.shape.length != 0, | ||
"Please provide shape information in the descriptor") | ||
|
||
require(!inputDescriptors.head.layout.isEmpty, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need for the I already changed the datadescriptor to have shape.length == layout.length
val inputLayout = inputDescriptors(0).layout | ||
|
||
val inputShape = inputDescriptors(0).shape | ||
|
||
// Considering 'NCHW' as default layout when not provided | ||
// Else get axis according to the layout | ||
// [TODO] if layout is different |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if layout is different than the bufferedImage layout, transpose to match the inputdescriptor shape
@@ -122,31 +133,45 @@ object ImageClassifier { | |||
} | |||
|
|||
/** | |||
* Read image file from provided path | |||
* Convert input BufferedImage to NDArray of input shape | |||
* |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
start the note with <p>
and end with <p>
scala-package/examples/pom.xml
Outdated
@@ -147,5 +174,12 @@ | |||
<artifactId>opencv</artifactId> | |||
<version>2.4.9-7</version> | |||
</dependency> | |||
<!-- https://mvnrepository.com/artifact/org.mockito/mockito-all --> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you need mockito for testing the example. this should be like a integration test
|
||
require(!inputDescriptors.head.layout.isEmpty, | ||
"Please provide layout information in the descriptor") | ||
|
||
val inputLayout = inputDescriptors(0).layout |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you make the below variables protected[infer]
abf4688
to
f7e70c4
Compare
3d5659c
to
6b3c6ec
Compare
Docs GTG. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added my review. @aaronmarkham @nswamy are your comments still valid?
@@ -14,6 +14,24 @@ | |||
<name>MXNet Scala Package - Examples</name> | |||
|
|||
<profiles> | |||
<profile> | |||
<id>osx-x86_64-cpu</id> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the default behaviour windows or why is it not defined here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is no support for windows in the current Scala-package and there are no plans to add with this change. This has to be taken up as a separate task for MXNet-Scala binding
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I wasn't aware of that, thanks for elaborating
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add set -e (to allow proper failure detection in automated environments)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
"/resnet18/ -q") ! | ||
|
||
Process("wget " + | ||
"http://thenotoriouspug.com/wp-content/uploads/2015/01/Pug-Cookie-1920x1080-1024x576.jpg " + |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this be moved to S3?
|
||
<artifactId>mxnet-infer</artifactId> | ||
<artifactId>mxnet-infer_2.11</artifactId> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What exactly is the "2.11" here? To me, it sounds like this is some kind of centralized version. Maybe use a variable instead? (also above). Is this related to "{scala.binary.version}"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is the Scala binary version, this is how other modules in this package are built. We won't be refactoring the entire project here.
val image1 = new BufferedImage(100, 200, BufferedImage.TYPE_BYTE_GRAY) | ||
val image2 = ImageClassifier.reshapeImage(image1, 1000, 2000) | ||
|
||
assert(image2.getWidth === 1000) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you execute this test once, generate an MD5 hash based on that output and verify it here? The shape will probably always be right, but the question is whether the actual content has been resized correctly or the code just generated some random data.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Python has libraries to do image hashing(imagehash). Getting hash is not that trivial in Java. Will have to implement image hashing function (average_hash)to match the hash.
However, I ran the function "ImageClassifier.reshapeImage" and tested the hash with python if it's resizing correctly or not. It's working as expected.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@marcoabreu I am not convinced that it is worth the effort and extra code to test this functionality. see this https://www.pyimagesearch.com/2017/11/27/image-hashing-opencv-python/.
Also, the reshaped images are used to test against the model for inference and we verify the output(label) matches what the expected output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I see, validating them in a follow up step with the output labels is even better
|
||
val result = ImageClassifier.bufferedImageToPixels(image2, Shape(1, 3, 2, 2)) | ||
|
||
assert(result.shape == inputDescriptor(0).shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you verify the content? (See above)
@@ -14,6 +14,24 @@ | |||
<name>MXNet Scala Package - Examples</name> | |||
|
|||
<profiles> | |||
<profile> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens if an undefined platform is being detected?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only the mentioned platforms are supported similar to what the current scala-package supports.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lgtm, Thanks for addressing my comments so quickly. Good job!
* Image classifier for infer package
* Image classifier for infer package
* Image classifier for infer package
* Image classifier for infer package
Description
https://issues.apache.org/jira/browse/MXNET-50
Depends on #9678
So, build will fail till that PR gets merged
Checklist
Essentials
make lint
)Changes
Comments