8282    AquaDeploymentDetail ,
8383    ConfigValidationError ,
8484    CreateModelDeploymentDetails ,
85+     ModelDeploymentDetails ,
86+     UpdateModelDeploymentDetails ,
8587)
8688from  ads .aqua .modeldeployment .model_group_config  import  ModelGroupConfig 
8789from  ads .aqua .shaperecommend .recommend  import  AquaShapeRecommend 
110112    ModelDeploymentInfrastructure ,
111113    ModelDeploymentMode ,
112114)
115+ from  ads .model .deployment .model_deployment  import  (
116+     ModelDeploymentUpdateType ,
117+ )
113118from  ads .model .model_metadata  import  ModelCustomMetadata , ModelCustomMetadataItem 
114119from  ads .telemetry  import  telemetry 
115120
@@ -397,14 +402,14 @@ def create(
397402
398403    def  _validate_input_models (
399404        self ,
400-         create_deployment_details :  CreateModelDeploymentDetails ,
405+         deployment_details :  ModelDeploymentDetails ,
401406    ):
402-         """Validates the base models and associated fine tuned models from 'models' in create_deployment_details for stacked or multi model deployment.""" 
407+         """Validates the base models and associated fine tuned models from 'models' in create_deployment_details or update_deployment_details  for stacked or multi model deployment.""" 
403408        # Collect all unique model IDs (including fine-tuned models) 
404409        source_model_ids  =  list (
405410            {
406411                model_id 
407-                 for  model  in  create_deployment_details .models 
412+                 for  model  in  deployment_details .models 
408413                for  model_id  in  model .all_model_ids ()
409414            }
410415        )
@@ -415,7 +420,7 @@ def _validate_input_models(
415420        source_models  =  self .get_multi_source (source_model_ids ) or  {}
416421
417422        try :
418-             create_deployment_details .validate_input_models (model_details = source_models )
423+             deployment_details .validate_input_models (model_details = source_models )
419424        except  ConfigValidationError  as  err :
420425            raise  AquaValueError (f"{ err }  ) from  err 
421426
@@ -1249,6 +1254,219 @@ def _get_container_type_key(
12491254
12501255        return  container_type_key 
12511256
1257+     @telemetry (entry_point = "plugin=deployment&action=update" , name = "aqua" ) 
1258+     def  update (
1259+         self ,
1260+         model_deployment_id : str ,
1261+         update_model_deployment_details : Optional [UpdateModelDeploymentDetails ] =  None ,
1262+         ** kwargs ,
1263+     ) ->  AquaDeployment :
1264+         """Updates a AQUA model group deployment. 
1265+ 
1266+         Args: 
1267+             update_model_deployment_details : UpdateModelDeploymentDetails, optional 
1268+                 An instance of UpdateModelDeploymentDetails containing all optional 
1269+                 fields for updating a model deployment via Aqua. 
1270+             kwargs: 
1271+                 display_name (str): The name of the model deployment. 
1272+                 description (Optional[str]): The description of the deployment. 
1273+                 models (Optional[List[AquaMultiModelRef]]): List of models for deployment. 
1274+                 instance_count (int): Number of instances used for deployment. 
1275+                 log_group_id (Optional[str]): OCI logging group ID for logs. 
1276+                 access_log_id (Optional[str]): OCID for access logs. 
1277+                 predict_log_id (Optional[str]): OCID for prediction logs. 
1278+                 bandwidth_mbps (Optional[int]): Bandwidth limit on the load balancer in Mbps. 
1279+                 web_concurrency (Optional[int]): Number of worker processes/threads for handling requests. 
1280+                 memory_in_gbs (Optional[float]): Memory (in GB) for the selected shape. 
1281+                 ocpus (Optional[float]): OCPU count for the selected shape. 
1282+                 freeform_tags (Optional[Dict]): Freeform tags for model deployment. 
1283+                 defined_tags (Optional[Dict]): Defined tags for model deployment. 
1284+ 
1285+         Returns 
1286+         ------- 
1287+         AquaDeployment 
1288+             An Aqua deployment instance. 
1289+         """ 
1290+         if  not  update_model_deployment_details :
1291+             try :
1292+                 update_model_deployment_details  =  UpdateModelDeploymentDetails (** kwargs )
1293+             except  ValidationError  as  ex :
1294+                 custom_errors  =  build_pydantic_error_message (ex )
1295+                 raise  AquaValueError (
1296+                     f"Invalid parameters for updating a model group deployment. Error details: { custom_errors }  
1297+                 ) from  ex 
1298+ 
1299+         model_deployment  =  ModelDeployment .from_id (model_deployment_id )
1300+ 
1301+         infrastructure  =  model_deployment .infrastructure 
1302+         runtime  =  model_deployment .runtime 
1303+ 
1304+         if  not  runtime .model_group_id :
1305+             raise  AquaValueError (
1306+                 "Invalid 'model_deployment_id'. Only model group deployment is supported to update." 
1307+             )
1308+ 
1309+         # updates model group if fine tuned weights changed. 
1310+         model  =  self ._update_model_group (
1311+             runtime .model_group_id , update_model_deployment_details 
1312+         )
1313+ 
1314+         # updates model group deployment infrastructure 
1315+         (
1316+             infrastructure .with_bandwidth_mbps (
1317+                 update_model_deployment_details .bandwidth_mbps 
1318+                 or  infrastructure .bandwidth_mbps 
1319+             )
1320+             .with_replica (
1321+                 update_model_deployment_details .instance_count  or  infrastructure .replica 
1322+             )
1323+             .with_web_concurrency (
1324+                 update_model_deployment_details .web_concurrency 
1325+                 or  infrastructure .web_concurrency 
1326+             )
1327+         )
1328+ 
1329+         if  (
1330+             update_model_deployment_details .log_group_id 
1331+             and  update_model_deployment_details .access_log_id 
1332+         ):
1333+             infrastructure .with_access_log (
1334+                 log_group_id = update_model_deployment_details .log_group_id ,
1335+                 log_id = update_model_deployment_details .access_log_id ,
1336+             )
1337+ 
1338+         if  (
1339+             update_model_deployment_details .log_group_id 
1340+             and  update_model_deployment_details .predict_log_id 
1341+         ):
1342+             infrastructure .with_predict_log (
1343+                 log_group_id = update_model_deployment_details .log_group_id ,
1344+                 log_id = update_model_deployment_details .predict_log_id ,
1345+             )
1346+ 
1347+         if  (
1348+             update_model_deployment_details .memory_in_gbs 
1349+             and  update_model_deployment_details .ocpus 
1350+             and  infrastructure .shape_name .endswith ("Flex" )
1351+         ):
1352+             infrastructure .with_shape_config_details (
1353+                 ocpus = update_model_deployment_details .ocpus ,
1354+                 memory_in_gbs = update_model_deployment_details .memory_in_gbs ,
1355+             )
1356+ 
1357+         # applies ZDT as default type to update parameters if model group id hasn't been changed 
1358+         update_type  =  ModelDeploymentUpdateType .ZDT 
1359+         # applies LIVE update if model group id has been changed 
1360+         if  runtime .model_group_id  !=  model .id :
1361+             runtime .with_model_group_id (model .id )
1362+             update_type  =  ModelDeploymentUpdateType .LIVE 
1363+ 
1364+         freeform_tags  =  (
1365+             update_model_deployment_details .freeform_tags 
1366+             or  model_deployment .freeform_tags 
1367+         )
1368+         defined_tags  =  (
1369+             update_model_deployment_details .defined_tags 
1370+             or  model_deployment .defined_tags 
1371+         )
1372+ 
1373+         # updates model group deployment 
1374+         (
1375+             model_deployment .with_display_name (
1376+                 update_model_deployment_details .display_name 
1377+                 or  model_deployment .display_name 
1378+             )
1379+             .with_description (
1380+                 update_model_deployment_details .description 
1381+                 or  model_deployment .description 
1382+             )
1383+             .with_freeform_tags (** (freeform_tags  or  {}))
1384+             .with_defined_tags (** (defined_tags  or  {}))
1385+             .with_infrastructure (infrastructure )
1386+             .with_runtime (runtime )
1387+         )
1388+ 
1389+         model_deployment .update (wait_for_completion = False , update_type = update_type )
1390+ 
1391+         logger .info (f"Updating Aqua Model Deployment { model_deployment .id }  )
1392+ 
1393+         return  AquaDeployment .from_oci_model_deployment (
1394+             model_deployment .dsc_model_deployment , self .region 
1395+         )
1396+ 
1397+     def  _update_model_group (
1398+         self ,
1399+         model_group_id : str ,
1400+         update_model_deployment_details : UpdateModelDeploymentDetails ,
1401+     ) ->  DataScienceModelGroup :
1402+         """Creates a new model group if fine tuned weights changed. 
1403+ 
1404+         Parameters 
1405+         ---------- 
1406+         model_group_id: str 
1407+             The model group id. 
1408+         update_model_deployment_details: UpdateModelDeploymentDetails 
1409+             An instance of UpdateModelDeploymentDetails containing all optional 
1410+             fields for updating a model deployment via Aqua. 
1411+ 
1412+         Returns 
1413+         ------- 
1414+         DataScienceModelGroup 
1415+             The instance of DataScienceModelGroup. 
1416+         """ 
1417+         model_group  =  DataScienceModelGroup .from_id (model_group_id )
1418+         # create a new model group if fine tune weights changed as member models in ds model group is inmutable 
1419+         if  update_model_deployment_details .models :
1420+             if  len (update_model_deployment_details .models ) !=  1 :
1421+                 raise  AquaValueError (
1422+                     "Invalid 'models' provided. Only one base model is required for updating model stack deployment." 
1423+                 )
1424+             # validates input base and fine tune models 
1425+             self ._validate_input_models (update_model_deployment_details )
1426+             target_stacked_model  =  update_model_deployment_details .models [0 ]
1427+             target_base_model_id  =  target_stacked_model .model_id 
1428+             if  model_group .base_model_id  !=  target_base_model_id :
1429+                 raise  AquaValueError (
1430+                     "Invalid parameter 'models'. Base model id can't be changed for stacked model deployment." 
1431+                 )
1432+ 
1433+             # add member models 
1434+             member_models  =  [
1435+                 {
1436+                     "inference_key" : fine_tune_weight .model_name ,
1437+                     "model_id" : fine_tune_weight .model_id ,
1438+                 }
1439+                 for  fine_tune_weight  in  target_stacked_model .fine_tune_weights 
1440+             ]
1441+             # add base model 
1442+             member_models .append (
1443+                 {
1444+                     "inference_key" : target_stacked_model .model_name ,
1445+                     "model_id" : target_base_model_id ,
1446+                 }
1447+             )
1448+ 
1449+             # creates a model group with the same configurations from original model group except member models 
1450+             model_group  =  (
1451+                 DataScienceModelGroup ()
1452+                 .with_compartment_id (model_group .compartment_id )
1453+                 .with_project_id (model_group .project_id )
1454+                 .with_display_name (model_group .display_name )
1455+                 .with_description (model_group .description )
1456+                 .with_freeform_tags (** (model_group .freeform_tags  or  {}))
1457+                 .with_defined_tags (** (model_group .defined_tags  or  {}))
1458+                 .with_custom_metadata_list (model_group .custom_metadata_list )
1459+                 .with_base_model_id (target_base_model_id )
1460+                 .with_member_models (member_models )
1461+                 .create ()
1462+             )
1463+ 
1464+             logger .info (
1465+                 f"Model group of base model { target_base_model_id } { model_group .id }  
1466+             )
1467+ 
1468+         return  model_group 
1469+ 
12521470    @telemetry (entry_point = "plugin=deployment&action=list" , name = "aqua" ) 
12531471    def  list (self , ** kwargs ) ->  List ["AquaDeployment" ]:
12541472        """List Aqua model deployments in a given compartment and under certain project. 
0 commit comments