Skip to content

Commit 97dde82

Browse files
committed
Fix bugs in sampling
1 parent cc33460 commit 97dde82

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ case class SampleExec(
266266
if (withReplacement) {
267267
val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
268268
val initSampler = ctx.freshName("initSampler")
269+
ctx.copyResult = true
269270
ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
270271
s"$initSampler();")
271272

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,4 +1578,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
15781578
val df = spark.createDataFrame(rdd, StructType(schemas), false)
15791579
assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
15801580
}
1581+
1582+
test("copy results for sampling with replacement") {
1583+
val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
1584+
val sampleDf = df.sample(true, 2.00)
1585+
val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect
1586+
assert(d.size == d.distinct.size)
1587+
}
15811588
}

0 commit comments

Comments
 (0)