6 changes: 3 additions & 3 deletions python/pyspark/sql/connect/conversion.py
@@ -322,7 +322,7 @@ def convert_other(value: Any) -> Any:
return lambda value: value

@staticmethod
def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
def convert(data: Sequence[Any], schema: StructType, verifySchema: bool = False) -> "pa.Table":
assert isinstance(data, list) and len(data) > 0

assert schema is not None and isinstance(schema, StructType)
@@ -372,8 +372,8 @@ def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
]
)
)

return pa.Table.from_arrays(pylist, schema=pa_schema)
table = pa.Table.from_arrays(pylist, schema=pa_schema)
return table.cast(pa_schema, safe=verifySchema)


class ArrowTableToRowsConversion:
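For context on the conversion.py change above: the `safe=` flag on `pa.Table.cast` is what gives `verifySchema` its effect here. With `safe=True`, PyArrow rejects lossy conversions (for example, an int64 value that does not fit the target int32) instead of silently truncating them. A minimal standalone sketch of that behavior, using made-up values that mirror the test data later in this PR:

```python
import pyarrow as pa

# Values that overflow int32, mirroring the test data used later in this PR.
table = pa.table({"value": [100000000000, 200000000000]})
target = pa.schema([pa.field("value", pa.int32())])

# Unsafe cast: succeeds, but the overflowing values are silently truncated.
table.cast(target, safe=False)

# Safe cast: raises instead of producing corrupted data.
try:
    table.cast(target, safe=True)
except pa.lib.ArrowInvalid as e:
    print("safe cast rejected lossy conversion:", e)
```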
30 changes: 23 additions & 7 deletions python/pyspark/sql/connect/session.py
@@ -50,6 +50,7 @@
)
import urllib

from pyspark._globals import _NoValue, _NoValueType
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.dataframe import DataFrame as ParentDataFrame
from pyspark.sql.connect.logging import logger
@@ -449,7 +450,7 @@ def createDataFrame(
data: Union["pd.DataFrame", "np.ndarray", "pa.Table", Iterable[Any]],
schema: Optional[Union[AtomicType, StructType, str, List[str], Tuple[str, ...]]] = None,
samplingRatio: Optional[float] = None,
verifySchema: Optional[bool] = None,
verifySchema: Union[_NoValueType, bool] = _NoValue,
) -> "ParentDataFrame":
assert data is not None
if isinstance(data, DataFrame):
@@ -461,9 +462,6 @@
if samplingRatio is not None:
warnings.warn("'samplingRatio' is ignored. It is not supported with Spark Connect.")

if verifySchema is not None:
warnings.warn("'verifySchema' is ignored. It is not supported with Spark Connect.")

_schema: Optional[Union[AtomicType, StructType]] = None
_cols: Optional[List[str]] = None
_num_cols: Optional[int] = None
@@ -576,7 +574,10 @@ def createDataFrame(
"spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
)

ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")
if verifySchema is _NoValue:
verifySchema = safecheck == "true"

ser = ArrowStreamPandasSerializer(cast(str, timezone), verifySchema)

_table = pa.Table.from_batches(
[
@@ -596,6 +597,9 @@
).cast(arrow_schema)

elif isinstance(data, pa.Table):
if verifySchema is _NoValue:
verifySchema = False

prefer_timestamp_ntz = is_timestamp_ntz_preferred()

(timezone,) = self._client.get_configs("spark.sql.session.timeZone")
@@ -613,7 +617,10 @@

_table = (
_check_arrow_table_timestamps_localize(data, schema, True, timezone)
.cast(to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True))
.cast(
to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True),
safe=verifySchema,
)
.rename_columns(schema.names)
)

@@ -652,6 +659,12 @@
# The _table should already have the proper column names.
_cols = None

if verifySchema is not _NoValue:
warnings.warn(
"'verifySchema' is ignored. It is not supported"
" with np.ndarray input on Spark Connect."
)

else:
_data = list(data)

@@ -683,12 +696,15 @@
errorClass="CANNOT_DETERMINE_TYPE", messageParameters={}
)

if verifySchema is _NoValue:
verifySchema = True

from pyspark.sql.connect.conversion import LocalDataToArrowConversion

# Spark Connect will try its best to build the Arrow table with the
# inferred schema in the client side, and then rename the columns and
# cast the datatypes in the server side.
_table = LocalDataToArrowConversion.convert(_data, _schema)
_table = LocalDataToArrowConversion.convert(_data, _schema, cast(bool, verifySchema))

# TODO: Beside the validation on number of columns, we should also check
# whether the Arrow Schema is compatible with the user provided Schema.
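Taken together, the session.py changes above make `verifySchema` meaningful on Spark Connect, with an input-dependent default resolved through the `_NoValue` sentinel: pandas input falls back to the `spark.sql.execution.pandas.convertToArrowArraySafely` conf, `pa.Table` input defaults to `False`, local Python data defaults to `True`, and `np.ndarray` input only warns that the flag is ignored. A usage sketch of the `pa.Table` path, assuming `spark` is an existing Spark Connect session (the data mirrors the tests below; this snippet is not part of the diff):

```python
import pyarrow as pa
from pyspark.sql.types import StructType, StructField, LongType, IntegerType

schema = StructType([StructField("id", LongType()), StructField("value", IntegerType())])
table = pa.table({"id": [1, 2, 3], "value": [100000000000, 200000000000, 300000000000]})

# Default for pa.Table input is verifySchema=False: the overflowing "value"
# column is cast unsafely and the DataFrame is created.
spark.createDataFrame(table, schema=schema)

# With verifySchema=True the client-side Arrow cast runs with safe=True and raises.
try:
    spark.createDataFrame(table, schema=schema, verifySchema=True)
except Exception as e:
    print("verifySchema=True rejected the lossy cast:", e)
```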
3 changes: 1 addition & 2 deletions python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -137,9 +137,8 @@ def test_toPandas_udt(self):
def test_create_dataframe_namedtuples(self):
self.check_create_dataframe_namedtuples(True)

@unittest.skip("Spark Connect does not support verifySchema.")
def test_createDataFrame_verifySchema(self):
super().test_createDataFrame_verifySchema()
self.check_createDataFrame_verifySchema(True)


if __name__ == "__main__":
46 changes: 27 additions & 19 deletions python/pyspark/sql/tests/test_arrow.py
@@ -533,6 +533,11 @@ def test_createDataFrame_arrow_pandas(self):
self.assertEqual(df_arrow.collect(), df_pandas.collect())

def test_createDataFrame_verifySchema(self):
for arrow_enabled in [True, False]:
with self.subTest(arrow_enabled=arrow_enabled):
self.check_createDataFrame_verifySchema(arrow_enabled)

def check_createDataFrame_verifySchema(self, arrow_enabled):
data = {"id": [1, 2, 3], "value": [100000000000, 200000000000, 300000000000]}
# data.value should fail schema validation when verifySchema is True
schema = StructType(
@@ -547,29 +552,32 @@ def test_createDataFrame_verifySchema(self):
table = pa.table(data)
df = self.spark.createDataFrame(table, schema=schema)
self.assertEqual(df.collect(), expected)

with self.assertRaises(Exception):
self.spark.createDataFrame(table, schema=schema, verifySchema=True)

# pandas DataFrame with Arrow optimization
pdf = pd.DataFrame(data)
df = self.spark.createDataFrame(pdf, schema=schema)
# verifySchema defaults to `spark.sql.execution.pandas.convertToArrowArraySafely`,
# which is false by default
self.assertEqual(df.collect(), expected)
with self.assertRaises(Exception):
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": True}):
if arrow_enabled:
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
# pandas DataFrame with Arrow optimization
pdf = pd.DataFrame(data)
df = self.spark.createDataFrame(pdf, schema=schema)
with self.assertRaises(Exception):
df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=True)

# pandas DataFrame without Arrow optimization
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
pdf = pd.DataFrame(data)
with self.assertRaises(Exception):
df = self.spark.createDataFrame(pdf, schema=schema) # verifySchema defaults to True
df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=False)
self.assertEqual(df.collect(), expected)
# verifySchema defaults to `spark.sql.execution.pandas.convertToArrowArraySafely`,
# which is false by default
self.assertEqual(df.collect(), expected)
with self.assertRaises(Exception):
with self.sql_conf(
{"spark.sql.execution.pandas.convertToArrowArraySafely": True}
):
df = self.spark.createDataFrame(pdf, schema=schema)
with self.assertRaises(Exception):
df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=True)
else:
# pandas DataFrame without Arrow optimization
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
pdf = pd.DataFrame(data)
with self.assertRaises(Exception):
self.spark.createDataFrame(pdf, schema=schema) # verifySchema defaults to True
df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=False)
self.assertEqual(df.collect(), expected)

def _createDataFrame_toggle(self, data, schema=None):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):