Skip to content

Commit

Permalink
feat(validate): update DL fromat model validate (#151)
Browse files Browse the repository at this point in the history
* feat(example): rename graphdef model

* feat(graphdef): update graphdef model checker:

* feat(validate): update framerwork validate

* fix(validate): validate suffix func

* feat(validate): validate ML format for exactly one legal file

Co-authored-by: judgeeeeee <yf@caicloud.io>
  • Loading branch information
judgeeeeee and judgeeeeee authored Nov 26, 2020
1 parent 6039392 commit e4867e6
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 69 deletions.
File renamed without changes.
145 changes: 76 additions & 69 deletions pkg/model/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io/ioutil"
"os"
"path"
"strings"
)

// Format is the definition of model format.
Expand Down Expand Up @@ -77,48 +78,55 @@ func (f Format) ValidateDirectory(rootPath string) error {
return nil
}

func ValidateError(modelPath string, modelName string, modelNum int32) error {
if modelNum != 1 {
return fmt.Errorf("Expected one %v file in %v directory, but found %v .", modelName, modelPath, modelNum)
}
return nil
}

func (f Format) validateForSavedModel(modelPath string, files []os.FileInfo) error {
var pbFileFlag bool
var variablesDirFlag bool
var pbFileNum int32
var variablesDirNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".pb" {
pbFileFlag = true
if file.Name() == "saved_model.pb" {
pbFileNum++
}
if file.IsDir() && file.Name() == "variables" {
variablesDirFlag = true
variablesDirNum++
}
}
if !pbFileFlag {
return fmt.Errorf("there are no *.pb file in %v directory", modelPath)
if e := ValidateError(modelPath, "saved_model.pb", pbFileNum); e != nil {
return e
}
if !variablesDirFlag {
return fmt.Errorf("there are no variables dir in %v directory", modelPath)
if e := ValidateError(modelPath, "variables", variablesDirNum); e != nil {
return e
}
return nil
}

func (f Format) validateForONNX(modelPath string, files []os.FileInfo) error {
var onnxFileFlag bool
var onnxFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".onnx" {
onnxFileFlag = true
onnxFileNum++
}
}
if !onnxFileFlag {
return fmt.Errorf("there are no *.onnx file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.onnx", onnxFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForH5(modelPath string, files []os.FileInfo) error {
var h5FileFlag bool
var h5FileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".h5" {
h5FileFlag = true
h5FileNum++
}
}
if !h5FileFlag {
return fmt.Errorf("there are no *.h5 file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.h5", h5FileNum); e != nil {
return e
}
return nil
}
Expand All @@ -135,141 +143,140 @@ func (f Format) validateForPMML(modelPath string, files []os.FileInfo) error {
}

func (f Format) validateForCaffeModel(modelPath string, files []os.FileInfo) error {
var caffeModelFileFlag bool
var prototxtFileFlag bool
var caffeModelFileNum int32
var prototxtFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".caffemodel" {
caffeModelFileFlag = true
caffeModelFileNum++
}
if path.Ext(file.Name()) == ".prototxt" {
prototxtFileFlag = true
prototxtFileNum++
}
}
if !caffeModelFileFlag {
return fmt.Errorf("there are no *.caffemodel file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.caffemodel", caffeModelFileNum); e != nil {
return e
}
if !prototxtFileFlag {
return fmt.Errorf("there are no *.prototxt file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.prototxt", prototxtFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForNetDef(modelPath string, files []os.FileInfo) error {
var initFileFlag bool
var predictFileFlag bool
var initFileNum int32
var predictFileNum int32
for _, file := range files {
if file.Name() == "init_net.pb" {
initFileFlag = true
initFileNum++
}
if file.Name() == "predict_net.pb" {
predictFileFlag = true
predictFileNum++
}
}
if !initFileFlag {
return fmt.Errorf("there are no init_net.pb file in %v directory", modelPath)
if e := ValidateError(modelPath, "init_net.pb", initFileNum); e != nil {
return e
}
if !predictFileFlag {
return fmt.Errorf("there are no predict_net.pb file in %v directory", modelPath)
if e := ValidateError(modelPath, "predict_net.pb", predictFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForMXNETParams(modelPath string, files []os.FileInfo) error {
var jsonFileFlag bool
var paramsFileFlag bool
var jsonFileNum int32
var paramsFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".json" {
jsonFileFlag = true
if strings.HasSuffix(file.Name(), "symbol.json") {
jsonFileNum++
}
if path.Ext(file.Name()) == ".params" {
paramsFileFlag = true
paramsFileNum++
}
}
if !jsonFileFlag {
return fmt.Errorf("there are no *.json file in %v directory", modelPath)
if e := ValidateError(modelPath, "*symbol.json", jsonFileNum); e != nil {
return e
}
if !paramsFileFlag {
return fmt.Errorf("there are no *.params file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.params", paramsFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForTorchScript(modelPath string, files []os.FileInfo) error {
var ptFileFlag bool
var ptFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".pt" {
ptFileFlag = true
ptFileNum++
}
}
if !ptFileFlag {
return fmt.Errorf("there are no *.pt file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.pt", ptFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForGraphDef(modelPath string, files []os.FileInfo) error {
var pbFileFlag bool
var graphdefFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".pb" {
pbFileFlag = true
break
if path.Ext(file.Name()) == ".graphdef" {
graphdefFileNum++
}
}
if !pbFileFlag {
return fmt.Errorf("there are no *.pb file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.graphdef", graphdefFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForTensorRT(modelPath string, files []os.FileInfo) error {
var tensorrtFileFlag bool
var tensorrtFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".plan" {
tensorrtFileFlag = true
if path.Ext(file.Name()) == ".plan" || path.Ext(file.Name()) == ".engine" {
tensorrtFileNum++
}
}
if !tensorrtFileFlag {
return fmt.Errorf("there are no *.plan file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.plan or *.engine", tensorrtFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForSKLearn(modelPath string, files []os.FileInfo) error {
var sklearnFileFlag bool
var sklearnFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".joblib" {
sklearnFileFlag = true
sklearnFileNum++
}
}
if !sklearnFileFlag {
return fmt.Errorf("there are no *.joblib file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.joblib", sklearnFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForXGBoost(modelPath string, files []os.FileInfo) error {
var xgboostFileFlag bool
var xgboostFileNum int32
for _, file := range files {
if path.Ext(file.Name()) == ".xgboost" {
xgboostFileFlag = true
xgboostFileNum++
}
}
if !xgboostFileFlag {
return fmt.Errorf("there are no *.xgboost file in %v directory", modelPath)
if e := ValidateError(modelPath, "*.xgboost", xgboostFileNum); e != nil {
return e
}
return nil
}

func (f Format) validateForMLflow(modelPath string, files []os.FileInfo) error {
var isMLflowFile bool
var MLflowFileNum int32
for _, file := range files {
if file.Name() == "MLmodel" {
// assuming that user would not fool the tool
isMLflowFile = true
MLflowFileNum++
}
}
if !isMLflowFile {
return fmt.Errorf("there are no MLmodel file in %v, directory", modelPath)
if e := ValidateError(modelPath, "MLmodel", MLflowFileNum); e != nil {
return e
}
return nil
}

0 comments on commit e4867e6

Please sign in to comment.