29 changes: 19 additions & 10 deletions python/pyspark/ml/classification.py
@@ -89,7 +89,7 @@
from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
from pyspark.ml.common import inherit_doc
from pyspark.ml.linalg import Matrix, Vector, Vectors, VectorUDT
from pyspark.sql import DataFrame, Row
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.storagelevel import StorageLevel
@@ -3678,7 +3678,7 @@ class _OneVsRestSharedReadWrite:
@staticmethod
def saveImpl(
instance: Union[OneVsRest, "OneVsRestModel"],
sc: "SparkContext",
sc: Union["SparkContext", SparkSession],
path: str,
extraMetadata: Optional[Dict[str, Any]] = None,
) -> None:
@@ -3691,7 +3691,10 @@ def saveImpl(
cast(MLWritable, instance.getClassifier()).save(classifierPath)

@staticmethod
def loadClassifier(path: str, sc: "SparkContext") -> Union[OneVsRest, "OneVsRestModel"]:
def loadClassifier(
path: str,
sc: Union["SparkContext", SparkSession],
) -> Union[OneVsRest, "OneVsRestModel"]:
classifierPath = os.path.join(path, "classifier")
return DefaultParamsReader.loadParamsInstance(classifierPath, sc)

@@ -3716,11 +3719,13 @@ def __init__(self, cls: Type[OneVsRest]) -> None:
self.cls = cls

def load(self, path: str) -> OneVsRest:
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
else:
classifier = cast(Classifier, _OneVsRestSharedReadWrite.loadClassifier(path, self.sc))
classifier = cast(
Classifier, _OneVsRestSharedReadWrite.loadClassifier(path, self.sparkSession)
)
ova: OneVsRest = OneVsRest(classifier=classifier)._resetUid(metadata["uid"])
DefaultParamsReader.getAndSetParams(ova, metadata, skipParams=["classifier"])
return ova
@@ -3734,7 +3739,7 @@ def __init__(self, instance: OneVsRest):

def saveImpl(self, path: str) -> None:
_OneVsRestSharedReadWrite.validateParams(self.instance)
_OneVsRestSharedReadWrite.saveImpl(self.instance, self.sc, path)
_OneVsRestSharedReadWrite.saveImpl(self.instance, self.sparkSession, path)


class OneVsRestModel(
@@ -3963,16 +3968,18 @@ def __init__(self, cls: Type[OneVsRestModel]):
self.cls = cls

def load(self, path: str) -> OneVsRestModel:
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
else:
classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sc)
classifier = _OneVsRestSharedReadWrite.loadClassifier(path, self.sparkSession)
numClasses = metadata["numClasses"]
subModels = [None] * numClasses
for idx in range(numClasses):
subModelPath = os.path.join(path, f"model_{idx}")
subModels[idx] = DefaultParamsReader.loadParamsInstance(subModelPath, self.sc)
subModels[idx] = DefaultParamsReader.loadParamsInstance(
subModelPath, self.sparkSession
)
ovaModel = OneVsRestModel(cast(List[ClassificationModel], subModels))._resetUid(
metadata["uid"]
)
@@ -3992,7 +3999,9 @@ def saveImpl(self, path: str) -> None:
instance = self.instance
numClasses = len(instance.models)
extraMetadata = {"numClasses": numClasses}
_OneVsRestSharedReadWrite.saveImpl(instance, self.sc, path, extraMetadata=extraMetadata)
_OneVsRestSharedReadWrite.saveImpl(
instance, self.sparkSession, path, extraMetadata=extraMetadata
)
for idx in range(numClasses):
subModelPath = os.path.join(path, f"model_{idx}")
cast(MLWritable, instance.models[idx]).save(subModelPath)
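For context (not part of the diff): the public OneVsRest persistence API is unchanged by this patch; only the object handed to `_OneVsRestSharedReadWrite` switches from the writer's/reader's `SparkContext` to its `SparkSession`. A minimal, hypothetical round trip that exercises this path might look as follows; the estimator settings and output path are illustrative.

```python
# Hypothetical usage sketch; parameters and path are illustrative only.
from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression, OneVsRest

spark = SparkSession.builder.getOrCreate()

ovr = OneVsRest(classifier=LogisticRegression(maxIter=5))
# write() returns an OneVsRestWriter, which now passes self.sparkSession
# (rather than self.sc) down to _OneVsRestSharedReadWrite.saveImpl.
ovr.write().overwrite().save("/tmp/ovr_example")

# load() goes through OneVsRestReader, which likewise resolves the classifier
# and metadata via the reader's SparkSession.
restored = OneVsRest.load("/tmp/ovr_example")
```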
19 changes: 11 additions & 8 deletions python/pyspark/ml/pipeline.py
@@ -35,6 +35,7 @@
)
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame

if TYPE_CHECKING:
@@ -230,7 +231,7 @@ def __init__(self, instance: Pipeline):
def saveImpl(self, path: str) -> None:
stages = self.instance.getStages()
PipelineSharedReadWrite.validateStages(stages)
PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sparkSession, path)


@inherit_doc
@@ -244,11 +245,11 @@ def __init__(self, cls: Type[Pipeline]):
self.cls = cls

def load(self, path: str) -> Pipeline:
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if "language" not in metadata["paramMap"] or metadata["paramMap"]["language"] != "Python":
return JavaMLReader(cast(Type["JavaMLReadable[Pipeline]"], self.cls)).load(path)
else:
uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
uid, stages = PipelineSharedReadWrite.load(metadata, self.sparkSession, path)
return Pipeline(stages=stages)._resetUid(uid)


Expand All @@ -266,7 +267,7 @@ def saveImpl(self, path: str) -> None:
stages = self.instance.stages
PipelineSharedReadWrite.validateStages(cast(List["PipelineStage"], stages))
PipelineSharedReadWrite.saveImpl(
self.instance, cast(List["PipelineStage"], stages), self.sc, path
self.instance, cast(List["PipelineStage"], stages), self.sparkSession, path
)


Expand All @@ -281,11 +282,11 @@ def __init__(self, cls: Type["PipelineModel"]):
self.cls = cls

def load(self, path: str) -> "PipelineModel":
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if "language" not in metadata["paramMap"] or metadata["paramMap"]["language"] != "Python":
return JavaMLReader(cast(Type["JavaMLReadable[PipelineModel]"], self.cls)).load(path)
else:
uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
uid, stages = PipelineSharedReadWrite.load(metadata, self.sparkSession, path)
return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid)


@@ -403,7 +404,7 @@ def validateStages(stages: List["PipelineStage"]) -> None:
def saveImpl(
instance: Union[Pipeline, PipelineModel],
stages: List["PipelineStage"],
sc: "SparkContext",
sc: Union["SparkContext", SparkSession],
path: str,
) -> None:
"""
@@ -422,7 +423,9 @@ def saveImpl(

@staticmethod
def load(
metadata: Dict[str, Any], sc: "SparkContext", path: str
metadata: Dict[str, Any],
sc: Union["SparkContext", SparkSession],
path: str,
) -> Tuple[str, List["PipelineStage"]]:
"""
Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
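Likewise for pipelines (again a sketch, not part of the diff): persisting a `Pipeline` or `PipelineModel` that contains Python-only stages now routes the writer's/reader's `SparkSession` into `PipelineSharedReadWrite` instead of a bare `SparkContext`; the public API is untouched. The stage choice and path below are illustrative.

```python
# Hypothetical usage sketch; stages and path are illustrative only.
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression, OneVsRest

spark = SparkSession.builder.getOrCreate()

# OneVsRest is a Python-only stage, so this Pipeline is persisted through the
# Python PipelineWriter/PipelineReader rather than the Java writer.
pipeline = Pipeline(stages=[OneVsRest(classifier=LogisticRegression(maxIter=5))])
# PipelineWriter.saveImpl now forwards self.sparkSession (not self.sc) to
# PipelineSharedReadWrite.saveImpl, which accepts either object.
pipeline.write().overwrite().save("/tmp/pipeline_example")

restored = Pipeline.load("/tmp/pipeline_example")
```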
23 changes: 14 additions & 9 deletions python/pyspark/ml/util.py
@@ -32,6 +32,7 @@
TypeVar,
cast,
TYPE_CHECKING,
Union,
)

from pyspark import since
@@ -424,7 +425,7 @@ def __init__(self, instance: "Params"):
self.instance = instance

def saveImpl(self, path: str) -> None:
DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)
DefaultParamsWriter.saveMetadata(self.instance, path, self.sparkSession)

@staticmethod
def extractJsonParams(instance: "Params", skipParams: Sequence[str]) -> Dict[str, Any]:
@@ -438,7 +439,7 @@ def extractJsonParams(instance: "Params", skipParams: Sequence[str]) -> Dict[str
def saveMetadata(
instance: "Params",
path: str,
sc: "SparkContext",
sc: Union["SparkContext", SparkSession],
extraMetadata: Optional[Dict[str, Any]] = None,
paramMap: Optional[Dict[str, Any]] = None,
) -> None:
@@ -464,15 +465,15 @@ def saveMetadata(
metadataJson = DefaultParamsWriter._get_metadata_to_save(
instance, sc, extraMetadata, paramMap
)
spark = SparkSession._getActiveSessionOrCreate()
spark = sc if isinstance(sc, SparkSession) else SparkSession._getActiveSessionOrCreate()
spark.createDataFrame([(metadataJson,)], schema=["value"]).coalesce(1).write.text(
metadataPath
)

@staticmethod
def _get_metadata_to_save(
instance: "Params",
sc: "SparkContext",
sc: Union["SparkContext", SparkSession],
extraMetadata: Optional[Dict[str, Any]] = None,
paramMap: Optional[Dict[str, Any]] = None,
) -> str:
@@ -560,27 +561,31 @@ def __get_class(clazz: str) -> Type[RL]:
return getattr(m, parts[-1])

def load(self, path: str) -> RL:
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
py_type: Type[RL] = DefaultParamsReader.__get_class(metadata["class"])
instance = py_type()
cast("Params", instance)._resetUid(metadata["uid"])
DefaultParamsReader.getAndSetParams(instance, metadata)
return instance

@staticmethod
def loadMetadata(path: str, sc: "SparkContext", expectedClassName: str = "") -> Dict[str, Any]:
def loadMetadata(
path: str,
sc: Union["SparkContext", SparkSession],
expectedClassName: str = "",
) -> Dict[str, Any]:
"""
Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`

Parameters
----------
path : str
sc : :py:class:`pyspark.SparkContext`
sc : :py:class:`pyspark.SparkContext` or :py:class:`pyspark.sql.SparkSession`
expectedClassName : str, optional
If non empty, this is checked against the loaded metadata.
"""
metadataPath = os.path.join(path, "metadata")
spark = SparkSession._getActiveSessionOrCreate()
spark = sc if isinstance(sc, SparkSession) else SparkSession._getActiveSessionOrCreate()
metadataStr = spark.read.text(metadataPath).first()[0] # type: ignore[index]
loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
return loadedVals
@@ -641,7 +646,7 @@ def isPythonParamsInstance(metadata: Dict[str, Any]) -> bool:
return metadata["class"].startswith("pyspark.ml.")

@staticmethod
def loadParamsInstance(path: str, sc: "SparkContext") -> RL:
def loadParamsInstance(path: str, sc: Union["SparkContext", SparkSession]) -> RL:
"""
Load a :py:class:`Params` instance from the given path, and return it.
This assumes the instance inherits from :py:class:`MLReadable`.
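To make the new contract concrete (a sketch, not part of the diff): each helper that previously required a `SparkContext` now also accepts a `SparkSession`, resolving a session with `isinstance(sc, SparkSession)` before falling back to `SparkSession._getActiveSessionOrCreate()`. The path below is illustrative and assumes a model was saved there earlier.

```python
# Hypothetical call pattern; the saved-model path is illustrative only.
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.ml.util import DefaultParamsReader

spark = SparkSession.builder.getOrCreate()
sc: SparkContext = spark.sparkContext

# Both calls are now valid and resolve to the same SparkSession internally.
meta_via_session = DefaultParamsReader.loadMetadata("/tmp/ovr_example", spark)
meta_via_context = DefaultParamsReader.loadMetadata("/tmp/ovr_example", sc)
```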