python/pyspark/sql/dataframe.py (41 additions, 0 deletions)

@@ -441,6 +441,42 @@ def sample(self, withReplacement, fraction, seed=None):
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)

@since(1.5)
def sampleBy(self, col, fractions, seed=None):
"""
Returns a stratified sample without replacement based on the
fraction given on each stratum.

:param col: column that defines strata
:param fractions:
sampling fraction for each stratum. If a stratum is not
specified, we treat its fraction as zero.
:param seed: random seed
:return: a new DataFrame that represents the stratified sample

>>> from pyspark.sql.functions import col
>>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
>>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
>>> sampled.groupBy("key").count().orderBy("key").show()
+---+-----+
|key|count|
+---+-----+
|  0|    3|
Review comment (Contributor): This is different from the Scala output because the Python unit tests use a different number of threads. See https://issues.apache.org/jira/browse/SPARK-9487.

|  1|    8|
+---+-----+

"""
if not isinstance(col, str):
raise ValueError("col must be a string, but got %r" % type(col))
if not isinstance(fractions, dict):
raise ValueError("fractions must be a dict, but got %r" % type(fractions))
for k, v in fractions.items():
if not isinstance(k, (float, int, long, basestring)):
raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
fractions[k] = float(v)
seed = seed if seed is not None else random.randint(0, sys.maxsize)
return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
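
The review comment above is worth making concrete: rand(seed) is seeded per partition, so the sampled counts depend on how the rows happen to be partitioned, not only on the seed. A minimal sketch of the effect, assuming a live sqlContext as in the doctest (the comparison is illustrative; exact counts will vary):

    from pyspark.sql.functions import col

    # Same seed, different partitioning: rand(seed) draws from an independent
    # stream in each partition, so the Bernoulli draws (and counts) can differ.
    two = sqlContext.range(0, 100, numPartitions=2).select((col("id") % 3).alias("key"))
    four = sqlContext.range(0, 100, numPartitions=4).select((col("id") % 3).alias("key"))
    two.sampleBy("key", {0: 0.1, 1: 0.2}, seed=0).count()
    four.sampleBy("key", {0: 0.1, 1: 0.2}, seed=0).count()  # may differ from the line above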

@since(1.4)
def randomSplit(self, weights, seed=None):
"""Randomly splits this :class:`DataFrame` with the provided weights.
@@ -1314,6 +1350,11 @@ def freqItems(self, cols, support=None):

freqItems.__doc__ = DataFrame.freqItems.__doc__

def sampleBy(self, col, fractions, seed=None):
return self.df.sampleBy(col, fractions, seed)

sampleBy.__doc__ = DataFrame.sampleBy.__doc__


def _test():
import doctest

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

@@ -17,6 +17,10 @@

package org.apache.spark.sql

import java.{util => ju, lang => jl}

import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.stat._

@@ -166,4 +170,42 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
def freqItems(cols: Seq[String]): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, 0.01)
}

/**
* Returns a stratified sample without replacement based on the fraction given on each stratum.
* @param col column that defines strata
* @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
* its fraction as zero.
* @param seed random seed
* @tparam T stratum type
* @return a new [[DataFrame]] that represents the stratified sample
*
* @since 1.5.0
*/
def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
s"Fractions must be in [0, 1], but got $fractions.")
import org.apache.spark.sql.functions.{rand, udf}
val c = Column(col)
val r = rand(seed)
val f = udf { (stratum: Any, x: Double) =>
x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
}
df.filter(f(c, r))
}
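
The filter above is plain per-row Bernoulli sampling: each row draws a uniform value in [0, 1) via rand(seed) and survives iff the draw falls below its stratum's fraction, with strata missing from the map treated as fraction 0.0. A stand-alone sketch of the same scheme in plain Python, with an illustrative helper name that is not part of the patch:

    import random

    def sample_by_sketch(rows, fractions, seed):
        # Keep a (stratum, value) pair iff a uniform draw in [0, 1) falls
        # below the stratum's fraction; unspecified strata default to 0.0.
        rng = random.Random(seed)
        return [(k, v) for (k, v) in rows if rng.random() < fractions.get(k, 0.0)]

    rows = [(i % 3, i) for i in range(100)]
    sampled = sample_by_sketch(rows, {0: 0.1, 1: 0.2}, seed=0)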

/**
* Returns a stratified sample without replacement based on the fraction given on each stratum.
* @param col column that defines strata
* @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
* its fraction as zero.
* @param seed random seed
* @tparam T stratum type
* @return a new [[DataFrame]] that represents the stratified sample
*
* @since 1.5.0
*/
def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
}
}

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

@@ -226,4 +226,13 @@ public void testCovariance() {
Double result = df.stat().cov("a", "b");
Assert.assertTrue(Math.abs(result) < 1e-6);
}

@Test
public void testSampleBy() {
DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key"));
DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)};
Assert.assertArrayEquals(expected, actual);
}
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

@@ -21,9 +21,9 @@ import java.util.Random

import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.functions.col

class DataFrameStatSuite extends SparkFunSuite {
class DataFrameStatSuite extends QueryTest {

private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
import sqlCtx.implicits._
@@ -130,4 +130,12 @@ class DataFrameStatSuite extends SparkFunSuite {
val items2 = singleColResults.collect().head
items2.getSeq[Double](0) should contain (-1.0)
}

test("sampleBy") {
val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
checkAnswer(
sampled.groupBy("key").count().orderBy("key"),
Seq(Row(0, 5), Row(1, 8)))
}
}
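
As a sanity check on the hard-coded counts in these tests: keying range(0, 100) by id % 3 puts 34 rows in stratum 0 (sampled at 0.1) and 33 rows in stratum 1 (sampled at 0.2), so the binomial means are about 3.4 and 6.6 with standard deviations of roughly 1.7 and 2.3. The fixed-seed results (5 and 8 here, 3 and 8 in the Python doctest) all land within about one standard deviation of the mean. The arithmetic, in Python:

    import math

    # Binomial mean and standard deviation per stratum for range(0, 100) keyed by id % 3.
    for label, n, p in [("stratum 0", 34, 0.1), ("stratum 1", 33, 0.2)]:
        mean, sd = n * p, math.sqrt(n * p * (1 - p))
        print("%s: mean=%.1f sd=%.2f" % (label, mean, sd))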