Commit df02e98

update blockify strategy
1 parent: 9245263

2 files changed: 18 additions, 41 deletions

mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
12 additions, 35 deletions

@@ -142,55 +142,32 @@ private[spark] object InstanceBlock {
     new Iterator[InstanceBlock] {
       private var numCols = -1L
       private val buff = mutable.ArrayBuilder.make[Instance]
-      private var buffCnt = 0L
-      private var buffNnz = 0L
-      private var buffUnitWeight = true
-      private var block = Option.empty[InstanceBlock]

-      private def flush(): Unit = {
-        block = Some(InstanceBlock.fromInstances(buff.result()))
-        buff.clear()
-        buffCnt = 0L
-        buffNnz = 0L
-        buffUnitWeight = true
-      }
+      override def hasNext: Boolean = iterator.hasNext

-      private def blockify(): Unit = {
-        block = None
+      override def next(): InstanceBlock = {
+        buff.clear()
+        var buffCnt = 0L
+        var buffNnz = 0L
+        var buffUnitWeight = true
+        var blockMemUsage = 0L

-        while (block.isEmpty && iterator.hasNext) {
+        while (iterator.hasNext && blockMemUsage < maxMemUsage) {
           val instance = iterator.next()
           if (numCols < 0L) numCols = instance.features.size
           require(numCols == instance.features.size)
           val nnz = instance.features.numNonzeros

-          // Check if enough memory remains to add this instance to the block.
-          if (getBlockMemUsage(numCols, buffCnt + 1L, buffNnz + nnz,
-            buffUnitWeight && (instance.weight == 1)) > maxMemUsage) {
-            // Check if this instance is too large
-            require(buffCnt > 0, s"instance $instance exceeds memory limit $maxMemUsage, " +
-              s"please increase block size")
-            flush()
-          }
-
           buff += instance
           buffCnt += 1L
           buffNnz += nnz
           buffUnitWeight &&= (instance.weight == 1)
+          blockMemUsage = getBlockMemUsage(numCols, buffCnt, buffNnz, buffUnitWeight)
         }

-        if (block.isEmpty && buffCnt > 0) flush()
-      }
-
-      override def hasNext: Boolean = {
-        block.nonEmpty || { blockify(); block.nonEmpty }
-      }
-
-      override def next(): InstanceBlock = {
-        if (block.isEmpty) blockify()
-        val ret = block.get
-        blockify()
-        ret
+        // The block memory usage may slightly exceed the threshold; not a big issue.
+        // This also ensures that even if a single row exceeds the block limit, each block holds at least one row.
+        InstanceBlock.fromInstances(buff.result())
       }
     }
   }
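In short, the new strategy is: next() greedily pulls instances and re-estimates the block's memory footprint after each append, stopping once the estimate reaches maxMemUsage. Because the check happens after the append, an oversized row simply becomes a single-row block instead of triggering the old require failure. Below is a minimal, self-contained sketch of the same idea using only the Scala standard library; chunkBySize and costOf are illustrative names, not Spark API, and it sums a per-element cost rather than recomputing the whole-block estimate the way getBlockMemUsage does.

import scala.collection.mutable.ArrayBuffer

object ChunkBySizeSketch {

  // Greedily pull elements into a chunk and stop once the accumulated cost
  // reaches the cap. The check runs after an element is appended, so every
  // chunk holds at least one element, even one whose cost alone exceeds the cap.
  def chunkBySize[T](it: Iterator[T], maxCost: Long)(costOf: T => Long): Iterator[List[T]] =
    new Iterator[List[T]] {
      override def hasNext: Boolean = it.hasNext

      override def next(): List[T] = {
        val buff = ArrayBuffer.empty[T]
        var cost = 0L
        while (it.hasNext && cost < maxCost) {
          val elem = it.next()
          buff += elem
          cost += costOf(elem)  // may slightly overshoot maxCost, as in the diff
        }
        buff.toList
      }
    }

  def main(args: Array[String]): Unit = {
    // Each element "costs" its own value; cap each chunk at 10.
    val chunks = chunkBySize(Iterator(3, 4, 5, 25, 2, 2), 10L)(_.toLong).toList
    println(chunks)  // List(List(3, 4, 5), List(25), List(2, 2)) -- 25 gets its own chunk
  }
}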

mllib/src/test/scala/org/apache/spark/ml/feature/InstanceSuite.scala
6 additions, 6 deletions

@@ -97,13 +97,13 @@ class InstanceSuite extends SparkFunSuite{
       assert(vec.toArray === instances(i).features.toArray)
     }

+    // instances larger than maxMemUsage
     val bigInstance = Instance(-1.0, 2.0, Vectors.dense(Array.fill(10000)(1.0)))
-    val inputIter1 = Iterator.apply(bigInstance)
-    val inputIter2 = Iterator.apply(instance1, instance2, bigInstance)
-    Seq(inputIter1, inputIter2).foreach { inputIter =>
-      intercept[IllegalArgumentException] {
-        InstanceBlock.blokifyWithMaxMemUsage(inputIter, 1024).toArray
-      }
+    InstanceBlock.blokifyWithMaxMemUsage(Iterator.fill(10)(bigInstance), 64).size
+
+    // different numFeatures
+    intercept[IllegalArgumentException] {
+      InstanceBlock.blokifyWithMaxMemUsage(Iterator.apply(instance1, bigInstance), 64).size
     }
   }
