This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Memory fixes. Resolves #10867 and resolves #14080 (#14372)

Merged: 14 commits, Mar 28, 2019
@@ -45,29 +45,47 @@ object Executor {
* @see Symbol.bind : to create executor
*/
class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
private[mxnet] val symbol: Symbol) extends NativeResource {
private[mxnet] var argArrays: Array[NDArray] = null
private[mxnet] var gradArrays: Array[NDArray] = null
private[mxnet] var auxArrays: Array[NDArray] = null
private[mxnet] val symbol: Symbol,
private[mxnet] var argArrays: Array[NDArray] = null,
private[mxnet] var gradArrays: Array[NDArray] = null,
private[mxnet] var auxArrays: Array[NDArray] = null,
private var _ctx: Context = null,
private var _gradsReq: Iterable[_] = null,
private var _group2ctx: Map[String, Context] = null
) extends NativeResource {

val outputs: Array[NDArray] = getOutputs
protected var _argDict: Map[String, NDArray] = null
protected var _gradDict: Map[String, NDArray] = null
protected var _auxDict: Map[String, NDArray] = null
protected var monitorCallback: MXMonitorCallback = null
private[mxnet] var _ctx: Context = null
private[mxnet] var _gradsReq: Iterable[_] = null
private[mxnet] var _group2ctx: Map[String, Context] = null
private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])

private[mxnet] var ownsArgArrays = false
private[mxnet] var ownsGradArrays = false
private[mxnet] var ownsAuxArrays = false

override def nativeAddress: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
// cannot determine the off-heap size of this object
override val bytesAllocated: Long = 0
override val ref: NativeResourceRef = super.register()

override def dispose(): Unit = {
if (!super.isDisposed) {
super.dispose()
outputs.foreach(o => o.dispose())
// Symbol.bind clones symbol when creating the executor so we need to dispose of the clone
symbol.dispose()
if (ownsArgArrays && argArrays != null) {argArrays.foreach(a => a.dispose())}
if (ownsGradArrays && gradArrays != null) {gradArrays.foreach(
// Symbol will sometimes fill this with nulls so we've got to check the elements too
a => if (a != null) {a.dispose()})
}
if (ownsAuxArrays && auxArrays != null) {auxArrays.foreach(a => a.dispose())}
if (_argDict != null) {_argDict.foreach(a => a._2.dispose())}
if (_gradDict != null) {_gradDict.foreach(a => a._2.dispose())}
if (_auxDict != null) {_auxDict.foreach(a => a._2.dispose())}
}
}

@@ -86,6 +104,9 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
*/
def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false,
kwargs: Map[String, Shape]): Executor = {
var setArgOwner = false
var setAuxOwner = false
var setGradOwner = false
val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs)
// TODO: more precise error message should be provided by backend
require(argShapes != null, "Shape inference failed." +
@@ -107,8 +128,10 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
"If you really want to up size, set allowUpSizing = true " +
"to enable allocation of new arrays.")
newArgDict = newArgDict + (name -> NDArray.empty(newShape, arr.context, arr.dtype))
setArgOwner = true
if (dArr != null) {
newGradDict = newGradDict + (name -> NDArray.empty(newShape, dArr.context, dArr.dtype))
setGradOwner = true
}
} else {
newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray))
@@ -135,6 +158,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
"If you really want to up size, set allowUpSizing = true " +
"to enable allocation of new arrays.")
newAuxDict = newAuxDict + (name -> NDArray.empty(newShape, arr.context))
setAuxOwner = true
} else {
newAuxDict = newAuxDict + (name -> arr.reshape(newShape.toArray))
}
@@ -145,7 +169,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
"If this is intended, set partialShaping = true to suppress this warning.")
}
}
if (this._gradsReq.isInstanceOf[Seq[_]]) {
val reshapedExecutor = if (this._gradsReq.isInstanceOf[Seq[_]]) {
this.symbol.bind(this._ctx,
newArgDict,
newGradDict,
@@ -162,6 +186,13 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
this._group2ctx,
this)
}

// This method has created new NDArrays that will need to be managed by the new Executor
if (setArgOwner) reshapedExecutor.ownsArgArrays = true
if (setGradOwner) reshapedExecutor.ownsGradArrays = true
if (setAuxOwner) reshapedExecutor.ownsAuxArrays = true

reshapedExecutor
}
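For illustration, a minimal sketch of the ownership contract these flags implement. The helper name, the input name "data", and the shapes are assumptions, not part of this PR; the point is that any NDArray newly allocated by reshape with allowUpSizing = true is owned by the returned executor, so disposing that executor later releases those arrays as well.

import org.apache.mxnet.{Executor, Shape}

// Hypothetical helper: `exec` is assumed to come from Symbol.bind and to expose an
// input named "data"; the shapes are made up for the sketch.
def reshapeForBatch(exec: Executor, batchSize: Int): Executor = {
  // With allowUpSizing = true, reshape may allocate brand-new NDArrays; the
  // owns*Arrays flags mark them as owned by the returned executor, so calling
  // dispose() on that executor releases them too.
  exec.reshape(allowUpSizing = true,
    kwargs = Map("data" -> Shape(batchSize, 3, 224, 224)))
}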

/**
@@ -28,15 +28,17 @@ object Optimizer {
def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = {
new MXKVStoreUpdater with MXKVStoreCachedStates {
override def update(index: Int, grad: NDArray, weight: NDArray): Unit = {
val state =
if (states.contains(index)) {
states.get(index).get
} else {
val newState = optimizer.createState(index, weight)
states.put(index, newState)
newState
}
optimizer.update(index, weight, grad, state)
ResourceScope.usingIfScopeExists(this.scope) {
val state =
if (states.contains(index)) {
states.get(index).get
} else {
val newState = optimizer.createState(index, weight)
states.put(index, newState)
newState
}
optimizer.update(index, weight, grad, state)
}
}

override def dispose(): Unit = {
@@ -48,8 +48,10 @@ class ResourceScope extends AutoCloseable {
*/
override def close(): Unit = {
ResourceScope.removeFromThreadLocal(this)
resourceQ.foreach(resource => if (resource != null) resource.dispose(false) )
resourceQ.clear()
if (!ResourceScope.threadLocalScopes.get().contains(this)) {
resourceQ.foreach(resource => if (resource != null) resource.dispose(false))
resourceQ.clear()
}
}

/**
@@ -145,7 +147,7 @@ object ResourceScope {
null.asInstanceOf[A] // we'll throw in finally
} finally {
var toThrow: Throwable = retThrowable
if (retThrowable eq null) curScope.close()
if (retThrowable eq null) curScope.close
else {
try {
curScope.close
@@ -160,6 +162,17 @@ }
}
}

  private[mxnet] def usingIfScopeExists[A](scope: Option[ResourceScope])(body: => A): A = {
    if (scope == None) {
      body
    } else {
      ResourceScope.addToThreadLocal(scope.get)
      ResourceScope.using(scope.get) {
        body
      }
    }
  }

Review comment (Member): Why do you need this new method?

Reply (Contributor, author): Yeah, so originally I wanted to use the .using method directly but ran into some problems. There are times when we need to add new native resources to the same scope as the parent, but there's no guarantee that the parent is actually in an existing scope.

When this happens it leads to a few issues. First, the parent is in a None scope, which we cannot pass to ResourceScope.using. Second, we still want to execute the body and allocate all the new resources, but we don't want to make a new ResourceScope, because all those new resources would disappear with the new scope.

Alternatives that I thought of were: 1) have the caller check whether or not it was in a scope and handle it appropriately; this is ugly and puts the onus on the callers in multiple places. 2) Change the using method to work with a None scope; I opted not to do this because it complicates that method and I believe it would require changing the method parameters, which I didn't want to do. 3) Change the default scope to something other than None; this is probably a reasonable solution (maybe some kind of base scope or something similar), but it's likely to be a fairly significant change in both the design and behavior of this class.
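For illustration, a minimal sketch of the calling pattern this helper enables. The helper name cloneIntoParentScope is hypothetical and the code is assumed to live inside org.apache.mxnet (usingIfScopeExists and NativeResource.scope are package-private); the pattern mirrors the Symbol.bind change further down.

import org.apache.mxnet.{ResourceScope, Symbol}

// If `parent` is attached to a ResourceScope, the clone is registered with that
// same scope; if parent.scope is None, the body still runs and the caller simply
// keeps the clone it created.
def cloneIntoParentScope(parent: Symbol): Symbol = {
  var cloned: Symbol = null
  ResourceScope.usingIfScopeExists(parent.scope) {
    cloned = parent.clone()
  }
  cloned
}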

// thread local Scopes
private[mxnet] val threadLocalScopes = new ThreadLocal[ArrayBuffer[ResourceScope]] {
override def initialValue(): ArrayBuffer[ResourceScope] =
@@ -179,7 +192,7 @@
* @param r ResourceScope to remove
*/
private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = {
threadLocalScopes.get() -= r
threadLocalScopes.get().remove(threadLocalScopes.get().lastIndexOf(r))
}
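A small standalone sketch (plain Scala collections, not the MXNet API) of why removal now targets the most recent occurrence: the buffer behaves like a stack of nested scopes, and a scope that has been re-pushed (as usingIfScopeExists does via addToThreadLocal) must be popped from the top, not removed from its original position.

import scala.collection.mutable.ArrayBuffer

val scopes = ArrayBuffer("outer", "inner")   // last element = current (innermost) scope
scopes += "outer"                            // the same scope re-pushed temporarily

// Old behaviour: `scopes -= "outer"` removes the FIRST occurrence, leaving the
// re-pushed copy on top and corrupting the nesting order.
scopes.remove(scopes.lastIndexOf("outer"))   // new behaviour: pop the most recent copy

println(scopes)                              // ArrayBuffer(outer, inner)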

/**
21 changes: 13 additions & 8 deletions scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -803,18 +803,23 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeResource
auxArgsHandle,
sharedHandle,
execHandle))
val executor = new Executor(execHandle.value, this.clone())
executor.argArrays = argsNDArray
executor.gradArrays = argsGradNDArray
executor.auxArrays = auxStatesNDArray
executor._ctx = new Context(ctx.deviceType, ctx.deviceId)
executor._gradsReq = gradsReq
executor._group2ctx =

val executorGroup2ctx =
if (group2ctx == null) null
else group2ctx.map { case (key, value) =>
key -> new Context(value.deviceType, value.deviceId)
}
executor

// If this is in a scope then we want to create the clone in the same scope
var newSymbol: Symbol = null
ResourceScope.usingIfScopeExists(this.scope) {
newSymbol = this.clone()
}

new Executor(execHandle.value, newSymbol, argsNDArray, argsGradNDArray,
auxStatesNDArray, new Context(ctx.deviceType, ctx.deviceId),
gradsReq, executorGroup2ctx)

}

/**
@@ -299,7 +299,6 @@ class DataParallelExecutorGroup private[module](

private var batchSize: Int = -1
private var slices: Array[(Int, Int)] = null
private var _defaultExecs: Array[Executor] = null
private var execs: Array[Executor] = null
private var dataArrays: Seq[Array[((Int, Int), NDArray)]] = null
private var labelArrays: Option[Seq[Array[((Int, Int), NDArray)]]] = None
@@ -373,7 +372,12 @@ class DataParallelExecutorGroup private[module](
val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts))
val inputShapes
= dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape])
execs(i) = _defaultExecs(i).reshape(allowUpSizing = true, kwargs = inputShapes)

ResourceScope.usingIfScopeExists(execs(i).scope) {
val tmpExec = execs(i).reshape(allowUpSizing = true, kwargs = inputShapes)
execs(i).dispose()
execs(i) = tmpExec
}
}
} else {
execs = (0 until contexts.length).map(i =>
@@ -434,9 +438,6 @@
*/
def reshape(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]): Unit = {
if (!(dataShapes == this.dataShapes && labelShapes == this.labelShapes)) {
if (this._defaultExecs == null) {
this._defaultExecs = this.execs.map(x => x)
}
this.bindExec(dataShapes, labelShapes, None, reshape = true)
}
}
@@ -19,7 +19,7 @@ package org.apache.mxnet.optimizer

import org.apache.mxnet.NDArrayConversions._
import org.apache.mxnet.util.SerializerUtils
import org.apache.mxnet.{LRScheduler, NDArray, Optimizer}
import org.apache.mxnet.{LRScheduler, NDArray, Optimizer, ResourceScope}

/**
* Adam optimizer as described in [King2014]
@@ -57,63 +57,54 @@ class Adam(val learningRate: Float = 0.002f, beta1: Float = 0.9f, beta2: Float =
* The auxiliary state used in optimization.
*/
override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = {
var lr =
(if (lrScheduler != null) {
val scheduledLr = lrScheduler(numUpdate)
updateCount(index)
scheduledLr
} else {
this.learningRate
})
lr = getLr(index, lr)

val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)]

// increment time only when the first parameters is called
timeFirstIndex match {
case Some(idx) =>
if (idx == index) time += 1
case None =>
timeFirstIndex = Option(index)
time = 0 // all parameters share the same time
}

val t1: Int = time + 1
val learningRate = (lr *
math.sqrt(1.0 - math.pow(beta2, t1)) /
(1.0 - math.pow(beta1, t1))).toFloat
val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat

var resdGrad = grad * rescaleGrad
if (clipGradient != 0f) {
val oldResdGrad = resdGrad
resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
oldResdGrad.dispose()
}

val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad)
.disposeDepsExcept(mean, resdGrad)
val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad)
.disposeDepsExcept(variance, resdGrad)
ResourceScope.using() {
var lr =
(if (lrScheduler != null) {
val scheduledLr = lrScheduler(numUpdate)
updateCount(index)
scheduledLr
} else {
this.learningRate
})
lr = getLr(index, lr)

val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon))
.disposeDepsExcept(meanT, varianceT)
val (mean, variance) = state.asInstanceOf[(NDArray, NDArray)]

val wd = this.getWd(index, this.wd)
if (wd > 0.0f) {
val stepDelta = lr * wd * weight
step += stepDelta
stepDelta.dispose()
// increment time only when the first parameters is called
timeFirstIndex match {
case Some(idx) =>
if (idx == index) time += 1
case None =>
timeFirstIndex = Option(index)
time = 0 // all parameters share the same time
}

val t1: Int = time + 1
val learningRate = (lr * math.sqrt(1.0 - math.pow(beta2, t1)) /
(1.0 - math.pow(beta1, t1))).toFloat
val beta1t = beta1 * math.pow(decayFactor, t1 - 1).toFloat

var resdGrad = grad * rescaleGrad
if (clipGradient != 0f) {
val oldResdGrad = resdGrad
resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
}

val meanT = (beta1t * mean + (1.0 - beta1t) * resdGrad)
val varianceT = (beta2 * variance + (1.0f - beta2) * resdGrad * resdGrad)
val step = (learningRate * meanT / (NDArray.sqrt(varianceT) + epsilon))

val wd = this.getWd(index, this.wd)
if (wd > 0.0f) {
val stepDelta = lr * wd * weight
step += stepDelta
}

weight -= step
mean.set(meanT)
variance.set(varianceT)
(mean, variance)
}

weight -= step
mean.set(meanT)
variance.set(varianceT)

meanT.dispose()
varianceT.dispose()
step.dispose()
resdGrad.dispose()
}
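For context, a minimal sketch of the scope semantics this rewrite relies on (shapes and values are assumptions): NDArrays created inside ResourceScope.using are released automatically when the block exits, and the block's return value escapes the scope, which is why the explicit dispose()/disposeDepsExcept() bookkeeping of the old version is no longer needed.

import org.apache.mxnet.{NDArray, ResourceScope, Shape}

val weight = NDArray.ones(Shape(2, 2))              // created outside, unaffected by the scope
ResourceScope.using() {
  val resdGrad = NDArray.ones(Shape(2, 2)) * 0.1f   // temporary
  val step = resdGrad * 0.01f                       // temporary
  weight -= step                                    // in-place update persists
  // resdGrad and step are released automatically when the block exits
}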

// Create additional optimizer state: mean, variance