
Commit

resolve conflicts
zhengruifeng committed Nov 22, 2024
1 parent 190c504 commit f276cf4
Showing 2 changed files with 67 additions and 19 deletions.
58 changes: 39 additions & 19 deletions python/pyspark/sql/connect/session.py
@@ -15,7 +15,6 @@
# limitations under the License.
#
from pyspark.sql.connect.utils import check_dependencies
from pyspark.sql.utils import is_timestamp_ntz_preferred

check_dependencies(__name__)

@@ -113,6 +112,7 @@
from pyspark.sql.connect.tvf import TableValuedFunction
from pyspark.sql.connect.shell.progress import ProgressHandler
from pyspark.sql.connect.datasource import DataSourceRegistration
from pyspark.sql.connect.utils import LazyConfigGetter

try:
import memory_profiler # noqa: F401
@@ -408,7 +408,10 @@ def clearProgressHandlers(self) -> None:
clearProgressHandlers.__doc__ = PySparkSession.clearProgressHandlers.__doc__

def _inferSchemaFromList(
self, data: Iterable[Any], names: Optional[List[str]] = None
self,
data: Iterable[Any],
names: Optional[List[str]],
conf_getter: "LazyConfigGetter",
) -> StructType:
"""
Infer schema from list of Row, dict, or tuple.
@@ -423,12 +426,12 @@ def _inferSchemaFromList(
infer_dict_as_struct,
infer_array_from_first_element,
infer_map_from_first_pair,
prefer_timestamp_ntz,
) = self._client.get_configs(
"spark.sql.pyspark.inferNestedDictAsStruct.enabled",
"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
"spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
"spark.sql.timestampType",
prefer_timestamp,
) = (
conf_getter["spark.sql.pyspark.inferNestedDictAsStruct.enabled"],
conf_getter["spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled"],
conf_getter["spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"],
conf_getter["spark.sql.timestampType"],
)
return functools.reduce(
_merge_type,
@@ -439,7 +442,7 @@ def _inferSchemaFromList(
infer_dict_as_struct=(infer_dict_as_struct == "true"),
infer_array_from_first_element=(infer_array_from_first_element == "true"),
infer_map_from_first_pair=(infer_map_from_first_pair == "true"),
prefer_timestamp_ntz=(prefer_timestamp_ntz == "TIMESTAMP_NTZ"),
prefer_timestamp_ntz=(prefer_timestamp == "TIMESTAMP_NTZ"),
)
for row in data
),
@@ -452,6 +455,22 @@ def createDataFrame(
samplingRatio: Optional[float] = None,
verifySchema: Union[_NoValueType, bool] = _NoValue,
) -> "ParentDataFrame":
from pyspark.sql.connect.utils import LazyConfigGetter

conf_getter = LazyConfigGetter(
keys=[
"spark.sql.timestampType",
"spark.sql.session.timeZone",
"spark.sql.session.localRelationCacheThreshold",
"spark.sql.execution.pandas.convertToArrowArraySafely",
"spark.sql.execution.pandas.inferPandasDictAsMap",
"spark.sql.pyspark.inferNestedDictAsStruct.enabled",
"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
"spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
],
session=self,
)

assert data is not None
if isinstance(data, DataFrame):
raise PySparkTypeError(
@@ -517,8 +536,7 @@ def createDataFrame(
if schema is None:
_cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
infer_pandas_dict_as_map = (
str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower()
== "true"
conf_getter["spark.sql.execution.pandas.inferPandasDictAsMap"] == "true"
)
if infer_pandas_dict_as_map:
struct = StructType()
@@ -570,9 +588,8 @@ def createDataFrame(
]
arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]

timezone, safecheck = self._client.get_configs(
"spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
)
timezone = conf_getter["spark.sql.session.timeZone"]
safecheck = conf_getter["spark.sql.execution.pandas.convertToArrowArraySafely"]

if verifySchema is _NoValue:
verifySchema = safecheck == "true"
@@ -600,7 +617,8 @@ def createDataFrame(
if verifySchema is _NoValue:
verifySchema = False

prefer_timestamp_ntz = is_timestamp_ntz_preferred()
timezone = conf_getter["spark.sql.session.timeZone"]
prefer_timestamp = conf_getter["spark.sql.timestampType"]

(timezone,) = self._client.get_configs("spark.sql.session.timeZone")

@@ -613,7 +631,9 @@ def createDataFrame(
_num_cols = len(_cols)

if not isinstance(schema, StructType):
schema = from_arrow_schema(data.schema, prefer_timestamp_ntz=prefer_timestamp_ntz)
schema = from_arrow_schema(
data.schema, prefer_timestamp_ntz=prefer_timestamp == "TIMESTAMP_NTZ"
)

_table = (
_check_arrow_table_timestamps_localize(data, schema, True, timezone)
@@ -684,7 +704,7 @@ def createDataFrame(
if not isinstance(_schema, StructType):
_schema = StructType().add("value", _schema)
else:
_schema = self._inferSchemaFromList(_data, _cols)
_schema = self._inferSchemaFromList(_data, _cols, conf_getter)

if _cols is not None and cast(int, _num_cols) < len(_cols):
_num_cols = len(_cols)
@@ -722,9 +742,9 @@ def createDataFrame(
else:
local_relation = LocalRelation(_table)

cache_threshold = self._client.get_configs("spark.sql.session.localRelationCacheThreshold")
cache_threshold = conf_getter["spark.sql.session.localRelationCacheThreshold"]
plan: LogicalPlan = local_relation
if cache_threshold[0] is not None and int(cache_threshold[0]) <= _table.nbytes:
if cache_threshold is not None and int(cache_threshold) <= _table.nbytes:
plan = CachedLocalRelation(self._cache_local_relation(local_relation))

df = DataFrame(plan, self)
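
A self-contained sketch of the cache-threshold decision in the last hunk above. It assumes only pyarrow; the 1 KiB threshold and the table contents are illustrative values, not Spark defaults or actual session code, and the string value mimics what the config getter would return.

# Illustrative sketch of the local-relation cache-threshold check (not session code).
import pyarrow as pa

table = pa.table({"id": list(range(1_000)), "v": [float(i) for i in range(1_000)]})
cache_threshold = "1024"  # bytes, as a config getter would return it (a string or None)

if cache_threshold is not None and int(cache_threshold) <= table.nbytes:
    plan = "CachedLocalRelation"  # large input: upload once, reference by id
else:
    plan = "LocalRelation"        # small input: inline the Arrow data in the plan
print(plan, table.nbytes)
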
28 changes: 28 additions & 0 deletions python/pyspark/sql/connect/utils.py
@@ -15,6 +15,10 @@
# limitations under the License.
#
import sys
from typing import Optional, Sequence, Dict, TYPE_CHECKING

if TYPE_CHECKING:
from pyspark.sql.connect.session import SparkSession

from pyspark.loose_version import LooseVersion
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
@@ -98,3 +102,27 @@ def require_minimum_googleapis_common_protos_version() -> None:

def get_python_ver() -> str:
return "%d.%d" % sys.version_info[:2]


class LazyConfigGetter:
def __init__(
self,
keys: Sequence[str],
session: "SparkSession",
):
assert len(keys) > 0 and len(keys) == len(set(keys))
assert all(isinstance(key, str) for key in keys)
assert session is not None
self._keys = keys
self._session = session
self._values: Dict[str, Optional[str]] = {}

def __getitem__(self, key: str) -> Optional[str]:
assert key in self._keys

if len(self._values) == 0:
values = self._session._client.get_configs(*self._keys)
for i, value in enumerate(values):
self._values[self._keys[i]] = value

return self._values[key]
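
For reference, a minimal usage sketch of the new helper: constructing the getter only records the keys, the first lookup fetches every listed key in a single get_configs round trip, and later lookups are served from the cached values. It assumes a pyspark build that contains this patch; FakeClient and FakeSession are illustrative stand-ins, not real pyspark classes.

# Minimal usage sketch of LazyConfigGetter (assumes a pyspark build with this patch).
from pyspark.sql.connect.utils import LazyConfigGetter


class FakeClient:
    def __init__(self) -> None:
        self.calls = 0

    def get_configs(self, *keys):
        self.calls += 1                      # count round trips
        return tuple("dummy" for _ in keys)  # pretend every config is "dummy"


class FakeSession:
    def __init__(self) -> None:
        self._client = FakeClient()


session = FakeSession()
getter = LazyConfigGetter(
    keys=["spark.sql.session.timeZone", "spark.sql.timestampType"],
    session=session,  # a real SparkSession in actual use
)
assert session._client.calls == 0          # constructing the getter sends nothing
_ = getter["spark.sql.session.timeZone"]   # first access fetches all keys at once
_ = getter["spark.sql.timestampType"]      # served from the cached values
assert session._client.calls == 1
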
