diff --git a/streamingpro-core/src/main/java/streaming/dsl/load/batch/ModelExplain.scala b/streamingpro-core/src/main/java/streaming/dsl/load/batch/ModelExplain.scala index 042665605..1a5862837 100644 --- a/streamingpro-core/src/main/java/streaming/dsl/load/batch/ModelExplain.scala +++ b/streamingpro-core/src/main/java/streaming/dsl/load/batch/ModelExplain.scala @@ -20,10 +20,8 @@ package streaming.dsl.load.batch import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import tech.mlsql.common.utils.reflect.ClassPath import tech.mlsql.dsl.adaptor.MLMapping -import scala.collection.JavaConversions._ /** * Created by allwefantasy on 21/9/2018. diff --git a/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/LoadAdaptor.scala b/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/LoadAdaptor.scala index eb235760e..1be5a5ae4 100644 --- a/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/LoadAdaptor.scala +++ b/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/LoadAdaptor.scala @@ -136,10 +136,16 @@ class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener, // return the load table table }.getOrElse { - // calculate resource real absolute path + // path could be: + // 1) fileSystem path; code example: load modelExplain.`/tmp/model` where alg="RandomForest" as output; + // 2) ET name; code example: load modelExample.`JsonExpandExt` AS output_1; load modelParams.`JsonExpandExt` as output; + // For FileSystem path, pass the real path to ModelSelfExplain; for ET name pass original path val resourcePath = resourceRealPath(scriptSQLExecListener, option.get("owner"), path) - - table = ModelSelfExplain(format, cleanStr(path), option, sparkSession).isMatch.thenDo.orElse(() => { + val fsPathOrETName = format match { + case "modelExplain" => resourcePath + case _ => cleanStr(path) + } + table = ModelSelfExplain(format, fsPathOrETName, option, sparkSession).isMatch.thenDo.orElse(() => { reader.format(format).load(resourcePath) }).get } diff --git a/streamingpro-it/src/test/resources/sql/simple/23_random_forest_model_explain.mlsql b/streamingpro-it/src/test/resources/sql/simple/23_random_forest_model_explain.mlsql new file mode 100644 index 000000000..d7dc057ac --- /dev/null +++ b/streamingpro-it/src/test/resources/sql/simple/23_random_forest_model_explain.mlsql @@ -0,0 +1,38 @@ +--%comparator=tech.mlsql.it.IgnoreResultComparator + +set jsonStr=''' +{"features":[5.1,3.5,1.4,0.2],"label":0.0}, +{"features":[5.1,3.5,1.4,0.2],"label":1.0} +{"features":[5.1,3.5,1.4,0.2],"label":0.0} +{"features":[4.4,2.9,1.4,0.2],"label":0.0} +{"features":[5.1,3.5,1.4,0.2],"label":1.0} +{"features":[5.1,3.5,1.4,0.2],"label":0.0} +{"features":[5.1,3.5,1.4,0.2],"label":0.0} +{"features":[4.7,3.2,1.3,0.2],"label":1.0} +{"features":[5.1,3.5,1.4,0.2],"label":0.0} +{"features":[5.1,3.5,1.4,0.2],"label":0.0} +'''; +load jsonStr.`jsonStr` as mock_data; + + +select vec_dense(features) as features, label as label from mock_data as mock_data_1; + + +train mock_data_1 as RandomForest.`/tmp/model` where +keepVersion="true" +and evaluateTable="mock_data_validate" + +and `fitParam.0.labelCol`="label" +and `fitParam.0.featuresCol`="features" +and `fitParam.0.maxDepth`="2" + +and `fitParam.1.featuresCol`="features" +and `fitParam.1.labelCol`="label" +and `fitParam.1.maxDepth`="10" +; + +load modelExplain.`/tmp/model/_model_0/` where alg="RandomForest" as output_1; + +select `name` from output_1 where name="uid" as result; + +!assert result ''':name=="uid"''' "RandomForest modelExplain should be successful"; \ No newline at end of file diff --git a/streamingpro-it/src/test/resources/sql/simple/andie_huang_02_train_linear_regression_model.mlsql b/streamingpro-it/src/test/resources/sql/simple/andie_huang_02_train_linear_regression_model.mlsql index 1381a94fb..a6cbb2907 100644 --- a/streamingpro-it/src/test/resources/sql/simple/andie_huang_02_train_linear_regression_model.mlsql +++ b/streamingpro-it/src/test/resources/sql/simple/andie_huang_02_train_linear_regression_model.mlsql @@ -48,4 +48,11 @@ select pred_func(features) as predict_label, label from data1 as output; select name,value from model_result where name="status" as result; -- make sure status of all models are success. -!assert result ''':value=="success"''' "all model status should be success"; \ No newline at end of file +!assert result ''':value=="success"''' "all model status should be success"; + + +load modelExplain.`/tmp/linearregression` WHERE `alg`="LinearRegression" as lr_model_explain; + +select `name` from lr_model_explain where name="uid" as result; + +!assert result ''':name=="uid"''' "LinearRegression modelExplain should be successful"; \ No newline at end of file