This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-729] Use NDArrayCollector to fix memory leaks in Scala Examples (#12232)

* initial fix for RNN

* add CI test

* ignore the test due to memory leaks

* release the GAN beast

* enable rnn

* add collector and dispose

* revert the hacky thing after rebase

* rename with inference

* add collector in some examples

* add experimental tag and comments

* change the scope of the NDArrayCollector

* apply final changes...

* fix scalastyle
lanking520 authored and nswamy committed Aug 23, 2018
1 parent 08f1c2d commit 2f177d8
Showing 22 changed files with 808 additions and 781 deletions.
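
The change that repeats across these files is the (experimental) NDArrayCollector auto scope: NDArrays allocated inside the block are tracked and disposed when the block exits, instead of lingering until GC finalization. A minimal sketch of the pattern, with made-up values, assuming the API exactly as it appears in this diff:

```scala
import org.apache.mxnet.{NDArray, NDArrayCollector, Shape}

object CollectorSketch extends App {
  // NDArrays created inside withScope are tracked and disposed on exit;
  // copy results out (e.g. via toArray) before the scope closes.
  val doubled: Array[Float] = NDArrayCollector.auto().withScope {
    val a = NDArray.array(Array(1f, 2f, 3f), Shape(3))
    (a * 2).toArray // `a` and the `a * 2` temporary are freed here
  }
  println(doubled.mkString(", ")) // 2.0, 4.0, 6.0
}
```

Note that withScope returns the block's result, which is why the examples below can wrap entire training loops and still return their metrics.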
org/apache/mxnet/NDArrayCollector.scala
@@ -18,6 +18,7 @@
 package org.apache.mxnet
 
 import org.apache.mxnet.Base.CPtrAddress
+import org.apache.mxnet.annotation.Experimental
 import org.slf4j.LoggerFactory
 
 import scala.annotation.varargs
@@ -80,6 +81,7 @@ object NDArrayCollector {
    * Create a collector that allows users to later dispose the collected NDArrays manually.
    * @return a manually-disposable collector.
    */
+  @Experimental
   def manual(): NDArrayCollector = new NDArrayCollector(false)
 
   /**
@@ -135,6 +137,7 @@ class NDArrayCollector private(private val autoDispose: Boolean = true,
    * @tparam T return type of the function <em>codeBlock</em>.
    * @return The result of function <em>codeBlock</em>.
    */
+  @Experimental
   def withScope[T](codeBlock: => T): T = {
     val old = NDArrayCollector.currCollector.get()
     NDArrayCollector.currCollector.set(this)
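
Both newly annotated entry points appear above. For the manual variant, a hedged sketch following the usage pattern in the class's scaladoc; collector.foreach iterating the collected NDArrays is assumed from that doc, not shown in this hunk:

```scala
import org.apache.mxnet.{NDArray, NDArrayCollector, Shape}

object ManualSketch extends App {
  val collector = NDArrayCollector.manual()
  val res = collector.withScope {
    (NDArray.array(Array(1f, 2f, 3f), Shape(3)) * 2).toArray
  }
  // nothing is freed automatically; dispose when the caller is done
  collector.foreach(_.dispose())
  println(res.mkString(", "))
}
```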
org/apache/mxnet/annotation/Experimental.scala
@@ -19,6 +19,10 @@ package org.apache.mxnet.annotation
 
 import java.lang.annotation.{ElementType, Retention, Target, _}
 
+/**
+ * Experimental: there is a relatively high chance that
+ * this API will change in future releases.
+ */
 @Retention(RetentionPolicy.RUNTIME)
 @Target(Array(ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER,
   ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE))
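
For reference, a tiny hedged sketch of where the annotation goes; fusedScale is a hypothetical method invented for illustration:

```scala
import org.apache.mxnet.annotation.Experimental

object MyOps {
  // hypothetical method, shown only to illustrate annotation placement
  @Experimental
  def fusedScale(x: Float, k: Float): Float = x * k
}
```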
org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala
@@ -18,7 +18,7 @@
 package org.apache.mxnetexamples.cnntextclassification
 
 import org.apache.mxnet.optimizer.RMSProp
-import org.apache.mxnet.{Context, Executor, Model, NDArray, Optimizer, Shape, Symbol, Uniform}
+import org.apache.mxnet.{Context, Executor, Model, NDArray, NDArrayCollector, Optimizer, Shape, Symbol, Uniform}
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
 
@@ -131,56 +131,58 @@ object CNNTextClassification
       numTotal = 0f
       updateRate = 0
 
+      NDArrayCollector.auto().withScope {
         for (begin <- 0 until trainBatches.length by batchSize) {
           val (batchD, batchL) = {
             if (begin + batchSize <= trainBatches.length) {
               val datas = trainBatches.drop(begin).take(batchSize)
               val labels = trainLabels.drop(begin).take(batchSize)
               (datas, labels)
             } else {
               val right = (begin + batchSize) - trainBatches.length
               val left = trainBatches.length - begin
               val datas = trainBatches.drop(begin).take(left) ++ trainBatches.take(right)
               val labels = trainLabels.drop(begin).take(left) ++ trainLabels.take(right)
               (datas, labels)
             }
           }
           numTotal += batchSize
           model.data.set(batchD.flatten.flatten)
           model.label.set(batchL)
 
           model.cnnExec.forward(isTrain = true)
           model.cnnExec.backward()
 
           val tmpCorrect = {
             val predLabel = NDArray.api.argmax_channel(model.cnnExec.outputs(0))
             val result = predLabel.toArray.zip(batchL).map { predLabel =>
               if (predLabel._1 == predLabel._2) 1
               else 0
             }.sum.toFloat
             predLabel.dispose()
             result
           }
 
           numCorrect = numCorrect + tmpCorrect
           val norm = Math.sqrt(paramBlocks.map { case (idx, weight, grad, state, name) =>
             val temp = NDArray.api.norm(grad / batchSize).disposeDepsExcept(grad)
             val l2Norm = temp.toScalar
             temp.dispose()
             l2Norm * l2Norm
           }.sum).toFloat
 
           if (updateRate % 2 == 0) {
             paramBlocks.foreach { case (idx, weight, grad, state, name) =>
               if (norm > maxGradNorm) {
                 grad.set(grad.toArray.map(_ * (maxGradNorm / norm)))
                 opt.update(idx, weight, grad, state)
               }
               else opt.update(idx, weight, grad, state)
               grad.set(0f)
             }
           }
           updateRate = updateRate + 1
         }
+      }
 
       // decay learning rate
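
The gradient-norm computation in the hunk above still uses the older manual idiom that withScope supersedes for everything else. A standalone hedged sketch of that idiom, with made-up values:

```scala
import org.apache.mxnet.{NDArray, Shape}

object NormSketch extends App {
  val batchSize = 100f
  val grad = NDArray.ones(Shape(2, 2))
  // norm(grad / batchSize) allocates an intermediate for grad / batchSize;
  // disposeDepsExcept frees those dependencies but keeps `grad` alive.
  val temp = NDArray.api.norm(grad / batchSize).disposeDepsExcept(grad)
  val l2Norm = temp.toScalar // read the scalar out before disposing
  temp.dispose()
  grad.dispose()
  println(l2Norm)
}
```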
@@ -237,30 +239,33 @@ object CNNTextClassification
 
   def test(w2vFilePath : String, mrDatasetPath: String,
            ctx : Context, saveModelPath: String) : Float = {
+    val output = NDArrayCollector.auto().withScope {
       val (numEmbed, word2vec) = DataHelper.loadGoogleModel(w2vFilePath)
       val (datas, labels) = DataHelper.loadMSDataWithWord2vec(
         mrDatasetPath, numEmbed, word2vec)
       // randomly shuffle data
       val randIdx = Random.shuffle((0 until datas.length).toList)
       // split train/dev set
       val (trainDats, devDatas) = {
         val train = randIdx.dropRight(1000).map(datas(_)).toArray
         val dev = randIdx.takeRight(1000).map(datas(_)).toArray
         (train, dev)
       }
       val (trainLabels, devLabels) = {
         val train = randIdx.dropRight(1000).map(labels(_)).toArray
         val dev = randIdx.takeRight(1000).map(labels(_)).toArray
         (train, dev)
       }
       // reshape for convolution input
       val sentenceSize = datas(0).length
       val batchSize = 100
       val lr = 0.001f
       val cnnModel = setupCnnModel(ctx, batchSize, sentenceSize, numEmbed)
       val result = trainCNN(cnnModel, trainDats, trainLabels, devDatas, devLabels, batchSize,
         saveModelPath, learningRate = lr)
       result
+    }
+    output
   }
 
   def main(args: Array[String]): Unit = {
org/apache/mxnetexamples/customop/ExampleCustomOp.scala
@@ -19,7 +19,7 @@ package org.apache.mxnetexamples.customop
 
 import org.apache.mxnet.Callback.Speedometer
 import org.apache.mxnet.DType.DType
-import org.apache.mxnet.{Accuracy, Context, CustomOp, CustomOpProp, NDArray, Operator, Shape, Symbol, Xavier}
+import org.apache.mxnet.{Accuracy, Context, CustomOp, CustomOpProp, NDArray, NDArrayCollector, Operator, Shape, Symbol, Xavier}
 import org.apache.mxnet.optimizer.RMSProp
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
@@ -141,49 +141,50 @@ object ExampleCustomOp
       evalMetric.reset()
       var nBatch = 0
       var epochDone = false
 
+      NDArrayCollector.auto().withScope {
         trainIter.reset()
         while (!epochDone) {
           var doReset = true
           while (doReset && trainIter.hasNext) {
             val dataBatch = trainIter.next()
             argDict("data").set(dataBatch.data(0))
             argDict("label").set(dataBatch.label(0))
             executor.forward(isTrain = true)
             executor.backward()
             paramsGrads.foreach { case (idx, name, grad, optimState) =>
               opt.update(idx, argDict(name), grad, optimState)
             }
             evalMetric.update(dataBatch.label, executor.outputs)
             nBatch += 1
             batchEndCallback.invoke(epoch, nBatch, evalMetric)
           }
           if (doReset) {
             trainIter.reset()
           }
           epochDone = true
         }
         val (name, value) = evalMetric.get
         name.zip(value).foreach { case (n, v) =>
           logger.info(s"Epoch[$epoch] Train-accuracy=$v")
         }
         val toc = System.currentTimeMillis
         logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
 
         evalMetric.reset()
         testIter.reset()
         while (testIter.hasNext) {
           val evalBatch = testIter.next()
           argDict("data").set(evalBatch.data(0))
           argDict("label").set(evalBatch.label(0))
           executor.forward(isTrain = true)
           evalMetric.update(evalBatch.label, executor.outputs)
           evalBatch.dispose()
         }
         val (names, values) = evalMetric.get
         names.zip(values).foreach { case (n, v) =>
           logger.info(s"Epoch[$epoch] Validation-accuracy=$v")
           validationAcc = Math.max(validationAcc, v)
         }
+      }
     }
     executor.dispose()
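
Note that executor.dispose() stays outside the collector scope: the executor is a long-lived handle reused across epochs, while the scope frees only per-epoch temporaries. A hedged skeleton of that layering, with hypothetical names:

```scala
import org.apache.mxnet.NDArrayCollector

object EpochSketch {
  // Hypothetical training skeleton: long-lived resources live outside the
  // collector scope; per-epoch NDArray temporaries are freed on scope exit.
  def run(numEpochs: Int)(forwardBackward: () => Unit): Unit = {
    for (epoch <- 0 until numEpochs) {
      NDArrayCollector.auto().withScope {
        forwardBackward() // temporaries created here are disposed per epoch
      }
    }
    // executor.dispose() would follow here, once, after all epochs
  }
}
```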