Deploy fine-tuned model v2 as a stacked deployment.

Summary of the change carried by this patch:
  * AquaDeploymentApp.create() now uses the value returned by
    validate_base_model() as the model to deploy.
  * validate_base_model() no longer rejects fine-tuned models. A model whose
    fine-tune version tag equals AQUA_FINE_TUNE_MODEL_VERSION is converted to
    an AquaMultiModelRef (base model + LoRA weights); any other model id is
    returned unchanged.
  * Unit tests assert against the mocked validate_base_model() return value.

Review edits relative to the submitted patch (entities.py only):
  * guard against `freeform_tags` being None before calling `.get()`;
  * drop the dead `not segments` test (str.split never returns an empty list);
  * pass the version to logger.debug() as a lazy %s argument, matching the
    logger.error() calls in the same file;
  * restore the tag-format placeholder in the error message;
  * document the mutation of `self.model_id` / `self.models`.

NOTE(review): the line structure of this patch was reconstructed. Indentation
and the position of blank context lines were inferred from Python syntax and
the hunk line counts, and the placeholder wording in the error message is an
assumption; confirm both against the target tree. The blob ids on the `index`
lines predate the review edits. If context whitespace does not match, apply
with `git apply --ignore-whitespace`.

diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py
index 9ffc85b3f..4dfbd8e59 100644
--- a/ads/aqua/modeldeployment/deployment.py
+++ b/ads/aqua/modeldeployment/deployment.py
@@ -242,7 +242,9 @@ def create(
                 model = create_deployment_details.models[0]
             else:
                 try:
-                    create_deployment_details.validate_base_model(model_id=model)
+                    model = create_deployment_details.validate_base_model(
+                        model_id=model
+                    )
                 except ConfigValidationError as err:
                     raise AquaValueError(f"{err}") from err
 
diff --git a/ads/aqua/modeldeployment/entities.py b/ads/aqua/modeldeployment/entities.py
index 5bc468300..184ceff2f 100644
--- a/ads/aqua/modeldeployment/entities.py
+++ b/ads/aqua/modeldeployment/entities.py
@@ -8,9 +8,10 @@
 from pydantic import BaseModel, Field, model_validator
 
 from ads.aqua import logger
-from ads.aqua.common.entities import AquaMultiModelRef
+from ads.aqua.common.entities import AquaMultiModelRef, LoraModuleSpec
 from ads.aqua.common.enums import Tags
 from ads.aqua.common.errors import AquaValueError
+from ads.aqua.common.utils import is_valid_ocid
 from ads.aqua.config.utils.serializer import Serializable
 from ads.aqua.constants import (
     AQUA_FINE_TUNE_MODEL_VERSION,
@@ -717,34 +718,70 @@ def validate_ft_model_v2(
                     f"Invalid fine-tuned model ID '{base_model.id}': for fine tuned models like Phi4, the deployment is not supported. "
                 )
 
-    def validate_base_model(self, model_id: str) -> None:
+    def validate_base_model(self, model_id: str) -> Union[str, AquaMultiModelRef]:
         """
         Validates the input base model for single model deployment configuration.
 
         Validation Criteria:
-        - Fine-tuned models are not supported in single model deployment.
+        - Legacy fine-tuned models will be deployed as single model deployment.
+        - Fine-tuned models v2 will be deployed as stacked deployment.
 
         Parameters
         ----------
         model_id : str
             The OCID of DataScienceModel instance.
 
+        Returns
+        -------
+        Union[str, AquaMultiModelRef]
+            A string of model id or an instance of AquaMultiModelRef.
+            When an AquaMultiModelRef is returned, `self.model_id` is reset to None
+            and `self.models` is replaced with that single reference.
+
         Raises
         ------
         ConfigValidationError
             If any of the above conditions are violated.
         """
         base_model = DataScienceModel.from_id(model_id)
-        if Tags.AQUA_FINE_TUNED_MODEL_TAG in base_model.freeform_tags:
-            logger.error(
-                "Validation failed: Fine-tuned model ID '%s' is not supported for single-model deployment.",
-                base_model.id,
-            )
-            raise ConfigValidationError(
-                f"Invalid base model ID '{base_model.id}': "
-                "single-model deployment does not support fine-tuned models. "
-                f"Please deploy the fine-tuned model '{base_model.id}' as a stacked model deployment instead."
+        # Guard against a model without freeform tags (None) before calling .get().
+        freeform_tags = base_model.freeform_tags or {}
+        aqua_fine_tuned_model = freeform_tags.get(
+            Tags.AQUA_FINE_TUNED_MODEL_TAG, UNKNOWN
+        )
+        if aqua_fine_tuned_model:
+            fine_tuned_model_version = freeform_tags.get(
+                Tags.AQUA_FINE_TUNE_MODEL_VERSION, UNKNOWN
             )
+            # TODO: revisit to block deploying single fine tuned model after AQUA UI is integrated.
+            if fine_tuned_model_version.lower() == AQUA_FINE_TUNE_MODEL_VERSION:
+                # extracts base model id from tag 'aqua_fine_tuned_model' and builds AquaMultiModelRef instance for stacked deployment.
+                logger.debug(
+                    "Detected base model is fine-tuned model %s and switched to stack deployment.",
+                    AQUA_FINE_TUNE_MODEL_VERSION,
+                )
+                # str.split() always yields at least one segment, so only the OCID check is needed.
+                segments = aqua_fine_tuned_model.split("#")
+                if not is_valid_ocid(segments[0]):
+                    logger.error(
+                        "Validation failed: Fine-tuned model ID '%s' is not supported for model deployment.",
+                        base_model.id,
+                    )
+                    raise ConfigValidationError(
+                        f"Invalid fine-tuned model ID '{base_model.id}': missing or invalid tag '{Tags.AQUA_FINE_TUNED_MODEL_TAG}' format. "
+                        f"Make sure tag '{Tags.AQUA_FINE_TUNED_MODEL_TAG}' is added with format <service_model_id>#<service_model_name>."
+                    )
+                # reset the model_id and models in create_model_deployment_details for stack deployment
+                self.model_id = None
+                self.models = [
+                    AquaMultiModelRef(
+                        model_id=segments[0],
+                        fine_tune_weights=[LoraModuleSpec(model_id=base_model.id)],
+                    )
+                ]
+                return self.models[0]
+
+        return model_id
 
     class Config:
         extra = "allow"
diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py
index 80158ebcd..0c86ff8e5 100644
--- a/tests/unitary/with_extras/aqua/test_deployment.py
+++ b/tests/unitary/with_extras/aqua/test_deployment.py
@@ -1539,7 +1539,7 @@ def test_create_deployment_for_foundation_model(
 
         mock_validate_base_model.assert_called()
         mock_create.assert_called_with(
-            model=TestDataset.MODEL_ID,
+            model=mock_validate_base_model.return_value,
             compartment_id=TestDataset.USER_COMPARTMENT_ID,
             project_id=TestDataset.USER_PROJECT_ID,
             freeform_tags=freeform_tags,
@@ -1640,7 +1640,7 @@ def test_create_deployment_for_fine_tuned_model(
 
         mock_validate_base_model.assert_called()
         mock_create.assert_called_with(
-            model=TestDataset.MODEL_ID,
+            model=mock_validate_base_model.return_value,
             compartment_id=TestDataset.USER_COMPARTMENT_ID,
             project_id=TestDataset.USER_PROJECT_ID,
             freeform_tags=None,
@@ -1741,7 +1741,7 @@ def test_create_deployment_for_gguf_model(
 
         mock_validate_base_model.assert_called()
         mock_create.assert_called_with(
-            model=TestDataset.MODEL_ID,
+            model=mock_validate_base_model.return_value,
             compartment_id=TestDataset.USER_COMPARTMENT_ID,
             project_id=TestDataset.USER_PROJECT_ID,
             freeform_tags=None,
@@ -1849,7 +1849,7 @@ def test_create_deployment_for_tei_byoc_embedding_model(
 
         mock_validate_base_model.assert_called()
         mock_create.assert_called_with(
-            model=TestDataset.MODEL_ID,
+            model=mock_validate_base_model.return_value,
             compartment_id=TestDataset.USER_COMPARTMENT_ID,
             project_id=TestDataset.USER_PROJECT_ID,
             freeform_tags=None,