diff --git a/generative_ai/tune_code_generation_model.py b/generative_ai/tune_code_generation_model.py
index 88da72ed8408..762455380306 100644
--- a/generative_ai/tune_code_generation_model.py
+++ b/generative_ai/tune_code_generation_model.py
@@ -16,71 +16,34 @@
from __future__ import annotations
-from typing import Optional
-
+def tune_code_generation_model(
+ project_id: str
+) -> object:
-from google.auth import default
-from google.cloud import aiplatform
-import pandas as pd
-import vertexai
-from vertexai.preview.language_models import CodeGenerationModel, TuningEvaluationSpec
+ # [START generativeaionvertexai_tune_code_generation_model]
+ from google.auth import default
+ import vertexai
+ from vertexai.language_models import CodeGenerationModel
+ credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
-credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
+ # TODO(developer): Update and un-comment below lines
+ # project_id = "PROJECT_ID"
+ vertexai.init(project=project_id, location="us-central1", credentials=credentials)
-def tune_code_generation_model(
- project_id: str,
- location: str,
- training_data: pd.DataFrame | str,
- train_steps: int = 300,
- evaluation_dataset: Optional[str] = None,
- tensorboard_instance_name: Optional[str] = None,
-) -> None:
- """Tune a new model, based on a prompt-response data.
+ model = CodeGenerationModel.from_pretrained("code-bison@002")
- "training_data" can be either the GCS URI of a file formatted in JSONL format
- (for example: training_data=f'gs://{bucket}/{filename}.jsonl'), or a pandas
- DataFrame. Each training example should be JSONL record with two keys, for
- example:
- {
- "input_text": <input prompt>,
- "output_text": <associated output>
- },
- or the pandas DataFame should contain two columns:
- ['input_text', 'output_text']
- with rows for each training example.
-
- Args:
- project_id: GCP Project ID, used to initialize vertexai
- location: GCP Region, used to initialize vertexai
- training_data: GCS URI of jsonl file or pandas dataframe of training data
- train_steps: Number of training steps to use when tuning the model.
- evaluation_dataset: GCS URI of jsonl file of evaluation data.
- tensorboard_instance_name: The full name of the existing Vertex AI TensorBoard instance:
- projects/PROJECT_ID/locations/LOCATION_ID/tensorboards/TENSORBOARD_INSTANCE_ID
- Note that this instance must be in the same region as your tuning job.
- """
- vertexai.init(project=project_id, location=location, credentials=credentials)
- eval_spec = TuningEvaluationSpec(evaluation_data=evaluation_dataset)
- eval_spec.tensorboard = aiplatform.Tensorboard(
- tensorboard_name=tensorboard_instance_name
- )
- model = CodeGenerationModel.from_pretrained("code-bison@001")
-
- model.tune_model(
- training_data=training_data,
- # Optional:
- train_steps=train_steps,
+ tuning_job = model.tune_model(
+ training_data="gs://cloud-samples-data/ai-platform/generative_ai/headline_classification.jsonl",
tuning_job_location="europe-west4",
- tuned_model_location=location,
- tuning_evaluation_spec=eval_spec,
+ tuned_model_location="us-central1",
)
- print(model._job.status)
+ print(tuning_job._status)
+ # [END generativeaionvertexai_tune_code_generation_model]
+
return model
# [END aiplatform_sdk_tune_code_generation_model]
-if __name__ == "__main__":
- tune_code_generation_model()
diff --git a/generative_ai/tune_code_generation_model_test.py b/generative_ai/tune_code_generation_model_test.py
index 1ee4427fd056..f7ad5173f477 100644
--- a/generative_ai/tune_code_generation_model_test.py
+++ b/generative_ai/tune_code_generation_model_test.py
@@ -13,97 +13,36 @@
# limitations under the License.
import os
-import uuid
from google.cloud import aiplatform
-from google.cloud import storage
-from google.cloud.aiplatform.compat.types import pipeline_state
import pytest
-from vertexai.preview.language_models import TextGenerationModel
+from vertexai.language_models import TextGenerationModel
import tune_code_generation_model
_PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
-_LOCATION = "us-central1"
-_BUCKET = os.environ["CLOUD_STORAGE_BUCKET"]
-def get_model_display_name(tuned_model: TextGenerationModel) -> str:
- language_model_tuning_job = tuned_model._job
- pipeline_job = language_model_tuning_job._job
- return dict(pipeline_job._gca_resource.runtime_config.parameter_values)[
- "model_display_name"
- ]
-
-
-def upload_to_gcs(bucket: str, name: str, data: str) -> None:
- client = storage.Client()
- bucket = client.get_bucket(bucket)
- blob = bucket.blob(name)
- blob.upload_from_string(data)
-
-
-def download_from_gcs(bucket: str, name: str) -> str:
- client = storage.Client()
- bucket = client.get_bucket(bucket)
- blob = bucket.blob(name)
- data = blob.download_as_bytes()
- return "\n".join(data.decode().splitlines()[:10])
-
-
-def delete_from_gcs(bucket: str, name: str) -> None:
- client = storage.Client()
- bucket = client.get_bucket(bucket)
- blob = bucket.blob(name)
- blob.delete()
-
-
-@pytest.fixture(scope="function")
-def training_data_filename() -> str:
- temp_filename = f"{uuid.uuid4()}.jsonl"
- data = download_from_gcs(
- "cloud-samples-data", "ai-platform/generative_ai/headline_classification.jsonl"
- )
- upload_to_gcs(_BUCKET, temp_filename, data)
- try:
- yield f"gs://{_BUCKET}/{temp_filename}"
- finally:
- delete_from_gcs(_BUCKET, temp_filename)
-
-
-def teardown_model(
- tuned_model: TextGenerationModel, training_data_filename: str
-) -> None:
+def teardown_model(tuned_model: TextGenerationModel) -> None:
for tuned_model_name in tuned_model.list_tuned_model_names():
model_registry = aiplatform.models.ModelRegistry(model=tuned_model_name)
- if (
- training_data_filename
- in model_registry.get_version_info("1").model_display_name
- ):
- display_name = model_registry.get_version_info("1").model_display_name
- for endpoint in aiplatform.Endpoint.list():
- for _ in endpoint.list_models():
- if endpoint.display_name == display_name:
- endpoint.undeploy_all()
- endpoint.delete()
- aiplatform.Model(model_registry.model_resource_name).delete()
+
+ display_name = model_registry.get_version_info("1").model_display_name
+ for endpoint in aiplatform.Endpoint.list():
+ for _ in endpoint.list_models():
+ if endpoint.display_name == display_name:
+ endpoint.undeploy_all()
+ endpoint.delete()
+ aiplatform.Model(model_registry.model_resource_name).delete()
@pytest.mark.skip("Blocked on b/277959219")
-def test_tuning_code_generation_model(training_data_filename: str) -> None:
+def test_tuning_code_generation_model() -> None:
"""Takes approx. 20 minutes."""
tuned_model = tune_code_generation_model.tune_code_generation_model(
- training_data=training_data_filename,
- project_id=_PROJECT_ID,
- location=_LOCATION,
- train_steps=1,
- evaluation_dataset=training_data_filename,
- tensorboard_instance_name="python-docs-samples-test",
+ project_id=_PROJECT_ID
)
try:
- assert (
- tuned_model._job.status
- == pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
- )
+ assert tuned_model
finally:
- teardown_model(tuned_model, training_data_filename)
+ teardown_model(tuned_model)