From 052fc951462535e6088fa9239652415a6b070337 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Tue, 2 Dec 2025 19:53:16 +0800
Subject: [PATCH 1/3] test

test
---
 python/pyspark/sql/session.py | 38 ++++++++++++++++++++------------------
 1 file changed, 20 insertions(+), 18 deletions(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index dbcfe52e77b6e..8a56ee6328019 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -1565,32 +1565,34 @@ def createDataFrame(  # type: ignore[misc]
             has_pyarrow = False
 
         if has_numpy and isinstance(data, np.ndarray):
-            # `data` of numpy.ndarray type will be converted to a pandas DataFrame,
-            # so pandas is required.
-            from pyspark.sql.pandas.utils import require_minimum_pandas_version
+            # `data` of numpy.ndarray type will be converted to an arrow Table,
+            # so pyarrow is required.
+            from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
 
-            require_minimum_pandas_version()
+            require_minimum_pyarrow_version()
 
             if data.ndim not in [1, 2]:
                 raise PySparkValueError(
                     errorClass="INVALID_NDARRAY_DIMENSION",
                     messageParameters={"dimensions": "1 or 2"},
                 )
-            if data.ndim == 1 or data.shape[1] == 1:
-                column_names = ["value"]
+            col_names: list[str] = []
+            if isinstance(schema, StructType):
+                col_names = schema.names
+            elif isinstance(schema, list):
+                col_names = schema
+
+            if data.ndim == 1:
+                col_names = col_names or ["value"]
+                data = pa.Table.from_arrays(arrays=[pa.array(data)], names=col_names)
+            elif data.shape[1] == 1:
+                col_names = col_names or ["value"]
+                data = pa.Table.from_arrays(arrays=[pa.array(data.squeeze())], names=col_names)
             else:
-                column_names = ["_%s" % i for i in range(1, data.shape[1] + 1)]
-
-            if schema is None and not self._jconf.arrowPySparkEnabled():
-                # Construct `schema` from `np.dtype` of the input NumPy array
-                # TODO: Apply the logic below when self._jconf.arrowPySparkEnabled() is True
-                spark_type = _from_numpy_type(data.dtype)
-                if spark_type is not None:
-                    schema = StructType(
-                        [StructField(name, spark_type, nullable=True) for name in column_names]
-                    )
-
-            data = pd.DataFrame(data, columns=column_names)
+                n_cols = data.shape[1]
+                arrow_columns = [pa.array(data[::, i]) for i in range(0, n_cols)]
+                col_names = col_names or [f"_{i + 1}" for i in range(0, n_cols)]
+                data = pa.Table.from_arrays(arrays=arrow_columns, names=col_names)
 
         if has_pandas and isinstance(data, pd.DataFrame):
             # Create a DataFrame from pandas DataFrame.
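What PATCH 1/3 changes, in isolation: the NumPy branch of createDataFrame now converts the ndarray to a pyarrow.Table instead of a pandas DataFrame, taking column names from a StructType or list schema when one is given and falling back to "value" / "_1".."_n" otherwise. A minimal standalone sketch of the 2-D fallback case (plain NumPy + pyarrow, outside Spark; the sample values are illustrative):

    import numpy as np
    import pyarrow as pa

    # 2-D input with no schema: fall back to _1, _2, ... as the patch does.
    data = np.array([[1, 2], [3, 4]], dtype=np.int64)
    col_names = [f"_{i + 1}" for i in range(data.shape[1])]
    # data[:, i] is equivalent to the patch's data[::, i]: one column per array.
    table = pa.Table.from_arrays(
        arrays=[pa.array(data[:, i]) for i in range(data.shape[1])],
        names=col_names,
    )
    print(table.column_names)  # ['_1', '_2']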
From 40fe3f670bfb6c66564318010718fa4c9aa9fc95 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Tue, 2 Dec 2025 20:04:53 +0800
Subject: [PATCH 2/3] simplify

---
 python/pyspark/sql/session.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 8a56ee6328019..43cb3802c20c2 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -1581,18 +1581,20 @@ def createDataFrame(  # type: ignore[misc]
                 col_names = schema.names
             elif isinstance(schema, list):
                 col_names = schema
+            elif data.ndim == 1 or data.shape[1] == 1:
+                col_names = ["value"]
+            else:
+                col_names = [f"_{i + 1}" for i in range(0, data.shape[1])]
 
             if data.ndim == 1:
-                col_names = col_names or ["value"]
                 data = pa.Table.from_arrays(arrays=[pa.array(data)], names=col_names)
             elif data.shape[1] == 1:
-                col_names = col_names or ["value"]
                 data = pa.Table.from_arrays(arrays=[pa.array(data.squeeze())], names=col_names)
             else:
-                n_cols = data.shape[1]
-                arrow_columns = [pa.array(data[::, i]) for i in range(0, n_cols)]
-                col_names = col_names or [f"_{i + 1}" for i in range(0, n_cols)]
-                data = pa.Table.from_arrays(arrays=arrow_columns, names=col_names)
+                data = pa.Table.from_arrays(
+                    arrays=[pa.array(data[::, i]) for i in range(0, data.shape[1])],
+                    names=col_names,
+                )
 
         if has_pandas and isinstance(data, pd.DataFrame):
             # Create a DataFrame from pandas DataFrame.

From d0812c0a5ea58d3c11a3db5aa4a2fd00063a3c7f Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Wed, 3 Dec 2025 08:30:43 +0800
Subject: [PATCH 3/3] lint

---
 python/pyspark/sql/session.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 43cb3802c20c2..b59ed9f0a840e 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -52,7 +52,6 @@
 from pyspark.sql.types import (
     AtomicType,
     DataType,
-    StructField,
     StructType,
     VariantVal,
     _make_type_verifier,
@@ -60,7 +59,6 @@
     _has_nulltype,
     _merge_type,
     _create_converter,
-    _from_numpy_type,
 )
 from pyspark.errors.exceptions.captured import install_exception_handler
 from pyspark.sql.utils import (
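Taken together: PATCH 2/3 hoists the fallback naming into the schema dispatch so each conversion branch only builds the Table, and PATCH 3/3 drops the now-unused StructField and _from_numpy_type imports. End to end, spark.createDataFrame(ndarray) now goes through Arrow. A quick way to exercise the new path, assuming an active SparkSession bound to spark (values are illustrative):

    import numpy as np

    # 1-D input gets the single fallback column name "value".
    df1 = spark.createDataFrame(np.array([1.0, 2.0, 3.0]))
    print(df1.columns)  # ['value']

    # A list schema now names the columns up front, per PATCH 2/3.
    df2 = spark.createDataFrame(np.array([[1, 2], [3, 4]]), schema=["a", "b"])
    print(df2.columns)  # ['a', 'b']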