diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index dbcfe52e77b6e..b59ed9f0a840e 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -52,7 +52,6 @@ from pyspark.sql.types import ( AtomicType, DataType, - StructField, StructType, VariantVal, _make_type_verifier, @@ -60,7 +59,6 @@ _has_nulltype, _merge_type, _create_converter, - _from_numpy_type, ) from pyspark.errors.exceptions.captured import install_exception_handler from pyspark.sql.utils import ( @@ -1565,32 +1563,36 @@ def createDataFrame( # type: ignore[misc] has_pyarrow = False if has_numpy and isinstance(data, np.ndarray): - # `data` of numpy.ndarray type will be converted to a pandas DataFrame, - # so pandas is required. - from pyspark.sql.pandas.utils import require_minimum_pandas_version + # `data` of numpy.ndarray type will be converted to an arrow Table, + # so pyarrow is required. + from pyspark.sql.pandas.utils import require_minimum_pyarrow_version - require_minimum_pandas_version() + require_minimum_pyarrow_version() if data.ndim not in [1, 2]: raise PySparkValueError( errorClass="INVALID_NDARRAY_DIMENSION", messageParameters={"dimensions": "1 or 2"}, ) - if data.ndim == 1 or data.shape[1] == 1: - column_names = ["value"] + col_names: list[str] = [] + if isinstance(schema, StructType): + col_names = schema.names + elif isinstance(schema, list): + col_names = schema + elif data.ndim == 1 or data.shape[1] == 1: + col_names = ["value"] else: - column_names = ["_%s" % i for i in range(1, data.shape[1] + 1)] - - if schema is None and not self._jconf.arrowPySparkEnabled(): - # Construct `schema` from `np.dtype` of the input NumPy array - # TODO: Apply the logic below when self._jconf.arrowPySparkEnabled() is True - spark_type = _from_numpy_type(data.dtype) - if spark_type is not None: - schema = StructType( - [StructField(name, spark_type, nullable=True) for name in column_names] - ) + col_names = [f"_{i + 1}" for i in range(0, data.shape[1])] - data = pd.DataFrame(data, columns=column_names) + if data.ndim == 1: + data = pa.Table.from_arrays(arrays=[pa.array(data)], names=col_names) + elif data.shape[1] == 1: + data = pa.Table.from_arrays(arrays=[pa.array(data.squeeze())], names=col_names) + else: + data = pa.Table.from_arrays( + arrays=[pa.array(data[::, i]) for i in range(0, data.shape[1])], + names=col_names, + ) if has_pandas and isinstance(data, pd.DataFrame): # Create a DataFrame from pandas DataFrame.