Skip to content

Commit

Permalink
Fix scalastyle (apache#14669)
Browse files Browse the repository at this point in the history
  • Loading branch information
zachgk authored and larroy committed Apr 15, 2019
1 parent 2c17660 commit 6dfe3eb
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 205 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -180,21 +180,24 @@ class FeedForward private(

// Initialize the predictor module for running prediction.
private def initPredictor(inputShapes: Map[String, Shape]): Unit = {
var shouldInit = true
if (this.predExec != null) {
val (argShapes, _, _) = symbol.inferShape(inputShapes)
require(argShapes != null, "Shape inference failed." +
s"Known shapes are $inputShapes for symbol arguments ${symbol.listArguments()} " +
s"and aux states ${symbol.listAuxiliaryStates()}")
val predShapes = this.predExec.argArrays.map(_.shape)
if (argShapes.sameElements(predShapes)) {
return
shouldInit = false
}
}
// for now only use the first device
val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = inputShapes)
predExec.copyParamsFrom(_argParams, _auxParams)
ExecutorManager.checkArguments(symbol)
this.predExec = predExec
if(shouldInit) {
// for now only use the first device
val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = inputShapes)
predExec.copyParamsFrom(_argParams, _auxParams)
ExecutorManager.checkArguments(symbol)
this.predExec = predExec
}
}

// Initialize the iterator given input.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,13 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
allowMissing: Boolean = false,
forceInit: Boolean = false,
allowExtra: Boolean = false): Unit = {
if (paramsInitialized && !forceInit) {
return
if (!paramsInitialized || forceInit) {
require(binded, "call bind before initializing the parameters")
this._currModule.initParams(initializer, argParams, auxParams,
allowMissing, forceInit, allowExtra)
this.paramsDirty = false
this.paramsInitialized = true
}
require(binded, "call bind before initializing the parameters")
this._currModule.initParams(initializer, argParams, auxParams,
allowMissing, forceInit, allowExtra)
this.paramsDirty = false
this.paramsInitialized = true
}

/**
Expand Down Expand Up @@ -218,28 +217,27 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[

if (this.binded) {
logger.warn("Already bound, ignoring bind()")
return
}
} else {
require(sharedModule.isEmpty,
"sharedModule for BucketingModule is not supported")

require(sharedModule.isEmpty,
"sharedModule for BucketingModule is not supported")

this.forTraining = forTraining
this.inputsNeedGrad = inputsNeedGrad
this.binded = true

val (sym, dNames, lNames) = this.symGen(this.defaultBucketKey)
val module = new Module(sym, dNames, lNames, this.contexts,
this.workLoadList, this.fixedParamNames)
module.bind(dataShapes, labelShapes, forTraining, inputsNeedGrad,
forceRebind = false, sharedModule = None, gradReq)
this._currModule = module
this._currBucketKey = this.defaultBucketKey
this._buckets(this.defaultBucketKey) = module

// copy back saved params, if already initialized
if (this.paramsInitialized) {
this.setParams(argParams, auxParams)
this.forTraining = forTraining
this.inputsNeedGrad = inputsNeedGrad
this.binded = true

val (sym, dNames, lNames) = this.symGen(this.defaultBucketKey)
val module = new Module(sym, dNames, lNames, this.contexts,
this.workLoadList, this.fixedParamNames)
module.bind(dataShapes, labelShapes, forTraining, inputsNeedGrad,
forceRebind = false, sharedModule = None, gradReq)
this._currModule = module
this._currBucketKey = this.defaultBucketKey
this._buckets(this.defaultBucketKey) = module

// copy back saved params, if already initialized
if (this.paramsInitialized) {
this.setParams(argParams, auxParams)
}
}
}

Expand Down
153 changes: 76 additions & 77 deletions scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,36 +121,35 @@ class Module(symbolVar: Symbol,
allowMissing: Boolean = false,
forceInit: Boolean = false,
allowExtra: Boolean = false): Unit = {
if (paramsInitialized && !forceInit) {
return
}
require(binded, "call bind before initializing the parameters")
if (!paramsInitialized || forceInit) {
require(binded, "call bind before initializing the parameters")

if (this.argParams == null) {
val paramArrays =
execGroup.paramArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = nds(0).dtype))
this.argParams = this.paramNames.zip(paramArrays).toMap
}
if (this.argParams == null) {
val paramArrays =
execGroup.paramArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = nds(0).dtype))
this.argParams = this.paramNames.zip(paramArrays).toMap
}

if (this.auxParams == null) {
val auxArrays =
execGroup.auxArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = nds(0).dtype))
this.auxParams = this.auxNames.zip(auxArrays).toMap
}
if (this.auxParams == null) {
val auxArrays =
execGroup.auxArrays.map(nds => NDArray.zeros(nds(0).shape, dtype = nds(0).dtype))
this.auxParams = this.auxNames.zip(auxArrays).toMap
}

this.argParams.foreach { case (name, arr) =>
impl(name, arr, allowMissing, Option(initializer), argParams)
}
this.argParams.foreach { case (name, arr) =>
impl(name, arr, allowMissing, Option(initializer), argParams)
}

this.auxParams.foreach { case (name, arr) =>
impl(name, arr, allowMissing, Option(initializer), auxParams)
}
this.auxParams.foreach { case (name, arr) =>
impl(name, arr, allowMissing, Option(initializer), auxParams)
}

this.paramsInitialized = true
this.paramsDirty = false
this.paramsInitialized = true
this.paramsDirty = false

// copy the initialized parameters to devices
this.execGroup.setParams(this.argParams, this.auxParams, allowExtra = allowExtra)
// copy the initialized parameters to devices
this.execGroup.setParams(this.argParams, this.auxParams, allowExtra = allowExtra)
}
}

// Internal helper for parameter initialization
Expand Down Expand Up @@ -246,64 +245,64 @@ class Module(symbolVar: Symbol,

if (binded) {
logger.warn("Already binded, ignoring bind()")
return
}
} else {
this.forTraining = forTraining
this.inputsNeedGrad = inputsNeedGrad
this.binded = true

this.forTraining = forTraining
this.inputsNeedGrad = inputsNeedGrad
this.binded = true
if (!forTraining) {
require(!inputsNeedGrad, "Invalid inputsNeedGrad (cannot be true if not forTraining)")
} else {
// this is not True, as some module might not contains a loss function
// that consumes the labels
// require(labelShapes != None)
}

if (!forTraining) {
require(!inputsNeedGrad, "Invalid inputsNeedGrad (cannot be true if not forTraining)")
} else {
// this is not True, as some module might not contains a loss function
// that consumes the labels
// require(labelShapes != None)
}
this.dataShapesVar = dataShapes
this.labelShapesVar = labelShapes

this.dataShapesVar = dataShapes
this.labelShapesVar = labelShapes

val sharedGroup =
sharedModule.map(sharedModuleInst => {
require(sharedModuleInst.binded && sharedModuleInst.paramsInitialized,
s"bind() and initParams() must be called first on shared module.")
sharedModuleInst.execGroup
})

val inputTypes = this.dataShapesVar.map(dataDesc => (dataDesc.name, dataDesc.dtype)).toMap ++
labelShapes.map(shapes => shapes.map(dataDesc => (dataDesc.name, dataDesc.dtype)).toMap)
.getOrElse(Map.empty[String, DType])

execGroup = new Builder(symbol, contexts, paramNames)
.setWorkLoadList(workLoads)
.setDataShapes(dataShapes)
.setLabelShapes(labelShapes.orNull)
.setForTraining(forTraining)
.setInputsNeedGrad(inputsNeedGrad)
.setSharedGroup(sharedGroup.orNull)
.setFixedParamNames(fixedParamNames.orNull)
.setGradReq(gradReq)
.setInputTypes(inputTypes)
.build()

if (sharedModule.isDefined) {
paramsInitialized = true
argParams = sharedModule.get.argParams
auxParams = sharedModule.get.auxParams
} else if (paramsInitialized) {
// if the parameters are already initialized, we are re-binding
// so automatically copy the already initialized params
execGroup.setParams(argParams, auxParams)
}
val sharedGroup =
sharedModule.map(sharedModuleInst => {
require(sharedModuleInst.binded && sharedModuleInst.paramsInitialized,
s"bind() and initParams() must be called first on shared module.")
sharedModuleInst.execGroup
})

sharedModule.foreach {
case sharedModuleInst: Module =>
if (sharedModuleInst.optimizerInitialized) {
borrowOptimizer(sharedModuleInst)
}
case _ =>
val inputTypes = this.dataShapesVar.map(dataDesc => (dataDesc.name, dataDesc.dtype)).toMap ++
labelShapes.map(shapes => shapes.map(dataDesc => (dataDesc.name, dataDesc.dtype)).toMap)
.getOrElse(Map.empty[String, DType])

execGroup = new Builder(symbol, contexts, paramNames)
.setWorkLoadList(workLoads)
.setDataShapes(dataShapes)
.setLabelShapes(labelShapes.orNull)
.setForTraining(forTraining)
.setInputsNeedGrad(inputsNeedGrad)
.setSharedGroup(sharedGroup.orNull)
.setFixedParamNames(fixedParamNames.orNull)
.setGradReq(gradReq)
.setInputTypes(inputTypes)
.build()

if (sharedModule.isDefined) {
paramsInitialized = true
argParams = sharedModule.get.argParams
auxParams = sharedModule.get.auxParams
} else if (paramsInitialized) {
// if the parameters are already initialized, we are re-binding
// so automatically copy the already initialized params
execGroup.setParams(argParams, auxParams)
}

sharedModule.foreach {
case sharedModuleInst: Module =>
if (sharedModuleInst.optimizerInitialized) {
borrowOptimizer(sharedModuleInst)
}
case _ =>
}
}

}

/**
Expand Down
Loading

0 comments on commit 6dfe3eb

Please sign in to comment.