Skip to content

Commit 29f5314

Browse files
authored
create_model_bundle metadata param (#98)
* create_model_bundle metadata param * Fix
1 parent e3a5eef commit 29f5314

File tree

1 file changed

+36
-17
lines changed

1 file changed

+36
-17
lines changed

Diff for: launch/client.py

+36-17
Original file line numberDiff line numberDiff line change
@@ -330,11 +330,8 @@ def _upload_model_bundle(
330330
self,
331331
load_model_fn: Callable,
332332
load_predict_fn: Callable,
333-
bundle_metadata: Dict[str, Any],
334333
):
335334
bundle = dict(load_model_fn=load_model_fn, load_predict_fn=load_predict_fn)
336-
bundle_metadata["load_predict_fn"] = inspect.getsource(load_predict_fn) # type: ignore
337-
bundle_metadata["load_model_fn"] = inspect.getsource(load_model_fn) # type: ignore
338335
serialized_bundle = cloudpickle.dumps(bundle)
339336
bundle_location = self._upload_data(data=serialized_bundle)
340337
return bundle_location
@@ -361,6 +358,7 @@ def create_model_bundle_from_callable_v2(
361358
custom_base_image_repository: Optional[str] = None,
362359
custom_base_image_tag: Optional[str] = None,
363360
app_config: Optional[Union[Dict[str, Any], str]] = None,
361+
metadata: Optional[Dict[str, Any]] = None,
364362
) -> CreateModelBundleV2Response:
365363
"""
366364
Uploads and registers a model bundle to Scale Launch.
@@ -420,14 +418,15 @@ def predict_fn(input):
420418
bundle when it is run. These values can be accessed by the bundle via the
421419
``app_config`` global variable.
422420
421+
metadata: Metadata to record with the bundle.
422+
423423
Returns:
424424
An object containing the following keys:
425425
426426
- ``model_bundle_id``: The ID of the created model bundle.
427427
"""
428428
nonnull_requirements = requirements or []
429-
bundle_metadata: Dict[str, Any] = {}
430-
bundle_location = self._upload_model_bundle(load_model_fn, load_predict_fn, bundle_metadata)
429+
bundle_location = self._upload_model_bundle(load_model_fn, load_predict_fn)
431430
schema_location = self._upload_schemas(request_schema=request_schema, response_schema=response_schema)
432431
framework = _get_model_bundle_framework(
433432
pytorch_image_tag=pytorch_image_tag,
@@ -447,9 +446,12 @@ def predict_fn(input):
447446
)
448447
)
449448
create_model_bundle_request = CreateModelBundleV2Request(
450-
name=model_bundle_name,
451-
schema_location=schema_location,
452-
flavor=flavor,
449+
**dict_not_none(
450+
name=model_bundle_name,
451+
schema_location=schema_location,
452+
flavor=flavor,
453+
metadata=metadata,
454+
)
453455
)
454456
with ApiClient(self.configuration) as api_client:
455457
api_instance = DefaultApi(api_client)
@@ -476,6 +478,7 @@ def create_model_bundle_from_dirs_v2(
476478
custom_base_image_repository: Optional[str] = None,
477479
custom_base_image_tag: Optional[str] = None,
478480
app_config: Optional[Dict[str, Any]] = None,
481+
metadata: Optional[Dict[str, Any]] = None,
479482
) -> CreateModelBundleV2Response:
480483
"""
481484
Packages up code from one or more local filesystem folders and uploads them as a bundle
@@ -556,6 +559,8 @@ def create_model_bundle_from_dirs_v2(
556559
bundle when it is run. These values can be accessed by the bundle via the
557560
``app_config`` global variable.
558561
562+
metadata: Metadata to record with the bundle.
563+
559564
Returns:
560565
An object containing the following keys:
561566
@@ -585,9 +590,12 @@ def create_model_bundle_from_dirs_v2(
585590
)
586591
)
587592
create_model_bundle_request = CreateModelBundleV2Request(
588-
name=model_bundle_name,
589-
schema_location=schema_location,
590-
flavor=flavor,
593+
**dict_not_none(
594+
name=model_bundle_name,
595+
schema_location=schema_location,
596+
flavor=flavor,
597+
metadata=metadata,
598+
)
591599
)
592600
with ApiClient(self.configuration) as api_client:
593601
api_instance = DefaultApi(api_client)
@@ -610,6 +618,7 @@ def create_model_bundle_from_runnable_image_v2(
610618
command: List[str],
611619
env: Dict[str, str],
612620
readiness_initial_delay_seconds: int,
621+
metadata: Optional[Dict[str, Any]] = None,
613622
) -> CreateModelBundleV2Response:
614623
"""
615624
Create a model bundle from a runnable image. The specified ``command`` must start a process
@@ -636,6 +645,7 @@ def create_model_bundle_from_runnable_image_v2(
636645
readiness_initial_delay_seconds: The number of seconds to wait for the HTTP server to become ready and
637646
successfully respond on its healthcheck.
638647
648+
metadata: Metadata to record with the bundle.
639649
640650
Returns:
641651
An object containing the following keys:
@@ -655,9 +665,12 @@ def create_model_bundle_from_runnable_image_v2(
655665
)
656666
)
657667
create_model_bundle_request = CreateModelBundleV2Request(
658-
name=model_bundle_name,
659-
schema_location=schema_location,
660-
flavor=flavor,
668+
**dict_not_none(
669+
name=model_bundle_name,
670+
schema_location=schema_location,
671+
flavor=flavor,
672+
metadata=metadata,
673+
)
661674
)
662675

663676
with ApiClient(self.configuration) as api_client:
@@ -688,6 +701,7 @@ def create_model_bundle_from_triton_enhanced_runnable_image_v2(
688701
triton_storage: Optional[str],
689702
triton_memory: Optional[str],
690703
triton_readiness_initial_delay_seconds: int,
704+
metadata: Optional[Dict[str, Any]] = None,
691705
) -> CreateModelBundleV2Response:
692706
"""
693707
Create a model bundle from a runnable image and a tritonserver image.
@@ -732,6 +746,8 @@ def create_model_bundle_from_triton_enhanced_runnable_image_v2(
732746
triton_readiness_initial_delay_seconds: Like readiness_initial_delay_seconds, but for
733747
tritonserver's own healthcheck.
734748
749+
metadata: Metadata to record with the bundle.
750+
735751
Returns:
736752
An object containing the following keys:
737753
@@ -757,9 +773,12 @@ def create_model_bundle_from_triton_enhanced_runnable_image_v2(
757773
)
758774
)
759775
create_model_bundle_request = CreateModelBundleV2Request(
760-
name=model_bundle_name,
761-
schema_location=schema_location,
762-
flavor=flavor,
776+
**dict_not_none(
777+
name=model_bundle_name,
778+
schema_location=schema_location,
779+
flavor=flavor,
780+
metadata=metadata,
781+
)
763782
)
764783

765784
with ApiClient(self.configuration) as api_client:

0 commit comments

Comments
 (0)