Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 90 additions & 111 deletions core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,14 @@ import org.apache.spark.annotation.DeveloperApi
trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable {

/** take a random sample */
def sample(items: Iterator[T]): Iterator[U]
def sample(items: Iterator[T]): Iterator[U] =
  // Default behavior: keep an item exactly once whenever sample() reports a positive count.
  // Counts greater than 1 are not replicated here, so a sampler that draws with replacement
  // should override this method.
  // sample() is called with explicit parentheses because it is declared with an empty
  // parameter list; relying on auto-application is deprecated in Scala 2.13 and rejected
  // by Scala 3.
  items.filter(_ => sample() > 0).asInstanceOf[Iterator[U]]

/**
* Decide whether to sample the next item.
* Return how many times the next item will be sampled, or 0 if it is not sampled.
*/
def sample(): Int

/** return a copy of the RandomSampler object */
override def clone: RandomSampler[T, U] =
Expand Down Expand Up @@ -107,21 +114,13 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals

override def setSeed(seed: Long): Unit = rng.setSeed(seed)

override def sample(items: Iterator[T]): Iterator[T] = {
override def sample(): Int = {
if (ub - lb <= 0.0) {
if (complement) items else Iterator.empty
if (complement) 1 else 0
} else {
if (complement) {
items.filter { item => {
val x = rng.nextDouble()
(x < lb) || (x >= ub)
}}
} else {
items.filter { item => {
val x = rng.nextDouble()
(x >= lb) && (x < ub)
}}
}
val x = rng.nextDouble()
val n = if ((x >= lb) && (x < ub)) 1 else 0
if (complement) 1 - n else n
}
}

Expand Down Expand Up @@ -155,15 +154,22 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T

override def setSeed(seed: Long): Unit = rng.setSeed(seed)

override def sample(items: Iterator[T]): Iterator[T] = {
private lazy val gapSampling: GapSampling =
new GapSampling(fraction, rng, RandomSampler.rngEpsilon)

override def sample(): Int = {
if (fraction <= 0.0) {
Iterator.empty
0
} else if (fraction >= 1.0) {
items
1
} else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
new GapSamplingIterator(items, fraction, rng, RandomSampler.rngEpsilon)
gapSampling.sample()
} else {
items.filter { _ => rng.nextDouble() <= fraction }
if (rng.nextDouble() <= fraction) {
1
} else {
0
}
}
}

Expand Down Expand Up @@ -201,15 +207,29 @@ class PoissonSampler[T: ClassTag](
rngGap.setSeed(seed)
}

override def sample(items: Iterator[T]): Iterator[T] = {
private lazy val gapSamplingReplacement =
new GapSamplingReplacement(fraction, rngGap, RandomSampler.rngEpsilon)

override def sample(): Int = {
if (fraction <= 0.0) {
Iterator.empty
0
} else if (useGapSamplingIfPossible &&
fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
gapSamplingReplacement.sample()
} else {
rng.sample()
}
}

override def sample(items: Iterator[T]): Iterator[T] = {
if (fraction <= 0.0) {
Iterator.empty
} else {
val useGapSampling = useGapSamplingIfPossible &&
fraction <= RandomSampler.defaultMaxGapSamplingFraction

items.flatMap { item =>
val count = rng.sample()
val count = if (useGapSampling) gapSamplingReplacement.sample() else rng.sample()
if (count == 0) Iterator.empty else Iterator.fill(count)(item)
}
}
Expand All @@ -220,50 +240,36 @@ class PoissonSampler[T: ClassTag](


private[spark]
class GapSamplingIterator[T: ClassTag](
var data: Iterator[T],
class GapSampling(
f: Double,
rng: Random = RandomSampler.newDefaultRNG,
epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
epsilon: Double = RandomSampler.rngEpsilon) extends Serializable {

require(f > 0.0 && f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)")
require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")

/** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */
private val iterDrop: Int => Unit = {
val arrayClass = Array.empty[T].iterator.getClass
val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
data.getClass match {
case `arrayClass` =>
(n: Int) => { data = data.drop(n) }
case `arrayBufferClass` =>
(n: Int) => { data = data.drop(n) }
case _ =>
(n: Int) => {
var j = 0
while (j < n && data.hasNext) {
data.next()
j += 1
}
}
}
}

override def hasNext: Boolean = data.hasNext
private val lnq = math.log1p(-f)

override def next(): T = {
val r = data.next()
advance()
r
/** Return 1 if the next item should be sampled. Otherwise, return 0. */
def sample(): Int =
  if (countForDropping <= 0) {
    // No pending rejections: accept this item and draw the length of the next gap.
    advance()
    1
  } else {
    // Still inside a gap of rejected items: consume one slot of it.
    countForDropping -= 1
    0
  }

private val lnq = math.log1p(-f)
private var countForDropping: Int = 0

/** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */
/**
* Decide the number of elements that won't be sampled,
* according to geometric dist P(k) = (f)(1-f)^k.
*/
private def advance(): Unit = {
val u = math.max(rng.nextDouble(), epsilon)
val k = (math.log(u) / lnq).toInt
iterDrop(k)
countForDropping = (math.log(u) / lnq).toInt
}

/** advance to first sample as part of object construction. */
Expand All @@ -273,73 +279,24 @@ class GapSamplingIterator[T: ClassTag](
// work reliably.
}


private[spark]
class GapSamplingReplacementIterator[T: ClassTag](
var data: Iterator[T],
f: Double,
rng: Random = RandomSampler.newDefaultRNG,
epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
class GapSamplingReplacement(
val f: Double,
val rng: Random = RandomSampler.newDefaultRNG,
epsilon: Double = RandomSampler.rngEpsilon) extends Serializable {

require(f > 0.0, s"Sampling fraction ($f) must be > 0")
require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")

/** implement efficient linear-sequence drop until scala includes fix for jira SI-8835. */
private val iterDrop: Int => Unit = {
val arrayClass = Array.empty[T].iterator.getClass
val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
data.getClass match {
case `arrayClass` =>
(n: Int) => { data = data.drop(n) }
case `arrayBufferClass` =>
(n: Int) => { data = data.drop(n) }
case _ =>
(n: Int) => {
var j = 0
while (j < n && data.hasNext) {
data.next()
j += 1
}
}
}
}

/** current sampling value, and its replication factor, as we are sampling with replacement. */
private var v: T = _
private var rep: Int = 0

override def hasNext: Boolean = data.hasNext || rep > 0

override def next(): T = {
val r = v
rep -= 1
if (rep <= 0) advance()
r
}

/**
* Skip elements with replication factor zero (i.e. elements that won't be sampled).
* Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is
* q is the probability of Poisson(0; f)
*/
private def advance(): Unit = {
val u = math.max(rng.nextDouble(), epsilon)
val k = (math.log(u) / (-f)).toInt
iterDrop(k)
// set the value and replication factor for the next value
if (data.hasNext) {
v = data.next()
rep = poissonGE1
}
}

private val q = math.exp(-f)
protected val q = math.exp(-f)

/**
* Sample from Poisson distribution, conditioned such that the sampled value is >= 1.
* This is an adaptation from the algorithm for Generating Poisson distributed random variables:
* http://en.wikipedia.org/wiki/Poisson_distribution
*/
private def poissonGE1: Int = {
protected def poissonGE1: Int = {
// simulate that the standard poisson sampling
// gave us at least one iteration, for a sample of >= 1
var pp = q + ((1.0 - q) * rng.nextDouble())
Expand All @@ -353,6 +310,28 @@ class GapSamplingReplacementIterator[T: ClassTag](
}
r
}
private var countForDropping: Int = 0

def sample(): Int =
  if (countForDropping <= 0) {
    // The replication count is drawn before the next gap length; both draws consume the
    // same rng, so this order is significant.
    val replication = poissonGE1
    advance()
    replication
  } else {
    // Still inside a gap of unsampled items: consume one slot of it.
    countForDropping -= 1
    0
  }

/**
* Skip elements with replication factor zero (i.e. elements that won't be sampled).
* Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is
* q is the probability of Poisson(0; f)
*/
private def advance(): Unit = {
  // Clamp the uniform draw to at least epsilon so the logarithm below stays finite.
  val clamped = math.max(rng.nextDouble(), epsilon)
  val gap = math.log(clamped) / (-f)
  // Truncate to a whole number of items to skip before the next sampled one.
  countForDropping = gap.toInt
}

/** advance to first sample as part of object construction. */
advance()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class MockSampler extends RandomSampler[Long, Long] {
s = seed
}

override def sample(): Int = 1

override def sample(items: Iterator[Long]): Iterator[Long] = {
Iterator(s)
}
Expand Down
Loading