
Commit

optimized datasummary et
ckeys committed Aug 23, 2022
1 parent b6bea28 commit f6a4703
Showing 2 changed files with 68 additions and 3 deletions.
@@ -418,7 +418,7 @@ class SQLDataSummary(override val uid: String) extends SQLAlg with MllibFunction
 }
 }): _*)
 
-val mode_df = repartitionDF.select(getModeNum(repartitionDF.schema, numericCols, repartitionDF, modeFormat): _*).select(lit("mode").alias("metric"), col("*"))
+// val mode_df = repartitionDF.select(getModeNum(repartitionDF.schema, numericCols, repartitionDF, modeFormat): _*).select(lit("mode").alias("metric"), col("*"))
 val maxlength_df = repartitionDF.select(getMaxLength(repartitionDF.schema): _*).select(lit("maximumLength").alias("metric"), col("*"))
 val minlength_df = repartitionDF.select(getMinLength(repartitionDF.schema): _*).select(lit("minimumLength").alias("metric"), col("*"))
@@ -456,8 +456,7 @@ class SQLDataSummary(override val uid: String) extends SQLAlg with MllibFunction


 val colunm_idx = Seq("ordinalPosition" +: repartitionDF.columns.map(col_name => String.valueOf(repartitionDF.columns.indexOf(col_name) + 1))).map(Row.fromSeq(_))
-var numeric_metric_df = mode_df
-  .union(distinct_proportion_df)
+var numeric_metric_df = distinct_proportion_df
   .union(null_value_proportion_df)
   .union(empty_value_proportion_df)
   .union(mean_df)
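The change above removes the per-column mode metric (mode_df) from the summary union. For context only, here is a minimal sketch of one way a column mode can be computed in Spark SQL; ModeSketch and columnMode are hypothetical names for illustration, not the plugin's getModeNum implementation (which is not shown in this diff). The point is that each column needs its own group-by job, which is comparatively expensive and consistent with the metric being dropped in an optimization commit.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

object ModeSketch {
  // Hypothetical helper: the mode of one column, i.e. its most frequent non-null value.
  // Every call runs a full group-by over that column, so summarizing a wide DataFrame
  // this way costs one aggregation job per column.
  def columnMode(df: DataFrame, colName: String): Option[Any] = {
    df.select(col(colName))
      .where(col(colName).isNotNull)
      .groupBy(col(colName))
      .agg(count(lit(1)).alias("cnt"))
      .orderBy(desc("cnt"))
      .limit(1)
      .collect()
      .headOption
      .map(_.get(0))
  }
}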
@@ -0,0 +1,66 @@
package tech.mlsql.plugins.mllib.ets.fe

import org.apache.spark.sql.SparkSession
import org.apache.spark.streaming.SparkOperationUtil
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}
import streaming.core.strategy.platform.SparkRuntime
import tech.mlsql.test.BasicMLSQLConfig

import java.sql.Timestamp
import java.time.LocalDateTime
import java.util.{Date, UUID}

/**
*
* @Author: Andie Huang
* @Date: 2022/6/27 19:07
*
*/
class SQLDataSummaryTTest extends FlatSpec with SparkOperationUtil with Matchers with BasicMLSQLConfig with BeforeAndAfterAll {
def startParams = Array(
"-streaming.master", "local[2]",
"-streaming.name", "unit-test",
"-streaming.rest", "false",
"-streaming.platform", "spark",
"-streaming.enableHiveSupport", "false",
"-streaming.hive.javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=metastore_db/${UUID.randomUUID().toString};create=true",
"-streaming.spark.service", "false",
"-streaming.unittest", "true",
"-spark.sql.shuffle.partitions", "12",
"-spark.default.parallelism", "12",
"-spark.driver.memory", "14g",
"-spark.executor.memoryOverheadFactor", "0.2",
"-spark.memory.offHeap.enabled", "true",
"-spark.memory.offHeap.size", "2g",
)

"DataSummary" should "Summarize the Dataset" in {
withBatchContext(setupBatchContext(startParams)) { runtime: SparkRuntime =>
implicit val spark: SparkSession = runtime.sparkSession
val et = new SQLDataSummaryV2()

// val sseq = Seq(
// ("elena", 57, 57, 110L, "433000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), 110F, true, null, null, BigDecimal.valueOf(12),1.123D),
// ("abe", 57, 50, 120L, "433000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), 120F, true, null, null, BigDecimal.valueOf(2), 1.123D),
// ("AA", 57, 10, 130L, "432000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), 130F, true, null, null, BigDecimal.valueOf(2),2.224D),
// ("cc", 0, 40, 100L, "", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), Float.NaN, true, null, null, BigDecimal.valueOf(2),2D),
// ("", -1, 30, 150L, "434000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), 150F, true, null, null, BigDecimal.valueOf(2),3.375D),
// ("bb", 57, 21, 160L, "533000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), Float.NaN, false, null, null, BigDecimal.valueOf(2),3.375D)
// )
// var seq_df = spark.createDataFrame(sseq).toDF("name", "favoriteNumber", "age", "mock_col1", "income", "date", "mock_col2", "alived", "extra", "extra1", "extra2","extra3")
var start_time = new Date().getTime
val parquetDF = spark.sqlContext.read.format("parquet").load("/Users/yonghui.huang/Data/benchmark2")
// et.train(parquetDF, "", Map("atRound" -> "2")).show()
var end_time = new Date().getTime
println(end_time - start_time)

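// Time the approximate summary: elapsed milliseconds are printed once right after train() returns and again after df.show() materializes the result.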
start_time = new Date().getTime
val df = et.train(parquetDF, "", Map("atRound" -> "2", "approxSwitch" -> "true"))
end_time = new Date().getTime
println(end_time - start_time)
df.show()
end_time = new Date().getTime
println(end_time - start_time)
}
}
}
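The test drives SQLDataSummaryV2 with approxSwitch -> true; what the flag toggles internally is not visible in this diff. As a hedged illustration of the exact-versus-approximate trade-off such a switch usually implies, the sketch below compares exact and HyperLogLog-based distinct counts in Spark. ApproxMetricsSketch and distinctCounts are hypothetical names, not part of the plugin.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

object ApproxMetricsSketch {
  // One distinct-count value per column, computed either exactly or approximately.
  def distinctCounts(df: DataFrame, approx: Boolean): DataFrame = {
    val aggs = df.columns.map { c =>
      if (approx) approx_count_distinct(col(c)).alias(c) // HyperLogLog++: bounded memory, small relative error
      else countDistinct(col(c)).alias(c)                // exact, but a heavier shuffle
    }
    df.agg(aggs.head, aggs.tail: _*)
  }
}

Approximate aggregations like this trade a small, bounded error for a much cheaper pass over large or wide data, which is the kind of speed-up the timing code in this test appears intended to measure.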
