diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index ec059d625843..569887bce1a0 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -254,6 +254,7 @@ def substr(self, startPos, length): :param startPos: start position (int or Column) :param length: length of the substring (int or Column) + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.select(df.name.substr(1, 3).alias("col")).collect() [Row(col=u'Ali'), Row(col=u'Bob')] """ @@ -276,6 +277,7 @@ def isin(self, *cols): A boolean expression that is evaluated to true if the value of this expression is contained by the evaluated values of the arguments. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df[df.name.isin("Bob", "Mike")].collect() [Row(age=5, name=u'Bob')] >>> df[df.age.isin([1, 2, 3])].collect() @@ -303,6 +305,7 @@ def alias(self, *alias): Returns this column aliased with a new name or names (in the case of expressions that return more than one column, such as explode). + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.select(df.age.alias("age2")).collect() [Row(age2=2), Row(age2=5)] """ @@ -320,6 +323,7 @@ def alias(self, *alias): def cast(self, dataType): """ Convert the column into type ``dataType``. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.select(df.age.cast("string").alias('ages')).collect() [Row(ages=u'2'), Row(ages=u'5')] >>> df.select(df.age.cast(StringType()).alias('ages')).collect() @@ -344,6 +348,7 @@ def between(self, lowerBound, upperBound): A boolean expression that is evaluated to true if the value of this expression is between the given columns. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.select(df.name, df.age.between(2, 4)).show() +-----+---------------------------+ | name|((age >= 2) AND (age <= 4))| @@ -366,6 +371,7 @@ def when(self, condition, value): :param value: a literal value, or a :class:`Column` expression. >>> from pyspark.sql import functions as F + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show() +-----+------------------------------------------------------------+ | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END| @@ -391,6 +397,7 @@ def otherwise(self, value): :param value: a literal value, or a :class:`Column` expression. 
>>> from pyspark.sql import functions as F + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show() +-----+-------------------------------------+ | name|CASE WHEN (age > 3) THEN 1 ELSE 0 END| @@ -412,9 +419,17 @@ def over(self, window): :return: a Column >>> from pyspark.sql import Window - >>> window = Window.partitionBy("name").orderBy("age").rowsBetween(-1, 1) - >>> from pyspark.sql.functions import rank, min - >>> # df.select(rank().over(window), min('age').over(window)) + >>> window = Window.partitionBy("name").orderBy("age") + >>> from pyspark.sql.functions import rank + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob'), (3, 'Bob')], ['age', 'name']) + >>> df.select('name', 'age', rank().over(window)).show() + +-----+---+-----------------------------------------------------------------------------+ + | name|age|RANK() OVER (PARTITION BY name ORDER BY age ASC NULLS FIRST UnspecifiedFrame)| + +-----+---+-----------------------------------------------------------------------------+ + | Bob| 3| 1| + | Bob| 5| 2| + |Alice| 2| 1| + +-----+---+-----------------------------------------------------------------------------+ """ from pyspark.sql.window import WindowSpec if not isinstance(window, WindowSpec): @@ -442,9 +457,7 @@ def _test(): .getOrCreate() sc = spark.sparkContext globs['sc'] = sc - globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ - .toDF(StructType([StructField('age', IntegerType()), - StructField('name', StringType())])) + globs['spark'] = spark (failure_count, test_count) = doctest.testmod( pyspark.sql.column, globs=globs, diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c22f4b87e1a7..e6e8dafdb7dc 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -338,6 +338,8 @@ def registerDataFrameAsTable(self, df, tableName): Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. + >>> df = sqlContext.createDataFrame([(1, 'row1'), (2, 'row2'), (3, 'row3')], + ... ['field1', 'field2']) >>> sqlContext.registerDataFrameAsTable(df, "table1") """ df.createOrReplaceTempView(tableName) @@ -346,6 +348,8 @@ def registerDataFrameAsTable(self, df, tableName): def dropTempTable(self, tableName): """ Remove the temp table from catalog. + >>> df = sqlContext.createDataFrame([(1, 'row1'), (2, 'row2'), (3, 'row3')], + ... ['field1', 'field2']) >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> sqlContext.dropTempTable("table1") """ @@ -376,6 +380,8 @@ def sql(self, sqlQuery): :return: :class:`DataFrame` + >>> df = sqlContext.createDataFrame([(1, 'row1'), (2, 'row2'), (3, 'row3')], + ... ['field1', 'field2']) >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() @@ -389,6 +395,8 @@ def table(self, tableName): :return: :class:`DataFrame` + >>> df = sqlContext.createDataFrame([(1, 'row1'), (2, 'row2'), (3, 'row3')], + ... ['field1', 'field2']) >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -409,6 +417,8 @@ def tables(self, dbName=None): :param dbName: string, name of the database to use. :return: :class:`DataFrame` + >>> df = sqlContext.createDataFrame([(1, 'row1'), (2, 'row2'), (3, 'row3')], + ... 
['field1', 'field2']) >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() @@ -426,6 +436,8 @@ def tableNames(self, dbName=None): :param dbName: string, name of the database to use. Default to the current database. :return: list of table names, in string + >>> df = sqlContext.createDataFrame([(1, 'row1'), (2, 'row2'), (3, 'row3')], + ... ['field1', 'field2']) >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> "table1" in sqlContext.tableNames() True @@ -474,6 +486,7 @@ def readStream(self): :return: :class:`DataStreamReader` + >>> import tempfile >>> text_sdf = sqlContext.readStream.text(tempfile.mkdtemp()) >>> text_sdf.isStreaming True @@ -553,34 +566,18 @@ def register(self, name, f, returnType=StringType()): def _test(): import os import doctest - import tempfile from pyspark.context import SparkContext - from pyspark.sql import Row, SQLContext + from pyspark.sql import SQLContext import pyspark.sql.context os.chdir(os.environ["SPARK_HOME"]) globs = pyspark.sql.context.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') - globs['tempfile'] = tempfile globs['os'] = os globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - globs['rdd'] = rdd = sc.parallelize( - [Row(field1=1, field2="row1"), - Row(field1=2, field2="row2"), - Row(field1=3, field2="row3")] - ) - globs['df'] = rdd.toDF() - jsonStrings = [ - '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', - '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', - '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' - ] - globs['jsonStrings'] = jsonStrings - globs['json'] = sc.parallelize(jsonStrings) + (failure_count, test_count) = doctest.testmod( pyspark.sql.context, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 50373b858519..77dd1423c854 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -105,6 +105,8 @@ def toJSON(self, use_unicode=True): Each row is turned into a JSON document as one element in the returned RDD. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.toJSON().first() u'{"age":2,"name":"Alice"}' """ @@ -118,6 +120,8 @@ def registerTempTable(self, name): The lifetime of this temporary table is tied to the :class:`SQLContext` that was used to create this :class:`DataFrame`. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.registerTempTable("people") >>> df2 = spark.sql("select * from people") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -137,6 +141,8 @@ def createTempView(self, name): throws :class:`TempTableAlreadyExistsException`, if the view name already exists in the catalog. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.createTempView("people") >>> df2 = spark.sql("select * from people") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -157,6 +163,8 @@ def createOrReplaceTempView(self, name): The lifetime of this temporary table is tied to the :class:`SparkSession` that was used to create this :class:`DataFrame`. 
+ >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.createOrReplaceTempView("people") >>> df2 = df.filter(df.age > 3) >>> df2.createOrReplaceTempView("people") @@ -176,6 +184,8 @@ def createGlobalTempView(self, name): throws :class:`TempTableAlreadyExistsException`, if the view name already exists in the catalog. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.createGlobalTempView("people") >>> df2 = spark.sql("select * from global_temp.people") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -218,8 +228,10 @@ def writeStream(self): def schema(self): """Returns the schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.schema - StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) + StructType(List(StructField(age,LongType,true),StructField(name,StringType,true))) """ if self._schema is None: try: @@ -233,11 +245,13 @@ def schema(self): def printSchema(self): """Prints out the schema in the tree format. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.printSchema() root - |-- age: integer (nullable = true) + |-- age: long (nullable = true) |-- name: string (nullable = true) - + """ print(self._jdf.schema().treeString()) @@ -247,9 +261,11 @@ def explain(self, extended=False): :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.explain() == Physical Plan == - Scan ExistingRDD[age#0,name#1] + Scan ExistingRDD[age#...,name#...] >>> df.explain(True) == Parsed Logical Plan == @@ -296,8 +312,10 @@ def show(self, n=20, truncate=True): If set to a number greater than one, truncates long strings to length ``truncate`` and align cells right. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df - DataFrame[age: int, name: string] + DataFrame[age: bigint, name: string] >>> df.show() +---+-----+ |age| name| @@ -359,6 +377,8 @@ def withWatermark(self, eventTime, delayThreshold): .. note:: Experimental + >>> sdf = spark.createDataFrame([('Tom', 1479441846), ('Bob', 1479442946)], + ... ['name', 'time']) >>> sdf.select('name', sdf.time.cast('timestamp')).withWatermark('time', '10 minutes') DataFrame[name: string, time: timestamp] """ @@ -373,6 +393,8 @@ def withWatermark(self, eventTime, delayThreshold): def count(self): """Returns the number of rows in this :class:`DataFrame`. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.count() 2 """ @@ -383,6 +405,8 @@ def count(self): def collect(self): """Returns all the records as a list of :class:`Row`. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ @@ -397,6 +421,8 @@ def toLocalIterator(self): Returns an iterator that contains all of the rows in this :class:`DataFrame`. The iterator will consume as much memory as the largest partition in this DataFrame. 
+ >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> list(df.toLocalIterator()) [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ @@ -409,6 +435,8 @@ def toLocalIterator(self): def limit(self, num): """Limits the result count to the number specified. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.limit(1).collect() [Row(age=2, name=u'Alice')] >>> df.limit(0).collect() @@ -422,6 +450,8 @@ def limit(self, num): def take(self, num): """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.take(2) [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ @@ -433,6 +463,8 @@ def foreach(self, f): This is a shorthand for ``df.rdd.foreach()``. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> def f(person): ... print(person.name) >>> df.foreach(f) @@ -445,6 +477,8 @@ def foreachPartition(self, f): This a shorthand for ``df.rdd.foreachPartition()``. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> def f(people): ... for person in people: ... print(person.name) @@ -481,6 +515,9 @@ def persist(self, storageLevel=StorageLevel.MEMORY_AND_DISK): def storageLevel(self): """Get the :class:`DataFrame`'s current storage level. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) + >>> df2 = spark.createDataFrame([Row(name='Tom', height=80), Row(name='Bob', height=85)]) >>> df.storageLevel StorageLevel(False, False, False, False, 1) >>> df.cache().storageLevel @@ -517,6 +554,8 @@ def coalesce(self, numPartitions): there will not be a shuffle, instead each of the 100 new partitions will claim 10 of the current partitions. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.coalesce(1).rdd.getNumPartitions() 1 """ @@ -536,6 +575,8 @@ def repartition(self, numPartitions, *cols): Added optional arguments to specify the partitioning columns. Also made numPartitions optional if partitioning columns are specified. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.repartition(10).rdd.getNumPartitions() 10 >>> data = df.union(df).repartition("age") @@ -553,10 +594,10 @@ def repartition(self, numPartitions, *cols): +---+-----+ |age| name| +---+-----+ - | 2|Alice| | 5| Bob| - | 2|Alice| | 5| Bob| + | 2|Alice| + | 2|Alice| +---+-----+ >>> data.rdd.getNumPartitions() 7 @@ -565,10 +606,10 @@ def repartition(self, numPartitions, *cols): +---+-----+ |age| name| +---+-----+ - | 5| Bob| - | 5| Bob| | 2|Alice| | 2|Alice| + | 5| Bob| + | 5| Bob| +---+-----+ """ if isinstance(numPartitions, int): @@ -587,6 +628,8 @@ def repartition(self, numPartitions, *cols): def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.distinct().count() 2 """ @@ -599,6 +642,8 @@ def sample(self, withReplacement, fraction, seed=None): .. 
note:: This is not guaranteed to provide exactly the fraction specified of the total count of the given :class:`DataFrame`. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.sample(False, 0.5, 42).count() 2 """ @@ -651,6 +696,11 @@ def randomSplit(self, weights, seed=None): be normalized if they don't sum up to 1.0. :param seed: The seed for sampling. + >>> from pyspark.sql import Row + >>> df4 = spark.createDataFrame([Row(name='Alice', age=10, height=80), + ... Row(name='Bob', age=5, height=None), + ... Row(name='Tom', age=None, height=None), + ... Row(name=None, age=None, height=None)]) >>> splits = df4.randomSplit([1.0, 2.0], 24) >>> splits[0].count() 1 @@ -670,8 +720,10 @@ def randomSplit(self, weights, seed=None): def dtypes(self): """Returns all column names and their data types as a list. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.dtypes - [('age', 'int'), ('name', 'string')] + [('age', 'bigint'), ('name', 'string')] """ return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] @@ -680,6 +732,8 @@ def dtypes(self): def columns(self): """Returns all column names as a list. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.columns ['age', 'name'] """ @@ -690,7 +744,9 @@ def columns(self): def alias(self, alias): """Returns a new :class:`DataFrame` with an alias set. - >>> from pyspark.sql.functions import * + >>> from pyspark.sql import Row + >>> from pyspark.sql.functions import col + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df_as1 = df.alias("df_as1") >>> df_as2 = df.alias("df_as2") >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') @@ -707,6 +763,9 @@ def crossJoin(self, other): :param other: Right side of the cartesian product. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) + >>> df2 = spark.createDataFrame([Row(name='Tom', height=80), Row(name='Bob', height=85)]) >>> df.select("age", "name").collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] >>> df2.select("name", "height").collect() @@ -735,12 +794,16 @@ def join(self, other, on=None, how=None): The following performs a full outer join between ``df1`` and ``df2``. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) + >>> df2 = spark.createDataFrame([Row(name='Tom', height=80), Row(name='Bob', height=85)]) >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] >>> df.join(df2, 'name', 'outer').select('name', 'height').collect() [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] + >>> df3 = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> cond = [df.name == df3.name, df.age == df3.age] >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect() [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] @@ -748,6 +811,10 @@ def join(self, other, on=None, how=None): >>> df.join(df2, 'name').select(df.name, df2.height).collect() [Row(name=u'Bob', height=85)] + >>> df4 = spark.createDataFrame([Row(name='Alice', age=10, height=80), + ... 
Row(name='Bob', age=5, height=None), + ... Row(name='Tom', age=None, height=None), + ... Row(name=None, age=None, height=None)]) >>> df.join(df4, ['name', 'age']).select(df.name, df.age).collect() [Row(name=u'Bob', age=5)] """ @@ -781,6 +848,8 @@ def sortWithinPartitions(self, *cols, **kwargs): Sort ascending vs. descending. Specify list for multiple sort orders. If a list is specified, length of the list must equal length of the `cols`. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.sortWithinPartitions("age", ascending=False).show() +---+-----+ |age| name| @@ -802,6 +871,8 @@ def sort(self, *cols, **kwargs): Sort ascending vs. descending. Specify list for multiple sort orders. If a list is specified, length of the list must equal length of the `cols`. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.sort(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] >>> df.sort("age", ascending=False).collect() @@ -867,6 +938,8 @@ def describe(self, *cols): .. note:: This function is meant for exploratory data analysis, as we make no guarantee about the backward compatibility of the schema of the resulting DataFrame. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.describe(['age']).show() +-------+------------------+ |summary| age| @@ -877,6 +950,7 @@ def describe(self, *cols): | min| 2| | max| 5| +-------+------------------+ + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.describe().show() +-------+------------------+-----+ |summary| age| name| @@ -905,6 +979,8 @@ def head(self, n=None): :return: If n is greater than 1, return a list of :class:`Row`. If n is 1, return a single Row. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.head() Row(age=2, name=u'Alice') >>> df.head(1) @@ -920,6 +996,8 @@ def head(self, n=None): def first(self): """Returns the first row as a :class:`Row`. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.first() Row(age=2, name=u'Alice') """ @@ -930,6 +1008,8 @@ def first(self): def __getitem__(self, item): """Returns the column as a :class:`Column`. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.select(df['age']).collect() [Row(age=2), Row(age=5)] >>> df[ ["name", "age"]].collect() @@ -956,6 +1036,8 @@ def __getitem__(self, item): def __getattr__(self, name): """Returns the :class:`Column` denoted by ``name``. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.select(df.age).collect() [Row(age=2), Row(age=5)] """ @@ -974,6 +1056,8 @@ def select(self, *cols): If one of the column names is '*', that column is expanded to include all columns in the current DataFrame. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.select('*').collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] >>> df.select('name', 'age').collect() @@ -990,6 +1074,8 @@ def selectExpr(self, *expr): This is a variant of :func:`select` that accepts SQL expressions. 
+ >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.selectExpr("age * 2", "abs(age)").collect() [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)] """ @@ -1008,6 +1094,8 @@ def filter(self, condition): :param condition: a :class:`Column` of :class:`types.BooleanType` or a string of SQL expression. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.filter(df.age > 3).collect() [Row(age=5, name=u'Bob')] >>> df.where(df.age == 2).collect() @@ -1038,6 +1126,8 @@ def groupBy(self, *cols): :param cols: list of columns to group by. Each element should be a column name (string) or an expression (:class:`Column`). + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.groupBy().avg().collect() [Row(avg(age)=3.5)] >>> sorted(df.groupBy('name').agg({'age': 'mean'}).collect()) @@ -1057,6 +1147,8 @@ def rollup(self, *cols): Create a multi-dimensional rollup for the current :class:`DataFrame` using the specified columns, so we can run aggregation on them. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.rollup("name", df.age).count().orderBy("name", "age").show() +-----+----+-----+ | name| age|count| @@ -1078,6 +1170,8 @@ def cube(self, *cols): Create a multi-dimensional cube for the current :class:`DataFrame` using the specified columns, so we can run aggregation on them. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.cube("name", df.age).count().orderBy("name", "age").show() +-----+----+-----+ | name| age|count| @@ -1100,6 +1194,8 @@ def agg(self, *exprs): """ Aggregate on the entire :class:`DataFrame` without groups (shorthand for ``df.groupBy.agg()``). + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.agg({"age": "max"}).collect() [Row(max(age)=5)] >>> from pyspark.sql import functions as F @@ -1191,6 +1287,11 @@ def dropna(self, how='any', thresh=None, subset=None): This overwrites the `how` parameter. :param subset: optional list of column names to consider. + >>> from pyspark.sql import Row + >>> df4 = spark.createDataFrame([Row(name='Alice', age=10, height=80), + ... Row(name='Bob', age=5, height=None), + ... Row(name='Tom', age=None, height=None), + ... Row(name=None, age=None, height=None)]) >>> df4.na.drop().show() +---+------+-----+ |age|height| name| @@ -1228,6 +1329,11 @@ def fillna(self, value, subset=None): For example, if `value` is a string, and subset contains a non-string column, then the non-string column is simply ignored. + >>> from pyspark.sql import Row + >>> df4 = spark.createDataFrame([Row(name='Alice', age=10, height=80), + ... Row(name='Bob', age=5, height=None), + ... Row(name='Tom', age=None, height=None), + ... Row(name=None, age=None, height=None)]) >>> df4.na.fill(50).show() +---+------+-----+ |age|height| name| @@ -1286,6 +1392,11 @@ def replace(self, to_replace, value, subset=None): For example, if `value` is a string, and subset contains a non-string column, then the non-string column is simply ignored. + >>> from pyspark.sql import Row + >>> df4 = spark.createDataFrame([Row(name='Alice', age=10, height=80), + ... Row(name='Bob', age=5, height=None), + ... Row(name='Tom', age=None, height=None), + ... 
Row(name=None, age=None, height=None)]) >>> df4.na.replace(10, 20).show() +----+------+-----+ | age|height| name| @@ -1510,6 +1621,8 @@ def withColumn(self, colName, col): :param colName: string, name of the new column. :param col: a :class:`Column` expression for the new column. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.withColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] """ @@ -1525,6 +1638,8 @@ def withColumnRenamed(self, existing, new): :param existing: string, name of the existing column to rename. :param col: string, new name of the column. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] """ @@ -1539,6 +1654,9 @@ def drop(self, *cols): :param cols: a string name of the column to drop, or a :class:`Column` to drop, or a list of string name of the columns to drop. + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) + >>> df2 = spark.createDataFrame([Row(name='Tom', height=80), Row(name='Bob', height=85)]) >>> df.drop('age').collect() [Row(name=u'Alice'), Row(name=u'Bob')] @@ -1576,6 +1694,8 @@ def toDF(self, *cols): :param cols: list of new column names (string) + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.toDF('f1', 'f2').collect() [Row(f1=2, f2=u'Alice'), Row(f1=5, f2=u'Bob')] """ @@ -1591,6 +1711,8 @@ def toPandas(self): .. note:: This method should only be used if the resulting Pandas's DataFrame is expected to be small, as all the data is loaded into the driver's memory. 
+ >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) >>> df.toPandas() # doctest: +SKIP age name 0 2 Alice @@ -1694,26 +1816,13 @@ def sampleBy(self, col, fractions, seed=None): def _test(): import doctest from pyspark.context import SparkContext - from pyspark.sql import Row, SQLContext, SparkSession + from pyspark.sql import SQLContext, SparkSession import pyspark.sql.dataframe - from pyspark.sql.functions import from_unixtime globs = pyspark.sql.dataframe.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) globs['spark'] = SparkSession(sc) - globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\ - .toDF(StructType([StructField('age', IntegerType()), - StructField('name', StringType())])) - globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() - globs['df3'] = sc.parallelize([Row(name='Alice', age=2), - Row(name='Bob', age=5)]).toDF() - globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), - Row(name='Bob', age=5, height=None), - Row(name='Tom', age=None, height=None), - Row(name=None, age=None, height=None)]).toDF() - globs['sdf'] = sc.parallelize([Row(name='Tom', time=1479441846), - Row(name='Bob', time=1479442946)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.dataframe, globs=globs, diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 02c2350dc2d6..45fd9423a450 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -208,6 +208,7 @@ def approxCountDistinct(col, rsd=None): def approx_count_distinct(col, rsd=None): """Returns a new :class:`Column` for approximate distinct count of ``col``. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.agg(approx_count_distinct(df.age).alias('c')).collect() [Row(c=2)] """ @@ -313,6 +314,7 @@ def covar_samp(col1, col2): def countDistinct(col, *cols): """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.agg(countDistinct(df.age, df.name).alias('c')).collect() [Row(c=2)] @@ -342,6 +344,7 @@ def grouping(col): Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated or not, returns 1 for aggregated or 0 for not aggregated in the result set. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.cube("name").agg(grouping("name"), sum("age")).orderBy("name").show() +-----+--------------+--------+ | name|grouping(name)|sum(age)| @@ -366,6 +369,7 @@ def grouping_id(*cols): .. note:: The list of columns should match with grouping columns exactly, or empty (means all the grouping columns). + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show() +-----+-------------+--------+ | name|grouping_id()|sum(age)| @@ -553,6 +557,7 @@ def spark_partition_id(): .. note:: This is indeterministic because it depends on data partitioning and task scheduling. 
+ >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.repartition(1).select(spark_partition_id().alias("pid")).collect() [Row(pid=0), Row(pid=0)] """ @@ -564,6 +569,7 @@ def spark_partition_id(): def expr(str): """Parses the expression string into the column that it represents + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(expr("length(name)")).collect() [Row(length(name)=5), Row(length(name)=3)] """ @@ -578,6 +584,7 @@ def struct(*cols): :param cols: list of column names (string) or list of :class:`Column` expressions + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(struct('age', 'name').alias("struct")).collect() [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] >>> df.select(struct([df.age, df.name]).alias("struct")).collect() @@ -630,6 +637,7 @@ def when(condition, value): :param condition: a boolean :class:`Column` expression. :param value: a literal value, or a :class:`Column` expression. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect() [Row(age=3), Row(age=4)] @@ -650,6 +658,7 @@ def log(arg1, arg2=None): If there is only one argument, then this takes the natural logarithm of the argument. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(log(10.0, df.age).alias('ten')).rdd.map(lambda l: str(l.ten)[:7]).collect() ['0.30102', '0.69897'] @@ -1188,6 +1197,7 @@ def sha2(col, numBits): and SHA-512). The numBits indicates the desired bit length of the result, which must have a value of 224, 256, 384, 512, or 0 (which is equivalent to 256). + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> digests = df.select(sha2(df.name, 256).alias('s')).collect() >>> digests[0] Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043') @@ -1526,6 +1536,7 @@ def soundex(col): def bin(col): """Returns the string representation of the binary value of the given column. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(bin(df.age).alias('c')).collect() [Row(c=u'10'), Row(c=u'101')] """ @@ -1600,6 +1611,7 @@ def create_map(*cols): :param cols: list of column names (string) or list of :class:`Column` expressions that grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...). + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(create_map('name', 'age').alias("map")).collect() [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})] >>> df.select(create_map([df.name, df.age]).alias("map")).collect() @@ -1619,6 +1631,7 @@ def array(*cols): :param cols: list of column names (string) or list of :class:`Column` expressions that have the same data type. + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> df.select(array('age', 'age').alias("arr")).collect() [Row(arr=[2, 2]), Row(arr=[5, 5])] >>> df.select(array([df.age, df.age]).alias("arr")).collect() @@ -1763,7 +1776,6 @@ def to_json(col, options={}): :param options: options to control converting. 
accepts the same options as the json datasource >>> from pyspark.sql import Row - >>> from pyspark.sql.types import * >>> data = [(1, Row(name='Alice', age=2))] >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() @@ -1871,6 +1883,7 @@ def udf(f, returnType=StringType()): :param f: python function :param returnType: a :class:`pyspark.sql.types.DataType` object + >>> df = spark.createDataFrame([('Alice', 2), ('Bob', 5)], ['name', 'age']) >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) >>> df.select(slen(df.name).alias('slen')).collect() @@ -1896,7 +1909,6 @@ def _test(): sc = spark.sparkContext globs['sc'] = sc globs['spark'] = spark - globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.functions, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index f2092f9c6305..65f491aa944c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -73,6 +73,7 @@ def agg(self, *exprs): :param exprs: a dict mapping from column name (string) to aggregate functions (string), or a list of :class:`Column`. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> gdf = df.groupBy(df.name) >>> sorted(gdf.agg({"*": "count"}).collect()) [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] @@ -96,6 +97,7 @@ def agg(self, *exprs): def count(self): """Counts the number of records for each group. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> sorted(df.groupBy(df.age).count().collect()) [Row(age=2, count=1), Row(age=5, count=1)] """ @@ -109,8 +111,12 @@ def mean(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.groupBy().mean('age').collect() [Row(avg(age)=3.5)] + + >>> df3 = spark.createDataFrame([(2, 'Alice', 80), (5, 'Bob', 85)], + ... ['age', 'name', 'height']) >>> df3.groupBy().mean('age', 'height').collect() [Row(avg(age)=3.5, avg(height)=82.5)] """ @@ -124,8 +130,12 @@ def avg(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.groupBy().avg('age').collect() [Row(avg(age)=3.5)] + + >>> df3 = spark.createDataFrame([(2, 'Alice', 80), (5, 'Bob', 85)], + ... ['age', 'name', 'height']) >>> df3.groupBy().avg('age', 'height').collect() [Row(avg(age)=3.5, avg(height)=82.5)] """ @@ -135,8 +145,12 @@ def avg(self, *cols): def max(self, *cols): """Computes the max value for each numeric columns for each group. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.groupBy().max('age').collect() [Row(max(age)=5)] + + >>> df3 = spark.createDataFrame([(2, 'Alice', 80), (5, 'Bob', 85)], + ... ['age', 'name', 'height']) >>> df3.groupBy().max('age', 'height').collect() [Row(max(age)=5, max(height)=85)] """ @@ -148,8 +162,12 @@ def min(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.groupBy().min('age').collect() [Row(min(age)=2)] + + >>> df3 = spark.createDataFrame([(2, 'Alice', 80), (5, 'Bob', 85)], + ... 
['age', 'name', 'height']) >>> df3.groupBy().min('age', 'height').collect() [Row(min(age)=2, min(height)=80)] """ @@ -161,8 +179,12 @@ def sum(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. + >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name']) >>> df.groupBy().sum('age').collect() [Row(sum(age)=7)] + + >>> df3 = spark.createDataFrame([(2, 'Alice', 80), (5, 'Bob', 85)], + ... ['age', 'name', 'height']) >>> df3.groupBy().sum('age', 'height').collect() [Row(sum(age)=7, sum(height)=165)] """ @@ -180,6 +202,12 @@ def pivot(self, pivot_col, values=None): # Compute the sum of earnings for each year by course with each course as a separate column + >>> df4 = spark.createDataFrame([("dotNET", 10000, 2012), + ... ("Java", 20000, 2012), + ... ("dotNET", 5000, 2012), + ... ("dotNET", 48000, 2013), + ... ("Java", 30000, 2013)], + ... ['course', 'earnings', 'year']) >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] @@ -204,18 +232,7 @@ def _test(): .master("local[4]")\ .appName("sql.group tests")\ .getOrCreate() - sc = spark.sparkContext - globs['sc'] = sc - globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ - .toDF(StructType([StructField('age', IntegerType()), - StructField('name', StringType())])) - globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), - Row(name='Bob', age=5, height=85)]).toDF() - globs['df4'] = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000), - Row(course="Java", year=2012, earnings=20000), - Row(course="dotNET", year=2012, earnings=5000), - Row(course="dotNET", year=2013, earnings=48000), - Row(course="Java", year=2013, earnings=30000)]).toDF() + globs['spark'] = spark (failure_count, test_count) = doctest.testmod( pyspark.sql.group, globs=globs, diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index d31f3fb8f604..fc1db9e42566 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -471,6 +471,9 @@ def mode(self, saveMode): * `error`: Throw an exception if data already exists. * `ignore`: Silently ignore this operation if data already exists. + >>> import os + >>> import tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ # At the JVM side, the default value of mode is already set to "error". @@ -485,6 +488,9 @@ def format(self, source): :param source: string, name of the data source, e.g. 'json', 'parquet'. 
+ >>> import os + >>> import tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data')) """ self._jwrite = self._jwrite.format(source) @@ -514,6 +520,9 @@ def partitionBy(self, *cols): :param cols: name of columns + >>> import os + >>> import tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ if len(cols) == 1 and isinstance(cols[0], (list, tuple)): @@ -540,6 +549,9 @@ def save(self, path=None, format=None, mode=None, partitionBy=None, **options): :param partitionBy: names of partitioning columns :param options: all other string options + >>> import os + >>> import tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode).options(**options) @@ -613,6 +625,9 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm This applies to timestamp type. If None is set, it uses the default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + >>> import os + >>> import tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) @@ -638,6 +653,9 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): is set, it uses the value specified in ``spark.sql.parquet.compression.codec``. + >>> import os + >>> import tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) @@ -705,6 +723,9 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No This applies to timestamp type. If None is set, it uses the default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + >>> import os + >>> import tempfile + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) @@ -732,6 +753,8 @@ def orc(self, path, mode=None, partitionBy=None, compression=None): This will override ``orc.compress``. If None is set, it uses the default value, ``snappy``. 
+ >>> import os + >>> import tempfile >>> orc_df = spark.read.orc('python/test_support/sql/orc_partitioned') >>> orc_df.write.orc(os.path.join(tempfile.mkdtemp(), 'data')) """ @@ -771,10 +794,9 @@ def jdbc(self, url, table, mode=None, properties=None): def _test(): import doctest import os - import tempfile import py4j from pyspark.context import SparkContext - from pyspark.sql import SparkSession, Row + from pyspark.sql import SparkSession import pyspark.sql.readwriter os.chdir(os.environ["SPARK_HOME"]) @@ -786,8 +808,6 @@ def _test(): except py4j.protocol.Py4JError: spark = SparkSession(sc) - globs['tempfile'] = tempfile - globs['os'] = os globs['sc'] = sc globs['spark'] = spark globs['df'] = spark.read.parquet('python/test_support/sql/parquet_partitioned') diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 9f4772eec9f2..21db40d5419e 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -51,8 +51,10 @@ def toDF(self, schema=None, sampleRatio=None): :param samplingRatio: the sample ratio of rows used for inferring :return: a DataFrame + >>> from pyspark.sql import Row + >>> rdd = sc.parallelize([Row(field1=1, field2="row1"), Row(field1=2, field2="row2")]) >>> rdd.toDF().collect() - [Row(name=u'Alice', age=1)] + [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] """ return sparkSession.createDataFrame(self, schema, sampleRatio) @@ -537,6 +539,8 @@ def sql(self, sqlQuery): :return: :class:`DataFrame` + >>> df = spark.createDataFrame([(1, 'row1'), (2, 'row2'), (3, 'row3')], + ... ['field1', 'field2']) >>> df.createOrReplaceTempView("table1") >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() @@ -550,6 +554,8 @@ def table(self, tableName): :return: :class:`DataFrame` + >>> df = spark.createDataFrame([(1, 'row1'), (2, 'row2'), (3, 'row3')], + ... ['field1', 'field2']) >>> df.createOrReplaceTempView("table1") >>> df2 = spark.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -631,11 +637,6 @@ def _test(): sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['spark'] = SparkSession(sc) - globs['rdd'] = rdd = sc.parallelize( - [Row(field1=1, field2="row1"), - Row(field1=2, field2="row2"), - Row(field1=3, field2="row3")]) - globs['df'] = rdd.toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.session, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index a10b185cd4c7..881dee918d72 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -160,6 +160,7 @@ def explain(self, extended=False): :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> sq = sdf.writeStream.format('memory').queryName('query_explain').start() >>> sq.processAllAvailable() # Wait a bit to generate the runtime plans. 
>>> sq.explain() @@ -211,6 +212,7 @@ def __init__(self, jsqm): def active(self): """Returns a list of active queries associated with this SQLContext + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> sq = sdf.writeStream.format('memory').queryName('this_query').start() >>> sqm = spark.streams >>> # get the list of active streaming queries @@ -226,6 +228,7 @@ def get(self, id): """Returns an active query from this SQLContext or throws exception if an active query with this name doesn't exist. + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> sq = sdf.writeStream.format('memory').queryName('this_query').start() >>> sq.name u'this_query' @@ -359,6 +362,8 @@ def schema(self, schema): :param schema: a :class:`pyspark.sql.types.StructType` object + >>> from pyspark.sql.types import StructType, StructField, StringType + >>> sdf_schema = StructType([StructField("data", StringType(), False)]) >>> s = spark.readStream.schema(sdf_schema) """ from pyspark.sql import SparkSession @@ -403,6 +408,9 @@ def load(self, path=None, format=None, schema=None, **options): :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema. :param options: all other string options + >>> import tempfile + >>> from pyspark.sql.types import StructType, StructField, StringType + >>> sdf_schema = StructType([StructField("data", StringType(), False)]) >>> json_sdf = spark.readStream.format("json") \\ ... .schema(sdf_schema) \\ ... .load(tempfile.mkdtemp()) @@ -482,7 +490,10 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, This applies to timestamp type. If None is set, it uses the default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. - >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) + >>> import tempfile + >>> from pyspark.sql.types import StructType, StructField, StringType + >>> sdf_schema = StructType([StructField("data", StringType(), False)]) + >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema=sdf_schema) >>> json_sdf.isStreaming True >>> json_sdf.schema == sdf_schema @@ -511,6 +522,9 @@ def parquet(self, path): .. note:: Experimental. + >>> import tempfile + >>> from pyspark.sql.types import StructType, StructField, StringType + >>> sdf_schema = StructType([StructField("data", StringType(), False)]) >>> parquet_sdf = spark.readStream.schema(sdf_schema).parquet(tempfile.mkdtemp()) >>> parquet_sdf.isStreaming True @@ -536,6 +550,7 @@ def text(self, path): :param paths: string, or list of strings, for input path(s). + >>> import tempfile >>> text_sdf = spark.readStream.text(tempfile.mkdtemp()) >>> text_sdf.isStreaming True @@ -615,6 +630,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. + >>> import tempfile + >>> from pyspark.sql.types import StructType, StructField, StringType + >>> sdf_schema = StructType([StructField("data", StringType(), False)]) >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) >>> csv_sdf.isStreaming True @@ -671,6 +689,7 @@ def outputMode(self, outputMode): .. note:: Experimental. 
+ >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> writer = sdf.writeStream.outputMode('append') """ if not outputMode or type(outputMode) != str or len(outputMode.strip()) == 0: @@ -686,6 +705,7 @@ def format(self, source): :param source: string, name of the data source, which for now can be 'parquet'. + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> writer = sdf.writeStream.format('json') """ self._jwrite = self._jwrite.format(source) @@ -737,6 +757,7 @@ def queryName(self, queryName): :param queryName: unique name for the query + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> writer = sdf.writeStream.queryName('streaming_query') """ if not queryName or type(queryName) != str or len(queryName.strip()) == 0: @@ -754,6 +775,7 @@ def trigger(self, processingTime=None): :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> # trigger the query for execution every 5 seconds >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') """ @@ -798,6 +820,7 @@ def start(self, path=None, format=None, outputMode=None, partitionBy=None, query :param options: All other string options. You may want to provide a `checkpointLocation` for most streams, however it is not required for a `memory` stream. + >>> sdf = spark.readStream.format('text').load('python/test_support/sql/streaming') >>> sq = sdf.writeStream.format('memory').queryName('this_query').start() >>> sq.isActive True @@ -844,15 +867,8 @@ def _test(): except py4j.protocol.Py4JError: spark = SparkSession(sc) - globs['tempfile'] = tempfile - globs['os'] = os globs['spark'] = spark globs['sqlContext'] = SQLContext.getOrCreate(spark.sparkContext) - globs['sdf'] = \ - spark.readStream.format('text').load('python/test_support/sql/streaming') - globs['sdf_schema'] = StructType([StructField("data", StringType(), False)]) - globs['df'] = \ - globs['spark'].readStream.format('text').load('python/test_support/sql/streaming') (failure_count, test_count) = doctest.testmod( pyspark.sql.streaming, globs=globs,
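The hunks above all make the same move: each docstring example now builds its own data with `spark.createDataFrame(...)` (or `spark.read`/`spark.readStream` on the bundled test files), and every module's `_test()` shrinks to injecting just `spark`/`sc`/`sqlContext` into the doctest globals instead of pre-building `df`, `df2`, `df3`, `df4`, `sdf`, and friends. As a rough illustration of that end state — a minimal sketch, not part of the patch; the `double_age` helper and its sample rows are invented for the example — a module written in this style looks roughly like:

```python
# Sketch of the self-contained-doctest pattern this diff applies across pyspark.sql.
# Assumes a local PySpark installation; the function and data are hypothetical.
import doctest
import sys


def double_age(df):
    """Return a new DataFrame with an ``age2`` column equal to ``age * 2``.

    The example creates the DataFrame it needs, so it can be copy-pasted
    into a PySpark shell as-is.

    >>> df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name'])
    >>> double_age(df).collect()
    [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=10)]
    """
    return df.withColumn('age2', df.age * 2)


def _test():
    from pyspark.sql import SparkSession

    globs = globals().copy()
    spark = SparkSession.builder \
        .master("local[4]") \
        .appName("self-contained doctest sketch") \
        .getOrCreate()
    # Only the session goes into the doctest globals; each example creates
    # its own data, so no per-module fixture DataFrames are needed.
    globs['spark'] = spark
    (failure_count, test_count) = doctest.testmod(
        globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()
```

The practical payoff, visible in every file touched above, is that each `>>>` example documents its own inputs and the `_test()` harnesses lose their module-specific fixture setup (the removed `sc.parallelize(...).toDF(...)` blocks), leaving only session creation plus `doctest.testmod`.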