Skip to content

Commit

Permalink
feat: GenAI - Added support for supervised fine-tuning
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621984253
  • Loading branch information
Ark-kun authored and copybara-github committed Apr 4, 2024
1 parent a2778ba commit 036d2d0
Show file tree
Hide file tree
Showing 6 changed files with 553 additions and 3 deletions.
189 changes: 189 additions & 0 deletions tests/unit/vertexai/test_tuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# -*- coding: utf-8 -*-

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for generative model tuning."""
# pylint: disable=protected-access,bad-continuation

import copy
import datetime
from typing import Dict, Iterable
from unittest import mock
import uuid

import vertexai
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils as aiplatform_utils
from google.cloud.aiplatform_v1.services import gen_ai_tuning_service
from google.cloud.aiplatform_v1.types import job_state
from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job
from vertexai.preview import tuning
from vertexai.preview.tuning import sft as supervised_tuning

import pytest

from google.rpc import status_pb2


_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"


_global_tuning_jobs: Dict[str, gca_tuning_job.TuningJob] = {}


class MockGenAiTuningServiceClient(gen_ai_tuning_service.GenAiTuningServiceClient):
@property
def _tuning_jobs(self) -> Dict[str, gca_tuning_job.TuningJob]:
return _global_tuning_jobs

def create_tuning_job(
self,
*,
parent: str,
tuning_job: gca_tuning_job.TuningJob,
**_,
) -> gca_tuning_job.TuningJob:
tuning_job = copy.deepcopy(tuning_job)
resource_id = uuid.uuid4().hex
resource_name = f"{parent}/tuningJobs/{resource_id}"
tuning_job.name = resource_name
current_time = datetime.datetime.now(datetime.timezone.utc)
tuning_job.create_time = current_time
tuning_job.update_time = current_time
tuning_job.state = job_state.JobState.JOB_STATE_PENDING
self._tuning_jobs[resource_name] = tuning_job
return tuning_job

def _progress_tuning_job(self, name: str):
tuning_job: gca_tuning_job.TuningJob = self._tuning_jobs[name]
current_time = datetime.datetime.now(datetime.timezone.utc)
if tuning_job.state == job_state.JobState.JOB_STATE_PENDING:
if (
"invalid_dataset"
in tuning_job.supervised_tuning_spec.training_dataset_uri
):
tuning_job.state = job_state.JobState.JOB_STATE_FAILED
tuning_job.error = status_pb2.Status(
code=400, message="Invalid dataset."
)
else:
tuning_job.state = job_state.JobState.JOB_STATE_RUNNING
tuning_job.update_time = current_time
elif tuning_job.state == job_state.JobState.JOB_STATE_RUNNING:
parent = tuning_job.name.partition("/tuningJobs/")[0]
tuning_job.state = job_state.JobState.JOB_STATE_SUCCEEDED
experiment_id = uuid.uuid4().hex
tuned_model_id = uuid.uuid4().hex
tuned_model_endpoint_id = uuid.uuid4().hex
tuning_job.experiment = (
f"{parent}/metadataStores/default/contexts/{experiment_id}"
)
tuning_job.tuned_model = gca_tuning_job.TunedModel(
model=f"{parent}/models/{tuned_model_id}",
endpoint=f"{parent}/endpoints/{tuned_model_endpoint_id}",
)
tuning_job.end_time = current_time
tuning_job.update_time = current_time
else:
pass

def get_tuning_job(self, *, name: str, **_) -> gca_tuning_job.TuningJob:
tuning_job = self._tuning_jobs[name]
tuning_job = copy.deepcopy(tuning_job)
self._progress_tuning_job(name)

return tuning_job

def list_tuning_jobs(
self, *, parent: str, **_
) -> Iterable[gca_tuning_job.TuningJob]:
return [
tuning_job
for name, tuning_job in self._tuning_jobs.items()
if name.startswith(parent + "/")
]

def cancel_tuning_job(self, *, name: str, **_) -> None:
tuning_job = self._tuning_jobs[name]
assert tuning_job.state in (
job_state.JobState.JOB_STATE_RUNNING,
job_state.JobState.JOB_STATE_PENDING,
)
tuning_job.state = job_state.JobState.JOB_STATE_CANCELLED


class MockTuningJobClientWithOverride(aiplatform_utils.ClientWithOverride):
_is_temporary = False
_default_version = compat.V1
_version_map = (
(compat.V1, MockGenAiTuningServiceClient),
# v1beta1 version does not exist
# (compat.V1BETA1, gen_ai_tuning_service_v1beta1.client.JobServiceClient),
)


@pytest.mark.usefixtures("google_auth_mock")
class TestgenerativeModelTuning:
"""Unit tests for generative model tuning."""

def setup_method(self):
vertexai.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)

def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

@mock.patch.object(
target=tuning.TuningJob,
attribute="client_class",
new=MockTuningJobClientWithOverride,
)
def test_genai_tuning_service_supervised_tuning_tune_model(self):
sft_tuning_job = supervised_tuning.train(
source_model="gemini-1.0-pro-001",
train_dataset="gs://some-bucket/some_dataset.jsonl",
# Optional:
validation_dataset="gs://some-bucket/some_dataset.jsonl",
epochs=300,
learning_rate_multiplier=1.0,
)
assert sft_tuning_job.state == job_state.JobState.JOB_STATE_PENDING
assert not sft_tuning_job.has_ended
assert not sft_tuning_job.has_succeeded

# Refreshing the job
sft_tuning_job.refresh()
assert sft_tuning_job.state == job_state.JobState.JOB_STATE_PENDING
assert not sft_tuning_job.has_ended
assert not sft_tuning_job.has_succeeded

# Refreshing the job
sft_tuning_job.refresh()
assert sft_tuning_job.state == job_state.JobState.JOB_STATE_RUNNING
assert not sft_tuning_job.has_ended
assert not sft_tuning_job.has_succeeded

# Refreshing the job
sft_tuning_job.refresh()
assert sft_tuning_job.state == job_state.JobState.JOB_STATE_SUCCEEDED
assert sft_tuning_job.has_ended
assert sft_tuning_job.has_succeeded
assert sft_tuning_job._experiment_name
assert sft_tuning_job.tuned_model_name
assert sft_tuning_job.tuned_model_endpoint_name
12 changes: 9 additions & 3 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,16 @@ def __init__(
Args:
model_name: Model Garden model resource name.
Alternatively, a tuned model endpoint resource name can be provided.
generation_config: Default generation config to use in generate_content.
safety_settings: Default safety settings to use in generate_content.
tools: Default tools to use in generate_content.
system_instruction: Default system instruction to use in generate_content.
Note: Only text should be used in parts.
Content of each part will become a separate paragraph.
"""
if not model_name:
raise ValueError("model_name must not be empty")
if "/" not in model_name:
model_name = "publishers/google/models/" + model_name
if model_name.startswith("models/"):
Expand All @@ -160,10 +163,13 @@ def __init__(
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location

if model_name.startswith("publishers/"):
prediction_resource_name = f"projects/{project}/locations/{location}/{model_name}"
else:
prediction_resource_name = model_name

self._model_name = model_name
self._prediction_resource_name = (
f"projects/{project}/locations/{location}/{model_name}"
)
self._prediction_resource_name = prediction_resource_name
self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools
Expand Down
23 changes: 23 additions & 0 deletions vertexai/preview/tuning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Classes for tuning models."""

# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.tuning._tuning import TuningJob

__all__ = [
"TuningJob",
]
27 changes: 27 additions & 0 deletions vertexai/preview/tuning/sft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Classes for supervised tuning."""

# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.tuning._supervised_tuning import (
train,
SupervisedTuningJob,
)

__all__ = [
"train",
"SupervisedTuningJob",
]
71 changes: 71 additions & 0 deletions vertexai/tuning/_supervised_tuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional, Union

from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job_types

from vertexai import generative_models
from vertexai.tuning import _tuning


def train(
*,
source_model: Union[str, generative_models.GenerativeModel],
train_dataset: str,
validation_dataset: Optional[str] = None,
tuned_model_display_name: Optional[str] = None,
epochs: Optional[int] = None,
learning_rate_multiplier: Optional[float] = None,
) -> "SupervisedTuningJob":
"""Tunes a model using supervised training.
Args:
source_model (str):
Model name for tuning, e.g., "gemini-1.0-pro" or "gemini-1.0-pro-001".
train_dataset: Cloud Storage path to file containing training dataset for tuning.
The dataset should be in JSONL format.
validation_dataset: Cloud Storage path to file containing validation dataset for tuning.
The dataset should be in JSONL format.
tuned_model_display_name: The display name of the
[TunedModel][google.cloud.aiplatform.v1.Model]. The name can
be up to 128 characters long and can consist of any UTF-8 characters.
epochs: Number of training epoches for this tuning job.
learning_rate_multiplier: Learning rate multiplier for tuning.
Returns:
A `TuningJob` object.
"""
supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
training_dataset_uri=train_dataset,
validation_dataset_uri=validation_dataset,
hyper_parameters=gca_tuning_job_types.SupervisedHyperParameters(
epoch_count=epochs,
learning_rate_multiplier=learning_rate_multiplier,
),
)

if isinstance(source_model, generative_models.GenerativeModel):
source_model = source_model._prediction_resource_name.rpartition('/')[-1]

return SupervisedTuningJob._create( # pylint: disable=protected-access
base_model=source_model,
tuning_spec=supervised_tuning_spec,
tuned_model_display_name=tuned_model_display_name,
)


class SupervisedTuningJob(_tuning.TuningJob):
pass
Loading

0 comments on commit 036d2d0

Please sign in to comment.