
Commit 6367b1c

add inference v3 support - testing
1 parent 73e8e69 commit 6367b1c

3 files changed: +95 −27 lines changed


centml/cli/cluster.py

Lines changed: 16 additions & 18 deletions
@@ -13,6 +13,7 @@
     DeploymentType.COMPUTE: "compute",
     DeploymentType.COMPILATION: "compilation",
     DeploymentType.INFERENCE_V2: "inference",
+    DeploymentType.INFERENCE_V3: "inference",
     DeploymentType.COMPUTE_V2: "compute",
     # For user, they are all cserve.
     DeploymentType.CSERVE: "cserve",
@@ -22,7 +23,7 @@
 }
 # use latest type to for user requests
 depl_name_to_type_map = {
-    "inference": DeploymentType.INFERENCE_V2,
+    "inference": DeploymentType.INFERENCE_V3,
     "cserve": DeploymentType.CSERVE_V3,
     "compute": DeploymentType.COMPUTE_V2,
     "rag": DeploymentType.RAG,
@@ -140,8 +141,8 @@ def get(type, id):
     with get_centml_client() as cclient:
         depl_type = depl_name_to_type_map[type]

-        if depl_type == DeploymentType.INFERENCE_V2:
-            deployment = cclient.get_inference(id)
+        if depl_type in [DeploymentType.INFERENCE_V2, DeploymentType.INFERENCE_V3]:
+            deployment = cclient.get_inference(id)  # handles both V2 and V3
         elif depl_type == DeploymentType.COMPUTE_V2:
             deployment = cclient.get_compute(id)
         elif depl_type in [DeploymentType.CSERVE_V2, DeploymentType.CSERVE_V3]:
@@ -169,21 +170,18 @@ def get(type, id):
         )

         click.echo("Additional deployment configurations:")
-        if depl_type == DeploymentType.INFERENCE_V2:
-            click.echo(
-                tabulate(
-                    [
-                        ("Image", deployment.image_url),
-                        ("Container port", deployment.container_port),
-                        ("Healthcheck", deployment.healthcheck or "/"),
-                        ("Replicas", _get_replica_info(deployment)),
-                        ("Environment variables", deployment.env_vars or "None"),
-                        ("Max concurrency", deployment.concurrency or "None"),
-                    ],
-                    tablefmt="rounded_outline",
-                    disable_numparse=True,
-                )
-            )
+        if depl_type in [DeploymentType.INFERENCE_V2, DeploymentType.INFERENCE_V3]:
+            replica_info = _get_replica_info(deployment)
+            display_rows = [
+                ("Image", deployment.image_url),
+                ("Container port", deployment.container_port),
+                ("Healthcheck", deployment.healthcheck or "/"),
+                ("Replicas", replica_info),
+                ("Environment variables", deployment.env_vars or "None"),
+                ("Max concurrency", deployment.concurrency or "None"),
+            ]
+
+            click.echo(tabulate(display_rows, tablefmt="rounded_outline", disable_numparse=True))
         elif depl_type == DeploymentType.COMPUTE_V2:
             click.echo(
                 tabulate(
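
For context, a minimal sketch (not part of the commit; the variable name of the type-to-name map is assumed from the hunk above) of how the two maps behave after this change: both V2 and V3 deployments display as "inference", while new requests by name resolve to the latest version, V3.

from centml.sdk import DeploymentType

# Assumed name for the type-to-display-name map edited in the first hunk.
depl_type_to_name_map = {
    DeploymentType.INFERENCE_V2: "inference",
    DeploymentType.INFERENCE_V3: "inference",
}
depl_name_to_type_map = {
    "inference": DeploymentType.INFERENCE_V3,
}

print(depl_type_to_name_map[DeploymentType.INFERENCE_V2])  # "inference"
print(depl_type_to_name_map[DeploymentType.INFERENCE_V3])  # "inference"
print(depl_name_to_type_map["inference"])                  # DeploymentType.INFERENCE_V3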

centml/sdk/api.py

Lines changed: 68 additions & 4 deletions
@@ -6,6 +6,7 @@
     DeploymentType,
     DeploymentStatus,
     CreateInferenceDeploymentRequest,
+    CreateInferenceV3DeploymentRequest,
     CreateComputeDeploymentRequest,
     CreateCServeV2DeploymentRequest,
     CreateCServeV3DeploymentRequest,
@@ -30,7 +31,21 @@ def get_status(self, id):
         return self._api.get_deployment_status_deployments_status_deployment_id_get(id)

     def get_inference(self, id):
-        return self._api.get_inference_deployment_deployments_inference_deployment_id_get(id)
+        """Get Inference deployment details - automatically handles both V2 and V3 deployments"""
+        # Try V3 first (recommended), fall back to V2 if the deployment is V2
+        try:
+            return self._api.get_inference_v3_deployment_deployments_inference_v3_deployment_id_get(id)
+        except ApiException as e:
+            # If V3 fails with 404 or similar, try V2
+            if e.status in [404, 400]:  # Deployment might be V2 or endpoint not found
+                try:
+                    return self._api.get_inference_deployment_deployments_inference_deployment_id_get(id)
+                except ApiException as v2_error:
+                    # If both fail, raise the original V3 error as it's more likely to be the real issue
+                    raise e
+            else:
+                # For other errors (auth, network, etc.), raise immediately
+                raise

     def get_compute(self, id):
         return self._api.get_compute_deployment_deployments_compute_deployment_id_get(id)
@@ -52,9 +67,15 @@ def get_cserve(self, id):
             # For other errors (auth, network, etc.), raise immediately
             raise

-    def create_inference(self, request: CreateInferenceDeploymentRequest):
+    def create_inference(self, request: CreateInferenceV3DeploymentRequest):
+        return self._api.create_inference_v3_deployment_deployments_inference_v3_post(request)
+
+    def create_inference_v2(self, request: CreateInferenceDeploymentRequest):
         return self._api.create_inference_deployment_deployments_inference_post(request)

+    def create_inference_v3(self, request: CreateInferenceV3DeploymentRequest):
+        return self._api.create_inference_v3_deployment_deployments_inference_v3_post(request)
+
     def create_compute(self, request: CreateComputeDeploymentRequest):
         return self._api.create_compute_deployment_deployments_compute_post(request)

@@ -67,8 +88,51 @@ def create_cserve_v2(self, request: CreateCServeV2DeploymentRequest):
     def create_cserve_v3(self, request: CreateCServeV3DeploymentRequest):
         return self._api.create_cserve_v3_deployment_deployments_cserve_v3_post(request)

-    def update_inference(self, deployment_id: int, request: CreateInferenceDeploymentRequest):
-        return self._api.update_inference_deployment_deployments_inference_put(deployment_id, request)
+    def detect_inference_deployment_version(self, deployment_id: int) -> str:
+        """Detect if an inference deployment is V2 or V3 by testing the specific API endpoints"""
+        try:
+            # Try V3 endpoint first
+            self._api.get_inference_v3_deployment_deployments_inference_v3_deployment_id_get(deployment_id)
+            return 'v3'
+        except ApiException as e:
+            if e.status in [404, 400]:  # V3 endpoint doesn't exist for this deployment
+                try:
+                    # Try V2 endpoint
+                    self._api.get_inference_deployment_deployments_inference_deployment_id_get(deployment_id)
+                    return 'v2'
+                except ApiException:
+                    # If both fail, it might not be an inference deployment or doesn't exist
+                    raise ValueError(
+                        f"Deployment {deployment_id} is not a valid inference deployment or does not exist"
+                    )
+            else:
+                # Other error (auth, network, etc.)
+                raise
+
+    def update_inference(
+        self, deployment_id: int, request: Union[CreateInferenceDeploymentRequest, CreateInferenceV3DeploymentRequest]
+    ):
+        """Update Inference deployment - validates request type matches deployment version"""
+        # Detect the deployment version
+        deployment_version = self.detect_inference_deployment_version(deployment_id)
+
+        # Validate request type matches deployment version
+        if isinstance(request, CreateInferenceV3DeploymentRequest):
+            if deployment_version != 'v3':
+                raise ValueError(
+                    f"Deployment {deployment_id} is Inference {deployment_version.upper()}, but you provided a V3 request. Please use CreateInferenceDeploymentRequest instead."
+                )
+            return self._api.update_inference_v3_deployment_deployments_inference_v3_put(deployment_id, request)
+        elif isinstance(request, CreateInferenceDeploymentRequest):
+            if deployment_version != 'v2':
+                raise ValueError(
+                    f"Deployment {deployment_id} is Inference {deployment_version.upper()}, but you provided a V2 request. Please use CreateInferenceV3DeploymentRequest instead."
+                )
+            return self._api.update_inference_deployment_deployments_inference_put(deployment_id, request)
+        else:
+            raise ValueError(
+                f"Unsupported request type: {type(request)}. Expected CreateInferenceDeploymentRequest or CreateInferenceV3DeploymentRequest."
+            )

     def update_compute(self, deployment_id: int, request: CreateComputeDeploymentRequest):
         return self._api.update_compute_deployment_deployments_compute_put(deployment_id, request)
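
A short usage sketch of the updated SDK surface (illustrative only; the deployment id is a placeholder and the request fields mirror the example file below rather than the full V3 schema):

from centml.sdk import CreateInferenceV3DeploymentRequest
from centml.sdk.api import get_centml_client

with get_centml_client() as cclient:
    # get_inference now resolves both V2 and V3 deployments for the same id.
    deployment = cclient.get_inference(123)  # 123 is a placeholder deployment id

    # update_inference validates that the request type matches the deployment version;
    # passing a V3 request against a V2 deployment raises ValueError.
    if cclient.detect_inference_deployment_version(123) == 'v3':
        cclient.update_inference(
            123,
            CreateInferenceV3DeploymentRequest(
                name="nginx",
                cluster_id=1000,
                hardware_instance_id=1000,
                image_url="nginxinc/nginx-unprivileged",
                port=8080,
                min_replicas=1,
                max_replicas=3,
            ),
        )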

examples/sdk/create_inference.py

Lines changed: 11 additions & 5 deletions
@@ -1,26 +1,32 @@
 import centml
 from centml.sdk.api import get_centml_client
-from centml.sdk import DeploymentType, CreateInferenceDeploymentRequest, UserVaultType
+from centml.sdk import DeploymentType, CreateInferenceV3DeploymentRequest, UserVaultType


 def main():
     with get_centml_client() as cclient:
         certs = cclient.get_user_vault(UserVaultType.CERTIFICATES)

-        request = CreateInferenceDeploymentRequest(
+        request = CreateInferenceV3DeploymentRequest(
             name="nginx",
             cluster_id=1000,
             hardware_instance_id=1000,
             image_url="nginxinc/nginx-unprivileged",
             port=8080,
-            min_scale=1,
-            max_scale=1,
+            min_replicas=1,  # V3 uses min_replicas instead of min_scale
+            max_replicas=3,  # V3 uses max_replicas instead of max_scale
+            initial_replicas=1,  # Optional in V3 - initial number of replicas
             endpoint_certificate_authority=certs["my_cert"],
+            # V3 rollout strategy parameters
+            max_surge=1,  # Allow 1 extra pod during updates
+            max_unavailable=0,  # Keep all pods available during updates
+            healthcheck="/",
+            concurrency=10,
         )
         response = cclient.create_inference(request)
         print("Create deployment response: ", response)

-        ### Get deployment details
+        ### Get deployment details (automatically detects V2 or V3)
         deployment = cclient.get_inference(response.id)
         print("Deployment details: ", deployment)

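For callers that still need the legacy path, the commit keeps an explicit V2 entry point. A sketch, assuming the V2 request fields shown in the removed lines above:

from centml.sdk import CreateInferenceDeploymentRequest
from centml.sdk.api import get_centml_client

# V2 requests keep the old scaling field names (min_scale / max_scale).
v2_request = CreateInferenceDeploymentRequest(
    name="nginx-v2",
    cluster_id=1000,
    hardware_instance_id=1000,
    image_url="nginxinc/nginx-unprivileged",
    port=8080,
    min_scale=1,
    max_scale=1,
)

with get_centml_client() as cclient:
    cclient.create_inference_v2(v2_request)  # explicit V2 create added by this commit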
