From d8c63d831c7819cf4379c63f369f02fb2443e6e9 Mon Sep 17 00:00:00 2001 From: Andrew Ayres Date: Thu, 28 Mar 2019 11:57:09 -0700 Subject: [PATCH] Memory fixes. Resolves #10867, and resolves #14080 (#14372) * Fixes for memory leak when reshaping executor * Fixed Adam Optimizer memory leak * Cleanup for PR * Added unit test for new ResourceScope method * Removing import that was added by overzealous ide * Add back in an import * Added flags for executor to know whether or not it owns NDArrays for disposal * Moving to ResourceScope.using implementation * Changes to make ResourceScope.using work with existing scope * Updating ResourceScope to work with existing scopes via usingIfScopeExists method * Fix clojure unit tests * Fixes to be compatibile with how clojure is using ResourceScope * Removing some unnecessary changes * Adding scope assertion in unit test --- .../scala/org/apache/mxnet/Executor.scala | 47 ++++++-- .../scala/org/apache/mxnet/Optimizer.scala | 20 ++-- .../org/apache/mxnet/ResourceScope.scala | 21 +++- .../main/scala/org/apache/mxnet/Symbol.scala | 21 ++-- .../module/DataParallelExecutorGroup.scala | 11 +- .../org/apache/mxnet/optimizer/Adam.scala | 101 ++++++++---------- .../org/apache/mxnet/ResourceScopeSuite.scala | 33 ++++++ 7 files changed, 165 insertions(+), 89 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index 85f45bc66fe0..aec44023a5d3 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -45,29 +45,47 @@ object Executor { * @see Symbol.bind : to create executor */ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, - private[mxnet] val symbol: Symbol) extends NativeResource { - private[mxnet] var argArrays: Array[NDArray] = null - private[mxnet] var gradArrays: Array[NDArray] = null - private[mxnet] var auxArrays: Array[NDArray] = null + private[mxnet] val symbol: Symbol, + private[mxnet] var argArrays: Array[NDArray] = null, + private[mxnet] var gradArrays: Array[NDArray] = null, + private[mxnet] var auxArrays: Array[NDArray] = null, + private var _ctx: Context = null, + private var _gradsReq: Iterable[_] = null, + private var _group2ctx: Map[String, Context] = null + ) extends NativeResource { + val outputs: Array[NDArray] = getOutputs protected var _argDict: Map[String, NDArray] = null protected var _gradDict: Map[String, NDArray] = null protected var _auxDict: Map[String, NDArray] = null protected var monitorCallback: MXMonitorCallback = null - private[mxnet] var _ctx: Context = null - private[mxnet] var _gradsReq: Iterable[_] = null - private[mxnet] var _group2ctx: Map[String, Context] = null private val logger: Logger = LoggerFactory.getLogger(classOf[Executor]) + private[mxnet] var ownsArgArrays = false + private[mxnet] var ownsGradArrays = false + private[mxnet] var ownsAuxArrays = false + override def nativeAddress: CPtrAddress = handle override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree // cannot determine the off-heap size of this object override val bytesAllocated: Long = 0 override val ref: NativeResourceRef = super.register() + override def dispose(): Unit = { if (!super.isDisposed) { super.dispose() outputs.foreach(o => o.dispose()) + // Symbol.bind clones symbol when creating the executor so we need to dispose of the clone + symbol.dispose() + if (ownsArgArrays && argArrays != null) {argArrays.foreach(a => a.dispose())} + if (ownsGradArrays && gradArrays != null) {gradArrays.foreach( + // Symbol will sometimes fill this with nulls so we've got to check the elements too + a => if (a != null) {a.dispose()}) + } + if (ownsAuxArrays && auxArrays != null) {auxArrays.foreach(a => a.dispose())} + if (_argDict != null) {_argDict.foreach(a => a._2.dispose())} + if (_gradDict != null) {_gradDict.foreach(a => a._2.dispose())} + if (_auxDict != null) {_auxDict.foreach(a => a._2.dispose())} } } @@ -86,6 +104,9 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, */ def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false, kwargs: Map[String, Shape]): Executor = { + var setArgOwner = false + var setAuxOwner = false + var setGradOwner = false val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs) // TODO: more precise error message should be provided by backend require(argShapes != null, "Shape inference failed." + @@ -107,8 +128,10 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, "If you really want to up size, set allowUpSizing = true " + "to enable allocation of new arrays.") newArgDict = newArgDict + (name -> NDArray.empty(newShape, arr.context, arr.dtype)) + setArgOwner = true if (dArr != null) { newGradDict = newGradDict + (name -> NDArray.empty(newShape, dArr.context, dArr.dtype)) + setGradOwner = true } } else { newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray)) @@ -135,6 +158,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, "If you really want to up size, set allowUpSizing = true " + "to enable allocation of new arrays.") newAuxDict = newAuxDict + (name -> NDArray.empty(newShape, arr.context)) + setAuxOwner = true } else { newAuxDict = newAuxDict + (name -> arr.reshape(newShape.toArray)) } @@ -145,7 +169,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, "If this is intended, set partialShaping = true to suppress this warning.") } } - if (this._gradsReq.isInstanceOf[Seq[_]]) { + val reshapedExecutor = if (this._gradsReq.isInstanceOf[Seq[_]]) { this.symbol.bind(this._ctx, newArgDict, newGradDict, @@ -162,6 +186,13 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, this._group2ctx, this) } + + // This method has created new NDArrays that will need to be managed by the new Executor + if (setArgOwner) reshapedExecutor.ownsArgArrays = true + if (setGradOwner) reshapedExecutor.ownsGradArrays = true + if (setAuxOwner) reshapedExecutor.ownsAuxArrays = true + + reshapedExecutor } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala index 1fb634cebb26..6e7877392081 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala @@ -28,15 +28,17 @@ object Optimizer { def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = { new MXKVStoreUpdater with MXKVStoreCachedStates { override def update(index: Int, grad: NDArray, weight: NDArray): Unit = { - val state = - if (states.contains(index)) { - states.get(index).get - } else { - val newState = optimizer.createState(index, weight) - states.put(index, newState) - newState - } - optimizer.update(index, weight, grad, state) + ResourceScope.usingIfScopeExists(this.scope) { + val state = + if (states.contains(index)) { + states.get(index).get + } else { + val newState = optimizer.createState(index, weight) + states.put(index, newState) + newState + } + optimizer.update(index, weight, grad, state) + } } override def dispose(): Unit = { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index bb363c0c396b..b955c185b6d1 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -48,8 +48,10 @@ class ResourceScope extends AutoCloseable { */ override def close(): Unit = { ResourceScope.removeFromThreadLocal(this) - resourceQ.foreach(resource => if (resource != null) resource.dispose(false) ) - resourceQ.clear() + if (!ResourceScope.threadLocalScopes.get().contains(this)) { + resourceQ.foreach(resource => if (resource != null) resource.dispose(false)) + resourceQ.clear() + } } /** @@ -145,7 +147,7 @@ object ResourceScope { null.asInstanceOf[A] // we'll throw in finally } finally { var toThrow: Throwable = retThrowable - if (retThrowable eq null) curScope.close() + if (retThrowable eq null) curScope.close else { try { curScope.close @@ -160,6 +162,17 @@ object ResourceScope { } } + private[mxnet] def usingIfScopeExists[A](scope: Option[ResourceScope])(body: => A): A = { + if (scope == None) { + body + } else { + ResourceScope.addToThreadLocal(scope.get) + ResourceScope.using(scope.get){ + body + } + } + } + // thread local Scopes private[mxnet] val threadLocalScopes = new ThreadLocal[ArrayBuffer[ResourceScope]] { override def initialValue(): ArrayBuffer[ResourceScope] = @@ -179,7 +192,7 @@ object ResourceScope { * @param r ResourceScope to remove */ private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = { - threadLocalScopes.get() -= r + threadLocalScopes.get().remove(threadLocalScopes.get().lastIndexOf(r)) } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 01349a689b6c..bfd8268f0b90 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -803,18 +803,23 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso auxArgsHandle, sharedHandle, execHandle)) - val executor = new Executor(execHandle.value, this.clone()) - executor.argArrays = argsNDArray - executor.gradArrays = argsGradNDArray - executor.auxArrays = auxStatesNDArray - executor._ctx = new Context(ctx.deviceType, ctx.deviceId) - executor._gradsReq = gradsReq - executor._group2ctx = + + val executorGroup2ctx = if (group2ctx == null) null else group2ctx.map { case (key, value) => key -> new Context(value.deviceType, value.deviceId) } - executor + + // If this is in a scope then we want to create the clone in the same scope + var newSymbol: Symbol = null + ResourceScope.usingIfScopeExists(this.scope) { + newSymbol = this.clone() + } + + new Executor(execHandle.value, newSymbol, argsNDArray, argsGradNDArray, + auxStatesNDArray, new Context(ctx.deviceType, ctx.deviceId), + gradsReq, executorGroup2ctx) + } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala index df66ea7721fb..74e63be3916b 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala @@ -299,7 +299,6 @@ class DataParallelExecutorGroup private[module]( private var batchSize: Int = -1 private var slices: Array[(Int, Int)] = null - private var _defaultExecs: Array[Executor] = null private var execs: Array[Executor] = null private var dataArrays: Seq[Array[((Int, Int), NDArray)]] = null private var labelArrays: Option[Seq[Array[((Int, Int), NDArray)]]] = None @@ -373,7 +372,12 @@ class DataParallelExecutorGroup private[module]( val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts)) val inputShapes = dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape]) - execs(i) = _defaultExecs(i).reshape(allowUpSizing = true, kwargs = inputShapes) + + ResourceScope.usingIfScopeExists(execs(i).scope) { + val tmpExec = execs(i).reshape(allowUpSizing = true, kwargs = inputShapes) + execs(i).dispose() + execs(i) = tmpExec + } } } else { execs = (0 until contexts.length).map(i => @@ -434,9 +438,6 @@ class DataParallelExecutorGroup private[module]( */ def reshape(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]): Unit = { if (!(dataShapes == this.dataShapes && labelShapes == this.labelShapes)) { - if (this._defaultExecs == null) { - this._defaultExecs = this.execs.map(x => x) - } this.bindExec(dataShapes, labelShapes, None, reshape = true) } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala index 24f3323073f7..5a8b3cb4e94c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/Adam.scala @@ -19,7 +19,7 @@ package org.apache.mxnet.optimizer import org.apache.mxnet.NDArrayConversions._ import org.apache.mxnet.util.SerializerUtils -import org.apache.mxnet.{LRScheduler, NDArray, Optimizer} +import org.apache.mxnet.{LRScheduler, NDArray, Optimizer, ResourceScope} /** * Adam optimizer as described in [King2014] @@ -57,63 +57,54 @@ class Adam(val learningRate: Float = 0.002f, beta1: Float = 0.9f, beta2: Float = * The auxiliary state used in optimization. */ override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = { - var lr = - (if (lrScheduler != null) { - val scheduledLr = lrScheduler(numUpdate) - updateCount(index) - scheduledLr - } else { - this.learningRate - }) - lr = getLr(index, lr) - - val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)] - - // increment time only when the first parameters is called - timeFirstIndex match { - case Some(idx) => - if (idx == index) time += 1 - case None => - timeFirstIndex = Option(index) - time = 0 // all parameters share the same time - } - - val t1: Int = time + 1 - val learningRate = (lr * - math.sqrt(1.0 - math.pow(beta2, t1)) / - (1.0 - math.pow(beta1, t1))).toFloat - val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat - - var resdGrad = grad * rescaleGrad - if (clipGradient != 0f) { - val oldResdGrad = resdGrad - resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient) - oldResdGrad.dispose() - } - - val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad) - .disposeDepsExcept(mean, resdGrad) - val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad) - .disposeDepsExcept(variance, resdGrad) + ResourceScope.using() { + var lr = + (if (lrScheduler != null) { + val scheduledLr = lrScheduler(numUpdate) + updateCount(index) + scheduledLr + } else { + this.learningRate + }) + lr = getLr(index, lr) - val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon)) - .disposeDepsExcept(meanT, varianceT) + val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)] - val wd = this.getWd(index, this.wd) - if (wd > 0.0f) { - val stepDelta = lr * wd * weight - step += stepDelta - stepDelta.dispose() + // increment time only when the first parameters is called + timeFirstIndex match { + case Some(idx) => + if (idx == index) time += 1 + case None => + timeFirstIndex = Option(index) + time = 0 // all parameters share the same time + } + + val t1: Int = time + 1 + val learningRate = (lr * math.sqrt(1.0 - math.pow(beta2, t1)) / + (1.0 - math.pow(beta1, t1))).toFloat + val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat + + var resdGrad = grad * rescaleGrad + if (clipGradient != 0f) { + val oldResdGrad = resdGrad + resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient) + } + + val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad) + val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad) + val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon)) + + val wd = this.getWd(index, this.wd) + if (wd > 0.0f) { + val stepDelta = lr * wd * weight + step += stepDelta + } + + weight -= step + mean.set(meanT) + variance.set(varianceT) + (mean, variance) } - - weight -= step - mean.set(meanT) - variance.set(varianceT) - - meanT.dispose() - varianceT.dispose() - step.dispose() - resdGrad.dispose() } // Create additional optimizer state: mean, variance diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala index 41dfa7d0ead2..19162385f0f7 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala @@ -101,6 +101,39 @@ class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers { assert(a.isDisposed == true, "returned object should be disposed in the outer scope") } + /** + * Tests passing a scope to using and creating new resources within. + */ + test("test moving scope of native resource to scope of another") { + var a: TestNativeResource = null + var b: TestNativeResource = null + var c: TestNativeResource = null + var d: TestNativeResource = null + + ResourceScope.using() { + a = new TestNativeResource() + ResourceScope.using() { + b = new TestNativeResource() + ResourceScope.usingIfScopeExists(a.scope) { + c = new TestNativeResource() + ResourceScope.using() { + d = new TestNativeResource() + assert(c.scope == a.scope) + } + assert(d.isDisposed == true) + } + assert(b.isDisposed == false) + assert(c.isDisposed == false) + } + assert(a.isDisposed == false) + assert(b.isDisposed == true) + assert(c.isDisposed == false) + } + assert(a.isDisposed == true) + assert(b.isDisposed == true) + assert(c.isDisposed == true) + } + test(testName = "test NativeResources in returned Lists are not disposed") { var ndListRet: IndexedSeq[TestNativeResource] = null ResourceScope.using() {