diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index fbb9e7fcdd86..117a0e93864d 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1659,8 +1659,7 @@ def _fit(self, dataset): multiclassLabeled = dataset.select(labelCol, featuresCol) # persist if underlying dataset is not persistent. - handlePersistence = \ - dataset.rdd.getStorageLevel() == StorageLevel(False, False, False, False) + handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False) if handlePersistence: multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK) @@ -1814,8 +1813,7 @@ def _transform(self, dataset): newDataset = dataset.withColumn(accColName, initUDF(dataset[origCols[0]])) # persist if underlying dataset is not persistent. - handlePersistence = \ - dataset.rdd.getStorageLevel() == StorageLevel(False, False, False, False) + handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False) if handlePersistence: newDataset.persist(StorageLevel.MEMORY_AND_DISK)