Skip to content

Commit

Permalink
Fix issue1649
Browse files Browse the repository at this point in the history
  • Loading branch information
chncaesar committed Jul 10, 2022
1 parent 445b081 commit f536732
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@ package streaming.dsl.load.batch

import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import tech.mlsql.common.utils.reflect.ClassPath
import tech.mlsql.dsl.adaptor.MLMapping

import scala.collection.JavaConversions._

/**
* Created by allwefantasy on 21/9/2018.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,16 @@ class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener,
// return the load table
table
}.getOrElse {
// calculate resource real absolute path
// path could be:
// 1) fileSystem path; code example: load modelExplain.`/tmp/model` where alg="RandomForest" as output;
// 2) ET name; code example: load modelExample.`JsonExpandExt` AS output_1; load modelParams.`JsonExpandExt` as output;
// For a file-system path (format "modelExplain"), pass the resolved real path to ModelSelfExplain; for an ET name, pass the original cleaned path
val resourcePath = resourceRealPath(scriptSQLExecListener, option.get("owner"), path)

table = ModelSelfExplain(format, cleanStr(path), option, sparkSession).isMatch.thenDo.orElse(() => {
val fsPathOrETName = format match {
case "modelExplain" => resourceRealPath(scriptSQLExecListener, option.get("owner"), path)
case _ => cleanStr(path)
}
table = ModelSelfExplain(format, fsPathOrETName, option, sparkSession).isMatch.thenDo.orElse(() => {
reader.format(format).load(resourcePath)
}).get
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
--%comparator=tech.mlsql.it.IgnoreResultComparator

-- Regression script for `load modelExplain` on a file-system path:
-- trains a RandomForest model under /tmp/model, loads the model explanation
-- from /tmp/model/_model_0/ and asserts that a row named "uid" is present.
-- NOTE(review): the comparator hint on the first line presumably makes the test
-- harness skip comparing statement output against an expected result, so only
-- the !assert at the end decides pass/fail - confirm.

-- Inline training data: one JSON object per line, each with a 4-element
-- `features` array and a numeric `label`.
-- NOTE(review): the first record ends with a trailing comma while the others
-- do not - confirm the jsonStr loader tolerates it.
set jsonStr='''
{"features":[5.1,3.5,1.4,0.2],"label":0.0},
{"features":[5.1,3.5,1.4,0.2],"label":1.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
{"features":[4.4,2.9,1.4,0.2],"label":0.0}
{"features":[5.1,3.5,1.4,0.2],"label":1.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
{"features":[4.7,3.2,1.3,0.2],"label":1.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
''';
-- Register the inline JSON as table mock_data.
load jsonStr.`jsonStr` as mock_data;


-- Apply vec_dense to the features column and keep the label;
-- the result is registered as mock_data_1.
select vec_dense(features) as features, label as label from mock_data as mock_data_1;


-- Train RandomForest on mock_data_1 with output path /tmp/model, using two
-- parameter groups: fitParam.0 (maxDepth=2) and fitParam.1 (maxDepth=10).
-- NOTE(review): evaluateTable names mock_data_validate, which this script never
-- defines - confirm whether it is ignored or should point at an existing table.
train mock_data_1 as RandomForest.`/tmp/model` where
keepVersion="true"
and evaluateTable="mock_data_validate"

and `fitParam.0.labelCol`="label"
and `fitParam.0.featuresCol`="features"
and `fitParam.0.maxDepth`="2"

and `fitParam.1.featuresCol`="features"
and `fitParam.1.labelCol`="label"
and `fitParam.1.maxDepth`="10"
;

-- Load the model explanation from a file-system path (not an ET name).
-- NOTE(review): /tmp/model/_model_0/ is presumably the first version directory
-- created because keepVersion="true" - confirm.
load modelExplain.`/tmp/model/_model_0/` where alg="RandomForest" as output_1;

-- Keep only the explanation row whose name is "uid".
select `name` from output_1 where name="uid" as result;

-- The test passes only if the filtered result satisfies name == "uid".
!assert result ''':name=="uid"''' "RandomForest modelExplain should be successful";
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,11 @@ select pred_func(features) as predict_label, label from data1 as output;

select name,value from model_result where name="status" as result;
-- make sure the status of every model is success.
!assert result ''':value=="success"''' "all model status should be success";
!assert result ''':value=="success"''' "all model status should be success";


load modelExplain.`/tmp/linearregression` WHERE `alg`="LinearRegression" as lr_model_explain;

select `name` from lr_model_explain where name="uid" as result;

!assert result ''':name=="uid"''' "LinearRegression modelExplain should be successful";

0 comments on commit f536732

Please sign in to comment.