Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-600][Scala] NDArray auto-collector #11751

Merged
merged 10 commits into from
Jul 19, 2018
Merged

Conversation

yzhliu
Copy link
Member

@yzhliu yzhliu commented Jul 13, 2018

For now user has to dispose all intermediate allocated NDArrays themselves. The NDArrayCollector introduced here is to provide a way for NDs to be disposed automatically.

In the test case & changes to NDArrayIter.scala one can find the example of how to use it.

And for Java users:
https://github.com/yzhliu/mxnet-java-example/blob/master/src/main/java/me/yzhi/mxnet/example/infer/RndImageInference.java#L28-L30
https://github.com/yzhliu/mxnet-java-example/blob/master/src/main/java/me/yzhi/mxnet/example/infer/Seq2SeqInference.java#L35-L37

In the cases above, NDArrays allocated within the scope will be disposed automatically after the code block finishes.

@yzhliu yzhliu requested a review from nswamy as a code owner July 13, 2018 17:49
@yzhliu yzhliu changed the title Scala nd collector [Scala] NDArray auto-collector Jul 13, 2018
@yzhliu yzhliu changed the title [Scala] NDArray auto-collector [MXNET-600][Scala] NDArray auto-collector Jul 13, 2018
@yzhliu yzhliu requested a review from CodingCat July 13, 2018 18:22
@lanking520
Copy link
Member

Hi @yzhliu , It's been so great to see this solution. I would recommend you to apply this new feature to the examples in the CI for testing and that include MNIST and GAN to see if there are any more memory leaks thrown.

@szha
Copy link
Member

szha commented Jul 14, 2018

@yzhliu nice to see you back on Scala for core development.

}

class NDArrayCollector private(private val autoDispose: Boolean = true,
private val doCollect: Boolean = true) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need this flag, in which case we will set it to false?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/apache/incubator-mxnet/pull/11751/files#diff-d502d315f6c6df78673dcde4d27a9577R69
new NDArray always do NDArrayCollector.collect(this), but nothing be really collected unless the user explicitly uses withScope

* </pre>
* In the case above, the intermediate NDArrays
* (created by <em>NDArray.relu</em> and <em>+</em>) will be disposed automatically. <br />
* User can also decide to use dispose the collected NDArrays later: <br />
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dispose is extra?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or use is extra?

* val collector = NDArrayCollector.manual()
* val res = collector.withScope {
* (NDArray.relu(a) + a).toArray
* }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the user does not want auto disposal, what's the other benefit withScope can bring to him/her?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if there is no other benefit, we may consider make it simpler as just a NDArray.scope, all NDArray in this scope will be automatically disposed (do not even need manual scope)...and for the other case, the user can just do what they are currently doing

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the ability to collect and dispose manually is useful, for use cases like,

  • Users may want to dispose only a subset of the new-allocated NDArrays, e.g., the withScope returns a complicated data structure which contains NDArrays, these NDArrays normally cannot be disposed automatically (and cannot easily be detect by withScope.
  • Debug/performance analysis.


class NDArrayCollector private(private val autoDispose: Boolean = true,
private val doCollect: Boolean = true) {
private val arrays: mutable.Map[Long, NDArray] = mutable.HashMap.empty[Long, NDArray]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

private variable may not need to have an explicit type

@yzhliu
Copy link
Member Author

yzhliu commented Jul 15, 2018

@lanking520 Here you can find how we can fix the memory leak for MNIST training (using SGD): yzhliu#5

But I prefer to make the PR separate. You can try similar thing for GAN, and with other optimizer.


class NDArrayCollector private(private val autoDispose: Boolean = true,
private val doCollect: Boolean = true) {
private val arrays = mutable.HashMap.empty[Long, NDArray]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please put some comments in here: The first long represent as the C++ ptr of the NDArray, I think type CPtrAddress = Long will be more clearer in here, it is defined in the base.scala

This HashMap would only keep the ptr as the key, there could be multiple NDArray reference point to the same ptr. Will there be any risks if we only keep one NDArray here (such as we lose reference if A NDArray have different parent of B NDArray but we only kill one of them)? In that case, ArrayBuffer[NDArray] will be better?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is exactly what I intended to avoid. Two NDArrays with same ptr must be disposed only once.

/**
* Clear the collector.
*/
def clear(): Unit = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see there is any use cases from outside world, shall we keep it private?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When using manual(), user may want to re-use one collector:

val c = NDArrayCollector.manual()
c.withScope { ... }
...
c.clear()
c.withScope { ... }

/**
* Create a code scope, NDArrays allocated within this scope will be collected.
* The collected NDArrays will be either <br />
* - disposed automatically when the code blcok finishes (when using <em>auto</em>) or <br />
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blcok -> block

* @return The result of function <em>body</em>.
*/
def withScope[T](body: => T): T = {
val old = NDArrayCollector.currCollector.get()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I correct?

In every local thread, there could be only 1 currCollector exist.

If there is a cascade NDArrayCollector, we still have one currCollector, but we kill the inner loop result and keep outer one there.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we talked offline, yes.

* @tparam T return type of the function <em>body</em>.
* @return The result of function <em>body</em>.
*/
def withScope[T](body: => T): T = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we always need to define a return NDArray if we want to keep something created from this inner scope? Shall we change the body to some better name? Such as returnValue?

Copy link
Member

@nswamy nswamy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like this solution, its an elegant implementation of enter/exit blocks like in python. I have concerns of using ThreadLocal which is a red flag to me(IMO) since it takes impacts thread-safety in subtle ways which will be hard to debug, for example in this case if a NDArray in WithScope() in thread-A is passed to another thread-B (think of a producer-consumer problem), if thread-A exits before the object is used in thread-B it could cause problems.

@yzhliu
Copy link
Member Author

yzhliu commented Jul 16, 2018

@nswamy Actually it is not a problem with threadlocal. No matter threadlocal is used or not, the NDArrays within the scope cannot be passed to another thread. It is stated in the Javadoc of NDArrayCollector.

ThreadLocal here is for,
Thread A:

collectorA.withScope {
   val a = new NDArray;
}

Thread B:

collectorB.withScope {
  val b = new NDArray;
}

ThreadLocal can guarantee b stays in NDCollectorB and a stays in NDCollectorA. Otherwise, the scope can be messed up, e.g.,

currCollector = NDCollectorA;  // thread A starts
currCollector = NDCollectorB;  // thread B starts, (but) before a = new NDArray runs
val a = new NDArray; // a will be in NDCollectorB
val b = new NDArray; // b will be also in NDCollectorB
// scope B finishes
a.dispose(); b.dispose(). // scope B dispose all its collected NDs
currCollector = NDCollectorA. // scope B reset the current collector
// but scope A is still running, where NDArray a can be accessed.

Copy link
Member

@lanking520 lanking520 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@nswamy
Copy link
Member

nswamy commented Jul 19, 2018

@yzhliu Thanks for elaborating with a detailed example and sorry for the delay in response.
While I understand WithScope is what manages the deallocation of the NDArrays and using it ThreadLocal is what makes managing it possible, nevertheless my concern about thread-safety still remains.
The use-case of producer/consumer is not very uncommon and

val q = new Queue[NDArray] //assume Queue  is thread-safe

val t1 = new Thread() {
@override def run() ={
withScope {
 val a = NDArray.load()
 q.enqueue(a)
}
}
val t2 = new Thread() {
@override def run() = {
b = q.dequeue()
b.reshape()
}
}

This use-case or similar is not hard imagine when you are creating data pipeline and I think we should cautious in introducing new patterns to users that aren't thread-safe.

We discussed offline about using reference-counts, any luck with that?

@yzhliu
Copy link
Member Author

yzhliu commented Jul 19, 2018

@nswamy unfortunately I don't think it is possible to implement within Java user code (unless we leverage finalize) - Users cannot access java heap (in our case, reference number of one object) - the only way is to dump from outside.

Moreover, not necessarily to be multi-threading, even single-thread can have problem:

var a: NDArray
collector.withScope {
    a = new NDArray
}  // a is disposed here
a.toArray() // fail

As I stated, it is users' responsibility not to leak new NDArray outside of scope, otherwise they need to use manual(), and dispose it later.

Every tool has its own limitation, e.g., users are not supposed to use HashMap in multi-threads - otherwise they need to do synchronize themselves.

As you can see, this tool make disposing super easy: https://github.com/yzhliu/mxnet/pull/5/files#diff-19b434167a8c2d81db271f8c47975ff6R43

btw, In your case, users need to do:

val q = new Queue[NDArray] //assume Queue  is thread-safe

val c = NDArrayCollector.manual()

val t1 = new Thread() {
@override def run() ={
c.withScope {
 val a = NDArray.load()
 q.enqueue(a)
}
c.foreach(nd => dispose if nd != a)
}

val t2 = new Thread() {
@override def run() = {
b = q.dequeue()
b.reshape()
}
}

t2.join()
c.foreach(_.dispose())

@nswamy
Copy link
Member

nswamy commented Jul 19, 2018

@yzhliu I understand what you are saying and I think it is given that it doesn't work outside of the scope.
agree every tool has its limitation, I already said that I like this :)

Do you think changing WithScope to WithThreadScope and make it explicit would be useful for users to communicate that it is not thread-safe?

Also, I did not intend my code to be a running version, it was more for illustration of my point.. i'll call it out next time :)

@yzhliu
Copy link
Member Author

yzhliu commented Jul 19, 2018

@nswamy Oh no no, I was not saying your code is wrong or cannot run, I was demonstrating how to change the code a bit to safely use the collector in such situation - my code was not a runnable version either :)

As I said, even single thread can have problem, don't think it is proper to have 'Thread' in the name.

@nswamy
Copy link
Member

nswamy commented Jul 19, 2018

you are right, changing to ThreadScope might again confuse people if the case you showed.

Thanks for the clarifying about the example code, appreciate your patience :)

Copy link
Member

@nswamy nswamy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noting that Thread-safety is a limitation of this solution. We can see what users think of this approach in solving the memory management issue.

@yzhliu yzhliu merged commit 1031fe1 into apache:master Jul 19, 2018
@yzhliu
Copy link
Member Author

yzhliu commented Jul 19, 2018

Thanks all people helped to review here.

KellenSunderland pushed a commit to KellenSunderland/incubator-mxnet that referenced this pull request Jul 21, 2018
* [Scala] NDArrayCollector for automatically disposing NDArrays

* modify doc for NDArrayCollector

* modify the function doc of NDArrayCollector.withScope

* remove trivial changes

* put dispose in finally

* fix jni NDArray signature

* modify doc and private var

* dispose res when test finishes

* add comments, change variable name
XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
* [Scala] NDArrayCollector for automatically disposing NDArrays

* modify doc for NDArrayCollector

* modify the function doc of NDArrayCollector.withScope

* remove trivial changes

* put dispose in finally

* fix jni NDArray signature

* modify doc and private var

* dispose res when test finishes

* add comments, change variable name
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants