diff --git a/centml/cli/cluster.py b/centml/cli/cluster.py
index b5d143c..bbb55ad 100644
--- a/centml/cli/cluster.py
+++ b/centml/cli/cluster.py
@@ -177,10 +177,13 @@ def get(type, id):
             click.echo(
                 tabulate(
                     [
-                        ("Hugging face model", deployment.model),
+                        ("Hugging face model", deployment.recipe.model),
                         (
                             "Parallelism",
-                            {"tensor": deployment.tensor_parallel_size, "pipeline": deployment.pipeline_parallel_size},
+                            {
+                                "tensor": deployment.recipe.additional_properties['tensor_parallel_size'],
+                                "pipeline": deployment.recipe.additional_properties['pipeline_parallel_size'],
+                            },
                         ),
                         ("Replicas", {"min": deployment.min_scale, "max": deployment.max_scale}),
                         ("Max concurrency", deployment.concurrency or "None"),
diff --git a/centml/sdk/api.py b/centml/sdk/api.py
index 3a637c6..247e0c4 100644
--- a/centml/sdk/api.py
+++ b/centml/sdk/api.py
@@ -6,7 +6,7 @@
     DeploymentStatus,
     CreateInferenceDeploymentRequest,
     CreateComputeDeploymentRequest,
-    CreateCServeDeploymentRequest,
+    CreateCServeV2DeploymentRequest,
 )
 
 from centml.sdk import auth
@@ -32,7 +32,7 @@ def get_compute(self, id):
         return self._api.get_compute_deployment_deployments_compute_deployment_id_get(id)
 
     def get_cserve(self, id):
-        return self._api.get_cserve_deployment_deployments_cserve_deployment_id_get(id)
+        return self._api.get_cserve_v2_deployment_deployments_cserve_v2_deployment_id_get(id)
 
     def create_inference(self, request: CreateInferenceDeploymentRequest):
         return self._api.create_inference_deployment_deployments_inference_post(request)
@@ -40,8 +40,8 @@ def create_inference(self, request: CreateInferenceDeploymentRequest):
     def create_compute(self, request: CreateComputeDeploymentRequest):
         return self._api.create_compute_deployment_deployments_compute_post(request)
 
-    def create_cserve(self, request: CreateCServeDeploymentRequest):
-        return self._api.create_cserve_deployment_deployments_cserve_post(request)
+    def create_cserve(self, request: CreateCServeV2DeploymentRequest):
+        return self._api.create_cserve_v2_deployment_deployments_cserve_v2_post(request)
 
     def _update_status(self, id, new_status):
         status_req = platform_api_python_client.DeploymentStatusRequest(status=new_status)
@@ -67,8 +67,20 @@ def get_hardware_instances(self, cluster_id=None):
     def get_prebuilt_images(self, depl_type: DeploymentType):
         return self._api.get_prebuilt_images_prebuilt_images_get(type=depl_type)
 
-    def get_cserve_recipe(self):
-        return self._api.get_cserve_recipe_deployments_cserve_recipes_get().results
+    def get_cserve_recipe(self, model=None, hf_token=None):
+        return self._api.get_cserve_recipe_deployments_cserve_recipes_get(model=model, hf_token=hf_token).results
+
+    def get_cluster_id(self, hardware_instance_id):
+        """Return the cluster id of the hardware instance with the given id.
+
+        Raises:
+            ValueError: if no hardware instance has ``hardware_instance_id``.
+        """
+        hardware = next((hw for hw in self.get_hardware_instances() if hw.id == hardware_instance_id), None)
+        if hardware is None:
+            raise ValueError(f"Invalid hardware instance id {hardware_instance_id}")
+
+        return hardware.cluster_id
 
 
 @contextmanager
diff --git a/examples/sdk/create_cserve.py b/examples/sdk/create_cserve.py
new file mode 100644
index 0000000..2173b8c
--- /dev/null
+++ b/examples/sdk/create_cserve.py
@@ -0,0 +1,34 @@
+import time
+import centml
+from centml.sdk.api import get_centml_client
+from centml.sdk import DeploymentType, CreateCServeV2DeploymentRequest
+
+with get_centml_client() as cclient:
+    # Get fastest recipe for the Qwen model
+    fastest = cclient.get_cserve_recipe(model="Qwen/Qwen2-VL-7B-Instruct")[0].fastest
+
+    # Modify the recipe if necessary
+    fastest.recipe.additional_properties["max_num_seqs"] = 512
+
+    # Create CServeV2 deployment
+    request = CreateCServeV2DeploymentRequest(
+        name="qwen-fastest",
+        cluster_id=cclient.get_cluster_id(fastest.hardware_instance_id),
+        hardware_instance_id=fastest.hardware_instance_id,
+        recipe=fastest.recipe,
+        min_scale=1,
+        max_scale=1,
+        env_vars={},
+    )
+    response = cclient.create_cserve(request)
+    print("Create deployment response: ", response)
+
+    # Get deployment details
+    deployment = cclient.get_cserve(response.id)
+    print("Deployment details: ", deployment)
+
+    # Pause the deployment
+    cclient.pause(deployment.id)
+
+    # Delete the deployment
+    cclient.delete(deployment.id)