From 646842542d6a870622b91ba5336a8acbb3ea502b Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Fri, 12 May 2023 17:52:21 -0700 Subject: [PATCH 1/2] create_model_bundle metadata param --- launch/client.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/launch/client.py b/launch/client.py index 41c4c8cb..62c1cecd 100644 --- a/launch/client.py +++ b/launch/client.py @@ -330,11 +330,8 @@ def _upload_model_bundle( self, load_model_fn: Callable, load_predict_fn: Callable, - bundle_metadata: Dict[str, Any], ): bundle = dict(load_model_fn=load_model_fn, load_predict_fn=load_predict_fn) - bundle_metadata["load_predict_fn"] = inspect.getsource(load_predict_fn) # type: ignore - bundle_metadata["load_model_fn"] = inspect.getsource(load_model_fn) # type: ignore serialized_bundle = cloudpickle.dumps(bundle) bundle_location = self._upload_data(data=serialized_bundle) return bundle_location @@ -361,6 +358,7 @@ def create_model_bundle_from_callable_v2( custom_base_image_repository: Optional[str] = None, custom_base_image_tag: Optional[str] = None, app_config: Optional[Union[Dict[str, Any], str]] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> CreateModelBundleV2Response: """ Uploads and registers a model bundle to Scale Launch. @@ -420,14 +418,15 @@ def predict_fn(input): bundle when it is run. These values can be accessed by the bundle via the ``app_config`` global variable. + metadata: Metadata to record with the bundle. + Returns: An object containing the following keys: - ``model_bundle_id``: The ID of the created model bundle. """ nonnull_requirements = requirements or [] - bundle_metadata: Dict[str, Any] = {} - bundle_location = self._upload_model_bundle(load_model_fn, load_predict_fn, bundle_metadata) + bundle_location = self._upload_model_bundle(load_model_fn, load_predict_fn) schema_location = self._upload_schemas(request_schema=request_schema, response_schema=response_schema) framework = _get_model_bundle_framework( pytorch_image_tag=pytorch_image_tag, @@ -450,6 +449,7 @@ def predict_fn(input): name=model_bundle_name, schema_location=schema_location, flavor=flavor, + metadata=metadata, ) with ApiClient(self.configuration) as api_client: api_instance = DefaultApi(api_client) @@ -476,6 +476,7 @@ def create_model_bundle_from_dirs_v2( custom_base_image_repository: Optional[str] = None, custom_base_image_tag: Optional[str] = None, app_config: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> CreateModelBundleV2Response: """ Packages up code from one or more local filesystem folders and uploads them as a bundle @@ -556,6 +557,8 @@ def create_model_bundle_from_dirs_v2( bundle when it is run. These values can be accessed by the bundle via the ``app_config`` global variable. + metadata: Metadata to record with the bundle. + Returns: An object containing the following keys: @@ -588,6 +591,7 @@ def create_model_bundle_from_dirs_v2( name=model_bundle_name, schema_location=schema_location, flavor=flavor, + metadata=metadata, ) with ApiClient(self.configuration) as api_client: api_instance = DefaultApi(api_client) @@ -610,6 +614,7 @@ def create_model_bundle_from_runnable_image_v2( command: List[str], env: Dict[str, str], readiness_initial_delay_seconds: int, + metadata: Optional[Dict[str, Any]] = None, ) -> CreateModelBundleV2Response: """ Create a model bundle from a runnable image. The specified ``command`` must start a process @@ -636,6 +641,7 @@ def create_model_bundle_from_runnable_image_v2( readiness_initial_delay_seconds: The number of seconds to wait for the HTTP server to become ready and successfully respond on its healthcheck. + metadata: Metadata to record with the bundle. Returns: An object containing the following keys: @@ -658,6 +664,7 @@ def create_model_bundle_from_runnable_image_v2( name=model_bundle_name, schema_location=schema_location, flavor=flavor, + metadata=metadata, ) with ApiClient(self.configuration) as api_client: @@ -688,6 +695,7 @@ def create_model_bundle_from_triton_enhanced_runnable_image_v2( triton_storage: Optional[str], triton_memory: Optional[str], triton_readiness_initial_delay_seconds: int, + metadata: Optional[Dict[str, Any]] = None, ) -> CreateModelBundleV2Response: """ Create a model bundle from a runnable image and a tritonserver image. @@ -732,6 +740,8 @@ def create_model_bundle_from_triton_enhanced_runnable_image_v2( triton_readiness_initial_delay_seconds: Like readiness_initial_delay_seconds, but for tritonserver's own healthcheck. + metadata: Metadata to record with the bundle. + Returns: An object containing the following keys: @@ -760,6 +770,7 @@ def create_model_bundle_from_triton_enhanced_runnable_image_v2( name=model_bundle_name, schema_location=schema_location, flavor=flavor, + metadata=metadata, ) with ApiClient(self.configuration) as api_client: From ad6a044ecb85bf4d32b9e91d4941bee8717cb92c Mon Sep 17 00:00:00 2001 From: Katie Wu Date: Mon, 15 May 2023 00:48:16 -0700 Subject: [PATCH 2/2] Fix --- launch/client.py | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/launch/client.py b/launch/client.py index 62c1cecd..a937f928 100644 --- a/launch/client.py +++ b/launch/client.py @@ -446,10 +446,12 @@ def predict_fn(input): ) ) create_model_bundle_request = CreateModelBundleV2Request( - name=model_bundle_name, - schema_location=schema_location, - flavor=flavor, - metadata=metadata, + **dict_not_none( + name=model_bundle_name, + schema_location=schema_location, + flavor=flavor, + metadata=metadata, + ) ) with ApiClient(self.configuration) as api_client: api_instance = DefaultApi(api_client) @@ -588,10 +590,12 @@ def create_model_bundle_from_dirs_v2( ) ) create_model_bundle_request = CreateModelBundleV2Request( - name=model_bundle_name, - schema_location=schema_location, - flavor=flavor, - metadata=metadata, + **dict_not_none( + name=model_bundle_name, + schema_location=schema_location, + flavor=flavor, + metadata=metadata, + ) ) with ApiClient(self.configuration) as api_client: api_instance = DefaultApi(api_client) @@ -661,10 +665,12 @@ def create_model_bundle_from_runnable_image_v2( ) ) create_model_bundle_request = CreateModelBundleV2Request( - name=model_bundle_name, - schema_location=schema_location, - flavor=flavor, - metadata=metadata, + **dict_not_none( + name=model_bundle_name, + schema_location=schema_location, + flavor=flavor, + metadata=metadata, + ) ) with ApiClient(self.configuration) as api_client: @@ -767,10 +773,12 @@ def create_model_bundle_from_triton_enhanced_runnable_image_v2( ) ) create_model_bundle_request = CreateModelBundleV2Request( - name=model_bundle_name, - schema_location=schema_location, - flavor=flavor, - metadata=metadata, + **dict_not_none( + name=model_bundle_name, + schema_location=schema_location, + flavor=flavor, + metadata=metadata, + ) ) with ApiClient(self.configuration) as api_client: