40 changes: 40 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -444,6 +444,41 @@ def sample(self, withReplacement, fraction, seed=None):
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)

@since(1.5)
def sampleBy(self, col, fractions, seed=None):
"""
Returns a stratified sample without replacement based on the
fraction given on each stratum.

:param col: column that defines strata
:param fractions:
sampling fraction for each stratum. If a stratum is not
specified, we treat its fraction as zero.
:param seed: random seed
:return: a new DataFrame that represents the stratified sample

>>> from pyspark.sql.functions import col
>>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
>>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
>>> sampled.groupBy("key").count().orderBy("key").show()
+---+-----+
|key|count|
+---+-----+
| 0| 5|
| 1| 8|
+---+-----+
"""
if not isinstance(col, str):
raise ValueError("col must be a string, but got %r" % type(col))
if not isinstance(fractions, dict):
raise ValueError("fractions must be a dict but got %r" % type(fractions))
for k, v in fractions.items():
if not isinstance(k, (float, int, long, basestring)):
raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
fractions[k] = float(v)
seed = seed if seed is not None else random.randint(0, sys.maxsize)
return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)

@since(1.4)
def randomSplit(self, weights, seed=None):
"""Randomly splits this :class:`DataFrame` with the provided weights.
@@ -1318,6 +1353,11 @@ def freqItems(self, cols, support=None):

freqItems.__doc__ = DataFrame.freqItems.__doc__

def sampleBy(self, col, fractions, seed=None):
return self.df.sampleBy(col, fractions, seed)

sampleBy.__doc__ = DataFrame.sampleBy.__doc__


def _test():
import doctest
24 changes: 24 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -17,6 +17,8 @@

package org.apache.spark.sql

import java.util.UUID

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.stat._

@@ -163,4 +165,26 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
def freqItems(cols: Seq[String]): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, 0.01)
}

/**
* Returns a stratified sample without replacement based on the fraction given on each stratum.
* @param col column that defines strata
* @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
* its fraction as zero.
* @param seed random seed
* @return a new [[DataFrame]] that represents the stratified sample
*
* @since 1.5.0
*/
def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = {
require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
s"Fractions must be in [0, 1], but got $fractions.")
    import org.apache.spark.sql.functions.{lit, rand}
val c = Column(col)
val r = rand(seed).as("rand_" + UUID.randomUUID().toString.take(8))
    // OR together one predicate per stratum; fold from `lit(false)` so an
    // empty fractions map yields an empty result instead of crashing reduce.
    val expr = fractions.toSeq.map { case (k, v) =>
      (c === k) && (r < v)
    }.foldLeft(lit(false))(_ || _)
df.filter(expr)
}
}
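For intuition, the predicate built in sampleBy ORs together one clause per stratum, so a row survives exactly when rand(seed) falls below its stratum's fraction; strata absent from fractions match no clause and are dropped, which implements the documented "fraction zero" behavior. Below is a minimal plain-Scala sketch of the same per-stratum Bernoulli sampling (not part of the patch; the object name is invented, and java.util.Random stands in for Spark's partition-seeded rand, so exact counts will differ):

import java.util.Random

// Sketch of sampleBy's semantics on a local collection: keep a value
// iff a uniform draw is below its stratum's fraction; strata missing
// from `fractions` get an implicit fraction of 0.0.
object SampleBySketch {
  def main(args: Array[String]): Unit = {
    val fractions = Map(0L -> 0.1, 1L -> 0.2) // stratum 2 is implicitly 0.0
    val rng = new Random(0L)
    val keys = (0L until 100L).map(_ % 3) // the "key" column from the doctest
    val sampled = keys.filter(k => rng.nextDouble() < fractions.getOrElse(k, 0.0))
    // Per-stratum counts; roughly 34 * 0.1 = 3.4 and 33 * 0.2 = 6.6 expected.
    println(sampled.groupBy(identity).map { case (k, v) => k -> v.size })
  }
}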
12 changes: 10 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql

import org.scalatest.Matchers._

-import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.functions.col

-class DataFrameStatSuite extends SparkFunSuite {
+class DataFrameStatSuite extends QueryTest {

private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
import sqlCtx.implicits._
@@ -98,4 +98,12 @@ class DataFrameStatSuite extends SparkFunSuite {
val items2 = singleColResults.collect().head
items2.getSeq[Double](0) should contain (-1.0)
}

test("sampleBy") {
val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
checkAnswer(
sampled.groupBy("key").count().orderBy("key"),
      Seq(Row(0, 5), Row(1, 8)))
}
}
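As a quick plausibility check on the seed-0 counts shown in the Python doctest (an editorial aside, not part of the patch): range(0, 100) has 34 ids with id % 3 == 0 and 33 with id % 3 == 1, so the expected sample sizes are 34 * 0.1 = 3.4 and 33 * 0.2 = 6.6, with binomial standard deviations of roughly 1.7 and 2.3; the observed counts sit within about one standard deviation of expectation. A small sketch (object name invented):

// Editorial sanity check: expected count and spread per stratum for
// Bernoulli sampling with the doctest's fractions.
object SampleByExpectation {
  def main(args: Array[String]): Unit = {
    // (stratum size in range(0, 100), sampling fraction, seed-0 doctest count)
    val strata = Seq((34, 0.1, 5), (33, 0.2, 8))
    strata.foreach { case (n, p, observed) =>
      val mean = n * p
      val sd = math.sqrt(n * p * (1 - p))
      println(f"expected $mean%.1f +/- $sd%.2f, observed $observed") // both within ~1 sd
    }
  }
}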