Commit

simplify
zhengruifeng committed Nov 22, 2024
1 parent f276cf4 commit c8af66d
Showing 3 changed files with 29 additions and 61 deletions.
5 changes: 5 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -43,6 +43,7 @@
Dict,
Set,
NoReturn,
+ Mapping,
cast,
TYPE_CHECKING,
Type,
@@ -1576,6 +1577,10 @@ def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
configs = dict(self.config(op).pairs)
return tuple(configs.get(key) for key in keys)

+ def get_config_dict(self, *keys: str) -> Mapping[str, Optional[str]]:
+     op = pb2.ConfigRequest.Operation(get=pb2.ConfigRequest.Get(keys=keys))
+     return dict(self.config(op).pairs)

def get_config_with_defaults(
self, *pairs: Tuple[str, Optional[str]]
) -> Tuple[Optional[str], ...]:
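The new get_config_dict mirrors get_configs above but returns the values keyed by config name rather than by position. A minimal usage sketch, not part of the commit; `spark` is assumed to be an active Spark Connect SparkSession:

```python
# Fetch several configs in one ConfigRequest round trip; unset keys map to None.
configs = spark._client.get_config_dict(
    "spark.sql.session.timeZone",
    "spark.sql.timestampType",
)

timezone = configs["spark.sql.session.timeZone"]       # plain dict lookups,
prefer_timestamp = configs["spark.sql.timestampType"]  # no extra round trips
```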
57 changes: 24 additions & 33 deletions python/pyspark/sql/connect/session.py
@@ -36,6 +36,7 @@
cast,
overload,
Iterable,
+ Mapping,
TYPE_CHECKING,
ClassVar,
)
@@ -112,7 +113,6 @@
from pyspark.sql.connect.tvf import TableValuedFunction
from pyspark.sql.connect.shell.progress import ProgressHandler
from pyspark.sql.connect.datasource import DataSourceRegistration
- from pyspark.sql.connect.utils import LazyConfigGetter

try:
import memory_profiler # noqa: F401
@@ -411,7 +411,7 @@ def _inferSchemaFromList(
self,
data: Iterable[Any],
names: Optional[List[str]],
- conf_getter: "LazyConfigGetter",
+ configs: Mapping[str, Optional[str]],
) -> StructType:
"""
Infer schema from list of Row, dict, or tuple.
@@ -428,10 +428,10 @@ def _inferSchemaFromList(
infer_map_from_first_pair,
prefer_timestamp,
) = (
- conf_getter["spark.sql.pyspark.inferNestedDictAsStruct.enabled"],
- conf_getter["spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled"],
- conf_getter["spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"],
- conf_getter["spark.sql.timestampType"],
+ configs["spark.sql.pyspark.inferNestedDictAsStruct.enabled"],
+ configs["spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled"],
+ configs["spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"],
+ configs["spark.sql.timestampType"],
)
return functools.reduce(
_merge_type,
@@ -455,22 +455,6 @@ def createDataFrame(
samplingRatio: Optional[float] = None,
verifySchema: Union[_NoValueType, bool] = _NoValue,
) -> "ParentDataFrame":
- from pyspark.sql.connect.utils import LazyConfigGetter
-
- conf_getter = LazyConfigGetter(
-     keys=[
-         "spark.sql.timestampType",
-         "spark.sql.session.timeZone",
-         "spark.sql.session.localRelationCacheThreshold",
-         "spark.sql.execution.pandas.convertToArrowArraySafely",
-         "spark.sql.execution.pandas.inferPandasDictAsMap",
-         "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
-         "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
-         "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
-     ],
-     session=self,
- )

assert data is not None
if isinstance(data, DataFrame):
raise PySparkTypeError(
@@ -525,8 +509,21 @@ def createDataFrame(
messageParameters={},
)

+ # Get all related configs in a batch
+ configs = self._client.get_config_dict(
+     "spark.sql.timestampType",
+     "spark.sql.session.timeZone",
+     "spark.sql.session.localRelationCacheThreshold",
+     "spark.sql.execution.pandas.convertToArrowArraySafely",
+     "spark.sql.execution.pandas.inferPandasDictAsMap",
+     "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
+     "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
+     "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
+ )

_table: Optional[pa.Table] = None
- timezone: Optional[str] = None
+ timezone: Optional[str] = configs["spark.sql.session.timeZone"]
+ prefer_timestamp = configs["spark.sql.timestampType"]

if isinstance(data, pd.DataFrame):
# Logic was borrowed from `_create_from_pandas_with_arrow` in
@@ -536,7 +533,7 @@
if schema is None:
_cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
infer_pandas_dict_as_map = (
- conf_getter["spark.sql.execution.pandas.inferPandasDictAsMap"] == "true"
+ configs["spark.sql.execution.pandas.inferPandasDictAsMap"] == "true"
)
if infer_pandas_dict_as_map:
struct = StructType()
@@ -588,8 +585,7 @@
]
arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]

- timezone = conf_getter["spark.sql.session.timeZone"]
- safecheck = conf_getter["spark.sql.execution.pandas.convertToArrowArraySafely"]
+ safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"]

if verifySchema is _NoValue:
verifySchema = safecheck == "true"
@@ -617,11 +613,6 @@
if verifySchema is _NoValue:
verifySchema = False

- timezone = conf_getter["spark.sql.session.timeZone"]
- prefer_timestamp = conf_getter["spark.sql.timestampType"]
-
- (timezone,) = self._client.get_configs("spark.sql.session.timeZone")

# If no schema supplied by user then get the names of columns only
if schema is None:
_cols = data.column_names
@@ -704,7 +695,7 @@
if not isinstance(_schema, StructType):
_schema = StructType().add("value", _schema)
else:
- _schema = self._inferSchemaFromList(_data, _cols, conf_getter)
+ _schema = self._inferSchemaFromList(_data, _cols, configs)

if _cols is not None and cast(int, _num_cols) < len(_cols):
_num_cols = len(_cols)
@@ -742,7 +733,7 @@
else:
local_relation = LocalRelation(_table)

- cache_threshold = conf_getter["spark.sql.session.localRelationCacheThreshold"]
+ cache_threshold = configs["spark.sql.session.localRelationCacheThreshold"]
plan: LogicalPlan = local_relation
if cache_threshold is not None and int(cache_threshold) <= _table.nbytes:
plan = CachedLocalRelation(self._cache_local_relation(local_relation))
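In createDataFrame, the per-key LazyConfigGetter reads are replaced by a single eager batch: every config the code path may need is fetched once via get_config_dict, then read from the resulting dict. A rough sketch of that pattern, assuming a connected client; the helper name and the shortened key list are illustrative, not names from the commit:

```python
from typing import Mapping, Optional

def read_creation_configs(client) -> Mapping[str, Optional[str]]:
    # Exactly one ConfigRequest per createDataFrame call, issued up front,
    # instead of deferring the fetch until a key is first read.
    return client.get_config_dict(
        "spark.sql.session.timeZone",
        "spark.sql.timestampType",
    )

# configs = read_creation_configs(spark._client)
# timezone = configs["spark.sql.session.timeZone"]  # local lookup, no extra RPC
```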
28 changes: 0 additions & 28 deletions python/pyspark/sql/connect/utils.py
@@ -15,10 +15,6 @@
# limitations under the License.
#
import sys
- from typing import Optional, Sequence, Dict, TYPE_CHECKING
-
- if TYPE_CHECKING:
-     from pyspark.sql.connect.session import SparkSession

from pyspark.loose_version import LooseVersion
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
@@ -102,27 +98,3 @@ def require_minimum_googleapis_common_protos_version() -> None:

def get_python_ver() -> str:
return "%d.%d" % sys.version_info[:2]


- class LazyConfigGetter:
-     def __init__(
-         self,
-         keys: Sequence[str],
-         session: "SparkSession",
-     ):
-         assert len(keys) > 0 and len(keys) == len(set(keys))
-         assert all(isinstance(key, str) for key in keys)
-         assert session is not None
-         self._keys = keys
-         self._session = session
-         self._values: Dict[str, Optional[str]] = {}
-
-     def __getitem__(self, key: str) -> Optional[str]:
-         assert key in self._keys
-
-         if len(self._values) == 0:
-             values = self._session._client.get_configs(*self._keys)
-             for i, value in enumerate(values):
-                 self._values[self._keys[i]] = value
-
-         return self._values[key]
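The removed class deferred one batched get_configs call until the first key access. If that laziness were ever wanted again, a cached helper could recover it on top of the existing client API; a sketch under that assumption (make_lazy_config_getter is hypothetical, not part of the codebase):

```python
from functools import lru_cache
from typing import Dict, Optional, Sequence

def make_lazy_config_getter(client, keys: Sequence[str]):
    @lru_cache(maxsize=1)
    def _fetch() -> Dict[str, Optional[str]]:
        # One get_configs RPC, issued only on the first lookup.
        return dict(zip(keys, client.get_configs(*keys)))

    return lambda key: _fetch()[key]
```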
