20 changes: 4 additions & 16 deletions python/docs/source/development/testing.rst
@@ -82,26 +82,14 @@ you should regenerate Python Protobuf client by running ``dev/connect-gen-protos
Running PySpark Shell with Python Client
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To run Spark Connect server you locally built:
For Apache Spark you built locally:

.. code-block:: bash

bin/spark-shell \
--jars `ls connector/connect/target/**/spark-connect*SNAPSHOT.jar | paste -sd ',' -` \
--conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin
bin/pyspark --remote "local[*]"

To run the Spark Connect server from the Apache Spark release:
For the Apache Spark release:

.. code-block:: bash

bin/spark-shell \
--packages org.apache.spark:spark-connect_2.12:3.4.0 \
--conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin


To run the PySpark Shell with the client for the Spark Connect server:

.. code-block:: bash

bin/pyspark --remote sc://localhost

bin/pyspark --remote "local[*]" --packages org.apache.spark:spark-connect_2.12:3.4.0
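Both commands boil down to the same kind of session. As a rough sketch (not part of the patch), the programmatic equivalent of ``bin/pyspark --remote "local[*]"`` mirrors the doctest setup used in the files below; it assumes a build or release where Spark Connect is available:

.. code-block:: python

    # Hypothetical, minimal equivalent of `bin/pyspark --remote "local[*]"`:
    # builds a Spark Connect client session backed by a local server.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("local[*]").getOrCreate()
    spark.range(3).show()  # executed through the local Spark Connect server
    spark.stop()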
2 changes: 1 addition & 1 deletion python/pyspark/context.py
@@ -179,7 +179,7 @@ def __init__(
udf_profiler_cls: Type[UDFBasicProfiler] = UDFBasicProfiler,
memory_profiler_cls: Type[MemoryProfiler] = MemoryProfiler,
):
if "SPARK_REMOTE" in os.environ and "SPARK_TESTING" not in os.environ:
if "SPARK_REMOTE" in os.environ and "SPARK_LOCAL_REMOTE" not in os.environ:
raise RuntimeError(
"Remote client cannot create a SparkContext. Create SparkSession instead."
)
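A hedged sketch of the new guard (not part of the patch; the connection URL is illustrative): with ``SPARK_REMOTE`` set and the new ``SPARK_LOCAL_REMOTE`` escape hatch unset, creating a classic ``SparkContext`` is rejected, just as it previously was unless ``SPARK_TESTING`` was set.

.. code-block:: python

    import os
    from pyspark import SparkContext

    os.environ["SPARK_REMOTE"] = "sc://localhost"
    os.environ.pop("SPARK_LOCAL_REMOTE", None)

    try:
        SparkContext("local[2]", "guard-demo")
    except RuntimeError as e:
        # "Remote client cannot create a SparkContext. Create SparkSession instead."
        print(e)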
15 changes: 5 additions & 10 deletions python/pyspark/sql/connect/catalog.py
@@ -18,7 +18,6 @@

import pandas as pd

from pyspark import SparkContext, SparkConf
from pyspark.sql.types import StructType
from pyspark.sql.connect import DataFrame
from pyspark.sql.catalog import (
@@ -324,16 +323,12 @@ def _test() -> None:
import pyspark.sql.connect.catalog

globs = pyspark.sql.connect.catalog.__dict__.copy()
# Works around to create a regular Spark session
sc = SparkContext("local[4]", "sql.connect.catalog tests", conf=SparkConf())
globs["_spark"] = PySparkSession(
sc, options={"spark.app.name": "sql.connect.catalog tests"}
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.catalog tests")
.remote("local[4]")
.getOrCreate()
)

# Creates a remote Spark session.
os.environ["SPARK_REMOTE"] = "sc://localhost"
globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate()

# TODO(SPARK-41612): Support Catalog.isCached
# TODO(SPARK-41600): Support Catalog.cacheTable
del pyspark.sql.connect.catalog.Catalog.clearCache.__doc__
@@ -348,8 +343,8 @@ def _test() -> None:
| doctest.NORMALIZE_WHITESPACE
| doctest.IGNORE_EXCEPTION_DETAIL,
)
globs["_spark"].stop()
globs["spark"].stop()

if failure_count:
sys.exit(-1)
else:
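The ``del ...__doc__`` statements here (and in the files below) rely on how ``doctest`` collects examples: once an object's docstring is removed, its examples are never gathered, so the unsupported Connect APIs are skipped. A self-contained sketch of that mechanism (the function name is illustrative, not from the patch):

.. code-block:: python

    import doctest

    def unsupported():
        """
        >>> unsupported()  # would fail if doctest collected it
        'ok'
        """
        raise NotImplementedError

    # Deleting the docstring hides the failing example from doctest.testmod().
    del unsupported.__doc__

    if __name__ == "__main__":
        print(doctest.testmod())  # TestResults(failed=0, attempted=0)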
14 changes: 6 additions & 8 deletions python/pyspark/sql/connect/column.py
@@ -28,7 +28,6 @@
Optional,
)

from pyspark import SparkContext, SparkConf
from pyspark.sql.types import DataType
from pyspark.sql.column import Column as PySparkColumn

@@ -434,13 +433,12 @@ def _test() -> None:
import pyspark.sql.connect.column

globs = pyspark.sql.connect.column.__dict__.copy()
# Works around to create a regular Spark session
sc = SparkContext("local[4]", "sql.connect.column tests", conf=SparkConf())
globs["_spark"] = PySparkSession(sc, options={"spark.app.name": "sql.connect.column tests"})
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.column tests")
.remote("local[4]")
.getOrCreate()
)

# Creates a remote Spark session.
os.environ["SPARK_REMOTE"] = "sc://localhost"
globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate()
# Spark Connect has a different string representation for Column.
del pyspark.sql.connect.column.Column.getItem.__doc__

@@ -456,7 +454,7 @@ def _test() -> None:
)

globs["spark"].stop()
globs["_spark"].stop()

if failure_count:
sys.exit(-1)
else:
18 changes: 7 additions & 11 deletions python/pyspark/sql/connect/dataframe.py
@@ -38,7 +38,7 @@
import warnings
from collections.abc import Iterable

from pyspark import _NoValue, SparkContext, SparkConf
from pyspark import _NoValue
from pyspark._globals import _NoValueType
from pyspark.sql.types import StructType, Row

@@ -1532,12 +1532,6 @@ def _test() -> None:
import pyspark.sql.connect.dataframe

globs = pyspark.sql.connect.dataframe.__dict__.copy()
# Works around to create a regular Spark session
sc = SparkContext("local[4]", "sql.connect.dataframe tests", conf=SparkConf())
globs["_spark"] = PySparkSession(
sc, options={"spark.app.name": "sql.connect.dataframe tests"}
)

# Spark Connect does not support RDD but the tests depend on them.
del pyspark.sql.connect.dataframe.DataFrame.coalesce.__doc__
del pyspark.sql.connect.dataframe.DataFrame.repartition.__doc__
@@ -1564,9 +1558,11 @@ def _test() -> None:
# TODO(SPARK-41818): Support saveAsTable
del pyspark.sql.connect.dataframe.DataFrame.write.__doc__

# Creates a remote Spark session.
os.environ["SPARK_REMOTE"] = "sc://localhost"
globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate()
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.dataframe tests")
.remote("local[4]")
.getOrCreate()
)

(failure_count, test_count) = doctest.testmod(
pyspark.sql.connect.dataframe,
@@ -1577,7 +1573,7 @@ def _test() -> None:
)

globs["spark"].stop()
globs["_spark"].stop()

if failure_count:
sys.exit(-1)
else:
17 changes: 7 additions & 10 deletions python/pyspark/sql/connect/functions.py
@@ -31,7 +31,6 @@
cast,
)

from pyspark import SparkContext, SparkConf
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import (
CaseWhen,
@@ -2354,11 +2353,7 @@ def _test() -> None:
import pyspark.sql.connect.functions

globs = pyspark.sql.connect.functions.__dict__.copy()
# Works around to create a regular Spark session
sc = SparkContext("local[4]", "sql.connect.functions tests", conf=SparkConf())
globs["_spark"] = PySparkSession(
sc, options={"spark.app.name": "sql.connect.functions tests"}
)

# Spark Connect does not support Spark Context but the test depends on that.
del pyspark.sql.connect.functions.monotonically_increasing_id.__doc__

@@ -2407,9 +2402,11 @@ def _test() -> None:
del pyspark.sql.connect.functions.map_zip_with.__doc__
del pyspark.sql.connect.functions.posexplode.__doc__

# Creates a remote Spark session.
os.environ["SPARK_REMOTE"] = "sc://localhost"
globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate()
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.functions tests")
.remote("local[4]")
.getOrCreate()
)

(failure_count, test_count) = doctest.testmod(
pyspark.sql.connect.functions,
@@ -2420,7 +2417,7 @@ def _test() -> None:
)

globs["spark"].stop()
globs["_spark"].stop()

if failure_count:
sys.exit(-1)
else:
15 changes: 7 additions & 8 deletions python/pyspark/sql/connect/group.py
@@ -226,7 +226,6 @@ def _test() -> None:
import os
import sys
import doctest
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession as PySparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

@@ -236,21 +235,21 @@ def _test() -> None:
import pyspark.sql.connect.group

globs = pyspark.sql.connect.group.__dict__.copy()
# Works around to create a regular Spark session
sc = SparkContext("local[4]", "sql.connect.group tests", conf=SparkConf())
globs["_spark"] = PySparkSession(sc, options={"spark.app.name": "sql.connect.group tests"})

# Creates a remote Spark session.
os.environ["SPARK_REMOTE"] = "sc://localhost"
globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate()
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.group tests")
.remote("local[4]")
.getOrCreate()
)

(failure_count, test_count) = doctest.testmod(
pyspark.sql.connect.group,
globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
)
globs["_spark"].stop()

globs["spark"].stop()

if failure_count:
sys.exit(-1)
else:
16 changes: 6 additions & 10 deletions python/pyspark/sql/connect/readwriter.py
@@ -20,7 +20,6 @@
from typing import Optional, Union, List, overload, Tuple, cast, Any
from typing import TYPE_CHECKING

from pyspark import SparkContext, SparkConf
from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, WriteOperation
from pyspark.sql.types import StructType
from pyspark.sql.utils import to_str
@@ -497,11 +496,6 @@ def _test() -> None:
import pyspark.sql.connect.readwriter

globs = pyspark.sql.connect.readwriter.__dict__.copy()
# Works around to create a regular Spark session
sc = SparkContext("local[4]", "sql.connect.readwriter tests", conf=SparkConf())
globs["_spark"] = PySparkSession(
sc, options={"spark.app.name": "sql.connect.readwriter tests"}
)

# TODO(SPARK-41817): Support reading with schema
del pyspark.sql.connect.readwriter.DataFrameReader.load.__doc__
@@ -517,9 +511,11 @@ def _test() -> None:
del pyspark.sql.connect.readwriter.DataFrameWriter.insertInto.__doc__
del pyspark.sql.connect.readwriter.DataFrameWriter.saveAsTable.__doc__

# Creates a remote Spark session.
os.environ["SPARK_REMOTE"] = "sc://localhost"
globs["spark"] = PySparkSession.builder.remote("sc://localhost").getOrCreate()
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.readwriter tests")
.remote("local[4]")
.getOrCreate()
)

(failure_count, test_count) = doctest.testmod(
pyspark.sql.connect.readwriter,
@@ -530,7 +526,7 @@ def _test() -> None:
)

globs["spark"].stop()
globs["_spark"].stop()

if failure_count:
sys.exit(-1)
else: