Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pathPrefix for modelExplain #1794

Merged
merged 2 commits into from
Jul 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ package streaming.dsl.load.batch

import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import tech.mlsql.common.utils.reflect.ClassPath
import tech.mlsql.dsl.adaptor.MLMapping
import tech.mlsql.tool.HDFSOperatorV2

import scala.collection.JavaConversions._

/**
* Created by allwefantasy on 21/9/2018.
Expand Down Expand Up @@ -219,7 +218,32 @@ class ModelExplain(format: String, path: String, option: Map[String, String])(sp
if ((format == "model" && path == "explain"))
option("path")
else {
path
option.get("index") match {
// If user specifies index, try to find _model_n subdirectory
case Some(idx) =>
val subDirs = HDFSOperatorV2.listFiles(path)
.filter( _.isDirectory )
val _modelDirExists = subDirs.exists( _.getPath.getName.startsWith("_model_"))
chncaesar marked this conversation as resolved.
Show resolved Hide resolved
val _modelIdxDirExists = subDirs.exists( s"_model_${idx}" == _.getPath.getName)
val metaDirExists = subDirs.exists( "meta" == _.getPath.getName)
val modelDirExists = subDirs.exists( "model" == _.getPath.getName)
if( _modelDirExists && _modelIdxDirExists) {
s"${path}/_model_${idx}"
}
else if( _modelDirExists && ! _modelIdxDirExists ) {
throw new RuntimeException(s"model directory with index ${idx} does not exist")
}
else if( metaDirExists && modelDirExists ) {
// `keepVersion`="false", index option is ignored.
path
}
else {
throw new RuntimeException(s"${path}/_model_${idx} does not exist")
}

// No index option was given: use the path as supplied
case None => path
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,16 @@ class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener,
// return the load table
table
}.getOrElse {
// calculate resource real absolute path
// path could be:
// 1) fileSystem path; code example: load modelExplain.`/tmp/model` where alg="RandomForest" as output;
// 2) ET name; code example: load modelExample.`JsonExpandExt` AS output_1; load modelParams.`JsonExpandExt` as output;
// For FileSystem path, pass the real path to ModelSelfExplain; for ET name pass original path
val resourcePath = resourceRealPath(scriptSQLExecListener, option.get("owner"), path)

table = ModelSelfExplain(format, cleanStr(path), option, sparkSession).isMatch.thenDo.orElse(() => {
val fsPathOrETName = format match {
case "modelExplain" => resourcePath
case _ => cleanStr(path)
}
table = ModelSelfExplain(format, fsPathOrETName, option, sparkSession).isMatch.thenDo.orElse(() => {
reader.format(format).load(resourcePath)
}).get
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
--%comparator=tech.mlsql.it.IgnoreResultComparator

-- Test script: train a RandomForest model with keepVersion="true", then load
-- modelExplain from the `_model_0` subdirectory, once without and once with
-- the `index` option, asserting each time that a row named "uid" is returned.

-- Inline training data: ten JSON rows, each with a 4-element `features` array
-- and a numeric `label`.
-- NOTE(review): the first row ends with a trailing comma while the others do
-- not — confirm this is intentional.
set jsonStr='''
{"features":[5.1,3.5,1.4,0.2],"label":0.0},
{"features":[5.1,3.5,1.4,0.2],"label":1.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
{"features":[4.4,2.9,1.4,0.2],"label":0.0}
{"features":[5.1,3.5,1.4,0.2],"label":1.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
{"features":[4.7,3.2,1.3,0.2],"label":1.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
''';
-- Register the inline JSON as table `mock_data`.
load jsonStr.`jsonStr` as mock_data;


-- Pass `features` through vec_dense and keep `label`; the result is `mock_data_1`.
select vec_dense(features) as features, label as label from mock_data as mock_data_1;


-- Train RandomForest into /tmp/model with keepVersion="true" and two fitParam
-- groups (index 0: maxDepth 2, index 1: maxDepth 10).
-- NOTE(review): evaluateTable refers to `mock_data_validate`, which is not
-- created anywhere in this script — confirm it is expected to exist or be ignored.
train mock_data_1 as RandomForest.`/tmp/model` where
keepVersion="true"
and evaluateTable="mock_data_validate"

and `fitParam.0.labelCol`="label"
and `fitParam.0.featuresCol`="features"
and `fitParam.0.maxDepth`="2"

and `fitParam.1.featuresCol`="features"
and `fitParam.1.labelCol`="label"
and `fitParam.1.maxDepth`="10"
;

-- Case 1: point modelExplain directly at the `_model_0` subdirectory, no `index` option.
load modelExplain.`/tmp/model/_model_0/` where alg="RandomForest" as output_1;

-- Keep only the row whose name is "uid" so the assertion below can check it.
select `name` from output_1 where name="uid" as result;

!assert result ''':name=="uid"''' "RandomForest modelExplain should be successful";


-- Case 2: same `_model_0` path, this time also passing index = "0".
load modelExplain.`/tmp/model/_model_0/` where alg="RandomForest" and index = "0" as output_2;
select `name` from output_2 where name="uid" as result_2;
!assert result_2 ''':name=="uid"''' "RandomForest modelExplain should be successful";
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,11 @@ select pred_func(features) as predict_label, label from data1 as output;

select name,value from model_result where name="status" as result;
-- make sure status of all models are success.
!assert result ''':value=="success"''' "all model status should be success";
!assert result ''':value=="success"''' "all model status should be success";


load modelExplain.`/tmp/linearregression` WHERE `alg`="LinearRegression" as lr_model_explain;

select `name` from lr_model_explain where name="uid" as result;

!assert result ''':name=="uid"''' "LinearRegression modelExplain should be successful";