Conversation
Would it make sense to have this package under the org.apache namespace, since it seems like an entirely new aspect? This might allow us to make a quicker transition away from dmlc.
@marcoabreu, this package needs to pull in the core package under dmlc, so I prefer to keep it under dmlc. I would prefer to have a separate discussion about how we want to manage versioning of the language bindings for MXNet. Currently, it does not make sense to me that an API change in one of the language bindings triggers a whole new version for the entire MXNet. Also, the Scala packages have not been published to Maven for a while now and the versions are maintained in Maven; we have to address this as well.
Unit tests are failing for a requirement (the number of dimensions in shape must match the number of elements in layout) that I put into DataDesc. I will fix the unit tests and update the PR.
val task = new Callable[T] {
  override def call(): T = {
    // scalastyle:off println
    println("threadId: %s".format(Thread.currentThread().getId()))
better to use Log
This was for testing; removed the print.
}
object MXNetSingleThreadHandler extends MXNetOneThreadPerModelHandler {
Why do we need an object that extends something? Are there any static members to be accessed?
I want to create a singleton object with only one thread, which will be used by default.
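For illustration, a minimal Scala sketch of the singleton pattern being described (the class and object names here are made up, not the PR's):

```scala
// A plain class that would own the worker pool (body elided for brevity).
class PoolHandlerSketch(numThreads: Int) {
  require(numThreads > 0, "numThreads must be positive")
}

// A Scala `object` extending the class gives exactly one lazily initialised
// instance, shared process-wide, which can serve as the default handler.
object DefaultHandlerSketch extends PoolHandlerSketch(numThreads = 1)
```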
  }
}

override val executor: ExecutorService = Executors.newFixedThreadPool(10, threadFactory)
Make 10 configurable through the constructor.
A fixed threadpool of 10 does not enforce the invariant of having a single thread for all MXNet use, which I believe to be the point of this approach. The thread count here should be fixed to 1 (or use Executors.newSingleThreadExecutor).
I want this to be configurable; the value 10 was for testing. I've made it configurable now.
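For context, a minimal sketch of what enforcing the single-thread invariant could look like (illustrative names, not the PR code; it simply relies on Executors.newSingleThreadExecutor):

```scala
import java.util.concurrent.{Callable, ExecutorService, Executors}

// Sketch: one dedicated worker thread for all MXNet native calls.
// newSingleThreadExecutor guarantees tasks run sequentially on a single thread.
object SingleThreadSketch {
  private val executor: ExecutorService = Executors.newSingleThreadExecutor()

  def execute[T](f: => T): T = {
    val task = new Callable[T] { override def call(): T = f }
    executor.submit(task).get()
  }
}
```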
  result.get()
}
catch {
  case e: ExecutionException => throw e.getCause()
Add throws ExecutionException to the signature of this method. Also, place catch on the same line as the closing }.
how about other uncaught exceptions?
@yzhliu What's the purpose of adding a @throws annotation? These are normally disregarded in Scala.
It would make sense to document why the ExecutionException is being unwrapped here (the answer is "so it looks like the code was called inline").
if (handlerType == MXNetHandlerType.OneThreadPerModelHandler) {
  new MXNetOneThreadPerModelHandler
}
else {
Place else on the same line as the closing }.
def execute[T](f: => T): T

val executor: ExecutorService
private
Arguably this should not be on the trait at all.
I want to be able to expose the executor in case I find that all NDArray creations have to go through the same thread.
import ml.dmlc.mxnet.infer.MXNetHandlerType.MXNetHandlerType

package object infer {
  private[mxnet] val handlerType: MXNetHandlerType = MXNetHandlerType.SingleThreadHandler
can we change this variable anywhere?
should we make it user-configurable?
I will expose this as a property if I find that MXNet can be called from multiple threads. As said in other places, I am going with the assumption that MXNet needs to be called from the same thread throughout the lifetime of the process, otherwise it seg-faults.
@@ -230,6 +230,8 @@ abstract class DataPack() extends Iterable[DataBatch] {
// Named data desc description contains name, shape, type and other extended attributes.
case class DataDesc(name: String, shape: Shape,
                    dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") {
  require(shape.length == layout.length, "number of dimensions in shape should match the layout")
would you show the current length of both in the error msg?
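A standalone sketch of the suggested error message, with Seq[Int] standing in for the real Shape type (illustrative only):

```scala
// Include both lengths in the message so mismatches are easy to diagnose.
case class DataDescSketch(name: String, shape: Seq[Int], layout: String = "NCHW") {
  require(shape.length == layout.length,
    s"number of dimensions in shape (${shape.length}) should match the layout '$layout' (${layout.length})")
}
```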
scala-package/infer/pom.xml
</plugins>
</build>
</profile>
</profiles>
all of the above things can be inherited from parent pom.xml
I wasn't able to inherit these and make the tests run; the tests need the jar to be built.
import java.util.concurrent._

trait MXNetHandler {
if you are not certain about the stability of the API, you may want to make it private[infer]
}

object MXNetHandlerType extends Enumeration {
if you are not certain about the stability of the API, you may want to make it private[infer]
class MXNetOneThreadPerModelHandler extends MXNetHandler {

  private val threadFactory = new ThreadFactory {
Since you are working with Scala, you may want to use the more Scala-native concurrency facilities, i.e. Future, which provide a more elegant way of handling errors, etc.
I'll learn about it and do that in the next iteration.
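For reference, a hedged sketch of what a Future-based handler might look like (illustrative names; the PR keeps the ExecutorService approach for now):

```scala
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object FutureHandlerSketch {
  // Single-threaded execution context so all MXNet calls stay on one thread.
  private implicit val ec: ExecutionContext =
    ExecutionContext.fromExecutorService(Executors.newSingleThreadExecutor())

  // Asynchronous variant: failures are captured inside the Future.
  def executeAsync[T](f: => T): Future[T] = Future(f)

  // Blocking variant matching the current execute[T] signature.
  def execute[T](f: => T): T = Await.result(executeAsync(f), Duration.Inf)
}
```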
if (batchSize != 1) {
  mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false, forceRebind = true))
}
extra empty line
Is it standard not to have blank lines? I tend to put blank lines so that the code is more readable.
for((i, d) <- input.zip(inputDescriptors)) {
  val shape = d.shape.toVector.patch(from = batchIndex, patch = Vector(1), replaced = 1)
extra empty line
please check all the other places
/**
 * This method will take input as IndexedSeq one dimensional arrays and creates
 * NDArray needed for inference. The array will be reshaped based on the input descriptors.
 * @param input: A IndexedSequence of Java one-dimensional array, An IndexedSequence is
Java?
for((i, d) <- input.zip(inputDescriptors)) {
  val shape = d.shape.toVector.patch(from = batchIndex, patch = Vector(1), replaced = 1)

  inputND += mxNetHandler.execute(NDArray.array(i, Shape(shape)))
I am a bit confused here: NDArray.array(i, Shape(shape)) is executed by a new thread, and it actually blocks here until the function is finished? Why do we need to do this?
Calls to MXNet's native library are not safe to make from more than one thread per program execution (unless that has been changed). This is just how MXNet works at present; we described this on the discussion site at one point.
It's actually not entirely known whether some calls are safe and some are not, but we've found the best stability from single-threading all calls to MXNet.
I am going by Calum's findings here
import ml.dmlc.mxnet.infer.MXNetHandlerType.MXNetHandlerType

package object infer {
  private[mxnet] val handlerType: MXNetHandlerType = MXNetHandlerType.SingleThreadHandler
should we make it user-configurable?
I have some feedback on this implementation as well.
private val threadFactory = new ThreadFactory {

  override def newThread(r: Runnable): Thread = new Thread(r) {
    setName(classOf[MXNetOneThreadPerModelHandler].getCanonicalName)
The name here is a little confusing given that it is also used for MXNetSingleThreadHandler.
type MXNetHandlerType = Value
val SingleThreadHandler = Value("MXNetSingleThreadHandler")
val OneThreadPerModelHandler = Value("MXNetOneThreadPerModelHandler")
Given that one-thread-per-model is (to the best of my knowledge) currently an unsafe operating mode, is there any point in supporting it just yet?
To be able to support it in the future, or in case I find that it's OK to run one thread per model.
override val executor: ExecutorService = Executors.newFixedThreadPool(10, threadFactory)

override def execute[T](f: => T): T = {
Do you want to support the recursive case? If you submit a task which in turn wants to submit a task to the MXNet thread, it will deadlock in this arrangement. You can avoid this by checking whether you're already on the managed thread and executing the code inline if so.
I did not encounter this when I unit tested by running on the same thread. However, I fixed it based on your implementation.
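A compact sketch of the reentrancy check being discussed, separate from the actual PR code shown further down (names are illustrative):

```scala
import java.util.concurrent.{Callable, ExecutorService, Executors}

class ReentrantHandlerSketch {
  private val executor: ExecutorService = Executors.newSingleThreadExecutor()

  // Capture the worker thread once so later calls can detect reentrancy.
  private val workerThread: Thread =
    executor.submit(new Callable[Thread] {
      override def call(): Thread = Thread.currentThread()
    }).get()

  def execute[T](f: => T): T = {
    if (Thread.currentThread() eq workerThread) {
      // Already on the managed thread: run inline to avoid self-deadlock.
      f
    } else {
      executor.submit(new Callable[T] { override def call(): T = f }).get()
    }
  }
}
```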
for((i, d) <- input.zip(inputDescriptors)) {
  require (i.length == d.shape.product/batchSize, "number of elements:" +
    " %d in the input does not match the shape:%s".format( i.length, d.shape.toString()))
Minor: The use of spaces within parentheses here is inconsistent.
Force-pushed from 7fb5992 to dad1436.
Hi, the community has voted to associate code changes with JIRA (https://lists.apache.org/thread.html/ab22cf0e35f1bce2c3bf3bec2bc5b85a9583a3fe7fd56ba1bbade55f@%3Cdev.mxnet.apache.org%3E). We have updated the guidelines for contributors at https://cwiki.apache.org/confluence/display/MXNET/Development+Process. Please ensure that you have created a JIRA at https://issues.apache.org/jira/projects/MXNET/issues/ describing your work in this pull request, and include the JIRA title in your PR as [MXNET-xxxx] your title, where MXNET-xxxx is the JIRA id. Thanks!
Force-pushed from 72e8bc6 to 85545fa.
@CodingCat @yzhliu @calumleslie @CodingCat @calumleslie
val synset = readSynsetFile(synsetFilePath)

val handler = MXNetHandler()
Can any of the above variables be 'private' or 'private[mxnet]'?
 * @return IndexedSequence of (Label, Score) tuples.
 */
def classify(input: IndexedSeq[Array[Float]],
             topK: Option[Int] = None): List[(String, Float)]
Is IndexedSeq better than List? (the return type)
Ok, changed the return type to IndexedSeq instead.
val synsetFilePath = getSynsetFilePath(modelPathPrefix)

val synset = readSynsetFile(synsetFilePath)
I think it is better to make synset optional
The Classifier just takes the Predictor and uses the synset to map the labels; this is the only difference between the Predictor and the Classifier.
  resultND
}

def loadModule(): Module = {
I feel it is better to rely on Symbol and Executor directly. Executor is much cleaner and cannot run into any memory leak before we fix the GC problem.
Wouldn't this just be bringing parts of the Module code into this class?
discussed offline with @yzhliu, Module is necessary to support GPUs.
private?
Force-pushed from 9ca4156 to c370736.
protected[mxnet] val synset = readSynsetFile(synsetFilePath)

protected[mxnet] val handler = MXNetHandler()
Can we make them protected[infer]?
  s.getCanonicalPath
}

def readSynsetFile(synsetFilePath: String): List[String] = {
Better to return IndexedSeq since you will do something like synset(sIndx) later.
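A small sketch of the suggested change, assuming the synset file has one label per line (illustrative only):

```scala
import scala.io.Source

// Read the synset file into an IndexedSeq so lookups like synset(sIndx) are O(1).
def readSynsetFile(synsetFilePath: String): IndexedSeq[String] = {
  val source = Source.fromFile(synsetFilePath)
  try {
    source.getLines().toIndexedSeq
  } finally {
    source.close()
  }
}
```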
 * @param input: Indexed Sequence of NDArrays
 * @param topK: (Optional) How many top_k(sorting will be based on the last axis)
 *              elements to return, if not passed returns unsorted output.
 * @return Traversable Sequence of (Label, Score) tuple, Score will be in the form of NDArray
You mean 'Results will be in the same order as the input NDArray' ? I cannot quite get what 'Score will be in the form of NDArray' means.
This remained from the previous version of the code; thanks for pointing it out. Fixed.
Sorry, have you fixed it? I don't see it on my end.
override def execute[T](f: => T): T = {

  if (Thread.currentThread() eq creatorThread) {
    f
Better to write f() to distinguish it from returning a function value.
That doesn't work since f does not take any parameters, so I added a return to be more explicit. A return statement fails the style check, so I have to keep it the way it is.
This is just how by-name parameters work, it's definitely a bit awkward :(
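For readers unfamiliar with the distinction, a tiny standalone example of how by-name parameters behave (not part of the PR):

```scala
object ByNameExample {
  // f: => T is a by-name parameter: the block the caller passes is evaluated
  // each time the bare name f is referenced. f() would only compile if f were
  // a function value such as f: () => T.
  def execute[T](f: => T): T = {
    f // evaluates the caller's block here, on this thread
  }

  def main(args: Array[String]): Unit = {
    val result = execute {
      println("running on thread " + Thread.currentThread().getId)
      21 * 2
    }
    println(result) // 42
  }
}
```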
val task = new Callable[T] {
  override def call(): T = {
    logger.debug("threadId: %s".format(Thread.currentThread().getId()))
    f
Better to write f() to distinguish it from returning a function value.
/**
 * Base Trait for MXNet Predictor classes.
 */
private[mxnet] trait PredictBase {
private[infer]?
 * This method will take input as IndexedSeq one dimensional arrays and creates
 * NDArray needed for inference. The array will be reshaped based on the input descriptors.
 * @param input: A IndexedSequence of Scala one-dimensional array, An IndexedSequence is
 *               is needed when the model has more than one input/output
I'm confused by 'An IndexedSequence is is needed when the model has more than one input/output'. The IndexedSeq is not about the model; it's just an array of input data, one entry per input sample, right?
This is for models that need more than one input; a crude example would be a model that takes two different images. Are you implying that it can just be another dimension of the same input?
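A hypothetical usage sketch of the multi-input case (the predictor call is commented out because its exact signature is part of this PR and not reproduced here):

```scala
// A model with two separate inputs (e.g. two images) gets one flat Float array
// per input, in the same order as its input descriptors.
object MultiInputExample {
  def main(args: Array[String]): Unit = {
    val imageA: Array[Float] = Array.fill(3 * 224 * 224)(0.0f)
    val imageB: Array[Float] = Array.fill(3 * 224 * 224)(0.0f)

    // One entry per model input; a single-input model would pass IndexedSeq(imageA).
    val input: IndexedSeq[Array[Float]] = IndexedSeq(imageA, imageB)

    // predictor.predict(input)  // hypothetical call into the inference API
    println(s"prepared ${input.length} input arrays")
  }
}
```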
protected[mxnet] var batchSize = if (batchIndex != -1) inputDescriptors(0).shape(batchIndex)
  else 1

protected[mxnet] var iDescriptors = inputDescriptors
protected[infer]?
protected[mxnet] val mxNetHandler = MXNetHandler()

protected[mxnet] val mod = loadModule()
protected[infer], or even better, protected/private?
If I don't make it package-private (private[infer]), I won't be able to mock it and unit test easily.
  resultND
}

def loadModule(): Module = {
private?
override val executor: ExecutorService =
  Executors.newFixedThreadPool(numThreads.get, threadFactory)

private val creatorThread = executor.submit(new Callable[Thread] {
I'm not sure if creatorThread is going to have a meaningful value when the size of the threadpool is not 1.
Given that the size of the thread pool must always be 1 at the moment, maybe it'd be easier to just avoid having to handle multiple threads?
try {
  result.get()
} catch {
  case e: Exception => throw e.getCause()
You should probably only be doing this for ExecutionException. This method can also throw InterruptedException which is not guaranteed to have a cause.
I also think that you should comment explaining why you're unwrapping this exception here.
You might also get other exceptions; I think it makes sense to catch all Exceptions. I will rethrow InterruptedException as-is separately.
I believe this API should only ever throw ExecutionException or InterruptedException but the point was the only type of exception you can safely call "getCause" on is the former. You probably just want all other exceptions to be propagated as is which you can do by just not catching them.
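A hedged sketch of the handling being suggested: catch only ExecutionException and unwrap it so failures look as if the code ran inline, and let everything else (including InterruptedException) propagate unchanged:

```scala
import java.util.concurrent.{Callable, ExecutionException, ExecutorService, Executors}

object ExceptionHandlingSketch {
  private val executor: ExecutorService = Executors.newSingleThreadExecutor()

  def execute[T](f: => T): T = {
    val result = executor.submit(new Callable[T] { override def call(): T = f })
    try {
      result.get()
    } catch {
      // Unwrap so the caller sees the original exception, as if f ran inline.
      case e: ExecutionException => throw e.getCause
      // InterruptedException and anything else propagate unchanged.
    }
  }
}
```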
  val OneThreadPerModelHandler = Value("MXNetOneThreadPerModelHandler")
}

private[infer] class MXNetThreadPoolHandler(numThreads: Option[Int] = Some(1))
numThreads is an Option[Int] but you always unconditionally dereference it. Is None a valid value? Could you just accept an Int here?
Good point. Regardless of whether it's an Option or just an Int, I needed to validate that the value is > 0. I changed it to an Int and added the validation.
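A minimal sketch of the described change, taking a plain Int with validation instead of an Option that is always dereferenced (illustrative class name):

```scala
class ThreadCountSketch(numThreads: Int = 1) {
  require(numThreads > 0, s"numThreads should be a positive number, got $numThreads")
  // ... threadFactory and executor would be built from numThreads as before
}
```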
private val threadFactory = new ThreadFactory {

  override def newThread(r: Runnable): Thread = new Thread(r) {
This will create all threads with the same name, which will make them difficult (or in some contexts impossible) to tell apart. If you're going to continue to support multiple threads consider incrementing a counter whenever you vend a thread and appending the number to the name (this is the behaviour of Guava's ThreadFactoryBuilder for example).
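A sketch of the naming scheme being suggested, with an incrementing counter per vended thread, similar in spirit to Guava's ThreadFactoryBuilder (names are illustrative):

```scala
import java.util.concurrent.ThreadFactory
import java.util.concurrent.atomic.AtomicInteger

// Each new thread gets a distinct numbered name so threads can be told apart
// in thread dumps and logs.
class NamedThreadFactorySketch(baseName: String) extends ThreadFactory {
  private val counter = new AtomicInteger(0)

  override def newThread(r: Runnable): Thread = {
    val t = new Thread(r)
    t.setName(s"$baseName-${counter.getAndIncrement()}")
    t
  }
}
```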
@nswamy Sorry it took me a while to get back, I've been off ill. Also, I don't really use GitHub, so I often miss the notifications! I've added some more comments on some of the implementation here.
It looks like you fixed this with the
Force-pushed from 37082d3 to ac2fe86.
* Scala Inference APIs
* fix unit tests for shape.length == layout.length in DataDesc
* make ThreadPoolHandler of size 1
* Rename PredictBase to Predictor
* change classify output from List to IndexedSeq
* modify MXNetHandler to check if the task is executing on the same thread that created the handler
* add argument epoch for Predictor/Classifier
Description
https://issues.apache.org/jira/browse/MXNET-50
MXNet Scala Inference Library.
Checklist
Essentials
make lint
Changes
Comments