diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 3562083b0674..dd78a784915d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -266,6 +266,7 @@ case class SampleExec(
     if (withReplacement) {
       val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
       val initSampler = ctx.freshName("initSampler")
+      ctx.copyResult = true
       ctx.addMutableState(s"$samplerClass", sampler,
         s"$initSampler();")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 499f3180379c..f89951760f7d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1578,4 +1578,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(rdd, StructType(schemas), false)
     assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
   }
+
+  test("copy results for sampling with replacement") {
+    val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
+    val sampleDf = df.sample(true, 2.00)
+    val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect
+    assert(d.size == d.distinct.size)
+  }
 }
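
Context for the change above: when sampling with replacement, SampleExec's generated code drives a PoissonSampler, which can hand the same UnsafeRow instance downstream more than once; setting ctx.copyResult = true makes the whole-stage-codegen pipeline copy each output row so consumers that retain row references do not see aliased, duplicated values. The following standalone sketch (not part of the patch; it assumes a local SparkSession and a hypothetical object name SampleWithReplacementCheck) mirrors the new regression test outside the suite:

// Sketch only: reproduces the duplicate-id check from the test above in a
// self-contained Spark application. Object and app names are illustrative.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.monotonically_increasing_id

object SampleWithReplacementCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("sample-with-replacement-check")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")

    // Sampling with replacement may emit the same underlying row object several
    // times; without the copyResult fix, the ids attached below could collide.
    val sampled = df.sample(withReplacement = true, fraction = 2.0)
    val ids = sampled.withColumn("c", monotonically_increasing_id()).select($"c").collect()

    assert(ids.length == ids.distinct.length,
      "every sampled row should carry a distinct monotonically increasing id")

    spark.stop()
  }
}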