From 1c0d4ebd30c9f03cb4ff84f645a9baad0fc5c9ef Mon Sep 17 00:00:00 2001 From: andiehuang Date: Tue, 23 Aug 2022 10:31:02 +0800 Subject: [PATCH] optimize DataSummaryET update codes update code update codes update DataSummary update codes update codes update codes revise the positions of metrics columns reduce an action in mode calculation rm modeformat rm modeformat update codes revise ut rm redundant file fix comments fix comments revise pom revise pom --- .../mlsql/plugins/mllib/app/MLSQLMllib.scala | 4 +- .../mllib/ets/fe/SQLDataSummaryV2.scala | 442 ++++++++++++++++++ .../mllib/ets/fe/SQLDataSummaryV2Test.scala | 96 ++++ 3 files changed, 540 insertions(+), 2 deletions(-) create mode 100644 mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2.scala create mode 100644 mlsql-mllib/src/test/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2Test.scala diff --git a/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/app/MLSQLMllib.scala b/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/app/MLSQLMllib.scala index 313b6a92..04be2a6b 100644 --- a/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/app/MLSQLMllib.scala +++ b/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/app/MLSQLMllib.scala @@ -5,7 +5,7 @@ import tech.mlsql.common.utils.log.Logging import tech.mlsql.dsl.CommandCollection import tech.mlsql.ets.register.ETRegister import tech.mlsql.plugins.mllib.ets._ -import tech.mlsql.plugins.mllib.ets.fe.{DataTranspose, OnehotExt, PSIExt, SQLDataSummary, SQLDescriptiveMetrics, SQLMissingValueProcess, SQLPatternDistribution, SQLUniqueIdentifier} +import tech.mlsql.plugins.mllib.ets.fe.{DataTranspose, OnehotExt, PSIExt, SQLDataSummary, SQLDataSummaryV2, SQLDescriptiveMetrics, SQLMissingValueProcess, SQLPatternDistribution, SQLUniqueIdentifier} import tech.mlsql.plugins.mllib.ets.fintech.scorecard.{SQLBinning, SQLScoreCard} import tech.mlsql.version.VersionCompatibility @@ -20,7 +20,7 @@ class MLSQLMllib extends tech.mlsql.app.App with VersionCompatibility with Loggi ETRegister.register("SampleDatasetExt", classOf[SampleDatasetExt].getName) ETRegister.register("TakeRandomSampleExt", classOf[TakeRandomSampleExt].getName) ETRegister.register("ColumnsExt", classOf[ColumnsExt].getName) - ETRegister.register("DataSummary", classOf[SQLDataSummary].getName) + ETRegister.register("DataSummary", classOf[SQLDataSummaryV2].getName) ETRegister.register("DataMissingValueProcess", classOf[SQLMissingValueProcess].getName) ETRegister.register("Binning", classOf[SQLBinning].getName) ETRegister.register("ScoreCard", classOf[SQLScoreCard].getName) diff --git a/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2.scala b/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2.scala new file mode 100644 index 00000000..6d94484e --- /dev/null +++ b/mlsql-mllib/src/main/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2.scala @@ -0,0 +1,442 @@ +package tech.mlsql.plugins.mllib.ets.fe + +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession, functions => F} +import streaming.dsl.ScriptSQLExec +import streaming.dsl.auth._ +import streaming.dsl.mmlib.algs.param.BaseParams +import streaming.dsl.mmlib.algs.{CodeExampleText, Functions, MllibFunctions} +import streaming.dsl.mmlib.{Code, SQLAlg, SQLCode} +import tech.mlsql.dsl.auth.ETAuth +import tech.mlsql.common.utils.log.Logging +import tech.mlsql.dsl.auth.dsl.mmlib.ETMethod.ETMethod + +import java.util.Date +import scala.util.{Failure, Success, Try} + +class SQLDataSummaryV2(override val uid: String) extends SQLAlg with MllibFunctions with Functions with BaseParams with ETAuth with Logging { + + def this() = this(BaseParams.randomUID()) + + var round_at = 2 + + var numericCols: Array[String] = null + + def colWithFilterBlank(sc: StructField): Column = { + val col_name = sc.name + sc.dataType match { + case DoubleType => col(col_name).isNotNull && !col(col_name).isNaN + case FloatType => col(col_name).isNotNull && !col(col_name).isNaN + case StringType => col(col_name).isNotNull && col(col_name) =!= "" + case _ => col(col_name).isNotNull + } + } + + + def countColsStdDevNumber(schema: StructType, numeric_columns: Array[String]): Array[Column] = { + schema.map(sc => { + val c = sc.name + if (numeric_columns.contains(c)) { + val expr = stddev(when(colWithFilterBlank(sc), col(c))) + when(expr.isNull, lit("")).otherwise(expr).alias(c + "_standardDeviation") + } else { + lit("").alias(c + "_standardDeviation") + } + }).toArray + } + + def countColsStdErrNumber(schema: StructType, numeric_columns: Array[String]): Array[Column] = { + schema.map(sc => { + val c = sc.name + if (numeric_columns.contains(c)) { + val expr = stddev(when(colWithFilterBlank(sc), col(c))) / sqrt(sum(when(colWithFilterBlank(sc), 1).otherwise(0))) + when(expr.isNull, lit("")).otherwise(expr).alias(c + "_standardError") + } else { + lit("").alias(c + "_standardError") + } + }).toArray + } + + def isPrimaryKey(schmea: StructType, approx: Boolean): Array[Column] = { + schmea.map(sc => { + val c = sc.name + val exp1 = if (approx) { + approx_count_distinct(when(colWithFilterBlank(sc), col(sc.name))) / sum(when(colWithFilterBlank(sc), 1).otherwise(0)) + } else { + countDistinct(when(colWithFilterBlank(sc), col(sc.name))) / sum(when(colWithFilterBlank(sc), 1).otherwise(0)) + } + when(exp1 === 1, 1).otherwise(0).alias(sc.name + "_primaryKeyCandidate") + }).toArray + } + + def countUniqueValueRatio(schema: StructType, approx: Boolean): Array[Column] = { + schema.map(sc => { + // TODO: + val sum_expr = sum(when(colWithFilterBlank(sc), 1).otherwise(0)) + val divide_expr = if (approx) { + approx_count_distinct(when(colWithFilterBlank(sc), col(sc.name))) / sum_expr + } else { + countDistinct(when(colWithFilterBlank(sc), col(sc.name))) / sum_expr + } + val ratio_expr = when(sum_expr === 0, 0.0).otherwise(divide_expr) + + ratio_expr.alias(sc.name + "_uniqueValueRatio") + }).toArray + } + + def getMaxNum(schema: StructType, numeric_columns: Array[String]): Array[Column] = { + schema.map(sc => { + val c = sc.name + val max_expr = max(when(colWithFilterBlank(sc), col(c))) + when(max_expr.isNull, "").otherwise(max_expr.cast(StringType)).alias(c + "_max") + }).toArray + } + + def getMinNum(schema: StructType, numeric_columns: Array[String]): Array[Column] = { + schema.map(sc => { + val c = sc.name + val min_expr = min(when(colWithFilterBlank(sc), col(c))) + when(min_expr.isNull, "").otherwise(min_expr.cast(StringType)).alias(c + "_min") + }).toArray + } + + def roundAtSingleCol(sc: StructField, column: Column): Column = { + if (numericCols.contains(sc.name)) { + return round(column, round_at).cast(StringType) + } + column.cast(StringType) + } + + def processModeValue(modeCandidates: Array[Row], modeFormat: String): Any = { + val mode = if (modeCandidates.lengthCompare(2) >= 0) { + modeFormat match { + case ModeValueFormat.empty => "" + case ModeValueFormat.all => "[" + modeCandidates.map(_.get(0).toString).mkString(",") + "]" + case ModeValueFormat.auto => modeCandidates.head.get(0) + } + } else { + modeCandidates.head.get(0) + } + mode + } + + def isArrayString(mode: Any): Boolean = { + mode.toString.startsWith("[") && mode.toString.endsWith("]") + } + + def countNonNullValue(schema: StructType): Array[Column] = { + schema.map(sc => { + sum(when(col(sc.name).isNotNull, 1).otherwise(0)) + }).toArray + } + + def nullValueCount(schema: StructType): Array[Column] = { + schema.map(sc => { + sc.dataType match { + case DoubleType | FloatType => (sum(when(col(sc.name).isNull || col(sc.name).isNaN, 1).otherwise(0))) / (sum(lit(1))).alias(sc.name + "_nullValueRatio") + case _ => (sum(when(col(sc.name).isNull, 1).otherwise(0))) / (sum(lit(1))).alias(sc.name + "_nullValueRatio") + } + }).toArray + } + + def emptyCount(schema: StructType): Array[Column] = { + schema.map(sc => { + sum(when(col(sc.name) === "", 1).otherwise(0)) / sum(lit(1.0)).alias(sc.name + "_blankValueRatio") + }).toArray + } + + def getMaxLength(schema: StructType): Array[Column] = { + schema.map(sc => { + sc.dataType match { + case StringType => max(length(col(sc.name))).alias(sc.name + "maximumLength") + case _ => lit("").alias(sc.name + "maximumLength") + } + }).toArray + } + + def getMinLength(schema: StructType): Array[Column] = { + schema.map(sc => { + sc.dataType match { + case StringType => min(length(col(sc.name))).alias(sc.name + "minimumLength") + case _ => lit("").alias(sc.name + "minimumLength") + } + }).toArray + } + + + def getMeanValue(schema: StructType): Array[Column] = { + schema.map(sc => { + val new_col = if (numericCols.contains(sc.name)) { + val avgExp = avg(when(colWithFilterBlank(sc), col(sc.name))) + // val roundExp = round(avgExp, round_at) + when(avgExp.isNull, lit("")).otherwise(avgExp).alias(sc.name + "_mean") + } else { + lit("").alias(sc.name + "_mean") + } + new_col + }).toArray + } + + def getTypeLength(schema: StructType): Array[Column] = { + schema.map(sc => { + sc.dataType.typeName match { + case "byte" => lit(1L).alias(sc.name) + case "short" => lit(2L).alias(sc.name) + case "integer" => lit(4L).alias(sc.name) + case "long" => lit(8L).alias(sc.name) + case "float" => lit(4L).alias(sc.name) + case "double" => lit(8L).alias(sc.name) + case "string" => max(length(col(sc.name))).alias(sc.name) + case "date" => lit(8L).alias(sc.name) + case "timestamp" => lit(8L).alias(sc.name) + case "boolean" => lit(1L).alias(sc.name) + case name: String if name.contains("decimal") => first(lit(16L)).alias(sc.name) + case _ => lit("").alias(sc.name) + } + }).toArray + } + + def roundNumericCols(df: DataFrame, round_at: Integer): DataFrame = { + df.select(df.schema.map(sc => { + sc.dataType match { + case DoubleType => expr(s"cast (${sc.name} as decimal(38,2)) as ${sc.name}") + case FloatType => expr(s"cast (${sc.name} as decimal(38,2)) as ${sc.name}") + case _ => col(sc.name) + } + }): _*) + } + + def dataFormat(resRow: Array[Seq[Any]], metricsIdx: Map[Int, String], roundAt: Int): Array[Seq[Any]] = { + resRow.map(row => { + row.zipWithIndex.map(el => { + val e = el._1 + val round_at = metricsIdx.getOrElse(el._2 - 1, "") match { + case t if t.endsWith("Ratio") => roundAt + 2 + case _ => roundAt + } + var newE = e + try { + val v = e.toString.toDouble + newE = BigDecimal(v).setScale(round_at, BigDecimal.RoundingMode.HALF_UP).toDouble + } catch { + case e: Exception => logInfo(e.toString) + } + newE + }) + }) + } + + def getPercentileRows(metrics: Array[String], schema: StructType, df: DataFrame, relativeError: Double): (Array[Array[Double]], Array[String]) = { + var percentilePoints: Array[Double] = Array() + var percentileCols: Array[String] = Array() + if (metrics.contains("%25")) { + percentilePoints = percentilePoints :+ 0.25 + percentileCols = percentileCols :+ "%25" + } + if (metrics.contains("median")) { + percentilePoints = percentilePoints :+ 0.5 + percentileCols = percentileCols :+ "median" + } + if (metrics.contains("%75")) { + percentilePoints = percentilePoints :+ 0.75 + percentileCols = percentileCols :+ "%75" + } + + val cols = schema.map(sc => { + var res = lit(0.0).as(sc.name) + if (numericCols.contains(sc.name)) { + res = col(sc.name) + } + res + }).toArray + val quantileRows: Array[Array[Double]] = df.select(cols: _*).na.fill(0.0).stat.approxQuantile(df.columns, percentilePoints, relativeError) + (quantileRows, percentileCols) + } + + def processSelectedMetrics(metrics: Array[String]): Array[String] = { + val normalMetrics = "maximumLength,minimumLength,uniqueValueRatio,nullValueRatio,blankValueRatio,mean,standardDeviation,standardError,max,min,dataLength,primaryKeyCandidate".split(",") + val computedMetrics = "%25,median,%75".split(",") + val modeMetric = "mode".split(",") + var leftMetrics: Array[String] = Array() + var rightMetrics: Array[String] = Array() + var appendMetrics: Array[String] = Array() + metrics.map(m => { + m match { + case metric if normalMetrics.contains(metric) => leftMetrics = leftMetrics :+ metric + case metric if computedMetrics.contains(metric) => rightMetrics = rightMetrics :+ metric + case metric if modeMetric.contains(metric) => appendMetrics = appendMetrics :+ metric + case _ => require(false, "The selected metrics contains unkonwn calculation! " + m) + } + }) + leftMetrics ++ rightMetrics ++ appendMetrics + } + + def getModeValue(schema: StructType, df: DataFrame): Array[Any] = { + val mode = schema.toList.par.map(sc => { + val dfWithoutNa = df.select(col(sc.name)).na.drop() + val modeDF = dfWithoutNa.groupBy(col(sc.name)).count().orderBy(F.desc("count")).limit(2) + val modeList = modeDF.collect() + if (modeList.length != 0) { + modeList match { + case p if p.length < 2 => p(0).get(0) + case p if p(0).get(1) == p(1).get(1) => "" + case _ => modeList(0).get(0) + } + } else { + "" + } + }).toArray + mode + } + + def train(df: DataFrame, path: String, params: Map[String, String]): DataFrame = { + + round_at = Integer.valueOf(params.getOrElse("roundAt", "2")) + + val metrics = params.getOrElse(DataSummary.metrics, "dataLength,max,min,maximumLength,minimumLength,mean,standardDeviation,standardError,nullValueRatio,blankValueRatio,uniqueValueRatio,primaryKeyCandidate,median,mode").split(",").filter(!_.equalsIgnoreCase("")) + val relativeError = params.getOrElse("relativeError", "0.01").toDouble + val approxCountDistinct = params.getOrElse("approxCountDistinct", "false").toBoolean + val repartitionDF = df + val columns = repartitionDF.columns + + columns.map(col => { + if (col.contains(".") || col.contains("`")) { + throw new RuntimeException(s"The column name : ${col} contains special symbols, like . or `, please rename it first!! ") + } + }) + + var start_time = new Date().getTime + numericCols = repartitionDF.schema.filter(sc => { + sc.dataType.typeName match { + case datatype: String => Array("integer", "short", "double", "float", "long").contains(datatype) || datatype.contains("decimal") + case _ => false + } + }).map(sc => { + sc.name + }).toArray + val schema = repartitionDF.schema + + val default_metrics = Map( + "dataLength" -> getTypeLength(schema), + "max" -> getMaxNum(schema, numericCols), + "min" -> getMinNum(schema, numericCols), + "maximumLength" -> getMaxLength(schema), + "minimumLength" -> getMinLength(schema), + "mean" -> getMeanValue(schema), + "standardDeviation" -> countColsStdDevNumber(schema, numericCols), + "standardError" -> countColsStdErrNumber(schema, numericCols), + "nullValueRatio" -> nullValueCount(schema), + "blankValueRatio" -> emptyCount(schema), + "uniqueValueRatio" -> countUniqueValueRatio(schema, approxCountDistinct), + "primaryKeyCandidate" -> isPrimaryKey(schema, approxCountDistinct), + ) + val processedSelectedMetrics = processSelectedMetrics(metrics) + val newCols = processedSelectedMetrics.map(name => default_metrics.getOrElse(name, null)).filter(_ != null).flatMap(arr => arr).toArray + val metricsIdx = processedSelectedMetrics.zipWithIndex.map(t => { + (t._2, t._1) + }).toMap + var resDF = repartitionDF.select(newCols: _*) + logInfo(s"normal metrics plan:\n${resDF.explain(true)}") + val rows = resDF.collect() + val rowN = schema.length + val ordinaryPosRow = df.columns.map(col_name => String.valueOf(df.columns.indexOf(col_name) + 1)).toSeq + val normalMetricsRow = (ordinaryPosRow ++ rows(0).toSeq).grouped(rowN).map(_.toSeq).toArray.toSeq.transpose + var end_time = new Date().getTime + + logInfo("The elapsed time for normal metrics is : " + (end_time - start_time)) + + // Calculate Percentile + start_time = new Date().getTime + val (quantileRows, quantileCols) = getPercentileRows(processedSelectedMetrics, schema, df, relativeError) + end_time = new Date().getTime + logInfo("The elapsed time for percentile metrics is: " + (end_time - start_time)) + + var datatype_schema: Array[StructField] = null + var resRows: Array[Seq[Any]] = null + quantileCols.length match { + case 0 => + resRows = Range(0, schema.length).map(i => { + Seq(schema(i).name) ++ normalMetricsRow(i) + }).toArray + case _ => + resRows = Range(0, schema.length).map(i => { + Seq(schema(i).name) ++ normalMetricsRow(i) ++ quantileRows(i).toSeq + }).toArray + } + datatype_schema = ("ColumnName" +: "ordinaryPosition" +: processedSelectedMetrics).map(t => { + StructField(t, StringType) + }) + + start_time = new Date().getTime + // Calculate Mode + if (processedSelectedMetrics.contains("mode")) { + val modeRows = getModeValue(schema, df) + resRows = Range(0, schema.length).map(i => { + resRows(i) :+ modeRows(i) + }).toArray + end_time = new Date().getTime + logInfo("The elapsed time for mode metric is: " + (end_time - start_time)) + } + + + resRows = dataFormat(resRows, metricsIdx, round_at) + val resAfterTransformed = resRows.map(row => { + val t = row.map(e => String.valueOf(e)) + t + }) + val spark = df.sparkSession + spark.createDataFrame(spark.sparkContext.parallelize(resAfterTransformed.map(Row.fromSeq(_)), 1), StructType(datatype_schema)) + } + + override def load(sparkSession: SparkSession, path: String, params: Map[String, String]): Any = { + } + + override def predict(sparkSession: SparkSession, _model: Any, name: String, params: Map[String, String]): UserDefinedFunction = ??? + + override def batchPredict(df: DataFrame, path: String, params: Map[String, String]): DataFrame = + train(df, path, params) + + override def codeExample: Code = Code(SQLCode, CodeExampleText.jsonStr + + """ + | + |set abc=''' + |{"name": "elena", "age": 57, "phone": 15552231521, "income": 433000, "label": 0} + |{"name": "candy", "age": 67, "phone": 15552231521, "income": 1200, "label": 0} + |{"name": "bob", "age": 57, "phone": 15252211521, "income": 89000, "label": 0} + |{"name": "candy", "age": 25, "phone": 15552211522, "income": 36000, "label": 1} + |{"name": "candy", "age": 31, "phone": 15552211521, "income": 300000, "label": 1} + |{"name": "finn", "age": 23, "phone": 15552211521, "income": 238000, "label": 1} + |'''; + | + |load jsonStr.`abc` as table1; + |select age, income from table1 as table2; + |run table2 as DataSummary.`` as summaryTable; + |; + """.stripMargin) + + + override def auth(etMethod: ETMethod, path: String, params: Map[String, String]): List[TableAuthResult] = { + val vtable = MLSQLTable( + Option(DB_DEFAULT.MLSQL_SYSTEM.toString), + Option("__fe_data_summary_operator__"), + OperateType.SELECT, + Option("select"), + TableType.SYSTEM) + + val context = ScriptSQLExec.contextGetOrForTest() + context.execListener.getTableAuth match { + case Some(tableAuth) => + tableAuth.auth(List(vtable)) + case None => + List(TableAuthResult(granted = true, "")) + } + } +} + +object ModeValueFormat { + val all = "all" + val empty = "empty" + val auto = "auto" +} \ No newline at end of file diff --git a/mlsql-mllib/src/test/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2Test.scala b/mlsql-mllib/src/test/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2Test.scala new file mode 100644 index 00000000..6ff6a5de --- /dev/null +++ b/mlsql-mllib/src/test/java/tech/mlsql/plugins/mllib/ets/fe/SQLDataSummaryV2Test.scala @@ -0,0 +1,96 @@ +package tech.mlsql.plugins.mllib.ets.fe + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions.{col, explode, struct, desc} +import org.apache.spark.streaming.SparkOperationUtil +import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers} +import streaming.core.strategy.platform.SparkRuntime +import tech.mlsql.test.BasicMLSQLConfig + +import java.sql.Timestamp +import java.time.LocalDateTime +import java.util.{Date, UUID} + +/** + * + * @Author; Andie Huang + * @Date: 2022/6/27 19:07 + * + */ +class SQLDataSummaryV2Test extends FlatSpec with SparkOperationUtil with Matchers with BasicMLSQLConfig with BeforeAndAfterAll { + def startParams = Array( + "-streaming.master", "local[*]", + "-streaming.name", "unit-test", + "-streaming.rest", "false", + "-streaming.platform", "spark", + "-streaming.enableHiveSupport", "false", + "-streaming.hive.javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=metastore_db/${UUID.randomUUID().toString};create=true", + "-streaming.spark.service", "false", + "-streaming.unittest", "true", + "-spark.sql.shuffle.partitions", "12", + "-spark.default.parallelism", "12", + "-spark.executor.memoryOverheadFactor", "0.2", + "-spark.dirver.maxResultSize", "2g" + ) + + "DataSummary" should "Summarize the Dataset" in { + withBatchContext(setupBatchContext(startParams)) { runtime: SparkRuntime => + implicit val spark: SparkSession = runtime.sparkSession + val et = new SQLDataSummaryV2() + val sseq1 = Seq( + ("elena", 57, "433000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0))), + ("abe", 50, "433000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0))), + ("AA", 10, "432000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0))), + ("cc", 40, "433000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0))), + ("", 30, "434000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0))), + ("bb", 21, "533000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0))) + ) + val seq_df1 = spark.createDataFrame(sseq1).toDF("name", "age", "income", "date") + val res1DF = et.train(seq_df1, "", Map("atRound" -> "2", "metrics" -> "dataLength,max,min,maximumLength,minimumLength,mean,standardDeviation,standardError,nullValueRatio,blankValueRatio,uniqueValueRatio,primaryKeyCandidate,median,mode")) + res1DF.show() + assert(res1DF.collect()(0).mkString(",") === "name,1.0,5.0,elena,AA,5.0,0.0,,,,0.0,0.1667,1.0,1.0,0.0,") + assert(res1DF.collect()(1).mkString(",") === "age,2.0,4.0,57.0,10.0,,,34.67,17.77,7.2556,0.0,0.0,1.0,1.0,30.0,") + assert(res1DF.collect()(2).mkString(",") === "income,3.0,6.0,533000.0,432000.0,6.0,6.0,,,,0.0,0.0,0.67,0.0,0.0,433000.0") + assert(res1DF.collect()(3).mkString(",") === "date,4.0,8.0,2021-03-08 18:00:00,2021-03-08 18:00:00,,,,,,0.0,0.0,0.17,0.0,0.0,2021-03-08 18:00:00.0") + val sseq = Seq( + ("elena", 57, 57, 110L, "433000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), 110F, true, null, null, BigDecimal.valueOf(12), 1.123D), + ("abe", 57, 50, 120L, "433000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), 120F, true, null, null, BigDecimal.valueOf(2), 1.123D), + ("AA", 57, 10, 130L, "432000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), 130F, true, null, null, BigDecimal.valueOf(2), 2.224D), + ("cc", 0, 40, 100L, "", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), Float.NaN, true, null, null, BigDecimal.valueOf(2), 2D), + ("", -1, 30, 150L, "434000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), 150F, true, null, null, BigDecimal.valueOf(2), 3.375D), + ("bb", 57, 21, 160L, "533000", Timestamp.valueOf(LocalDateTime.of(2021, 3, 8, 18, 0)), Float.NaN, false, null, null, BigDecimal.valueOf(2), 3.375D) + ) + val seq_df = spark.createDataFrame(sseq).toDF("name", "favoriteNumber", "age", "mock_col1", "income", "date", "mock_col2", "alived", "extra", "extra1", "extra2", "extra3") + val res2DF = et.train(seq_df, "", Map("atRound" -> "2", "metrics" -> "dataLength,max,min,maximumLength,minimumLength,mean,standardDeviation,standardError,nullValueRatio,blankValueRatio,uniqueValueRatio,primaryKeyCandidate,median,mode")) + res2DF.show() + assert(res2DF.collect()(0).mkString(",") === "name,1.0,5.0,elena,AA,5.0,0.0,,,,0.0,0.1667,1.0,1.0,0.0,") + assert(res2DF.collect()(1).mkString(",") === "favoriteNumber,2.0,4.0,57.0,-1.0,,,37.83,29.69,12.1228,0.0,0.0,0.5,0.0,57.0,57.0") + assert(res2DF.collect()(2).mkString(",") === "age,3.0,4.0,57.0,10.0,,,34.67,17.77,7.2556,0.0,0.0,1.0,1.0,30.0,") + assert(res2DF.collect()(3).mkString(",") === "mock_col1,4.0,8.0,160.0,100.0,,,128.33,23.17,9.4575,0.0,0.0,1.0,1.0,120.0,") + assert(res2DF.collect()(4).mkString(",") === "income,5.0,6.0,533000.0,432000.0,6.0,0.0,,,,0.0,0.1667,0.8,0.0,0.0,433000.0") + assert(res2DF.collect()(5).mkString(",") === "date,6.0,8.0,2021-03-08 18:00:00,2021-03-08 18:00:00,,,,,,0.0,0.0,0.17,0.0,0.0,2021-03-08 18:00:00.0") + assert(res2DF.collect()(6).mkString(",") === "mock_col2,7.0,4.0,150.0,110.0,,,127.5,17.08,8.5391,0.3333,0.0,1.0,1.0,110.0,") + assert(res2DF.collect()(7).mkString(",") === "alived,8.0,1.0,true,false,,,,,,0.0,0.0,0.33,0.0,0.0,true") + assert(res2DF.collect()(8).mkString(",") === "extra,9.0,,,,,,,,,1.0,0.0,0.0,0.0,0.0,") + assert(res2DF.collect()(9).mkString(",") === "extra1,10.0,,,,,,,,,1.0,0.0,0.0,0.0,0.0,") + assert(res2DF.collect()(10).mkString(",") === "extra2,11.0,16.0,12.0,2.0,,,3.67,4.08,1.6667,0.0,0.0,0.33,0.0,2.0,2.0") + assert(res2DF.collect()(11).mkString(",") === "extra3,12.0,8.0,3.38,1.12,,,2.2,1.01,0.4132,0.0,0.0,0.67,0.0,2.0,") + val sseq2 = Seq( + (null, null), + (null, null) + ) + val seq_df2 = spark.createDataFrame(sseq2).toDF("col1", "col2") + val res3DF = et.train(seq_df2, "", Map("atRound" -> "2", "metrics" -> "dataLength,max,min,maximumLength,minimumLength,mean,standardDeviation,standardError,nullValueRatio,blankValueRatio,uniqueValueRatio,primaryKeyCandidate,median,mode")) + res3DF.show() + assert(res3DF.collect()(0).mkString(",") === "col1,1.0,,,,,,,,,1.0,0.0,0.0,0.0,0.0,") + assert(res3DF.collect()(1).mkString(",") === "col2,2.0,,,,,,,,,1.0,0.0,0.0,0.0,0.0,") + // val paquetDF1 = spark.sqlContext.read.format("parquet").load("/Users/yonghui.huang/Data/benchmarkZL1") + // val paquetDF2 = paquetDF1.sample(true, 1) + // println(paquetDF2.count()) + // val df1 = et.train(paquetDF2, "", Map("atRound" -> "2", "relativeError" -> "0.01")) + // df1.show() + // val df2 = et.train(paquetDF2, "", Map("atRound" -> "2", "approxCountDistinct" -> "true")) + // df2.show() + } + } +} \ No newline at end of file