Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML-273] adjust data cache and make more efficient #281

Merged
merged 19 commits (branch names omitted in this extract)
Apr 28, 2023
Merged
1 change: 0 additions & 1 deletion mllib-dal/src/main/scala/com/intel/oap/mllib/OneDAL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,6 @@ object OneDAL {
.coalesce(executorNum,
partitionCoalescer = Some(new ExecutorInProcessCoalescePartitioner()))
.setName("coalescedRdd")
.cache()

// convert RDD to HomogenTable
val coalescedTables = coalescedRdd.mapPartitionsWithIndex { (index: Int, it: Iterator[Vector]) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.mllib.feature
import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel}
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.sql._
import org.apache.spark.storage.StorageLevel

/**
* PCA trains a model to project vectors to a lower dimensional space of the top `PCA!.k`
Expand All @@ -50,11 +51,15 @@ class PCA @Since("1.5.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): PCAModel = {
transformSchema(dataset.schema, logging = true)
val handlePersistence = (dataset.storageLevel == StorageLevel.NONE)
val input = dataset.select($(inputCol)).rdd
val inputVectors = input.map {
case Row(v: Vector) => v
}

if (handlePersistence) {
inputVectors.persist(StorageLevel.MEMORY_AND_DISK)
inputVectors.count()
}
val numFeatures = inputVectors.first().size
require($(k) <= numFeatures,
s"source vector size $numFeatures must be no less than k=$k")
Expand All @@ -68,6 +73,9 @@ class PCA @Since("1.5.0") (
val executor_cores = Utils.sparkExecutorCores()
val pca = new PCADALImpl(k = $(k), executor_num, executor_cores)
val pcaDALModel = pca.train(inputVectors)
if (handlePersistence) {
inputVectors.unpersist()
}
new OldPCAModel(pcaDALModel.k, pcaDALModel.pc, pcaDALModel.explainedVariance)
} else {
val inputOldVectors = inputVectors.map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,13 @@ class Correlation extends CorrelationShim {
dataset.sparkSession.sparkContext)
if (Utils.isOAPEnabled() && isPlatformSupported && method == "pearson") {
val handlePersistence = (dataset.storageLevel == StorageLevel.NONE)
if (handlePersistence) {
dataset.persist(StorageLevel.MEMORY_AND_DISK)
}
val rdd = dataset.select(column).rdd.map {
case Row(v: Vector) => v
}
if (handlePersistence) {
rdd.persist(StorageLevel.MEMORY_AND_DISK)
rdd.count()
}
val executor_num = Utils.sparkExecutorNum(dataset.sparkSession.sparkContext)
val executor_cores = Utils.sparkExecutorCores()
val matrix = new CorrelationDALImpl(executor_num, executor_cores)
Expand All @@ -88,7 +89,7 @@ class Correlation extends CorrelationShim {
val schema = StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false)))
val dataframe = dataset.sparkSession.createDataFrame(Seq(Row(matrix)).asJava, schema)
if (handlePersistence) {
dataset.unpersist()
rdd.unpersist()
}
dataframe
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,19 @@ class Statistics extends SummarizerShim {
X.sparkContext)
if (Utils.isOAPEnabled() && isPlatformSupported) {
val handlePersistence = (X.getStorageLevel == StorageLevel.NONE)
if (handlePersistence) {
X.persist(StorageLevel.MEMORY_AND_DISK)
}
val rdd = X.map {
v => v.asML
}
if (handlePersistence) {
rdd.persist(StorageLevel.MEMORY_AND_DISK)
rdd.count()
}
val executor_num = Utils.sparkExecutorNum(X.sparkContext)
val executor_cores = Utils.sparkExecutorCores()
val summary = new SummarizerDALImpl(executor_num, executor_cores)
.computeSummarizerMatrix(rdd)
if (handlePersistence) {
X.unpersist()
rdd.unpersist()
}
summary
} else {
Expand Down