Skip to content

Support file paths for ONNX predictor #1711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Dec 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
3cfa91b
Modify predictor validation struct accordingly
RobertLucian Dec 15, 2020
f833586
Merge branch 'master' into improvement/python-api-spec
RobertLucian Dec 15, 2020
d13b743
More validations for models
RobertLucian Dec 15, 2020
dc0fda9
Model resource extraction in Python stack
RobertLucian Dec 15, 2020
aa6e338
Fixes for DML on Python
RobertLucian Dec 15, 2020
7fe6d0f
Fix models not showing up in cortex get
RobertLucian Dec 15, 2020
da93dbd
Telemetry & misc
RobertLucian Dec 15, 2020
6bc7c64
DML fixes
RobertLucian Dec 15, 2020
bd09727
Merge branch 'master' into improvement/python-api-spec
RobertLucian Dec 16, 2020
15bdde5
Add dynamic_model_loading to the docs
RobertLucian Dec 16, 2020
12095cc
Fix test examples
RobertLucian Dec 16, 2020
268c297
Disallow DML for BatchAPI kind
RobertLucian Dec 16, 2020
9258050
Merge branch 'master' into feature/single-model-paths
RobertLucian Dec 16, 2020
c8c9aa6
Merge branch 'master' into improvement/python-api-spec
vishalbollu Dec 17, 2020
e3b09e8
Support ONNX model file paths
RobertLucian Dec 17, 2020
26e5e58
Move model_path field inside the models section
RobertLucian Dec 17, 2020
0f09a21
Fixes on the go-side + some docs
RobertLucian Dec 17, 2020
6fa13b4
Fixes for the Python side
RobertLucian Dec 18, 2020
927e2f5
Merge branch 'master' into improvement/python-api-spec
RobertLucian Dec 18, 2020
8f4f72e
Merge branch 'improvement/python-api-spec' into feature/single-model-…
RobertLucian Dec 18, 2020
e8bfc53
Update docs
RobertLucian Dec 18, 2020
7c79acd
Merge branch 'master' into improvement/python-api-spec
RobertLucian Dec 18, 2020
a906fff
Rename models:model_path to models:path
RobertLucian Dec 18, 2020
d2f5184
Merge branch 'improvement/python-api-spec' into feature/single-model-…
RobertLucian Dec 18, 2020
220f7fa
Misc changes
RobertLucian Dec 18, 2020
de694bf
Merge branch 'improvement/python-api-spec' into feature/single-model-…
RobertLucian Dec 18, 2020
43086b6
Fix merge conflicts from 'master' into feature/single-model-paths
RobertLucian Dec 22, 2020
c980ba8
Address review comments
RobertLucian Dec 22, 2020
e000387
Merge branch 'master' into feature/single-model-paths
RobertLucian Dec 22, 2020
08e7f02
Fix validation
RobertLucian Dec 22, 2020
f25c2f2
Merge branch 'master' into feature/single-model-paths
RobertLucian Dec 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions cli/local/model_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func CacheLocalModels(apiSpec *spec.API, models []spec.CuratedModelResource) err
if wasAlreadyCached {
modelsThatWereCachedAlready++
}
if len(model.Versions) == 0 {
if model.IsFilePath || len(model.Versions) == 0 {
localModelCache.TargetPath = filepath.Join(model.Name, "1")
} else {
localModelCache.TargetPath = model.Name
Expand Down Expand Up @@ -98,7 +98,7 @@ func cacheLocalModel(model spec.CuratedModelResource) (*spec.LocalModelCache, bo
destModelDir := filepath.Join(_modelCacheDir, localModelCache.ID)

if files.IsDir(destModelDir) {
if len(model.Versions) == 0 {
if model.IsFilePath || len(model.Versions) == 0 {
localModelCache.HostPath = filepath.Join(destModelDir, "1")
} else {
localModelCache.HostPath = destModelDir
Expand All @@ -110,7 +110,7 @@ func cacheLocalModel(model spec.CuratedModelResource) (*spec.LocalModelCache, bo
if err != nil {
return nil, false, err
}
if len(model.Versions) == 0 {
if model.IsFilePath || len(model.Versions) == 0 {
if _, err := files.CreateDirIfMissing(filepath.Join(destModelDir, "1")); err != nil {
return nil, false, err
}
Expand All @@ -137,10 +137,16 @@ func cacheLocalModel(model spec.CuratedModelResource) (*spec.LocalModelCache, bo
}
}

if len(model.Versions) == 0 {
if model.IsFilePath || len(model.Versions) == 0 {
destModelDir = filepath.Join(destModelDir, "1")
}
if err := files.CopyDirOverwrite(strings.TrimSuffix(model.Path, "/"), s.EnsureSuffix(destModelDir, "/")); err != nil {

if model.IsFilePath {
err = files.CopyFileOverwrite(model.Path, filepath.Join(destModelDir, filepath.Base(model.Path)))
} else {
err = files.CopyDirOverwrite(strings.TrimSuffix(model.Path, "/"), s.EnsureSuffix(destModelDir, "/"))
}
if err != nil {
return nil, false, err
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/cortex/downloader/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def start(args):

if from_path.startswith("s3://"):
bucket_name, prefix = S3.deconstruct_s3_path(from_path)
client = S3(bucket_name, client_config={})
client = S3(bucket_name)
elif from_path.startswith("gs://"):
bucket_name, prefix = GCS.deconstruct_gcs_path(from_path)
client = GCS(bucket_name)
Expand Down
27 changes: 22 additions & 5 deletions pkg/cortex/serve/cortex_internal/lib/model/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,17 +188,34 @@ def get_models_from_api_spec(
for model in models:
model_resource = {}
model_resource["name"] = model["name"]
model_resource["s3_path"] = model["path"].startswith("s3://")
model_resource["gcs_path"] = model["path"].startswith("gs://")
model_resource["local_path"] = (
not model_resource["s3_path"] and not model_resource["gcs_path"]
)

if not model["signature_key"]:
model_resource["signature_key"] = models_spec["signature_key"]
else:
model_resource["signature_key"] = model["signature_key"]

ends_as_file_path = model["path"].endswith(".onnx")
if ends_as_file_path and os.path.exists(
os.path.join(model_dir, model_resource["name"], "1", os.path.basename(model["path"]))
):
model_resource["is_file_path"] = True
model_resource["s3_path"] = False
model_resource["gcs_path"] = False
model_resource["local_path"] = True
model_resource["versions"] = []
model_resource["path"] = os.path.join(
model_dir, model_resource["name"], "1", os.path.basename(model["path"])
)
model_resources.append(model_resource)
continue
model_resource["is_file_path"] = False

model_resource["s3_path"] = model["path"].startswith("s3://")
model_resource["gcs_path"] = model["path"].startswith("gs://")
model_resource["local_path"] = (
not model_resource["s3_path"] and not model_resource["gcs_path"]
)

if model_resource["s3_path"] or model_resource["gcs_path"]:
model_resource["path"] = model["path"]
_, versions, _, _, _, _, _ = find_all_cloud_models(
Expand Down
46 changes: 35 additions & 11 deletions pkg/operator/operator/k8s.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ const (

const (
_specCacheDir = "/mnt/spec"
_modelDir = "/mnt/model"
_emptyDirMountPath = "/mnt"
_emptyDirVolumeName = "mnt"
_tfServingContainerName = "serve"
Expand Down Expand Up @@ -570,20 +571,43 @@ func pythonDownloadArgs(api *spec.API) string {
}

func onnxDownloadArgs(api *spec.API) string {
downloadConfig := downloadContainerConfig{
LastLog: fmt.Sprintf(_downloaderLastLog, "onnx"),
DownloadArgs: []downloadContainerArg{
{
From: config.BucketPath(api.ProjectKey),
To: path.Join(_emptyDirMountPath, "project"),
Unzip: true,
ItemName: "the project code",
HideFromLog: true,
HideUnzippingLog: true,
},
downloadContainerArs := []downloadContainerArg{
{
From: config.BucketPath(api.ProjectKey),
To: path.Join(_emptyDirMountPath, "project"),
Unzip: true,
ItemName: "the project code",
HideFromLog: true,
HideUnzippingLog: true,
},
}

if api.Predictor.Models.Path != nil && strings.HasSuffix(*api.Predictor.Models.Path, ".onnx") {
downloadContainerArs = append(downloadContainerArs, downloadContainerArg{
From: *api.Predictor.Models.Path,
To: path.Join(_modelDir, consts.SingleModelName, "1"),
ItemName: "the onnx model",
})
}

for _, model := range api.Predictor.Models.Paths {
if model == nil {
continue
}
if strings.HasSuffix(model.Path, ".onnx") {
downloadContainerArs = append(downloadContainerArs, downloadContainerArg{
From: model.Path,
To: path.Join(_modelDir, model.Name, "1"),
ItemName: fmt.Sprintf("%s onnx model", model.Name),
})
}
}

downloadConfig := downloadContainerConfig{
LastLog: fmt.Sprintf(_downloaderLastLog, "onnx"),
DownloadArgs: downloadContainerArs,
}

downloadArgsBytes, _ := json.Marshal(downloadConfig)
return base64.URLEncoding.EncodeToString(downloadArgsBytes)
}
Expand Down
9 changes: 5 additions & 4 deletions pkg/types/spec/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@ type LocalModelCache struct {

type CuratedModelResource struct {
*userconfig.ModelResource
S3Path bool `json:"s3_path"`
GCSPath bool `json:"gcs_path"`
LocalPath bool `json:"local_path"`
Versions []int64 `json:"versions"`
S3Path bool `json:"s3_path"`
GCSPath bool `json:"gcs_path"`
LocalPath bool `json:"local_path"`
IsFilePath bool `json:"file_path"`
Versions []int64 `json:"versions"`
}

/*
Expand Down
12 changes: 12 additions & 0 deletions pkg/types/spec/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ const (
ErrInvalidPythonModelPath = "spec.invalid_python_model_path"
ErrInvalidTensorFlowModelPath = "spec.invalid_tensorflow_model_path"
ErrInvalidONNXModelPath = "spec.invalid_onnx_model_path"
ErrInvalidONNXModelFilePath = "spec.invalid_onnx_model_file_path"

ErrDuplicateModelNames = "spec.duplicate_model_names"
ErrReservedModelName = "spec.reserved_model_name"
Expand Down Expand Up @@ -438,6 +439,17 @@ func ErrorInvalidONNXModelPath(modelPath string, modelSubPaths []string) error {
})
}

func ErrorInvalidONNXModelFilePath(filePath string) error {
message := fmt.Sprintf("%s: invalid %s model file path; specify an ONNX file path or provide a directory with one of the following structures:\n", filePath, userconfig.ONNXPredictorType.CasedString())
templateModelPath := "path/to/model/directory/"
message += fmt.Sprintf(_onnxVersionedExpectedStructMessage, templateModelPath, templateModelPath)

return errors.WithStack(&errors.Error{
Kind: ErrInvalidONNXModelFilePath,
Message: message,
})
}

func ErrorDuplicateModelNames(duplicateModel string) error {
return errors.WithStack(&errors.Error{
Kind: ErrDuplicateModelNames,
Expand Down
102 changes: 93 additions & 9 deletions pkg/types/spec/validations.go
Original file line number Diff line number Diff line change
Expand Up @@ -1197,18 +1197,26 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,

var modelWrapError func(error) error
var modelResources []userconfig.ModelResource
var modelFileResources []userconfig.ModelResource

if hasSingleModel {
modelWrapError = func(err error) error {
return errors.Wrap(err, userconfig.ModelsPathKey)
return errors.Wrap(err, userconfig.ModelsKey, userconfig.ModelsPathKey)
}
modelResources = []userconfig.ModelResource{
{
Name: consts.SingleModelName,
Path: *predictor.Models.Path,
},
modelResource := userconfig.ModelResource{
Name: consts.SingleModelName,
Path: *predictor.Models.Path,
}

if strings.HasSuffix(*predictor.Models.Path, ".onnx") && provider != types.LocalProviderType {
if err := validateONNXModelFilePath(*predictor.Models.Path, projectFiles.ProjectDir(), awsClient, gcpClient); err != nil {
return modelWrapError(err)
}
modelFileResources = append(modelFileResources, modelResource)
} else {
modelResources = append(modelResources, modelResource)
*predictor.Models.Path = s.EnsureSuffix(*predictor.Models.Path, "/")
}
*predictor.Models.Path = s.EnsureSuffix(*predictor.Models.Path, "/")
}
if hasMultiModels {
if len(predictor.Models.Paths) > 0 {
Expand All @@ -1225,8 +1233,15 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
path.Name,
)
}
(*path).Path = s.EnsureSuffix((*path).Path, "/")
modelResources = append(modelResources, *path)
if strings.HasSuffix((*path).Path, ".onnx") && provider != types.LocalProviderType {
if err := validateONNXModelFilePath((*path).Path, projectFiles.ProjectDir(), awsClient, gcpClient); err != nil {
return errors.Wrap(modelWrapError(err), path.Name)
}
modelFileResources = append(modelFileResources, *path)
} else {
(*path).Path = s.EnsureSuffix((*path).Path, "/")
modelResources = append(modelResources, *path)
}
}
}

Expand All @@ -1249,6 +1264,23 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
return modelWrapError(err)
}

for _, modelFileResource := range modelFileResources {
s3Path := strings.HasPrefix(modelFileResource.Path, "s3://")
gcsPath := strings.HasPrefix(modelFileResource.Path, "gs://")
localPath := !s3Path && !gcsPath

*models = append(*models, CuratedModelResource{
ModelResource: &userconfig.ModelResource{
Name: modelFileResource.Name,
Path: modelFileResource.Path,
},
S3Path: s3Path,
GCSPath: gcsPath,
LocalPath: localPath,
IsFilePath: true,
})
}

if hasMultiModels {
for _, model := range *models {
if model.Name == consts.SingleModelName {
Expand All @@ -1264,6 +1296,58 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
return nil
}

func validateONNXModelFilePath(modelPath string, projectDir string, awsClient *aws.Client, gcpClient *gcp.Client) error {
s3Path := strings.HasPrefix(modelPath, "s3://")
gcsPath := strings.HasPrefix(modelPath, "gs://")
localPath := !s3Path && !gcsPath

if s3Path {
awsClientForBucket, err := aws.NewFromClientS3Path(modelPath, awsClient)
if err != nil {
return err
}

bucket, modelPrefix, err := aws.SplitS3Path(modelPath)
if err != nil {
return err
}

isS3File, err := awsClientForBucket.IsS3File(bucket, modelPrefix)
if err != nil {
return err
}

if !isS3File {
return ErrorInvalidONNXModelFilePath(modelPrefix)
}
}

if gcsPath {
bucket, modelPrefix, err := gcp.SplitGCSPath(modelPath)
if err != nil {
return err
}

isGCSFile, err := gcpClient.IsGCSFile(bucket, modelPrefix)
if err != nil {
return err
}

if !isGCSFile {
return ErrorInvalidONNXModelFilePath(modelPrefix)
}
}

if localPath {
expandedLocalPath := files.RelToAbsPath(modelPath, projectDir)
if err := files.CheckFile(expandedLocalPath); err != nil {
return err
}
}

return nil
}

func validatePythonPath(predictor *userconfig.Predictor, projectFiles ProjectFiles) error {
if !projectFiles.HasDir(*predictor.PythonPath) {
return ErrorPythonPathNotFound(*predictor.PythonPath)
Expand Down