From f81aec531685701c9ae4a5f9c75a70cd0c26df76 Mon Sep 17 00:00:00 2001 From: Ayres Date: Fri, 8 Mar 2019 11:35:20 -0800 Subject: [PATCH 01/14] Fixes for memory leak when reshaping executor --- .../scala/org/apache/mxnet/Executor.scala | 37 +++++++++++++++---- .../org/apache/mxnet/NativeResource.scala | 12 ++++++ .../main/scala/org/apache/mxnet/Symbol.scala | 17 +++++---- .../module/DataParallelExecutorGroup.scala | 14 ++++--- 4 files changed, 60 insertions(+), 20 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index 85f45bc66fe0..3a8e7ffba7c3 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -45,18 +45,20 @@ object Executor { * @see Symbol.bind : to create executor */ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, - private[mxnet] val symbol: Symbol) extends NativeResource { - private[mxnet] var argArrays: Array[NDArray] = null - private[mxnet] var gradArrays: Array[NDArray] = null - private[mxnet] var auxArrays: Array[NDArray] = null + private[mxnet] val symbol: Symbol, + private[mxnet] var argArrays: Array[NDArray] = null, + private[mxnet] var gradArrays: Array[NDArray] = null, + private[mxnet] var auxArrays: Array[NDArray] = null, + private var _ctx: Context = null, + private var _gradsReq: Iterable[_] = null, + private var _group2ctx: Map[String, Context] = null + ) extends NativeResource { + val outputs: Array[NDArray] = getOutputs protected var _argDict: Map[String, NDArray] = null protected var _gradDict: Map[String, NDArray] = null protected var _auxDict: Map[String, NDArray] = null protected var monitorCallback: MXMonitorCallback = null - private[mxnet] var _ctx: Context = null - private[mxnet] var _gradsReq: Iterable[_] = null - private[mxnet] var _group2ctx: Map[String, Context] = null private val logger: Logger = 
LoggerFactory.getLogger(classOf[Executor]) override def nativeAddress: CPtrAddress = handle @@ -64,10 +66,30 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, // cannot determine the off-heap size of this object override val bytesAllocated: Long = 0 override val ref: NativeResourceRef = super.register() + + private[mxnet] def updateDepResourceScope(): Unit = { + if (argArrays != null) {argArrays.foreach(a => a.moveToScopeOf(this))} + if (gradArrays != null) {gradArrays.foreach( + // Symbol will sometimes fill this with nulls so we've got to check the elements too + a => if (a != null) {a.moveToScopeOf(this)}) + } + if (auxArrays != null) {auxArrays.foreach(a => a.moveToScopeOf(this))} + outputs.foreach(o => o.moveToScopeOf(this)) + } + override def dispose(): Unit = { if (!super.isDisposed) { super.dispose() outputs.foreach(o => o.dispose()) + if (argArrays != null) {argArrays.foreach(a => a.dispose())} + if (gradArrays != null) {gradArrays.foreach( + // Symbol will sometimes fill this with nulls so we've got to check the elements too + a => if (a != null) {a.dispose()}) + } + if (auxArrays != null) {auxArrays.foreach(a => a.dispose())} + if (_argDict != null) {_argDict.foreach(a => a._2.dispose())} + if (_gradDict != null) {_gradDict.foreach(a => a._2.dispose())} + if (_auxDict != null) {_auxDict.foreach(a => a._2.dispose())} } } @@ -145,6 +167,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, "If this is intended, set partialShaping = true to suppress this warning.") } } + if (this._gradsReq.isInstanceOf[Seq[_]]) { this.symbol.bind(this._ctx, newArgDict, diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index 1806b8653376..a657534dd279 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ 
-109,6 +109,18 @@ private[mxnet] trait NativeResource } } + /** + * This method will move the NativeResource from it's current scope and add it to the + * scope of another NativeResource. Useful for when a new resource is made internally and needs + * to be coupled to an existing resource. + * @param nativeResource the native resource which has the desired destination resourceScope + */ + private [mxnet] def moveToScopeOf(nativeResource: NativeResource): Unit = { + if (scope.isDefined) scope.get.remove(this) + scope = nativeResource.scope + if (scope.isDefined) scope.get.add(this) + } + /* this is used by the WarnIfNotDisposed finalizer, the object could be disposed by the GC without the need for explicit disposal diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 29885fc723cd..d726107b630e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -803,18 +803,19 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso auxArgsHandle, sharedHandle, execHandle)) - val executor = new Executor(execHandle.value, this.clone()) - executor.argArrays = argsNDArray - executor.gradArrays = argsGradNDArray - executor.auxArrays = auxStatesNDArray - executor._ctx = new Context(ctx.deviceType, ctx.deviceId) - executor._gradsReq = gradsReq - executor._group2ctx = + + val executorGroup2ctx = if (group2ctx == null) null else group2ctx.map { case (key, value) => key -> new Context(value.deviceType, value.deviceId) } - executor + + val newSymbol = this.clone() + newSymbol.moveToScopeOf(this) + new Executor(execHandle.value, newSymbol, argsNDArray, argsGradNDArray, + auxStatesNDArray, new Context(ctx.deviceType, ctx.deviceId), + gradsReq, executorGroup2ctx) + } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala index df66ea7721fb..61d927b83312 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala @@ -299,7 +299,7 @@ class DataParallelExecutorGroup private[module]( private var batchSize: Int = -1 private var slices: Array[(Int, Int)] = null - private var _defaultExecs: Array[Executor] = null + // private var _defaultExecs: Array[Executor] = null private var execs: Array[Executor] = null private var dataArrays: Seq[Array[((Int, Int), NDArray)]] = null private var labelArrays: Option[Seq[Array[((Int, Int), NDArray)]]] = None @@ -373,7 +373,11 @@ class DataParallelExecutorGroup private[module]( val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts)) val inputShapes = dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape]) - execs(i) = _defaultExecs(i).reshape(allowUpSizing = true, kwargs = inputShapes) + val tmpExec = execs(i).reshape(allowUpSizing = true, kwargs = inputShapes) + tmpExec.moveToScopeOf(execs(i)) + tmpExec.updateDepResourceScope() + execs(i).dispose() + execs(i) = tmpExec } } else { execs = (0 until contexts.length).map(i => @@ -434,9 +438,9 @@ class DataParallelExecutorGroup private[module]( */ def reshape(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]): Unit = { if (!(dataShapes == this.dataShapes && labelShapes == this.labelShapes)) { - if (this._defaultExecs == null) { - this._defaultExecs = this.execs.map(x => x) - } + // if (this._defaultExecs == null) { + // this._defaultExecs = this.execs.map(x => x) + // } this.bindExec(dataShapes, labelShapes, None, reshape = true) } } From 120532053d606e4f00a9383bf1d431db212c6afc Mon Sep 17 00:00:00 2001 From: Ayres Date: Fri, 8 Mar 2019 13:40:19 -0800 Subject: [PATCH 02/14] Fixed Adam Optimizer memory leak --- 
.../org/apache/mxnet/optimizer/Adam.scala | 105 +++++++++--------- 1 file changed, 50 insertions(+), 55 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala index 24f3323073f7..63257276ccec 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala @@ -17,9 +17,10 @@ package org.apache.mxnet.optimizer +import org.apache.log4j.lf5.util.Resource import org.apache.mxnet.NDArrayConversions._ import org.apache.mxnet.util.SerializerUtils -import org.apache.mxnet.{LRScheduler, NDArray, Optimizer} +import org.apache.mxnet.{LRScheduler, NDArray, Optimizer, ResourceScope} /** * Adam optimizer as described in [King2014] @@ -57,63 +58,57 @@ class Adam(val learningRate: Float = 0.002f, beta1: Float = 0.9f, beta2: Float = * The auxiliary state used in optimization. */ override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = { - var lr = - (if (lrScheduler != null) { - val scheduledLr = lrScheduler(numUpdate) - updateCount(index) - scheduledLr - } else { - this.learningRate - }) - lr = getLr(index, lr) - - val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)] - - // increment time only when the first parameters is called - timeFirstIndex match { - case Some(idx) => - if (idx == index) time += 1 - case None => - timeFirstIndex = Option(index) - time = 0 // all parameters share the same time - } - - val t1: Int = time + 1 - val learningRate = (lr * - math.sqrt(1.0 - math.pow(beta2, t1)) / - (1.0 - math.pow(beta1, t1))).toFloat - val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat - - var resdGrad = grad * rescaleGrad - if (clipGradient != 0f) { - val oldResdGrad = resdGrad - resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient) - oldResdGrad.dispose() - } - - val meanT = (beta1t * mean + (1.0 - beta1t) * 
resdGrad) - .disposeDepsExcept(mean, resdGrad) - val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad) - .disposeDepsExcept(variance, resdGrad) + ResourceScope.using() { + var lr = + (if (lrScheduler != null) { + val scheduledLr = lrScheduler(numUpdate) + updateCount(index) + scheduledLr + } else { + this.learningRate + }) + lr = getLr(index, lr) - val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon)) - .disposeDepsExcept(meanT, varianceT) + val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)] - val wd = this.getWd(index, this.wd) - if (wd > 0.0f) { - val stepDelta = lr * wd * weight - step += stepDelta - stepDelta.dispose() + // increment time only when the first parameters is called + timeFirstIndex match { + case Some(idx) => + if (idx == index) time += 1 + case None => + timeFirstIndex = Option(index) + time = 0 // all parameters share the same time + } + + val t1: Int = time + 1 + val learningRate = (lr * + math.sqrt(1.0 - math.pow(beta2, t1)) / + (1.0 - math.pow(beta1, t1))).toFloat + val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat + + var resdGrad = grad * rescaleGrad + if (clipGradient != 0f) { + val oldResdGrad = resdGrad + resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient) + } + + val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad) + val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad) + val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon)) + + val wd = this.getWd(index, this.wd) + if (wd > 0.0f) { + val stepDelta = lr * wd * weight + step += stepDelta + } + + weight -= step + mean.set(meanT) + variance.set(varianceT) + + mean.moveToScopeOf(state.asInstanceOf[(NDArray, NDArray)]._1) + variance.moveToScopeOf(state.asInstanceOf[(NDArray, NDArray)]._2) } - - weight -= step - mean.set(meanT) - variance.set(varianceT) - - meanT.dispose() - varianceT.dispose() - step.dispose() - resdGrad.dispose() } // Create additional optimizer state: mean, 
variance From c27efd9418e33d958315367c8ca33a367fcc74f7 Mon Sep 17 00:00:00 2001 From: Ayres Date: Fri, 8 Mar 2019 13:45:32 -0800 Subject: [PATCH 03/14] Cleanup for PR --- .../core/src/main/scala/org/apache/mxnet/Executor.scala | 1 - .../org/apache/mxnet/module/DataParallelExecutorGroup.scala | 4 ---- 2 files changed, 5 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index 3a8e7ffba7c3..3e27bb00ed50 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -167,7 +167,6 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, "If this is intended, set partialShaping = true to suppress this warning.") } } - if (this._gradsReq.isInstanceOf[Seq[_]]) { this.symbol.bind(this._ctx, newArgDict, diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala index 61d927b83312..f6c56aab7b9c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala @@ -299,7 +299,6 @@ class DataParallelExecutorGroup private[module]( private var batchSize: Int = -1 private var slices: Array[(Int, Int)] = null - // private var _defaultExecs: Array[Executor] = null private var execs: Array[Executor] = null private var dataArrays: Seq[Array[((Int, Int), NDArray)]] = null private var labelArrays: Option[Seq[Array[((Int, Int), NDArray)]]] = None @@ -438,9 +437,6 @@ class DataParallelExecutorGroup private[module]( */ def reshape(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]): Unit = { if (!(dataShapes == this.dataShapes && labelShapes == this.labelShapes)) { - // if 
(this._defaultExecs == null) { - // this._defaultExecs = this.execs.map(x => x) - // } this.bindExec(dataShapes, labelShapes, None, reshape = true) } } From cb2542e69bc4592dc6ac0fa1e1be13c1f4bb430a Mon Sep 17 00:00:00 2001 From: Ayres Date: Fri, 8 Mar 2019 14:10:30 -0800 Subject: [PATCH 04/14] Added unit test for new ResourceScope method --- .../org/apache/mxnet/ResourceScopeSuite.scala | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala index 41dfa7d0ead2..8557dfe49cab 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala @@ -101,6 +101,47 @@ class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers { assert(a.isDisposed == true, "returned object should be disposed in the outer scope") } + /** + * Goes multiple scopes deep to make sure we can move across multiple levels correctly + */ + test("test moving scope of native resource to scope of another") { + var a: TestNativeResource = null + var b: TestNativeResource = null + var c: TestNativeResource = null + var d: TestNativeResource = null + + ResourceScope.using() { + a = new TestNativeResource() + ResourceScope.using() { + b = new TestNativeResource() + b.moveToScopeOf(a) + ResourceScope.using() { + c = new TestNativeResource() + ResourceScope.using() { + d = new TestNativeResource() + c.moveToScopeOf(d) + d.moveToScopeOf(a) + assert(c.isDisposed == false) + assert(d.isDisposed == false) + } + assert(c.isDisposed == true) + assert(d.isDisposed == false) + } + assert(b.isDisposed == false) + assert(c.isDisposed == true) + assert(d.isDisposed == false) + } + assert(a.isDisposed == false) + assert(b.isDisposed == false) + assert(c.isDisposed == true) + assert(d.isDisposed == false) + } + assert(a.isDisposed == 
true) + assert(b.isDisposed == true) + assert(c.isDisposed == true) + assert(d.isDisposed == true) + } + test(testName = "test NativeResources in returned Lists are not disposed") { var ndListRet: IndexedSeq[TestNativeResource] = null ResourceScope.using() { From 1dee79129a8d186adc667aa7d611581682176460 Mon Sep 17 00:00:00 2001 From: Ayres Date: Fri, 8 Mar 2019 14:18:51 -0800 Subject: [PATCH 05/14] Removing import that was added by overzealous ide --- .../core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala index 63257276ccec..f1ae87dfbc96 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala @@ -17,8 +17,6 @@ package org.apache.mxnet.optimizer -import org.apache.log4j.lf5.util.Resource -import org.apache.mxnet.NDArrayConversions._ import org.apache.mxnet.util.SerializerUtils import org.apache.mxnet.{LRScheduler, NDArray, Optimizer, ResourceScope} From 5b8d88c8bebfe6b718b14741fb22a77b1490696f Mon Sep 17 00:00:00 2001 From: Ayres Date: Fri, 8 Mar 2019 14:21:25 -0800 Subject: [PATCH 06/14] Add back in an import --- .../core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala index f1ae87dfbc96..c89352746a4f 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala @@ -17,6 +17,7 @@ package org.apache.mxnet.optimizer +import org.apache.mxnet.NDArrayConversions._ import org.apache.mxnet.util.SerializerUtils import org.apache.mxnet.{LRScheduler, NDArray, Optimizer, 
ResourceScope} From 2f3d5165b17c770e20d26c532568effcb0bce4d4 Mon Sep 17 00:00:00 2001 From: Ayres Date: Sat, 9 Mar 2019 15:40:27 -0800 Subject: [PATCH 07/14] Added flags for executor to know whether or not it owns NDArrays for disposal --- .../scala/org/apache/mxnet/Executor.scala | 27 ++++++++++++++++--- .../org/apache/mxnet/NativeResource.scala | 19 ++++++++++--- .../org/apache/mxnet/ResourceScope.scala | 5 ++++ .../org/apache/mxnet/ResourceScopeSuite.scala | 9 ++++--- 4 files changed, 48 insertions(+), 12 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index 3e27bb00ed50..c20b542759d4 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -61,6 +61,10 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, protected var monitorCallback: MXMonitorCallback = null private val logger: Logger = LoggerFactory.getLogger(classOf[Executor]) + private[mxnet] var ownsArgArrays = false + private[mxnet] var ownsGradArrays = false + private[mxnet] var ownsAuxArrays = false + override def nativeAddress: CPtrAddress = handle override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree // cannot determine the off-heap size of this object @@ -81,12 +85,14 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, if (!super.isDisposed) { super.dispose() outputs.foreach(o => o.dispose()) - if (argArrays != null) {argArrays.foreach(a => a.dispose())} - if (gradArrays != null) {gradArrays.foreach( + // Symbol.bind clones symbol when creating the executor so we need to dispose of the clone + symbol.dispose() + if (ownsArgArrays && argArrays != null) {argArrays.foreach(a => a.dispose())} + if (ownsGradArrays && gradArrays != null) {gradArrays.foreach( // Symbol will sometimes fill this with nulls so we've got to check the 
elements too a => if (a != null) {a.dispose()}) } - if (auxArrays != null) {auxArrays.foreach(a => a.dispose())} + if (ownsAuxArrays && auxArrays != null) {auxArrays.foreach(a => a.dispose())} if (_argDict != null) {_argDict.foreach(a => a._2.dispose())} if (_gradDict != null) {_gradDict.foreach(a => a._2.dispose())} if (_auxDict != null) {_auxDict.foreach(a => a._2.dispose())} @@ -108,6 +114,9 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, */ def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false, kwargs: Map[String, Shape]): Executor = { + var setArgOwner = false + var setAuxOwner = false + var setGradOwner = false val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs) // TODO: more precise error message should be provided by backend require(argShapes != null, "Shape inference failed." + @@ -129,8 +138,10 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, "If you really want to up size, set allowUpSizing = true " + "to enable allocation of new arrays.") newArgDict = newArgDict + (name -> NDArray.empty(newShape, arr.context, arr.dtype)) + setArgOwner = true if (dArr != null) { newGradDict = newGradDict + (name -> NDArray.empty(newShape, dArr.context, dArr.dtype)) + setGradOwner = true } } else { newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray)) @@ -157,6 +168,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, "If you really want to up size, set allowUpSizing = true " + "to enable allocation of new arrays.") newAuxDict = newAuxDict + (name -> NDArray.empty(newShape, arr.context)) + setAuxOwner = true } else { newAuxDict = newAuxDict + (name -> arr.reshape(newShape.toArray)) } @@ -167,7 +179,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, "If this is intended, set partialShaping = true to suppress this warning.") } } - if (this._gradsReq.isInstanceOf[Seq[_]]) { + val reshapedExecutor = if 
(this._gradsReq.isInstanceOf[Seq[_]]) { this.symbol.bind(this._ctx, newArgDict, newGradDict, @@ -184,6 +196,13 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, this._group2ctx, this) } + + // This method has created new NDArrays that will need to be managed by the new Executor + if (setArgOwner) reshapedExecutor.ownsArgArrays = true + if (setGradOwner) reshapedExecutor.ownsGradArrays = true + if (setAuxOwner) reshapedExecutor.ownsAuxArrays = true + + reshapedExecutor } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index a657534dd279..b81419f84b0c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -112,13 +112,24 @@ private[mxnet] trait NativeResource /** * This method will move the NativeResource from it's current scope and add it to the * scope of another NativeResource. Useful for when a new resource is made internally and needs - * to be coupled to an existing resource. + * to be coupled to an existing resource. Only moves scopes down, never up since that could cause + * resources to be cleared sooner than they should be. 
* @param nativeResource the native resource which has the desired destination resourceScope */ private [mxnet] def moveToScopeOf(nativeResource: NativeResource): Unit = { - if (scope.isDefined) scope.get.remove(this) - scope = nativeResource.scope - if (scope.isDefined) scope.get.add(this) + if (scope.isDefined && nativeResource.scope.isDefined) { + val curScope = scope.get + val newScope = nativeResource.scope.get + if (ResourceScope.getIndexOfScope(curScope) > ResourceScope.getIndexOfScope(newScope)) { + curScope.remove(this) + scope = nativeResource.scope + newScope.add(this) + } + } else if (scope.isDefined) { + val curScope = scope.get + curScope.remove(this) + scope = nativeResource.scope + } } /* diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index bb363c0c396b..cdc846710820 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -182,6 +182,11 @@ object ResourceScope { threadLocalScopes.get() -= r } + private[mxnet] def getIndexOfScope(resourceScope: ResourceScope): Int = { + val scopes = threadLocalScopes.get() + scopes.indexOf(resourceScope) + } + /** * Get the latest Scope in the stack * @return diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala index 8557dfe49cab..4e80da1767e4 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala @@ -109,22 +109,23 @@ class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers { var b: TestNativeResource = null var c: TestNativeResource = null var d: TestNativeResource = null - + var notinAScope: TestNativeResource = new TestNativeResource() 
ResourceScope.using() { a = new TestNativeResource() ResourceScope.using() { b = new TestNativeResource() - b.moveToScopeOf(a) + b.moveToScopeOf(notinAScope) ResourceScope.using() { c = new TestNativeResource() ResourceScope.using() { d = new TestNativeResource() + // Should fail to move c since d is in a higher scope c.moveToScopeOf(d) d.moveToScopeOf(a) assert(c.isDisposed == false) assert(d.isDisposed == false) } - assert(c.isDisposed == true) + assert(c.isDisposed == false) assert(d.isDisposed == false) } assert(b.isDisposed == false) @@ -137,7 +138,7 @@ class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers { assert(d.isDisposed == false) } assert(a.isDisposed == true) - assert(b.isDisposed == true) + assert(b.isDisposed == false) assert(c.isDisposed == true) assert(d.isDisposed == true) } From 57e5f630d5759b482faa51dfc93eac030e3d699f Mon Sep 17 00:00:00 2001 From: Ayres Date: Tue, 12 Mar 2019 11:08:55 -0700 Subject: [PATCH 08/14] Moving to ResourceScope.using implementation --- .../scala/org/apache/mxnet/Executor.scala | 10 -- .../org/apache/mxnet/NativeResource.scala | 23 ----- .../org/apache/mxnet/ResourceScope.scala | 11 +-- .../main/scala/org/apache/mxnet/Symbol.scala | 7 +- .../module/DataParallelExecutorGroup.scala | 10 +- .../org/apache/mxnet/optimizer/Adam.scala | 94 +++++++++---------- .../org/apache/mxnet/ResourceScopeSuite.scala | 27 ++---- 7 files changed, 66 insertions(+), 116 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index c20b542759d4..aec44023a5d3 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -71,16 +71,6 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, override val bytesAllocated: Long = 0 override val ref: NativeResourceRef = super.register() - private[mxnet] def 
updateDepResourceScope(): Unit = { - if (argArrays != null) {argArrays.foreach(a => a.moveToScopeOf(this))} - if (gradArrays != null) {gradArrays.foreach( - // Symbol will sometimes fill this with nulls so we've got to check the elements too - a => if (a != null) {a.moveToScopeOf(this)}) - } - if (auxArrays != null) {auxArrays.foreach(a => a.moveToScopeOf(this))} - outputs.foreach(o => o.moveToScopeOf(this)) - } - override def dispose(): Unit = { if (!super.isDisposed) { super.dispose() diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index b81419f84b0c..1806b8653376 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -109,29 +109,6 @@ private[mxnet] trait NativeResource } } - /** - * This method will move the NativeResource from it's current scope and add it to the - * scope of another NativeResource. Useful for when a new resource is made internally and needs - * to be coupled to an existing resource. Only moves scopes down, never up since that could cause - * resources to be cleared sooner than they should be. 
- * @param nativeResource the native resource which has the desired destination resourceScope - */ - private [mxnet] def moveToScopeOf(nativeResource: NativeResource): Unit = { - if (scope.isDefined && nativeResource.scope.isDefined) { - val curScope = scope.get - val newScope = nativeResource.scope.get - if (ResourceScope.getIndexOfScope(curScope) > ResourceScope.getIndexOfScope(newScope)) { - curScope.remove(this) - scope = nativeResource.scope - newScope.add(this) - } - } else if (scope.isDefined) { - val curScope = scope.get - curScope.remove(this) - scope = nativeResource.scope - } - } - /* this is used by the WarnIfNotDisposed finalizer, the object could be disposed by the GC without the need for explicit disposal diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index cdc846710820..e3a5bc19efb0 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -145,10 +145,12 @@ object ResourceScope { null.asInstanceOf[A] // we'll throw in finally } finally { var toThrow: Throwable = retThrowable - if (retThrowable eq null) curScope.close() + if (retThrowable eq null) { + if (scope == null) curScope.close() + } else { try { - curScope.close + if (scope == null) curScope.close } catch { case closeThrowable: Throwable => if (NonFatal(retThrowable) && !NonFatal(closeThrowable)) toThrow = closeThrowable @@ -182,11 +184,6 @@ object ResourceScope { threadLocalScopes.get() -= r } - private[mxnet] def getIndexOfScope(resourceScope: ResourceScope): Int = { - val scopes = threadLocalScopes.get() - scopes.indexOf(resourceScope) - } - /** * Get the latest Scope in the stack * @return diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index d726107b630e..84803a047432 100644 --- 
a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -810,8 +810,11 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso key -> new Context(value.deviceType, value.deviceId) } - val newSymbol = this.clone() - newSymbol.moveToScopeOf(this) + var newSymbol: Symbol = null + ResourceScope.using(this.scope.get) { + newSymbol = this.clone() + } + new Executor(execHandle.value, newSymbol, argsNDArray, argsGradNDArray, auxStatesNDArray, new Context(ctx.deviceType, ctx.deviceId), gradsReq, executorGroup2ctx) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala index f6c56aab7b9c..c6cf23735cca 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala @@ -372,11 +372,11 @@ class DataParallelExecutorGroup private[module]( val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts)) val inputShapes = dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape]) - val tmpExec = execs(i).reshape(allowUpSizing = true, kwargs = inputShapes) - tmpExec.moveToScopeOf(execs(i)) - tmpExec.updateDepResourceScope() - execs(i).dispose() - execs(i) = tmpExec + ResourceScope.using(execs(i).scope.get) { + val tmpExec = execs(i).reshape(allowUpSizing = true, kwargs = inputShapes) + execs(i).dispose() + execs(i) = tmpExec + } } } else { execs = (0 until contexts.length).map(i => diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala index c89352746a4f..b9e01aef9c0e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala +++ 
b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala @@ -57,56 +57,54 @@ class Adam(val learningRate: Float = 0.002f, beta1: Float = 0.9f, beta2: Float = * The auxiliary state used in optimization. */ override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = { - ResourceScope.using() { - var lr = - (if (lrScheduler != null) { - val scheduledLr = lrScheduler(numUpdate) - updateCount(index) - scheduledLr - } else { - this.learningRate - }) - lr = getLr(index, lr) - + ResourceScope.using(state.asInstanceOf[(NDArray, NDArray)]._1.scope.get) { val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)] - - // increment time only when the first parameters is called - timeFirstIndex match { - case Some(idx) => - if (idx == index) time += 1 - case None => - timeFirstIndex = Option(index) - time = 0 // all parameters share the same time - } - - val t1: Int = time + 1 - val learningRate = (lr * - math.sqrt(1.0 - math.pow(beta2, t1)) / - (1.0 - math.pow(beta1, t1))).toFloat - val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat - - var resdGrad = grad * rescaleGrad - if (clipGradient != 0f) { - val oldResdGrad = resdGrad - resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient) + ResourceScope.using() { + var lr = + (if (lrScheduler != null) { + val scheduledLr = lrScheduler(numUpdate) + updateCount(index) + scheduledLr + } else { + this.learningRate + }) + lr = getLr(index, lr) + + // increment time only when the first parameters is called + timeFirstIndex match { + case Some(idx) => + if (idx == index) time += 1 + case None => + timeFirstIndex = Option(index) + time = 0 // all parameters share the same time + } + + val t1: Int = time + 1 + val learningRate = (lr * + math.sqrt(1.0 - math.pow(beta2, t1)) / + (1.0 - math.pow(beta1, t1))).toFloat + val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat + + var resdGrad = grad * rescaleGrad + if (clipGradient != 0f) { + val oldResdGrad = resdGrad + resdGrad = 
NDArray.clip(resdGrad, -clipGradient, clipGradient) + } + + val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad) + val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad) + val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon)) + + val wd = this.getWd(index, this.wd) + if (wd > 0.0f) { + val stepDelta = lr * wd * weight + step += stepDelta + } + + weight -= step + mean.set(meanT) + variance.set(varianceT) } - - val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad) - val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad) - val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon)) - - val wd = this.getWd(index, this.wd) - if (wd > 0.0f) { - val stepDelta = lr * wd * weight - step += stepDelta - } - - weight -= step - mean.set(meanT) - variance.set(varianceT) - - mean.moveToScopeOf(state.asInstanceOf[(NDArray, NDArray)]._1) - variance.moveToScopeOf(state.asInstanceOf[(NDArray, NDArray)]._2) } } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala index 4e80da1767e4..22fc3095f758 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala @@ -108,39 +108,24 @@ class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers { var a: TestNativeResource = null var b: TestNativeResource = null var c: TestNativeResource = null - var d: TestNativeResource = null - var notinAScope: TestNativeResource = new TestNativeResource() + ResourceScope.using() { a = new TestNativeResource() ResourceScope.using() { b = new TestNativeResource() - b.moveToScopeOf(notinAScope) - ResourceScope.using() { + ResourceScope.using(a.scope.get) { c = new TestNativeResource() - ResourceScope.using() { - d = new TestNativeResource() - // Should fail to move c since d is in a higher scope - 
c.moveToScopeOf(d) - d.moveToScopeOf(a) - assert(c.isDisposed == false) - assert(d.isDisposed == false) - } - assert(c.isDisposed == false) - assert(d.isDisposed == false) } assert(b.isDisposed == false) - assert(c.isDisposed == true) - assert(d.isDisposed == false) + assert(c.isDisposed == false) } assert(a.isDisposed == false) - assert(b.isDisposed == false) - assert(c.isDisposed == true) - assert(d.isDisposed == false) + assert(b.isDisposed == true) + assert(c.isDisposed == false) } assert(a.isDisposed == true) - assert(b.isDisposed == false) + assert(b.isDisposed == true) assert(c.isDisposed == true) - assert(d.isDisposed == true) } test(testName = "test NativeResources in returned Lists are not disposed") { From 2f84c9d1f119733c0b4eb2850d368a60546f6825 Mon Sep 17 00:00:00 2001 From: Ayres Date: Fri, 15 Mar 2019 21:45:41 -0700 Subject: [PATCH 09/14] Changes to make ResourceScope.using work with existing scope --- .../org/apache/mxnet/ResourceScope.scala | 26 ++++++++++++++++--- .../org/apache/mxnet/ResourceScopeSuite.scala | 7 ++++- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index e3a5bc19efb0..a4178c544fcf 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -47,6 +47,7 @@ class ResourceScope extends AutoCloseable { * the associated`'org.apache.mxnet.NativeResource.close()` method */ override def close(): Unit = { + print("in close\n") ResourceScope.removeFromThreadLocal(this) resourceQ.foreach(resource => if (resource != null) resource.dispose(false) ) resourceQ.clear() @@ -105,8 +106,16 @@ object ResourceScope { // TODO: we should move to the Scala util's Using method when we move to Scala 2.13 def using[A](scope: ResourceScope = null)(body: => A): A = { - val curScope = if (scope != 
null) scope else new ResourceScope() + val curScope = if (scope != null) { + ResourceScope.addToThreadLocal(scope) + scope + } else new ResourceScope() + if (scope != null) { + print("in non null using\n") + } else { + print("in null using\n") + } @inline def resourceInGeneric(g: scala.collection.Iterable[_]) = { g.foreach( n => n match { @@ -146,11 +155,19 @@ object ResourceScope { } finally { var toThrow: Throwable = retThrowable if (retThrowable eq null) { - if (scope == null) curScope.close() + if (scope == null) { + curScope.close() + } else { + ResourceScope.removeFromThreadLocal(scope) + } } else { try { - if (scope == null) curScope.close + if (scope == null) { + curScope.close + } else { + ResourceScope.removeFromThreadLocal(scope) + } } catch { case closeThrowable: Throwable => if (NonFatal(retThrowable) && !NonFatal(closeThrowable)) toThrow = closeThrowable @@ -181,7 +198,8 @@ object ResourceScope { * @param r ResourceScope to remove */ private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = { - threadLocalScopes.get() -= r + // threadLocalScopes.get() -= r + threadLocalScopes.get().remove(threadLocalScopes.get().lastIndexOf(r)) } /** diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala index 22fc3095f758..423d8a60f908 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala @@ -102,12 +102,13 @@ class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers { } /** - * Goes multiple scopes deep to make sure we can move across multiple levels correctly + * Tests passing a scope to using and creating new resources within. 
*/ test("test moving scope of native resource to scope of another") { var a: TestNativeResource = null var b: TestNativeResource = null var c: TestNativeResource = null + var d: TestNativeResource = null ResourceScope.using() { a = new TestNativeResource() @@ -115,6 +116,10 @@ class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers { b = new TestNativeResource() ResourceScope.using(a.scope.get) { c = new TestNativeResource() + ResourceScope.using() { + d = new TestNativeResource() + } + assert(d.isDisposed == true) } assert(b.isDisposed == false) assert(c.isDisposed == false) From 4576cdc5936115315b427473bb86d8f4d8554acf Mon Sep 17 00:00:00 2001 From: Ayres Date: Thu, 21 Mar 2019 01:47:52 -0700 Subject: [PATCH 10/14] Updating ResourceScope to work with existing scopes via usingIfScopeExists method --- .../scala/org/apache/mxnet/Optimizer.scala | 20 ++-- .../org/apache/mxnet/ResourceScope.scala | 28 +++--- .../main/scala/org/apache/mxnet/Symbol.scala | 3 +- .../module/DataParallelExecutorGroup.scala | 3 +- .../org/apache/mxnet/optimizer/Adam.scala | 91 +++++++++---------- 5 files changed, 77 insertions(+), 68 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala index 1fb634cebb26..6e7877392081 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala @@ -28,15 +28,17 @@ object Optimizer { def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = { new MXKVStoreUpdater with MXKVStoreCachedStates { override def update(index: Int, grad: NDArray, weight: NDArray): Unit = { - val state = - if (states.contains(index)) { - states.get(index).get - } else { - val newState = optimizer.createState(index, weight) - states.put(index, newState) - newState - } - optimizer.update(index, weight, grad, state) + ResourceScope.usingIfScopeExists(this.scope) { + 
val state = + if (states.contains(index)) { + states.get(index).get + } else { + val newState = optimizer.createState(index, weight) + states.put(index, newState) + newState + } + optimizer.update(index, weight, grad, state) + } } override def dispose(): Unit = { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index a4178c544fcf..a477fd55dec8 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -47,7 +47,6 @@ class ResourceScope extends AutoCloseable { * the associated`'org.apache.mxnet.NativeResource.close()` method */ override def close(): Unit = { - print("in close\n") ResourceScope.removeFromThreadLocal(this) resourceQ.foreach(resource => if (resource != null) resource.dispose(false) ) resourceQ.clear() @@ -111,11 +110,6 @@ object ResourceScope { scope } else new ResourceScope() - if (scope != null) { - print("in non null using\n") - } else { - print("in null using\n") - } @inline def resourceInGeneric(g: scala.collection.Iterable[_]) = { g.foreach( n => n match { @@ -140,12 +134,14 @@ object ResourceScope { try { val ret = body - ret match { + if (scope == null) { + ret match { // don't de-allocate if returning any collection that contains NativeResource. 
- case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric) - case nRes: NativeResource => curScope.moveToOuterScope(nRes) - case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => curScope.moveToOuterScope(nd) ) - case _ => // do nothing + case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric) + case nRes: NativeResource => curScope.moveToOuterScope(nRes) + case ndRet: NDArrayFuncReturn => ndRet.arr.foreach(nd => curScope.moveToOuterScope(nd)) + case _ => // do nothing + } } ret } catch { @@ -179,6 +175,16 @@ object ResourceScope { } } + private[mxnet] def usingIfScopeExists[A](scope: Option[ResourceScope])(body: => A): A = { + if (scope == None) { + body + } else { + ResourceScope.using(scope.get){ + body + } + } + } + // thread local Scopes private[mxnet] val threadLocalScopes = new ThreadLocal[ArrayBuffer[ResourceScope]] { override def initialValue(): ArrayBuffer[ResourceScope] = diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 84803a047432..821e04f08df2 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -810,8 +810,9 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso key -> new Context(value.deviceType, value.deviceId) } + // If this is in a scope then we want to create the clone in the same scope var newSymbol: Symbol = null - ResourceScope.using(this.scope.get) { + ResourceScope.usingIfScopeExists(this.scope) { newSymbol = this.clone() } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala index c6cf23735cca..74e63be3916b 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala +++ 
b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala @@ -372,7 +372,8 @@ class DataParallelExecutorGroup private[module]( val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts)) val inputShapes = dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape]) - ResourceScope.using(execs(i).scope.get) { + + ResourceScope.usingIfScopeExists(execs(i).scope) { val tmpExec = execs(i).reshape(allowUpSizing = true, kwargs = inputShapes) execs(i).dispose() execs(i) = tmpExec diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala index b9e01aef9c0e..5a8b3cb4e94c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala @@ -57,54 +57,53 @@ class Adam(val learningRate: Float = 0.002f, beta1: Float = 0.9f, beta2: Float = * The auxiliary state used in optimization. 
*/ override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = { - ResourceScope.using(state.asInstanceOf[(NDArray, NDArray)]._1.scope.get) { + ResourceScope.using() { + var lr = + (if (lrScheduler != null) { + val scheduledLr = lrScheduler(numUpdate) + updateCount(index) + scheduledLr + } else { + this.learningRate + }) + lr = getLr(index, lr) + val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)] - ResourceScope.using() { - var lr = - (if (lrScheduler != null) { - val scheduledLr = lrScheduler(numUpdate) - updateCount(index) - scheduledLr - } else { - this.learningRate - }) - lr = getLr(index, lr) - - // increment time only when the first parameters is called - timeFirstIndex match { - case Some(idx) => - if (idx == index) time += 1 - case None => - timeFirstIndex = Option(index) - time = 0 // all parameters share the same time - } - - val t1: Int = time + 1 - val learningRate = (lr * - math.sqrt(1.0 - math.pow(beta2, t1)) / - (1.0 - math.pow(beta1, t1))).toFloat - val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat - - var resdGrad = grad * rescaleGrad - if (clipGradient != 0f) { - val oldResdGrad = resdGrad - resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient) - } - - val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad) - val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad) - val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon)) - - val wd = this.getWd(index, this.wd) - if (wd > 0.0f) { - val stepDelta = lr * wd * weight - step += stepDelta - } - - weight -= step - mean.set(meanT) - variance.set(varianceT) + + // increment time only when the first parameters is called + timeFirstIndex match { + case Some(idx) => + if (idx == index) time += 1 + case None => + timeFirstIndex = Option(index) + time = 0 // all parameters share the same time + } + + val t1: Int = time + 1 + val learningRate = (lr * math.sqrt(1.0 - math.pow(beta2, t1)) / + (1.0 - math.pow(beta1, 
t1))).toFloat + val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat + + var resdGrad = grad * rescaleGrad + if (clipGradient != 0f) { + val oldResdGrad = resdGrad + resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient) } + + val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad) + val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad) + val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon)) + + val wd = this.getWd(index, this.wd) + if (wd > 0.0f) { + val stepDelta = lr * wd * weight + step += stepDelta + } + + weight -= step + mean.set(meanT) + variance.set(varianceT) + (mean, variance) } } From 0a1cbbbf5a2e89aeef2993d584cd289c21b93ba6 Mon Sep 17 00:00:00 2001 From: Ayres Date: Thu, 21 Mar 2019 02:20:30 -0700 Subject: [PATCH 11/14] Fix clojure unit tests --- .../main/scala/org/apache/mxnet/ResourceScope.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index a477fd55dec8..a5aa33f433ea 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -134,14 +134,12 @@ object ResourceScope { try { val ret = body - if (scope == null) { - ret match { + ret match { // don't de-allocate if returning any collection that contains NativeResource. 
- case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric) - case nRes: NativeResource => curScope.moveToOuterScope(nRes) - case ndRet: NDArrayFuncReturn => ndRet.arr.foreach(nd => curScope.moveToOuterScope(nd)) - case _ => // do nothing - } + case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric) + case nRes: NativeResource => curScope.moveToOuterScope(nRes) + case ndRet: NDArrayFuncReturn => ndRet.arr.foreach(nd => curScope.moveToOuterScope(nd)) + case _ => // do nothing } ret } catch { From 5af91db610219498904601819303c765e0abc29f Mon Sep 17 00:00:00 2001 From: Ayres Date: Thu, 21 Mar 2019 13:09:01 -0700 Subject: [PATCH 12/14] Fixes to be compatible with how clojure is using ResourceScope --- .../org/apache/mxnet/ResourceScope.scala | 24 +++++++------------ .../org/apache/mxnet/ResourceScopeSuite.scala | 2 +- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index a5aa33f433ea..2f34210d82c2 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -48,8 +48,10 @@ class ResourceScope extends AutoCloseable { */ override def close(): Unit = { ResourceScope.removeFromThreadLocal(this) - resourceQ.foreach(resource => if (resource != null) resource.dispose(false) ) - resourceQ.clear() + if (!ResourceScope.threadLocalScopes.get().contains(this)) { + resourceQ.foreach(resource => if (resource != null) resource.dispose(false)) + resourceQ.clear() + } } /** @@ -105,10 +107,7 @@ object ResourceScope { // TODO: we should move to the Scala util's Using method when we move to Scala 2.13 def using[A](scope: ResourceScope = null)(body: => A): A = { - val curScope = if (scope != null) { - ResourceScope.addToThreadLocal(scope) - scope - } else new ResourceScope() +
val curScope = if (scope != null) scope else new ResourceScope() @inline def resourceInGeneric(g: scala.collection.Iterable[_]) = { g.foreach( n => @@ -149,19 +148,11 @@ object ResourceScope { } finally { var toThrow: Throwable = retThrowable if (retThrowable eq null) { - if (scope == null) { - curScope.close() - } else { - ResourceScope.removeFromThreadLocal(scope) - } + curScope.close } else { try { - if (scope == null) { - curScope.close - } else { - ResourceScope.removeFromThreadLocal(scope) - } + curScope.close } catch { case closeThrowable: Throwable => if (NonFatal(retThrowable) && !NonFatal(closeThrowable)) toThrow = closeThrowable @@ -177,6 +168,7 @@ object ResourceScope { if (scope == None) { body } else { + ResourceScope.addToThreadLocal(scope.get) ResourceScope.using(scope.get){ body } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala index 423d8a60f908..a6bb43d3bf78 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala @@ -114,7 +114,7 @@ class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers { a = new TestNativeResource() ResourceScope.using() { b = new TestNativeResource() - ResourceScope.using(a.scope.get) { + ResourceScope.usingIfScopeExists(a.scope) { c = new TestNativeResource() ResourceScope.using() { d = new TestNativeResource() From 33c113cc361881401156788a23dd4787f556ac39 Mon Sep 17 00:00:00 2001 From: Ayres Date: Thu, 21 Mar 2019 14:27:38 -0700 Subject: [PATCH 13/14] Removing some unnecessary changes --- .../src/main/scala/org/apache/mxnet/ResourceScope.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index 
2f34210d82c2..b955c185b6d1 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -137,7 +137,7 @@ object ResourceScope { // don't de-allocate if returning any collection that contains NativeResource. case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric) case nRes: NativeResource => curScope.moveToOuterScope(nRes) - case ndRet: NDArrayFuncReturn => ndRet.arr.foreach(nd => curScope.moveToOuterScope(nd)) + case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => curScope.moveToOuterScope(nd) ) case _ => // do nothing } ret @@ -147,9 +147,7 @@ object ResourceScope { null.asInstanceOf[A] // we'll throw in finally } finally { var toThrow: Throwable = retThrowable - if (retThrowable eq null) { - curScope.close - } + if (retThrowable eq null) curScope.close else { try { curScope.close @@ -194,7 +192,6 @@ object ResourceScope { * @param r ResourceScope to remove */ private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = { - // threadLocalScopes.get() -= r threadLocalScopes.get().remove(threadLocalScopes.get().lastIndexOf(r)) } From 232d33adcee8a801a1f69baa46d8e1b77ccf9002 Mon Sep 17 00:00:00 2001 From: Ayres Date: Wed, 27 Mar 2019 14:58:07 -0700 Subject: [PATCH 14/14] Adding scope assertion in unit test --- .../src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala index a6bb43d3bf78..19162385f0f7 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala @@ -118,6 +118,7 @@ class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers { c = new TestNativeResource() ResourceScope.using() { d = new 
TestNativeResource() + assert(c.scope == a.scope) } assert(d.isDisposed == true) }