Skip to content

Commit 2db4023

Browse files
committed
simplify update
1 parent ac79cec commit 2db4023

File tree

1 file changed

+34
-66
lines changed

1 file changed

+34
-66
lines changed

centml/sdk/api.py

Lines changed: 34 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
CreateComputeDeploymentRequest,
1010
CreateCServeV2DeploymentRequest,
1111
CreateCServeV3DeploymentRequest,
12-
CServeV2Recipe,
1312
ApiException,
1413
Metric,
1514
)
@@ -74,59 +73,50 @@ def update_inference(self, deployment_id: int, request: CreateInferenceDeploymen
7473
def update_compute(self, deployment_id: int, request: CreateComputeDeploymentRequest):
    """Submit an update for a compute deployment.

    Args:
        deployment_id: Identifier of the deployment to update.
        request: The compute deployment request forwarded to the API client.

    Returns:
        Whatever the underlying API client's compute-update call returns.
    """
    put_compute = self._api.update_compute_deployment_deployments_compute_put
    return put_compute(deployment_id, request)
7675

76+
def detect_deployment_version(self, deployment_id: int) -> str:
    """Detect whether a deployment is CServe V2 or V3 by probing the version-specific endpoints.

    The V3 GET endpoint is tried first. A 404 or 400 from it is treated as
    "not a V3 deployment" and the V2 GET endpoint is tried next.

    Args:
        deployment_id: Identifier of the deployment to probe.

    Returns:
        'v3' if the V3 endpoint answers successfully, otherwise 'v2' if the
        V2 endpoint answers successfully.

    Raises:
        ValueError: Both endpoints answered 404/400, i.e. the deployment is
            not a CServe deployment or does not exist.
        ApiException: Either endpoint failed with any other status (auth,
            server error, ...); it is re-raised unchanged so real failures
            are not misreported as "deployment not found".
    """
    # Statuses that mean "this endpoint does not know this deployment".
    not_this_version = (404, 400)

    try:
        self._api.get_cserve_v3_deployment_deployments_cserve_v3_deployment_id_get(deployment_id)
    except ApiException as e:
        if e.status not in not_this_version:
            # Other error (auth, network, etc.)
            raise
    else:
        return 'v3'

    try:
        self._api.get_cserve_v2_deployment_deployments_cserve_v2_deployment_id_get(deployment_id)
    except ApiException as e:
        if e.status not in not_this_version:
            # Same policy as the V3 probe: only 404/400 means "not found".
            raise
        raise ValueError(
            f"Deployment {deployment_id} is not a valid CServe deployment or does not exist"
        ) from e
    return 'v2'
94+
7795
def update_cserve(
    self, deployment_id: int, request: Union[CreateCServeV2DeploymentRequest, CreateCServeV3DeploymentRequest]
):
    """Update a CServe deployment after checking the request type matches the deployment version.

    Args:
        deployment_id: Identifier of the deployment to update.
        request: A V2 or V3 CServe deployment request.

    Returns:
        The result of the V3 or V2 update call on the underlying API client.

    Raises:
        ValueError: The request is neither a V2 nor a V3 request, or its
            version does not match the detected deployment version.

    Errors raised by ``self.detect_deployment_version`` propagate unchanged.
    """
    # Reject unsupported request types before any network call is made.
    if isinstance(request, CreateCServeV3DeploymentRequest):
        expected_version = 'v3'
        provided_label = 'V3'
        suggested_type = 'CreateCServeV2DeploymentRequest'
        send_update = self._api.update_cserve_v3_deployment_deployments_cserve_v3_put
    elif isinstance(request, CreateCServeV2DeploymentRequest):
        expected_version = 'v2'
        provided_label = 'V2'
        suggested_type = 'CreateCServeV3DeploymentRequest'
        send_update = self._api.update_cserve_v2_deployment_deployments_cserve_v2_put
    else:
        raise ValueError(
            f"Unsupported request type: {type(request)}. Expected CreateCServeV2DeploymentRequest or CreateCServeV3DeploymentRequest."
        )

    # Detect the deployment version and validate it against the request type.
    deployment_version = self.detect_deployment_version(deployment_id)
    if deployment_version != expected_version:
        raise ValueError(
            f"Deployment {deployment_id} is CServe {deployment_version.upper()}, but you provided a {provided_label} request. Please use {suggested_type} instead."
        )
    return send_update(deployment_id, request)
108119

109-
def _convert_v3_to_v2_request(self, v3_request: CreateCServeV3DeploymentRequest) -> CreateCServeV2DeploymentRequest:
    """Build a V2 request from a V3 request (reverse of convert_v2_to_v3_request).

    The replica fields are renamed to their scale equivalents and the
    V3-only rollout fields are dropped.

    Args:
        v3_request: The V3 request to convert.

    Returns:
        A CreateCServeV2DeploymentRequest built from the remapped fields.
    """
    # Use model_dump() when the request object provides it, otherwise dict().
    if hasattr(v3_request, 'model_dump'):
        fields = v3_request.model_dump()
    else:
        fields = v3_request.dict()

    # Take the V3 replica fields out of the payload.
    lower_bound = fields.pop('min_replicas', None)
    upper_bound = fields.pop('max_replicas', None)
    starting = fields.pop('initial_replicas', None)

    # Discard fields that exist only on V3 requests.
    for v3_only in ('max_surge', 'max_unavailable'):
        fields.pop(v3_only, None)

    # Re-insert the values under the V2 field names.
    fields.update(min_scale=lower_bound, max_scale=upper_bound)
    if starting is not None:
        fields['initial_scale'] = starting

    return CreateCServeV2DeploymentRequest(**fields)
129-
130120
def _update_status(self, id, new_status):
131121
status_req = platform_api_python_client.DeploymentStatusRequest(status=new_status)
132122
self._api.update_deployment_status_deployments_status_deployment_id_put(id, status_req)
@@ -181,28 +171,6 @@ def detect_cserve_deployment_version(self, deployment_response):
181171
# Default to V2 for backward compatibility
182172
return 'v2'
183173

184-
def convert_v2_to_v3_request(self, v2_request: CreateCServeV2DeploymentRequest) -> CreateCServeV3DeploymentRequest:
    """Build a V3 request from a V2 request by remapping the scale fields.

    The scale fields are renamed to their replica equivalents and the
    V3-specific rollout fields are set to None.

    Args:
        v2_request: The V2 request to convert.

    Returns:
        A CreateCServeV3DeploymentRequest built from the remapped fields.
    """
    # Use model_dump() when the request object provides it, otherwise dict().
    if hasattr(v2_request, 'model_dump'):
        fields = v2_request.model_dump()
    else:
        fields = v2_request.dict()

    # Take the V2 scale fields out of the payload.
    lower_bound = fields.pop('min_scale', None)
    upper_bound = fields.pop('max_scale', None)
    starting = fields.pop('initial_scale', None)

    # Re-insert the values under the V3 field names.
    fields.update(min_replicas=lower_bound, max_replicas=upper_bound)
    if starting is not None:
        fields['initial_replicas'] = starting

    # V3-specific fields are passed explicitly as None.
    fields.update(max_surge=None, max_unavailable=None)

    return CreateCServeV3DeploymentRequest(**fields)
205-
206174
# pylint: disable=R0917
207175
def get_deployment_usage(
208176
self, id: int, metric: Metric, start_time_in_seconds: int, end_time_in_seconds: int, step: int

0 commit comments

Comments
 (0)