Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions python/pyspark/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,13 @@
from pyspark.sql.types import (
AtomicType,
DataType,
StructField,
StructType,
VariantVal,
_make_type_verifier,
_infer_schema,
_has_nulltype,
_merge_type,
_create_converter,
_from_numpy_type,
)
from pyspark.errors.exceptions.captured import install_exception_handler
from pyspark.sql.utils import (
Expand Down Expand Up @@ -1565,32 +1563,36 @@ def createDataFrame( # type: ignore[misc]
has_pyarrow = False

if has_numpy and isinstance(data, np.ndarray):
# `data` of numpy.ndarray type will be converted to a pandas DataFrame,
# so pandas is required.
from pyspark.sql.pandas.utils import require_minimum_pandas_version
# `data` of numpy.ndarray type will be converted to an arrow Table,
# so pyarrow is required.
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version

require_minimum_pandas_version()
require_minimum_pyarrow_version()
if data.ndim not in [1, 2]:
raise PySparkValueError(
errorClass="INVALID_NDARRAY_DIMENSION",
messageParameters={"dimensions": "1 or 2"},
)

if data.ndim == 1 or data.shape[1] == 1:
column_names = ["value"]
col_names: list[str] = []
if isinstance(schema, StructType):
col_names = schema.names
elif isinstance(schema, list):
col_names = schema
elif data.ndim == 1 or data.shape[1] == 1:
col_names = ["value"]
else:
column_names = ["_%s" % i for i in range(1, data.shape[1] + 1)]

if schema is None and not self._jconf.arrowPySparkEnabled():
# Construct `schema` from `np.dtype` of the input NumPy array
# TODO: Apply the logic below when self._jconf.arrowPySparkEnabled() is True
spark_type = _from_numpy_type(data.dtype)
if spark_type is not None:
schema = StructType(
[StructField(name, spark_type, nullable=True) for name in column_names]
)
col_names = [f"_{i + 1}" for i in range(0, data.shape[1])]

data = pd.DataFrame(data, columns=column_names)
if data.ndim == 1:
data = pa.Table.from_arrays(arrays=[pa.array(data)], names=col_names)
elif data.shape[1] == 1:
data = pa.Table.from_arrays(arrays=[pa.array(data.squeeze())], names=col_names)
else:
data = pa.Table.from_arrays(
arrays=[pa.array(data[::, i]) for i in range(0, data.shape[1])],
names=col_names,
)

if has_pandas and isinstance(data, pd.DataFrame):
# Create a DataFrame from pandas DataFrame.
Expand Down