This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.4.x][Cherry-Pick] Memory fixes. Resolves #10867, and resolves #14080 (#14372) #14586

Merged 1 commit on Apr 5, 2019
@@ -45,29 +45,47 @@ object Executor {
* @see Symbol.bind : to create executor
*/
class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private[mxnet] val symbol: Symbol) extends NativeResource {
private[mxnet] var argArrays: Array[NDArray] = null
private[mxnet] var gradArrays: Array[NDArray] = null
private[mxnet] var auxArrays: Array[NDArray] = null
private[mxnet] val symbol: Symbol,
private[mxnet] var argArrays: Array[NDArray] = null,
private[mxnet] var gradArrays: Array[NDArray] = null,
private[mxnet] var auxArrays: Array[NDArray] = null,
private var _ctx: Context = null,
private var _gradsReq: Iterable[_] = null,
private var _group2ctx: Map[String, Context] = null
) extends NativeResource {

val outputs: Array[NDArray] = getOutputs
protected var _argDict: Map[String, NDArray] = null
protected var _gradDict: Map[String, NDArray] = null
protected var _auxDict: Map[String, NDArray] = null
protected var monitorCallback: MXMonitorCallback = null
private[mxnet] var _ctx: Context = null
private[mxnet] var _gradsReq: Iterable[_] = null
private[mxnet] var _group2ctx: Map[String, Context] = null
private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])

private[mxnet] var ownsArgArrays = false
private[mxnet] var ownsGradArrays = false
private[mxnet] var ownsAuxArrays = false

override def nativeAddress: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
// cannot determine the off-heap size of this object
override val bytesAllocated: Long = 0
override val ref: NativeResourceRef = super.register()

override def dispose(): Unit = {
if (!super.isDisposed) {
super.dispose()
outputs.foreach(o => o.dispose())
// Symbol.bind clones symbol when creating the executor so we need to dispose of the clone
symbol.dispose()
if (ownsArgArrays && argArrays != null) {argArrays.foreach(a => a.dispose())}
if (ownsGradArrays && gradArrays != null) {gradArrays.foreach(
// Symbol will sometimes fill this with nulls so we've got to check the elements too
a => if (a != null) {a.dispose()})
}
if (ownsAuxArrays && auxArrays != null) {auxArrays.foreach(a => a.dispose())}
if (_argDict != null) {_argDict.foreach(a => a._2.dispose())}
if (_gradDict != null) {_gradDict.foreach(a => a._2.dispose())}
if (_auxDict != null) {_auxDict.foreach(a => a._2.dispose())}
}
}

@@ -86,6 +104,9 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
*/
def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false,
kwargs: Map[String, Shape]): Executor = {
var setArgOwner = false
var setAuxOwner = false
var setGradOwner = false
val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs)
// TODO: more precise error message should be provided by backend
require(argShapes != null, "Shape inference failed." +
@@ -107,8 +128,10 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
"If you really want to up size, set allowUpSizing = true " +
"to enable allocation of new arrays.")
newArgDict = newArgDict + (name -> NDArray.empty(newShape, arr.context, arr.dtype))
setArgOwner = true
if (dArr != null) {
newGradDict = newGradDict + (name -> NDArray.empty(newShape, dArr.context, dArr.dtype))
setGradOwner = true
}
} else {
newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray))
@@ -135,6 +158,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
"If you really want to up size, set allowUpSizing = true " +
"to enable allocation of new arrays.")
newAuxDict = newAuxDict + (name -> NDArray.empty(newShape, arr.context))
setAuxOwner = true
} else {
newAuxDict = newAuxDict + (name -> arr.reshape(newShape.toArray))
}
@@ -145,7 +169,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
"If this is intended, set partialShaping = true to suppress this warning.")
}
}
if (this._gradsReq.isInstanceOf[Seq[_]]) {
val reshapedExecutor = if (this._gradsReq.isInstanceOf[Seq[_]]) {
this.symbol.bind(this._ctx,
newArgDict,
newGradDict,
Expand All @@ -162,6 +186,13 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
this._group2ctx,
this)
}

// This method has created new NDArrays that will need to be managed by the new Executor
if (setArgOwner) reshapedExecutor.ownsArgArrays = true
if (setGradOwner) reshapedExecutor.ownsGradArrays = true
if (setAuxOwner) reshapedExecutor.ownsAuxArrays = true

reshapedExecutor
}

/**
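
A minimal sketch (not part of this diff) of the dispose behaviour these ownership flags enable. Assuming a trivial a + b network bound on CPU, Executor.dispose() now also releases the Symbol clone made by Symbol.bind and, when the executor owns them, the argument/gradient/auxiliary arrays:

import org.apache.mxnet._

object ExecutorDisposeSketch {
  def main(args: Array[String]): Unit = {
    val a = Symbol.Variable("a")
    val b = Symbol.Variable("b")
    // bind clones the symbol internally; the resulting executor owns that clone
    val exec = (a + b).bind(Context.cpu(),
      Map("a" -> NDArray.ones(Shape(2, 2)), "b" -> NDArray.ones(Shape(2, 2))))
    exec.forward()
    println(exec.outputs(0).toArray.mkString(", "))
    // dispose() frees the handle, the outputs, the cloned Symbol, and any arrays
    // the executor allocated itself (ownsArgArrays / ownsGradArrays / ownsAuxArrays)
    exec.dispose()
  }
}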
@@ -28,15 +28,17 @@ object Optimizer {
def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = {
new MXKVStoreUpdater with MXKVStoreCachedStates {
override def update(index: Int, grad: NDArray, weight: NDArray): Unit = {
val state =
if (states.contains(index)) {
states.get(index).get
} else {
val newState = optimizer.createState(index, weight)
states.put(index, newState)
newState
}
optimizer.update(index, weight, grad, state)
ResourceScope.usingIfScopeExists(this.scope) {
val state =
if (states.contains(index)) {
states.get(index).get
} else {
val newState = optimizer.createState(index, weight)
states.put(index, newState)
newState
}
optimizer.update(index, weight, grad, state)
}
}

override def dispose(): Unit = {
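
For context, a hedged sketch of the updater path touched above: getUpdater caches one state object per parameter index, and with this change that state is created inside the scope that owns the updater, if one exists. The SGD optimizer and shapes are illustrative assumptions, not taken from this PR; run it as a REPL snippet or inside an enclosing object:

import org.apache.mxnet._
import org.apache.mxnet.optimizer.SGD

val updater = Optimizer.getUpdater(new SGD(learningRate = 0.1f))
val weight = NDArray.ones(Shape(2, 2))
val grad = NDArray.ones(Shape(2, 2))
// update() lazily creates the cached state for index 0 and applies one step to weight in place
updater.update(0, grad, weight)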
@@ -48,8 +48,10 @@ class ResourceScope extends AutoCloseable {
*/
override def close(): Unit = {
ResourceScope.removeFromThreadLocal(this)
resourceQ.foreach(resource => if (resource != null) resource.dispose(false) )
resourceQ.clear()
if (!ResourceScope.threadLocalScopes.get().contains(this)) {
resourceQ.foreach(resource => if (resource != null) resource.dispose(false))
resourceQ.clear()
}
}

/**
Expand Down Expand Up @@ -145,7 +147,7 @@ object ResourceScope {
null.asInstanceOf[A] // we'll throw in finally
} finally {
var toThrow: Throwable = retThrowable
if (retThrowable eq null) curScope.close()
if (retThrowable eq null) curScope.close
else {
try {
curScope.close
@@ -160,6 +162,17 @@ object ResourceScope {
}
}

private[mxnet] def usingIfScopeExists[A](scope: Option[ResourceScope])(body: => A): A = {
if (scope == None) {
body
} else {
ResourceScope.addToThreadLocal(scope.get)
ResourceScope.using(scope.get){
body
}
}
}

// thread local Scopes
private[mxnet] val threadLocalScopes = new ThreadLocal[ArrayBuffer[ResourceScope]] {
override def initialValue(): ArrayBuffer[ResourceScope] =
@@ -179,7 +192,7 @@ object ResourceScope {
* @param r ResourceScope to remove
*/
private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = {
threadLocalScopes.get() -= r
threadLocalScopes.get().remove(threadLocalScopes.get().lastIndexOf(r))
}

/**
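
For readers unfamiliar with ResourceScope, a REPL-style sketch of the public semantics these fixes preserve: NDArrays allocated inside using are disposed when the scope closes, except what the block returns (usingIfScopeExists itself is private[mxnet], so it is not shown here; the shapes are illustrative):

import org.apache.mxnet._

val kept: NDArray = ResourceScope.using() {
  val tmp = NDArray.ones(Shape(2, 2)) // registered to this scope, disposed on close
  tmp + 1f                            // the returned NDArray escapes the scope
}
println(kept.toArray.mkString(", "))  // safe: kept was not disposed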
21 changes: 13 additions & 8 deletions scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -803,18 +803,23 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeResource
auxArgsHandle,
sharedHandle,
execHandle))
val executor = new Executor(execHandle.value, this.clone())
executor.argArrays = argsNDArray
executor.gradArrays = argsGradNDArray
executor.auxArrays = auxStatesNDArray
executor._ctx = new Context(ctx.deviceType, ctx.deviceId)
executor._gradsReq = gradsReq
executor._group2ctx =

val executorGroup2ctx =
if (group2ctx == null) null
else group2ctx.map { case (key, value) =>
key -> new Context(value.deviceType, value.deviceId)
}
executor

// If this is in a scope then we want to create the clone in the same scope
var newSymbol: Symbol = null
ResourceScope.usingIfScopeExists(this.scope) {
newSymbol = this.clone()
}

new Executor(execHandle.value, newSymbol, argsNDArray, argsGradNDArray,
auxStatesNDArray, new Context(ctx.deviceType, ctx.deviceId),
gradsReq, executorGroup2ctx)

}

/**
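
A short REPL-style sketch of why the clone is now created through usingIfScopeExists: when bind is called inside a ResourceScope, the Symbol clone made by bind (and the executor itself) are registered to that scope and released with it, instead of lingering until finalization. The variables and shapes are illustrative assumptions:

import org.apache.mxnet._

val a = Symbol.Variable("a")
val b = Symbol.Variable("b")
ResourceScope.using() {
  // the clone of (a + b) created inside bind is tracked by this scope
  val exec = (a + b).bind(Context.cpu(),
    Map("a" -> NDArray.ones(Shape(1, 2)), "b" -> NDArray.ones(Shape(1, 2))))
  exec.forward()
  exec.outputs(0).toArray // plain Array[Float]; no NDArray escapes the scope
}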
@@ -299,7 +299,6 @@ class DataParallelExecutorGroup private[module](

private var batchSize: Int = -1
private var slices: Array[(Int, Int)] = null
private var _defaultExecs: Array[Executor] = null
private var execs: Array[Executor] = null
private var dataArrays: Seq[Array[((Int, Int), NDArray)]] = null
private var labelArrays: Option[Seq[Array[((Int, Int), NDArray)]]] = None
@@ -373,7 +372,12 @@ class DataParallelExecutorGroup private[module](
val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts))
val inputShapes
= dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape])
execs(i) = _defaultExecs(i).reshape(allowUpSizing = true, kwargs = inputShapes)

ResourceScope.usingIfScopeExists(execs(i).scope) {
val tmpExec = execs(i).reshape(allowUpSizing = true, kwargs = inputShapes)
execs(i).dispose()
execs(i) = tmpExec
}
}
} else {
execs = (0 until contexts.length).map(i =>
@@ -434,9 +438,6 @@ class DataParallelExecutorGroup private[module](
*/
def reshape(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]): Unit = {
if (!(dataShapes == this.dataShapes && labelShapes == this.labelShapes)) {
if (this._defaultExecs == null) {
this._defaultExecs = this.execs.map(x => x)
}
this.bindExec(dataShapes, labelShapes, None, reshape = true)
}
}
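
The reshape-then-dispose swap above is also usable from user code; a hedged sketch with a hypothetical helper (swapExec is not part of the API): reshape builds a new executor and, per this PR, that executor owns any arrays it had to allocate, so the old one can be disposed right away:

import org.apache.mxnet._

// Hypothetical helper: rebuild an executor for new input shapes and free the old one.
def swapExec(old: Executor, shapes: Map[String, Shape]): Executor = {
  val next = old.reshape(allowUpSizing = true, kwargs = shapes)
  old.dispose() // safe: next owns whatever arrays it allocated during reshape
  next
}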
@@ -19,7 +19,7 @@ package org.apache.mxnet.optimizer

import org.apache.mxnet.NDArrayConversions._
import org.apache.mxnet.util.SerializerUtils
import org.apache.mxnet.{LRScheduler, NDArray, Optimizer}
import org.apache.mxnet.{LRScheduler, NDArray, Optimizer, ResourceScope}

/**
* Adam optimizer as described in [King2014]
@@ -57,63 +57,54 @@ class Adam(val learningRate: Float = 0.002f, beta1: Float = 0.9f, beta2: Float =
* The auxiliary state used in optimization.
*/
override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = {
var lr =
(if (lrScheduler != null) {
val scheduledLr = lrScheduler(numUpdate)
updateCount(index)
scheduledLr
} else {
this.learningRate
})
lr = getLr(index, lr)

val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)]

// increment time only when the first parameter is called
timeFirstIndex match {
case Some(idx) =>
if (idx == index) time += 1
case None =>
timeFirstIndex = Option(index)
time = 0 // all parameters share the same time
}

val t1: Int = time + 1
val learningRate = (lr *
math.sqrt(1.0 - math.pow(beta2, t1)) /
(1.0 - math.pow(beta1, t1))).toFloat
val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat

var resdGrad = grad * rescaleGrad
if (clipGradient != 0f) {
val oldResdGrad = resdGrad
resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
oldResdGrad.dispose()
}

val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad)
.disposeDepsExcept(mean, resdGrad)
val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad)
.disposeDepsExcept(variance, resdGrad)
ResourceScope.using() {
var lr =
(if (lrScheduler != null) {
val scheduledLr = lrScheduler(numUpdate)
updateCount(index)
scheduledLr
} else {
this.learningRate
})
lr = getLr(index, lr)

val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon))
.disposeDepsExcept(meanT, varianceT)
val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)]

val wd = this.getWd(index, this.wd)
if (wd > 0.0f) {
val stepDelta = lr * wd * weight
step += stepDelta
stepDelta.dispose()
// increment time only when the first parameter is called
timeFirstIndex match {
case Some(idx) =>
if (idx == index) time += 1
case None =>
timeFirstIndex = Option(index)
time = 0 // all parameters share the same time
}

val t1: Int = time + 1
val learningRate = (lr * math.sqrt(1.0 - math.pow(beta2, t1)) /
(1.0 - math.pow(beta1, t1))).toFloat
val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat

var resdGrad = grad * rescaleGrad
if (clipGradient != 0f) {
val oldResdGrad = resdGrad
resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
}

val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad)
val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad)
val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon))

val wd = this.getWd(index, this.wd)
if (wd > 0.0f) {
val stepDelta = lr * wd * weight
step += stepDelta
}

weight -= step
mean.set(meanT)
variance.set(varianceT)
(mean, variance)
}

weight -= step
mean.set(meanT)
variance.set(varianceT)

meanT.dispose()
varianceT.dispose()
step.dispose()
resdGrad.dispose()
}

// Create additional optimizer state: mean, variance
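
A hedged REPL-style sketch of what the ResourceScope.using wrapper above provides: the temporaries built during one Adam step (resdGrad, meanT, varianceT, step) now live in a scope and are released when update returns, replacing the explicit dispose calls the old code needed. The shapes are illustrative assumptions:

import org.apache.mxnet._
import org.apache.mxnet.optimizer.Adam

val adam = new Adam(learningRate = 0.002f)
val weight = NDArray.ones(Shape(4))
val grad = NDArray.ones(Shape(4)) * 0.1f
val state = adam.createState(0, weight) // the (mean, variance) pair used above
adam.update(0, weight, grad, state)     // temporaries are scoped and freed here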