Skip to content

Commit dcaba64

Browse files
authored
Support file paths for ONNX predictor (#1711)
1 parent 202749a commit dcaba64

File tree

7 files changed

+179
-35
lines changed

7 files changed

+179
-35
lines changed

Diff for: cli/local/model_cache.go

+11-5
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func CacheLocalModels(apiSpec *spec.API, models []spec.CuratedModelResource) err
6464
if wasAlreadyCached {
6565
modelsThatWereCachedAlready++
6666
}
67-
if len(model.Versions) == 0 {
67+
if model.IsFilePath || len(model.Versions) == 0 {
6868
localModelCache.TargetPath = filepath.Join(model.Name, "1")
6969
} else {
7070
localModelCache.TargetPath = model.Name
@@ -98,7 +98,7 @@ func cacheLocalModel(model spec.CuratedModelResource) (*spec.LocalModelCache, bo
9898
destModelDir := filepath.Join(_modelCacheDir, localModelCache.ID)
9999

100100
if files.IsDir(destModelDir) {
101-
if len(model.Versions) == 0 {
101+
if model.IsFilePath || len(model.Versions) == 0 {
102102
localModelCache.HostPath = filepath.Join(destModelDir, "1")
103103
} else {
104104
localModelCache.HostPath = destModelDir
@@ -110,7 +110,7 @@ func cacheLocalModel(model spec.CuratedModelResource) (*spec.LocalModelCache, bo
110110
if err != nil {
111111
return nil, false, err
112112
}
113-
if len(model.Versions) == 0 {
113+
if model.IsFilePath || len(model.Versions) == 0 {
114114
if _, err := files.CreateDirIfMissing(filepath.Join(destModelDir, "1")); err != nil {
115115
return nil, false, err
116116
}
@@ -137,10 +137,16 @@ func cacheLocalModel(model spec.CuratedModelResource) (*spec.LocalModelCache, bo
137137
}
138138
}
139139

140-
if len(model.Versions) == 0 {
140+
if model.IsFilePath || len(model.Versions) == 0 {
141141
destModelDir = filepath.Join(destModelDir, "1")
142142
}
143-
if err := files.CopyDirOverwrite(strings.TrimSuffix(model.Path, "/"), s.EnsureSuffix(destModelDir, "/")); err != nil {
143+
144+
if model.IsFilePath {
145+
err = files.CopyFileOverwrite(model.Path, filepath.Join(destModelDir, filepath.Base(model.Path)))
146+
} else {
147+
err = files.CopyDirOverwrite(strings.TrimSuffix(model.Path, "/"), s.EnsureSuffix(destModelDir, "/"))
148+
}
149+
if err != nil {
144150
return nil, false, err
145151
}
146152

Diff for: pkg/cortex/downloader/download.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def start(args):
3131

3232
if from_path.startswith("s3://"):
3333
bucket_name, prefix = S3.deconstruct_s3_path(from_path)
34-
client = S3(bucket_name, client_config={})
34+
client = S3(bucket_name)
3535
elif from_path.startswith("gs://"):
3636
bucket_name, prefix = GCS.deconstruct_gcs_path(from_path)
3737
client = GCS(bucket_name)

Diff for: pkg/cortex/serve/cortex_internal/lib/model/type.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -188,17 +188,34 @@ def get_models_from_api_spec(
188188
for model in models:
189189
model_resource = {}
190190
model_resource["name"] = model["name"]
191-
model_resource["s3_path"] = model["path"].startswith("s3://")
192-
model_resource["gcs_path"] = model["path"].startswith("gs://")
193-
model_resource["local_path"] = (
194-
not model_resource["s3_path"] and not model_resource["gcs_path"]
195-
)
196191

197192
if not model["signature_key"]:
198193
model_resource["signature_key"] = models_spec["signature_key"]
199194
else:
200195
model_resource["signature_key"] = model["signature_key"]
201196

197+
ends_as_file_path = model["path"].endswith(".onnx")
198+
if ends_as_file_path and os.path.exists(
199+
os.path.join(model_dir, model_resource["name"], "1", os.path.basename(model["path"]))
200+
):
201+
model_resource["is_file_path"] = True
202+
model_resource["s3_path"] = False
203+
model_resource["gcs_path"] = False
204+
model_resource["local_path"] = True
205+
model_resource["versions"] = []
206+
model_resource["path"] = os.path.join(
207+
model_dir, model_resource["name"], "1", os.path.basename(model["path"])
208+
)
209+
model_resources.append(model_resource)
210+
continue
211+
model_resource["is_file_path"] = False
212+
213+
model_resource["s3_path"] = model["path"].startswith("s3://")
214+
model_resource["gcs_path"] = model["path"].startswith("gs://")
215+
model_resource["local_path"] = (
216+
not model_resource["s3_path"] and not model_resource["gcs_path"]
217+
)
218+
202219
if model_resource["s3_path"] or model_resource["gcs_path"]:
203220
model_resource["path"] = model["path"]
204221
_, versions, _, _, _, _, _ = find_all_cloud_models(

Diff for: pkg/operator/operator/k8s.go

+35-11
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ const (
4848

4949
const (
5050
_specCacheDir = "/mnt/spec"
51+
_modelDir = "/mnt/model"
5152
_emptyDirMountPath = "/mnt"
5253
_emptyDirVolumeName = "mnt"
5354
_tfServingContainerName = "serve"
@@ -570,20 +571,43 @@ func pythonDownloadArgs(api *spec.API) string {
570571
}
571572

572573
func onnxDownloadArgs(api *spec.API) string {
573-
downloadConfig := downloadContainerConfig{
574-
LastLog: fmt.Sprintf(_downloaderLastLog, "onnx"),
575-
DownloadArgs: []downloadContainerArg{
576-
{
577-
From: config.BucketPath(api.ProjectKey),
578-
To: path.Join(_emptyDirMountPath, "project"),
579-
Unzip: true,
580-
ItemName: "the project code",
581-
HideFromLog: true,
582-
HideUnzippingLog: true,
583-
},
574+
downloadContainerArs := []downloadContainerArg{
575+
{
576+
From: config.BucketPath(api.ProjectKey),
577+
To: path.Join(_emptyDirMountPath, "project"),
578+
Unzip: true,
579+
ItemName: "the project code",
580+
HideFromLog: true,
581+
HideUnzippingLog: true,
584582
},
585583
}
586584

585+
if api.Predictor.Models.Path != nil && strings.HasSuffix(*api.Predictor.Models.Path, ".onnx") {
586+
downloadContainerArs = append(downloadContainerArs, downloadContainerArg{
587+
From: *api.Predictor.Models.Path,
588+
To: path.Join(_modelDir, consts.SingleModelName, "1"),
589+
ItemName: "the onnx model",
590+
})
591+
}
592+
593+
for _, model := range api.Predictor.Models.Paths {
594+
if model == nil {
595+
continue
596+
}
597+
if strings.HasSuffix(model.Path, ".onnx") {
598+
downloadContainerArs = append(downloadContainerArs, downloadContainerArg{
599+
From: model.Path,
600+
To: path.Join(_modelDir, model.Name, "1"),
601+
ItemName: fmt.Sprintf("%s onnx model", model.Name),
602+
})
603+
}
604+
}
605+
606+
downloadConfig := downloadContainerConfig{
607+
LastLog: fmt.Sprintf(_downloaderLastLog, "onnx"),
608+
DownloadArgs: downloadContainerArs,
609+
}
610+
587611
downloadArgsBytes, _ := json.Marshal(downloadConfig)
588612
return base64.URLEncoding.EncodeToString(downloadArgsBytes)
589613
}

Diff for: pkg/types/spec/api.go

+5-4
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,11 @@ type LocalModelCache struct {
5757

5858
type CuratedModelResource struct {
5959
*userconfig.ModelResource
60-
S3Path bool `json:"s3_path"`
61-
GCSPath bool `json:"gcs_path"`
62-
LocalPath bool `json:"local_path"`
63-
Versions []int64 `json:"versions"`
60+
S3Path bool `json:"s3_path"`
61+
GCSPath bool `json:"gcs_path"`
62+
LocalPath bool `json:"local_path"`
63+
IsFilePath bool `json:"file_path"`
64+
Versions []int64 `json:"versions"`
6465
}
6566

6667
/*

Diff for: pkg/types/spec/errors.go

+12
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ const (
6565
ErrInvalidPythonModelPath = "spec.invalid_python_model_path"
6666
ErrInvalidTensorFlowModelPath = "spec.invalid_tensorflow_model_path"
6767
ErrInvalidONNXModelPath = "spec.invalid_onnx_model_path"
68+
ErrInvalidONNXModelFilePath = "spec.invalid_onnx_model_file_path"
6869

6970
ErrDuplicateModelNames = "spec.duplicate_model_names"
7071
ErrReservedModelName = "spec.reserved_model_name"
@@ -438,6 +439,17 @@ func ErrorInvalidONNXModelPath(modelPath string, modelSubPaths []string) error {
438439
})
439440
}
440441

442+
func ErrorInvalidONNXModelFilePath(filePath string) error {
443+
message := fmt.Sprintf("%s: invalid %s model file path; specify an ONNX file path or provide a directory with one of the following structures:\n", filePath, userconfig.ONNXPredictorType.CasedString())
444+
templateModelPath := "path/to/model/directory/"
445+
message += fmt.Sprintf(_onnxVersionedExpectedStructMessage, templateModelPath, templateModelPath)
446+
447+
return errors.WithStack(&errors.Error{
448+
Kind: ErrInvalidONNXModelFilePath,
449+
Message: message,
450+
})
451+
}
452+
441453
func ErrorDuplicateModelNames(duplicateModel string) error {
442454
return errors.WithStack(&errors.Error{
443455
Kind: ErrDuplicateModelNames,

Diff for: pkg/types/spec/validations.go

+93-9
Original file line numberDiff line numberDiff line change
@@ -1197,18 +1197,26 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
11971197

11981198
var modelWrapError func(error) error
11991199
var modelResources []userconfig.ModelResource
1200+
var modelFileResources []userconfig.ModelResource
12001201

12011202
if hasSingleModel {
12021203
modelWrapError = func(err error) error {
1203-
return errors.Wrap(err, userconfig.ModelsPathKey)
1204+
return errors.Wrap(err, userconfig.ModelsKey, userconfig.ModelsPathKey)
12041205
}
1205-
modelResources = []userconfig.ModelResource{
1206-
{
1207-
Name: consts.SingleModelName,
1208-
Path: *predictor.Models.Path,
1209-
},
1206+
modelResource := userconfig.ModelResource{
1207+
Name: consts.SingleModelName,
1208+
Path: *predictor.Models.Path,
1209+
}
1210+
1211+
if strings.HasSuffix(*predictor.Models.Path, ".onnx") && provider != types.LocalProviderType {
1212+
if err := validateONNXModelFilePath(*predictor.Models.Path, projectFiles.ProjectDir(), awsClient, gcpClient); err != nil {
1213+
return modelWrapError(err)
1214+
}
1215+
modelFileResources = append(modelFileResources, modelResource)
1216+
} else {
1217+
modelResources = append(modelResources, modelResource)
1218+
*predictor.Models.Path = s.EnsureSuffix(*predictor.Models.Path, "/")
12101219
}
1211-
*predictor.Models.Path = s.EnsureSuffix(*predictor.Models.Path, "/")
12121220
}
12131221
if hasMultiModels {
12141222
if len(predictor.Models.Paths) > 0 {
@@ -1225,8 +1233,15 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
12251233
path.Name,
12261234
)
12271235
}
1228-
(*path).Path = s.EnsureSuffix((*path).Path, "/")
1229-
modelResources = append(modelResources, *path)
1236+
if strings.HasSuffix((*path).Path, ".onnx") && provider != types.LocalProviderType {
1237+
if err := validateONNXModelFilePath((*path).Path, projectFiles.ProjectDir(), awsClient, gcpClient); err != nil {
1238+
return errors.Wrap(modelWrapError(err), path.Name)
1239+
}
1240+
modelFileResources = append(modelFileResources, *path)
1241+
} else {
1242+
(*path).Path = s.EnsureSuffix((*path).Path, "/")
1243+
modelResources = append(modelResources, *path)
1244+
}
12301245
}
12311246
}
12321247

@@ -1249,6 +1264,23 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
12491264
return modelWrapError(err)
12501265
}
12511266

1267+
for _, modelFileResource := range modelFileResources {
1268+
s3Path := strings.HasPrefix(modelFileResource.Path, "s3://")
1269+
gcsPath := strings.HasPrefix(modelFileResource.Path, "gs://")
1270+
localPath := !s3Path && !gcsPath
1271+
1272+
*models = append(*models, CuratedModelResource{
1273+
ModelResource: &userconfig.ModelResource{
1274+
Name: modelFileResource.Name,
1275+
Path: modelFileResource.Path,
1276+
},
1277+
S3Path: s3Path,
1278+
GCSPath: gcsPath,
1279+
LocalPath: localPath,
1280+
IsFilePath: true,
1281+
})
1282+
}
1283+
12521284
if hasMultiModels {
12531285
for _, model := range *models {
12541286
if model.Name == consts.SingleModelName {
@@ -1264,6 +1296,58 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
12641296
return nil
12651297
}
12661298

1299+
func validateONNXModelFilePath(modelPath string, projectDir string, awsClient *aws.Client, gcpClient *gcp.Client) error {
1300+
s3Path := strings.HasPrefix(modelPath, "s3://")
1301+
gcsPath := strings.HasPrefix(modelPath, "gs://")
1302+
localPath := !s3Path && !gcsPath
1303+
1304+
if s3Path {
1305+
awsClientForBucket, err := aws.NewFromClientS3Path(modelPath, awsClient)
1306+
if err != nil {
1307+
return err
1308+
}
1309+
1310+
bucket, modelPrefix, err := aws.SplitS3Path(modelPath)
1311+
if err != nil {
1312+
return err
1313+
}
1314+
1315+
isS3File, err := awsClientForBucket.IsS3File(bucket, modelPrefix)
1316+
if err != nil {
1317+
return err
1318+
}
1319+
1320+
if !isS3File {
1321+
return ErrorInvalidONNXModelFilePath(modelPrefix)
1322+
}
1323+
}
1324+
1325+
if gcsPath {
1326+
bucket, modelPrefix, err := gcp.SplitGCSPath(modelPath)
1327+
if err != nil {
1328+
return err
1329+
}
1330+
1331+
isGCSFile, err := gcpClient.IsGCSFile(bucket, modelPrefix)
1332+
if err != nil {
1333+
return err
1334+
}
1335+
1336+
if !isGCSFile {
1337+
return ErrorInvalidONNXModelFilePath(modelPrefix)
1338+
}
1339+
}
1340+
1341+
if localPath {
1342+
expandedLocalPath := files.RelToAbsPath(modelPath, projectDir)
1343+
if err := files.CheckFile(expandedLocalPath); err != nil {
1344+
return err
1345+
}
1346+
}
1347+
1348+
return nil
1349+
}
1350+
12671351
func validatePythonPath(predictor *userconfig.Predictor, projectFiles ProjectFiles) error {
12681352
if !projectFiles.HasDir(*predictor.PythonPath) {
12691353
return ErrorPythonPathNotFound(*predictor.PythonPath)

0 commit comments

Comments
 (0)