diff --git a/src/main/scala/BIDMach/networks/layers/ModelLayer.scala b/src/main/scala/BIDMach/networks/layers/ModelLayer.scala index 34af05f2..09d600aa 100644 --- a/src/main/scala/BIDMach/networks/layers/ModelLayer.scala +++ b/src/main/scala/BIDMach/networks/layers/ModelLayer.scala @@ -15,7 +15,7 @@ import java.util.HashMap; import BIDMach.networks._ -class ModelLayer(override val net:Net, override val opts:ModelNodeOpts = new ModelNode) extends Layer(net, opts) { +class ModelLayer(override val net:Net, override val opts:ModelNodeOpts = new ModelNode, val nmats:Int = 1) extends Layer(net, opts) { var imodel = 0; override def getModelMats(net:Net):Unit = { @@ -27,11 +27,14 @@ class ModelLayer(override val net:Net, override val opts:ModelNodeOpts = new Mod } else { val len = net.modelMap.size; net.modelMap.put(opts.modelName, len + net.opts.nmodelmats); + for (i <- 1 until nmats) { + net.modelMap.put(opts.modelName+"_%d" format i, len + i + net.opts.nmodelmats); + } len; } } else { // Otherwise return the next available int - net.imodel += 1; - net.imodel - 1; + net.imodel += nmats; + net.imodel - nmats; }; } }