core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -186,7 +186,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T]
* @tparam T item type
*/
@DeveloperApi
-class PoissonSampler[T: ClassTag](
+class PoissonSampler[T](
[Contributor comment] @mengxr Is it OK to change this?

fraction: Double,
useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] {

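Why the ClassTag bound can be dropped: the generated Java code has to construct the sampler directly with new, and a Java call site cannot supply the implicit ClassTag the old signature required. A minimal Scala sketch of the new call site (fraction, seed, and element type are illustrative values, not taken from this patch):

import org.apache.spark.util.random.PoissonSampler

// useGapSamplingIfPossible = false matches what the generated code below
// passes, since gap sampling buffers rows internally.
val sampler = new PoissonSampler[String](0.01, useGapSamplingIfPossible = false)
sampler.setSeed(42L)  // setSeed comes from the Pseudorandom trait
val sampled = sampler.sample(Iterator("a", "b", "c"))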
4 changes: 4 additions & 0 deletions project/MimaExcludes.scala
@@ -597,6 +597,10 @@ object MimaExcludes {
// for multilayer perceptron.
// This class is marked as `private`.
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.ann.SoftmaxFunction")
) ++ Seq(
// [SPARK-13674][SQL] Add wholestage codegen support to Sample
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this")
)
case v if v.startsWith("1.6") =>
Seq(
sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -36,6 +36,8 @@ public abstract class BufferedRowIterator {
protected UnsafeRow unsafeRow = new UnsafeRow(0);
private long startTimeNs = System.nanoTime();

protected int partitionIndex = -1;

public boolean hasNext() throws IOException {
if (currentRows.isEmpty()) {
processNext();
@@ -58,7 +60,7 @@ public long durationMs() {
/**
* Initializes from array of iterators of InternalRow.
*/
-public abstract void init(Iterator<InternalRow> iters[]);
+public abstract void init(int index, Iterator<InternalRow> iters[]);

/**
* Append a row to currentRows.
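A hand-written Scala sketch of the new contract (the real implementations are generated by WholeStageCodegen; this assumes the protected append(InternalRow) helper documented in this file and that processNext() is the overridable hook that hasNext() drives):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.BufferedRowIterator

// Hand-written sketch, not generated code: a pass-through iterator that
// honors the new init(int, Iterator[]) contract.
class PassThroughIterator extends BufferedRowIterator {
  private var input: Iterator[InternalRow] = _

  override def init(index: Int, iters: Array[Iterator[InternalRow]]): Unit = {
    partitionIndex = index  // now visible to generated operators such as Sample
    input = iters(0)
  }

  override protected def processNext(): Unit = {
    if (input.hasNext) append(input.next())
  }
}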
sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution

-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -323,7 +323,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport {
this.references = references;
}

-public void init(scala.collection.Iterator inputs[]) {
+public void init(int index, scala.collection.Iterator inputs[]) {
+  partitionIndex = index;
${ctx.initMutableStates()}
}

@@ -351,10 +352,10 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport {
val rdds = child.asInstanceOf[CodegenSupport].upstreams()
assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
if (rdds.length == 1) {
-rdds.head.mapPartitions { iter =>
+rdds.head.mapPartitionsWithIndex { (index, iter) =>
val clazz = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
-buffer.init(Array(iter))
+buffer.init(index, Array(iter))
new Iterator[InternalRow] {
override def hasNext: Boolean = {
val v = buffer.hasNext
@@ -367,9 +368,10 @@
} else {
// Right now, we support up to two upstreams.
rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
+val partitionIndex = TaskContext.getPartitionId()
val clazz = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
-buffer.init(Array(leftIter, rightIter))
+buffer.init(partitionIndex, Array(leftIter, rightIter))
new Iterator[InternalRow] {
override def hasNext: Boolean = {
val v = buffer.hasNext
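The asymmetry between the two branches exists because RDD has a mapPartitionsWithIndex but no zipPartitionsWithIndex, so the two-upstream branch recovers the index from the running task via TaskContext.getPartitionId(). A self-contained sketch of that pattern (tagWithPartitionIndex is a hypothetical helper; sc is an assumed SparkContext and the RDD contents are arbitrary):

import org.apache.spark.{SparkContext, TaskContext}

def tagWithPartitionIndex(sc: SparkContext): Array[(Int, Int)] = {
  val left = sc.parallelize(1 to 8, numSlices = 4)
  val right = sc.parallelize(101 to 108, numSlices = 4)
  left.zipPartitions(right) { (leftIter, rightIter) =>
    // No index parameter here, unlike mapPartitionsWithIndex, so ask the task:
    val partitionIndex = TaskContext.getPartitionId()
    leftIter.zip(rightIter).map { case (a, b) => (partitionIndex, a + b) }
  }.collect()
}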
sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,11 +20,11 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
-import org.apache.spark.util.random.PoissonSampler
+import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}

case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
extends UnaryNode with CodegenSupport {
@@ -223,9 +223,12 @@ case class Sample(
upperBound: Double,
withReplacement: Boolean,
seed: Long,
-child: SparkPlan) extends UnaryNode {
+child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output

private[sql] override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

protected override def doExecute(): RDD[InternalRow] = {
if (withReplacement) {
// Disable gap sampling since the gap sampling method buffers two rows internally,
@@ -239,6 +242,63 @@ case class Sample(
child.execute().randomSampleWithRange(lowerBound, upperBound, seed)
}
}

override def upstreams(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].upstreams()
}

protected override def doProduce(ctx: CodegenContext): String = {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
val sampler = ctx.freshName("sampler")

if (withReplacement) {
val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
val initSampler = ctx.freshName("initSampler")
ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
s"$initSampler();")

ctx.addNewFunction(initSampler,
s"""
| private void $initSampler() {
| $sampler = new $samplerClass<UnsafeRow>($upperBound - $lowerBound, false);
| java.util.Random random = new java.util.Random(${seed}L);
| long randomSeed = random.nextLong();
| int loopCount = 0;
| while (loopCount < partitionIndex) {
| randomSeed = random.nextLong();
| loopCount += 1;
| }
| $sampler.setSeed(randomSeed);
| }
""".stripMargin.trim)

val samplingCount = ctx.freshName("samplingCount")
s"""
| int $samplingCount = $sampler.sample();
| while ($samplingCount-- > 0) {
| $numOutput.add(1);
| ${consume(ctx, input)}
| }
""".stripMargin.trim
} else {
val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName
ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
s"""
| $sampler = new $samplerClass<UnsafeRow>($lowerBound, $upperBound, false);
| $sampler.setSeed(${seed}L + partitionIndex);
""".stripMargin.trim)

s"""
| if ($sampler.sample() == 0) continue;
| $numOutput.add(1);
| ${consume(ctx, input)}
""".stripMargin.trim
}
}
}
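The two branches derive their per-partition seeds differently, and the sketch below replays both in plain Scala (poissonPartitionSeed and bernoulliPartitionSeed are hypothetical names, not from this patch). For the with-replacement branch the partition's seed is the (partitionIndex + 1)-th long drawn from a java.util.Random seeded with the user seed, which appears to match how PartitionwiseSampledRDD assigns per-partition seeds on the interpreted path; the Bernoulli branch simply offsets the user seed by the partition index.

import java.util.Random

// Mirrors the generated initSampler() above: draw once, then advance
// partitionIndex more times, keeping the last value.
def poissonPartitionSeed(seed: Long, partitionIndex: Int): Long = {
  val random = new Random(seed)
  var randomSeed = random.nextLong()
  var loopCount = 0
  while (loopCount < partitionIndex) {
    randomSeed = random.nextLong()
    loopCount += 1
  }
  randomSeed
}

// Mirrors the BernoulliCellSampler branch above.
def bernoulliPartitionSeed(seed: Long, partitionIndex: Int): Long =
  seed + partitionIndex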

case class Range(
@@ -320,11 +380,7 @@ case class Range(
| // initialize Range
| if (!$initTerm) {
| $initTerm = true;
-| if ($input.hasNext()) {
-|   initRange(((InternalRow) $input.next()).getInt(0));
-| } else {
-|   return;
-| }
+| initRange(partitionIndex);
| }
|
| while (!$overflow && $checkEnd) {
sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -85,6 +85,31 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
*/
}

ignore("range/sample/sum") {
val N = 500 << 20
runBenchmark("range/sample/sum", N) {
sqlContext.range(N).sample(true, 0.01).groupBy().sum().collect()
}
/*
Westmere E56xx/L56xx/X56xx (Nehalem-C)
range/sample/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
range/sample/sum codegen=false 53888 / 56592 9.7 102.8 1.0X
range/sample/sum codegen=true 41614 / 42607 12.6 79.4 1.3X
*/

runBenchmark("range/sample/sum", N) {
sqlContext.range(N).sample(false, 0.01).groupBy().sum().collect()
}
/*
Westmere E56xx/L56xx/X56xx (Nehalem-C)
range/sample/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
range/sample/sum codegen=false 12982 / 13384 40.4 24.8 1.0X
range/sample/sum codegen=true 7074 / 7383 74.1 13.5 1.8X
*/
}

ignore("stat functions") {
val N = 100L << 20
