feat: GenAI - Add cancel, delete, list methods in BatchPredictionJob
PiperOrigin-RevId: 633702194
jaycee-li authored and copybara-github committed May 14, 2024
1 parent 4d091c6 commit 7ff8071
Showing 2 changed files with 197 additions and 10 deletions.
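The new cancel and delete methods map directly onto the JobServiceClient calls mocked in the tests below. A minimal usage sketch, assuming the vertexai.batch_prediction module shown in this diff and a placeholder job ID:

from vertexai import batch_prediction

# Placeholder ID of an existing GenAI batch prediction job.
job = batch_prediction.BatchPredictionJob("1234567890")

# Best-effort cancellation; as the docstring below notes, success is not
# guaranteed, so refresh() and state should be used to confirm it.
job.cancel()

# Permanently deletes the job resource from the service.
job.delete()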
120 changes: 116 additions & 4 deletions tests/unit/vertexai/test_batch_prediction.py
@@ -89,11 +89,36 @@ def complete_bq_uri_mock():


@pytest.fixture
def get_batch_prediction_job_mock():
def get_batch_prediction_job_with_bq_output_mock():
with mock.patch.object(
job_service_client.JobServiceClient, "get_batch_prediction_job"
) as get_job_mock:
get_job_mock.return_value = _TEST_GAPIC_BATCH_PREDICTION_JOB
get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob(
name=_TEST_BATCH_PREDICTION_JOB_NAME,
display_name=_TEST_DISPLAY_NAME,
model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
state=_TEST_JOB_STATE_SUCCESS,
output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo(
bigquery_output_table=_TEST_BQ_OUTPUT_PREFIX
),
)
yield get_job_mock


@pytest.fixture
def get_batch_prediction_job_with_gcs_output_mock():
with mock.patch.object(
job_service_client.JobServiceClient, "get_batch_prediction_job"
) as get_job_mock:
get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob(
name=_TEST_BATCH_PREDICTION_JOB_NAME,
display_name=_TEST_DISPLAY_NAME,
model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
state=_TEST_JOB_STATE_SUCCESS,
output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo(
gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX
),
)
yield get_job_mock


@@ -120,6 +145,39 @@ def create_batch_prediction_job_mock():
yield create_job_mock


@pytest.fixture
def cancel_batch_prediction_job_mock():
with mock.patch.object(
job_service_client.JobServiceClient, "cancel_batch_prediction_job"
) as cancel_job_mock:
yield cancel_job_mock


@pytest.fixture
def delete_batch_prediction_job_mock():
with mock.patch.object(
job_service_client.JobServiceClient, "delete_batch_prediction_job"
) as delete_job_mock:
yield delete_job_mock


@pytest.fixture
def list_batch_prediction_jobs_mock():
with mock.patch.object(
job_service_client.JobServiceClient, "list_batch_prediction_jobs"
) as list_jobs_mock:
list_jobs_mock.return_value = [
_TEST_GAPIC_BATCH_PREDICTION_JOB,
gca_batch_prediction_job_compat.BatchPredictionJob(
name=_TEST_BATCH_PREDICTION_JOB_NAME,
display_name=_TEST_DISPLAY_NAME,
model=_TEST_PALM_MODEL_RESOURCE_NAME,
state=_TEST_JOB_STATE_SUCCESS,
),
]
yield list_jobs_mock


@pytest.mark.usefixtures(
"google_auth_mock", "generate_display_name_mock", "complete_bq_uri_mock"
)
@@ -138,10 +196,12 @@ def setup_method(self):
def teardown_method(self):
aiplatform_initializer.global_pool.shutdown(wait=True)

def test_init_batch_prediction_job(self, get_batch_prediction_job_mock):
def test_init_batch_prediction_job(
self, get_batch_prediction_job_with_gcs_output_mock
):
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)

get_batch_prediction_job_mock.assert_called_once_with(
get_batch_prediction_job_with_gcs_output_mock.assert_called_once_with(
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
)

@@ -157,6 +217,7 @@ def test_init_batch_prediction_job_invalid_model(self):
):
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)

@pytest.mark.usefixtures("get_batch_prediction_job_with_gcs_output_mock")
def test_submit_batch_prediction_job_with_gcs_input(
self, create_batch_prediction_job_mock
):
@@ -167,6 +228,15 @@ def test_submit_batch_prediction_job_with_gcs_input(
)

assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
assert job.state == _TEST_JOB_STATE_RUNNING
assert not job.has_ended
assert not job.has_succeeded

job.refresh()
assert job.state == _TEST_JOB_STATE_SUCCESS
assert job.has_ended
assert job.has_succeeded
assert job.output_location == _TEST_GCS_OUTPUT_PREFIX

expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
display_name=_TEST_DISPLAY_NAME,
@@ -188,6 +258,7 @@ def test_submit_batch_prediction_job_with_gcs_input(
timeout=None,
)

@pytest.mark.usefixtures("get_batch_prediction_job_with_bq_output_mock")
def test_submit_batch_prediction_job_with_bq_input(
self, create_batch_prediction_job_mock
):
@@ -198,6 +269,15 @@ def test_submit_batch_prediction_job_with_bq_input(
)

assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
assert job.state == _TEST_JOB_STATE_RUNNING
assert not job.has_ended
assert not job.has_succeeded

job.refresh()
assert job.state == _TEST_JOB_STATE_SUCCESS
assert job.has_ended
assert job.has_succeeded
assert job.output_location == _TEST_BQ_OUTPUT_PREFIX

expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
display_name=_TEST_DISPLAY_NAME,
@@ -349,3 +429,35 @@ def test_submit_batch_prediction_job_without_output_uri_prefix_and_bucket(self):
source_model=_TEST_GEMINI_MODEL_NAME,
input_dataset=_TEST_GCS_INPUT_URI,
)

@pytest.mark.usefixtures("create_batch_prediction_job_mock")
def test_cancel_batch_prediction_job(self, cancel_batch_prediction_job_mock):
job = batch_prediction.BatchPredictionJob.submit(
source_model=_TEST_GEMINI_MODEL_NAME,
input_dataset=_TEST_GCS_INPUT_URI,
output_uri_prefix=_TEST_GCS_OUTPUT_PREFIX,
)
job.cancel()

cancel_batch_prediction_job_mock.assert_called_once_with(
name=_TEST_BATCH_PREDICTION_JOB_NAME,
)

@pytest.mark.usefixtures("get_batch_prediction_job_with_gcs_output_mock")
def test_delete_batch_prediction_job(self, delete_batch_prediction_job_mock):
job = batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)
job.delete()

delete_batch_prediction_job_mock.assert_called_once_with(
name=_TEST_BATCH_PREDICTION_JOB_NAME,
)

def test_list_batch_prediction_jobs(self, list_batch_prediction_jobs_mock):
jobs = batch_prediction.BatchPredictionJob.list()

assert len(jobs) == 1
assert jobs[0].gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

list_batch_prediction_jobs_mock.assert_called_once_with(
request={"parent": _TEST_PARENT}
)
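The assertions added above exercise the new state-tracking surface (state, has_ended, has_succeeded, output_location) together with refresh(). A minimal polling sketch built on those properties; the model name, URIs, and sleep interval are placeholders, not part of this change:

import time

from vertexai import batch_prediction

job = batch_prediction.BatchPredictionJob.submit(
    source_model="gemini-1.0-pro-002",           # placeholder model name
    input_dataset="gs://my-bucket/input.jsonl",   # placeholder input URI
    output_uri_prefix="gs://my-bucket/output",    # placeholder output prefix
)

# Poll until the service reports a terminal job state.
while not job.has_ended:
    time.sleep(60)
    job.refresh()

if job.has_succeeded:
    print("Output written to:", job.output_location)
else:
    print("Job failed:", job.error)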
87 changes: 81 additions & 6 deletions vertexai/batch_prediction/_batch_prediction.py
@@ -23,8 +23,11 @@
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform import jobs
from google.cloud.aiplatform import utils as aiplatform_utils
from google.cloud.aiplatform_v1 import types as gca_types
from vertexai import generative_models

from google.rpc import status_pb2


_LOGGER = aiplatform_base.Logger(__name__)

@@ -37,7 +40,6 @@ class BatchPredictionJob(aiplatform_base._VertexAiResourceNounPlus):
_resource_noun = "batchPredictionJobs"
_getter_method = "get_batch_prediction_job"
_list_method = "list_batch_prediction_jobs"
_cancel_method = "cancel_batch_prediction_job"
_delete_method = "delete_batch_prediction_job"
_job_type = "batch-predictions"
_parse_resource_name_method = "parse_batch_prediction_job_path"
@@ -63,13 +65,46 @@ def __init__(self, batch_prediction_job_name: str):
resource_name=batch_prediction_job_name
)
# TODO(b/338452508) Support tuned GenAI models.
if not re.search(_GEMINI_MODEL_PATTERN, self._gca_resource.model):
if not re.search(_GEMINI_MODEL_PATTERN, self.model_name):
raise ValueError(
f"BatchPredictionJob '{batch_prediction_job_name}' "
f"runs with the model '{self._gca_resource.model}', "
f"runs with the model '{self.model_name}', "
"which is not a GenAI model."
)

@property
def model_name(self) -> str:
"""Returns the model name used for this batch prediction job."""
return self._gca_resource.model

@property
def state(self) -> gca_types.JobState:
"""Returns the state of this batch prediction job."""
return self._gca_resource.state

@property
def has_ended(self) -> bool:
"""Returns true if this batch prediction job has ended."""
return self.state in jobs._JOB_COMPLETE_STATES

@property
def has_succeeded(self) -> bool:
"""Returns true if this batch prediction job has succeeded."""
return self.state == gca_types.JobState.JOB_STATE_SUCCEEDED

@property
def error(self) -> Optional[status_pb2.Status]:
"""Returns detailed error info for this Job resource."""
return self._gca_resource.error

@property
def output_location(self) -> str:
"""Returns the output location of this batch prediction job."""
return (
self._gca_resource.output_info.gcs_output_directory
or self._gca_resource.output_info.bigquery_output_table
)

@classmethod
def submit(
cls,
@@ -178,14 +213,54 @@ def submit(
_LOGGER.log_create_complete(
cls, job._gca_resource, "job", module_name="batch_prediction"
)
_LOGGER.info(
"View Batch Prediction Job:\n%s" % aiplatform_job._dashboard_uri()
)
_LOGGER.info("View Batch Prediction Job:\n%s" % job._dashboard_uri())

return job
finally:
logging.getLogger("google.cloud.aiplatform.jobs").disabled = False

def refresh(self) -> "BatchPredictionJob":
"""Refreshes the batch prediction job from the service."""
self._sync_gca_resource()
return self

def cancel(self):
"""Cancels this BatchPredictionJob.
Success of cancellation is not guaranteed. Use `job.refresh()` and
`job.state` to verify if cancellation was successful.
"""
_LOGGER.log_action_start_against_resource("Cancelling", "run", self)
self.api_client.cancel_batch_prediction_job(name=self.resource_name)

def delete(self):
"""Deletes this BatchPredictionJob resource.
WARNING: This deletion is permanent.
"""
self._delete()

@classmethod
def list(cls, filter=None) -> List["BatchPredictionJob"]:
"""Lists all BatchPredictionJob instances that run with GenAI models."""
return cls._list(
cls_filter=lambda gca_resource: re.search(
_GEMINI_MODEL_PATTERN, gca_resource.model
),
filter=filter,
)

def _dashboard_uri(self) -> Optional[str]:
"""Returns the Google Cloud console URL where job can be viewed."""
fields = self._parse_resource_name(self.resource_name)
location = fields.pop("location")
project = fields.pop("project")
job = list(fields.values())[0]
return (
"https://console.cloud.google.com/ai/platform/locations/"
f"{location}/{self._job_type}/{job}?project={project}"
)

@classmethod
def _reconcile_model_name(cls, model_name: str) -> str:
"""Reconciles model name to a publisher model resource name."""
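The list() classmethod above accepts an optional server-side filter and additionally drops any job whose model does not match the Gemini pattern on the client side. A short sketch, assuming a hypothetical filter expression:

from vertexai import batch_prediction

# Placeholder filter expression; any job whose model is not a GenAI (Gemini)
# model is filtered out client-side regardless of this string.
for job in batch_prediction.BatchPredictionJob.list(filter='state="JOB_STATE_SUCCEEDED"'):
    print(job.resource_name, job.state, job.output_location)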
