# Reconstructed from a whitespace-collapsed patch to
# vertexai/language_models/_language_models.py.
#
# The same patch also renames the pre-existing
# `class TextEmbeddingModel(_LanguageModel)` to `_TextEmbeddingModel`, and is
# followed by `class _PreviewTextEmbeddingModel(TextEmbeddingModel,
# _ModelWithBatchPredict, _CountTokensMixin)`. Both classes are only partially
# visible in the patch, so their bodies are intentionally not reproduced here.


class _TunableTextEmbeddingModelMixin(_TunableModelMixin):
    """Tuning-related class methods for text embedding models.

    Tuned embedding models are obtained through `deploy_tuned_model`;
    `get_tuned_model` is deliberately disabled on this mixin.
    """

    @classmethod
    def get_tuned_model(cls, *args, **kwargs):
        """Unsupported for text embedding models.

        Accepts (and ignores) any arguments so that every call style reaches
        the explanatory error below. Without the `cls` parameter a classmethod
        call fails with an unrelated `TypeError` about positional arguments.

        Raises:
            NotImplementedError: Always. Use `deploy_tuned_model` instead.
        """
        del args, kwargs  # Unused.
        raise NotImplementedError(
            "Use deploy_tuned_model instead to get the tuned model."
        )

    # IMPORTANT: Keep this method supported even if you end up deploying the
    # tuned model as part of the tuning pipeline template.
    @classmethod
    def deploy_tuned_model(
        cls,
        tuned_model_name: str,
        machine_type: Optional[str] = None,
        accelerator: Optional[str] = None,
        accelerator_count: Optional[int] = None,
    ) -> "_LanguageModel":
        """Loads the specified tuned language model, deploying it if needed.

        If the tuned model has no existing deployment, it is deployed to an
        endpoint using the given compute settings. Otherwise the endpoint of
        its first existing deployment is reused and the compute arguments are
        not used.

        Args:
            tuned_model_name: Tuned model's resource name.
            machine_type: Machine type. E.g., "a2-highgpu-1g". See also:
                https://cloud.google.com/vertex-ai/docs/training/configure-compute.
            accelerator: Kind of accelerator. E.g., "NVIDIA_TESLA_A100". See
                also:
                https://cloud.google.com/vertex-ai/docs/training/configure-compute.
            accelerator_count: Count of accelerators.

        Returns:
            Tuned `LanguageModel` object.

        Raises:
            ValueError: If the model's labels do not contain a base model ID.
        """
        tuned_vertex_model = aiplatform.Model(tuned_model_name)
        tuned_model_labels = tuned_vertex_model.labels

        if _TUNING_BASE_MODEL_ID_LABEL_KEY not in tuned_model_labels:
            raise ValueError(
                f"The provided model {tuned_model_name} does not have a base model ID."
            )
        # Reuse the labels fetched above rather than reading the property again.
        tuning_model_id = tuned_model_labels[_TUNING_BASE_MODEL_ID_LABEL_KEY]

        tuned_model_deployments = tuned_vertex_model.gca_resource.deployed_models
        if not tuned_model_deployments:
            # Deploying a model to an endpoint requires a resource quota.
            endpoint_name = tuned_vertex_model.deploy(
                machine_type=machine_type,
                accelerator_type=accelerator,
                accelerator_count=accelerator_count,
            ).resource_name
        else:
            endpoint_name = tuned_model_deployments[0].endpoint

        base_model_id = _get_model_id_from_tuning_model_id(tuning_model_id)
        model_info = _model_garden_models._get_model_info(
            model_id=base_model_id,
            schema_to_class_map={cls._INSTANCE_SCHEMA_URI: cls},
        )
        return model_info.interface_class(
            model_id=base_model_id,
            endpoint_name=endpoint_name,
        )


# Public text embedding model: the embedding implementation combined with the
# tuning mixin. No class docstring is added here so that documentation tools
# keep falling back to the one inherited from `_TextEmbeddingModel`.
class TextEmbeddingModel(_TextEmbeddingModel, _TunableTextEmbeddingModelMixin):
    # Report the public package as the defining module.
    __module__ = "vertexai.language_models"