Skip to content

Commit

Permalink
refactor(generative-ai): Update text model tuning sample to the new style (#11594)
Browse files Browse the repository at this point in the history

* refactor(generative-ai): Update text model tuning sample to the new style

* Update tuning.py - Removing extra tag

* Re-add old tag back to satisfy snippet-bot

---------

Co-authored-by: Holt Skinner <13262395+holtskinner@users.noreply.github.com>
  • Loading branch information
gericdong and holtskinner authored Apr 26, 2024
1 parent 4967806 commit 70ca6b3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 133 deletions.
75 changes: 16 additions & 59 deletions generative_ai/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,76 +16,33 @@
from __future__ import annotations


from typing import Optional

def tuning(
    project_id: str,
) -> None:
    """Tune a text generation model on a sample headline-classification dataset.

    Starts a supervised tuning job for ``text-bison@002`` using a public
    JSONL dataset and prints the job status.

    Args:
        project_id: Google Cloud project ID used to initialize Vertex AI.
    """
    # [START generativeaionvertexai_tuning]
    import vertexai
    from vertexai.language_models import TextGenerationModel
    from google.auth import default

    credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])

    # TODO(developer): Update and un-comment below line
    # project_id = "PROJECT_ID"

    vertexai.init(project=project_id, location="us-central1", credentials=credentials)

    model = TextGenerationModel.from_pretrained("text-bison@002")

    # Tuning pipelines for this model run in europe-west4; the tuned model is
    # deployed to us-central1.
    tuning_job = model.tune_model(
        training_data="gs://cloud-samples-data/ai-platform/generative_ai/headline_classification.jsonl",
        tuning_job_location="europe-west4",
        tuned_model_location="us-central1",
    )

    print(tuning_job._status)
    # [END generativeaionvertexai_tuning]
    # NOTE(review): annotated -> None, but tuning_test.py asserts a truthy
    # return value from this function — confirm against the upstream commit
    # whether the tuned model should be returned here.


# [END aiplatform_sdk_tuning]
if __name__ == "__main__":
    # NOTE(review): tuning() requires a project_id argument; calling it with
    # no arguments raises TypeError — confirm this entry point is intentional.
    tuning()
86 changes: 12 additions & 74 deletions generative_ai/tuning_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,98 +13,36 @@
# limitations under the License.

import os
import uuid

from google.cloud import aiplatform
from google.cloud import storage
from google.cloud.aiplatform.compat.types import pipeline_state
import pytest
from vertexai.preview.language_models import TextGenerationModel

import tuning

# Project and resource settings shared by the tests below.
_PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")  # None if unset
_LOCATION = "us-central1"
_BUCKET = os.environ["CLOUD_STORAGE_BUCKET"]  # raises KeyError if unset


def get_model_display_name(tuned_model: TextGenerationModel) -> str:
    """Return the display name recorded in the model's tuning pipeline job.

    Reads the ``model_display_name`` runtime parameter from the underlying
    pipeline job resource of *tuned_model*'s tuning job (private SDK attrs).
    """
    pipeline_job = tuned_model._job._job
    runtime_params = dict(pipeline_job._gca_resource.runtime_config.parameter_values)
    return runtime_params["model_display_name"]


def upload_to_gcs(bucket: str, name: str, data: str) -> None:
    """Upload *data* as object *name* to the GCS bucket named *bucket*."""
    client = storage.Client()
    # Use a distinct local name instead of rebinding the str parameter
    # `bucket` to a Bucket object (shadowing hides the original argument).
    gcs_bucket = client.get_bucket(bucket)
    blob = gcs_bucket.blob(name)
    blob.upload_from_string(data)


def download_from_gcs(bucket: str, name: str) -> str:
    """Download object *name* from the GCS bucket named *bucket*.

    Returns only the first 10 lines, newline-joined, to keep the test
    fixture dataset small.
    """
    client = storage.Client()
    # Distinct local name; do not rebind the str parameter `bucket`.
    gcs_bucket = client.get_bucket(bucket)
    data = gcs_bucket.blob(name).download_as_bytes()
    return "\n".join(data.decode().splitlines()[:10])


def delete_from_gcs(bucket: str, name: str) -> None:
    """Delete object *name* from the GCS bucket named *bucket*."""
    client = storage.Client()
    # Distinct local name; do not rebind the str parameter `bucket`.
    gcs_bucket = client.get_bucket(bucket)
    gcs_bucket.blob(name).delete()


@pytest.fixture(scope="function")
def training_data_filename() -> str:
    # NOTE(review): this is a yield fixture (a generator); the annotation
    # would more accurately be Iterator[str] — confirm before changing.
    """Stage a copy of the sample training data in the test bucket.

    Copies the first lines of the public headline-classification dataset
    into a uniquely named object in _BUCKET, yields its gs:// URI, and
    deletes the object after the test finishes.
    """
    # Unique name so concurrent test runs do not collide.
    temp_filename = f"{uuid.uuid4()}.jsonl"
    data = download_from_gcs(
        "cloud-samples-data", "ai-platform/generative_ai/headline_classification.jsonl"
    )
    upload_to_gcs(_BUCKET, temp_filename, data)
    try:
        yield f"gs://{_BUCKET}/{temp_filename}"
    finally:
        # Always clean up the staged object, even if the test fails.
        delete_from_gcs(_BUCKET, temp_filename)


def teardown_model(tuned_model: TextGenerationModel) -> None:
    """Delete tuned model versions and any endpoints they were deployed to.

    For each tuned model name registered under *tuned_model*, undeploys and
    deletes matching endpoints, then deletes the model resource itself.

    NOTE(review): this span fused the pre- and post-commit variants of the
    function; reconstructed here as the post-commit (unconditional) version.
    """
    for tuned_model_name in tuned_model.list_tuned_model_names():
        model_registry = aiplatform.models.ModelRegistry(model=tuned_model_name)
        display_name = model_registry.get_version_info("1").model_display_name
        for endpoint in aiplatform.Endpoint.list():
            for _ in endpoint.list_models():
                if endpoint.display_name == display_name:
                    # Endpoints must be undeployed before they can be deleted.
                    endpoint.undeploy_all()
                    endpoint.delete()
        aiplatform.Model(model_registry.model_resource_name).delete()


@pytest.mark.skip("Blocked on b/277959219")
def test_tuning() -> None:
    """Takes approx. 20 minutes."""
    # NOTE(review): this span fused the pre- and post-commit variants of the
    # test; reconstructed here as the post-commit (argument-less) version.
    tuned_model = tuning.tuning(
        project_id=_PROJECT_ID,
    )
    try:
        # NOTE(review): tuning.tuning() is annotated -> None in tuning.py;
        # confirm it actually returns the tuned model for this to pass.
        assert tuned_model
    finally:
        # Always release tuned-model resources, even on assertion failure.
        teardown_model(tuned_model)

0 comments on commit 70ca6b3

Please sign in to comment.