Skip to content

Commit

Permalink
feat: support model monitoring for batch prediction in Vertex SDK (#1570
Browse files Browse the repository at this point in the history
)

* feat: support model monitoring for batch prediction in Vertex SDK

* fixed broken tests

* fixing syntax error

* addressed comments

* updated test variable name
  • Loading branch information
rosiezou authored Aug 16, 2022
1 parent 3d3e0aa commit bbec998
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 58 deletions.
64 changes: 63 additions & 1 deletion google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,13 @@ def create(
sync: bool = True,
create_request_timeout: Optional[float] = None,
batch_size: Optional[int] = None,
model_monitoring_objective_config: Optional[
"aiplatform.model_monitoring.ObjectiveConfig"
] = None,
model_monitoring_alert_config: Optional[
"aiplatform.model_monitoring.AlertConfig"
] = None,
analysis_instance_schema_uri: Optional[str] = None,
) -> "BatchPredictionJob":
"""Create a batch prediction job.
Expand Down Expand Up @@ -551,6 +558,23 @@ def create(
but too high value will result in a whole batch not fitting in a machine's memory,
and the whole operation will fail.
The default value is 64.
model_monitoring_objective_config (aiplatform.model_monitoring.ObjectiveConfig):
Optional. The objective config for model monitoring. Passing this parameter enables
monitoring on the model associated with this batch prediction job.
model_monitoring_alert_config (aiplatform.model_monitoring.EmailAlertConfig):
Optional. Configures how model monitoring alerts are sent to the user. Right now
only email alert is supported.
analysis_instance_schema_uri (str):
Optional. Only applicable if model_monitoring_objective_config is also passed.
This parameter specifies the YAML schema file uri describing the format of a single
instance that you want Tensorflow Data Validation (TFDV) to
analyze. If this field is empty, all the feature data types are
inferred from predict_instance_schema_uri, meaning that TFDV
will use the data in the exact format as prediction request/response.
If there are any data type differences between predict instance
and TFDV instance, this field can be used to override the schema.
For models trained with Vertex AI, this field must be set as all the
fields in predict instance formatted as string.
Returns:
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
Expand Down Expand Up @@ -601,7 +625,18 @@ def create(
f"{predictions_format} is not an accepted prediction format "
f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
)

# TODO: remove temporary import statements once model monitoring for batch prediction is GA
if model_monitoring_objective_config:
from google.cloud.aiplatform.compat.types import (
io_v1beta1 as gca_io_compat,
batch_prediction_job_v1beta1 as gca_bp_job_compat,
model_monitoring_v1beta1 as gca_model_monitoring_compat,
)
else:
from google.cloud.aiplatform.compat.types import (
io as gca_io_compat,
batch_prediction_job as gca_bp_job_compat,
)
gapic_batch_prediction_job = gca_bp_job_compat.BatchPredictionJob()

# Required Fields
Expand Down Expand Up @@ -688,6 +723,28 @@ def create(
)
)

# Model Monitoring
if model_monitoring_objective_config:
if model_monitoring_objective_config.drift_detection_config:
_LOGGER.info(
"Drift detection config is currently not supported for monitoring models associated with batch prediction jobs."
)
if model_monitoring_objective_config.explanation_config:
_LOGGER.info(
"XAI config is currently not supported for monitoring models associated with batch prediction jobs."
)
gapic_batch_prediction_job.model_monitoring_config = (
gca_model_monitoring_compat.ModelMonitoringConfig(
objective_configs=[
model_monitoring_objective_config.as_proto(config_for_bp=True)
],
alert_config=model_monitoring_alert_config.as_proto(
config_for_bp=True
),
analysis_instance_schema_uri=analysis_instance_schema_uri,
)
)

empty_batch_prediction_job = cls._empty_constructor(
project=project,
location=location,
Expand All @@ -702,6 +759,11 @@ def create(
sync=sync,
create_request_timeout=create_request_timeout,
)
# TODO: b/242108750
from google.cloud.aiplatform.compat.types import (
io as gca_io_compat,
batch_prediction_job as gca_bp_job_compat,
)

@classmethod
@base.optional_sync(return_input_arg="empty_batch_prediction_job")
Expand Down
24 changes: 21 additions & 3 deletions google/cloud/aiplatform/model_monitoring/alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@

from typing import Optional, List
from google.cloud.aiplatform_v1.types import (
model_monitoring as gca_model_monitoring,
model_monitoring as gca_model_monitoring_v1,
)

# TODO: remove imports from v1beta1 once model monitoring for batch prediction is GA
from google.cloud.aiplatform_v1beta1.types import (
model_monitoring as gca_model_monitoring_v1beta1,
)

gca_model_monitoring = gca_model_monitoring_v1


class EmailAlertConfig:
def __init__(
Expand All @@ -40,8 +47,19 @@ def __init__(
self.enable_logging = enable_logging
self.user_emails = user_emails

def as_proto(self):
"""Returns EmailAlertConfig as a proto message."""
# TODO: remove config_for_bp parameter when model monitoring for batch prediction is GA
def as_proto(self, config_for_bp: bool = False):
"""Returns EmailAlertConfig as a proto message.
Args:
config_for_bp (bool):
Optional. Set this parameter to True if the config object
is used for model monitoring on a batch prediction job.
"""
if config_for_bp:
gca_model_monitoring = gca_model_monitoring_v1beta1
else:
gca_model_monitoring = gca_model_monitoring_v1
user_email_alert_config = (
gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig(
user_emails=self.user_emails
Expand Down
96 changes: 63 additions & 33 deletions google/cloud/aiplatform/model_monitoring/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@
from typing import Optional, Dict

from google.cloud.aiplatform_v1.types import (
io as gca_io,
ThresholdConfig as gca_threshold_config,
model_monitoring as gca_model_monitoring,
io as gca_io_v1,
model_monitoring as gca_model_monitoring_v1,
)

# TODO: b/242108750
from google.cloud.aiplatform_v1beta1.types import (
io as gca_io_v1beta1,
model_monitoring as gca_model_monitoring_v1beta1,
)

gca_model_monitoring = gca_model_monitoring_v1
gca_io = gca_io_v1

TF_RECORD = "tf-record"
CSV = "csv"
JSONL = "jsonl"
Expand Down Expand Up @@ -80,19 +88,20 @@ def __init__(
self.attribute_skew_thresholds = attribute_skew_thresholds
self.data_format = data_format
self.target_field = target_field
self.training_dataset = None

def as_proto(self):
"""Returns _SkewDetectionConfig as a proto message."""
skew_thresholds_mapping = {}
attribution_score_skew_thresholds_mapping = {}
if self.skew_thresholds is not None:
for key in self.skew_thresholds.keys():
skew_threshold = gca_threshold_config(value=self.skew_thresholds[key])
skew_threshold = gca_model_monitoring.ThresholdConfig(
value=self.skew_thresholds[key]
)
skew_thresholds_mapping[key] = skew_threshold
if self.attribute_skew_thresholds is not None:
for key in self.attribute_skew_thresholds.keys():
attribution_score_skew_threshold = gca_threshold_config(
attribution_score_skew_threshold = gca_model_monitoring.ThresholdConfig(
value=self.attribute_skew_thresholds[key]
)
attribution_score_skew_thresholds_mapping[
Expand Down Expand Up @@ -134,12 +143,16 @@ def as_proto(self):
attribution_score_drift_thresholds_mapping = {}
if self.drift_thresholds is not None:
for key in self.drift_thresholds.keys():
drift_threshold = gca_threshold_config(value=self.drift_thresholds[key])
drift_threshold = gca_model_monitoring.ThresholdConfig(
value=self.drift_thresholds[key]
)
drift_thresholds_mapping[key] = drift_threshold
if self.attribute_drift_thresholds is not None:
for key in self.attribute_drift_thresholds.keys():
attribution_score_drift_threshold = gca_threshold_config(
value=self.attribute_drift_thresholds[key]
attribution_score_drift_threshold = (
gca_model_monitoring.ThresholdConfig(
value=self.attribute_drift_thresholds[key]
)
)
attribution_score_drift_thresholds_mapping[
key
Expand Down Expand Up @@ -186,11 +199,49 @@ def __init__(
self.drift_detection_config = drift_detection_config
self.explanation_config = explanation_config

def as_proto(self):
"""Returns _ObjectiveConfig as a proto message."""
# TODO: b/242108750
def as_proto(self, config_for_bp: bool = False):
"""Returns _SkewDetectionConfig as a proto message.
Args:
config_for_bp (bool):
Optional. Set this parameter to True if the config object
is used for model monitoring on a batch prediction job.
"""
if config_for_bp:
gca_io = gca_io_v1beta1
gca_model_monitoring = gca_model_monitoring_v1beta1
else:
gca_io = gca_io_v1
gca_model_monitoring = gca_model_monitoring_v1
training_dataset = None
if self.skew_detection_config is not None:
training_dataset = self.skew_detection_config.training_dataset
training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
target_field=self.skew_detection_config.target_field
)
)
if self.skew_detection_config.data_source.startswith("bq:/"):
training_dataset.bigquery_source = gca_io.BigQuerySource(
input_uri=self.skew_detection_config.data_source
)
elif self.skew_detection_config.data_source.startswith("gs:/"):
training_dataset.gcs_source = gca_io.GcsSource(
uris=[self.skew_detection_config.data_source]
)
if (
self.skew_detection_config.data_format is not None
and self.skew_detection_config.data_format
not in [TF_RECORD, CSV, JSONL]
):
raise ValueError(
"Unsupported value in skew detection config. `data_format` must be one of %s, %s, or %s"
% (TF_RECORD, CSV, JSONL)
)
training_dataset.data_format = self.skew_detection_config.data_format
else:
training_dataset.dataset = self.skew_detection_config.data_source

return gca_model_monitoring.ModelMonitoringObjectiveConfig(
training_dataset=training_dataset,
training_prediction_skew_detection_config=self.skew_detection_config.as_proto()
Expand Down Expand Up @@ -271,27 +322,6 @@ def __init__(
data_format,
)

training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
target_field=target_field
)
)
if data_source.startswith("bq:/"):
training_dataset.bigquery_source = gca_io.BigQuerySource(
input_uri=data_source
)
elif data_source.startswith("gs:/"):
training_dataset.gcs_source = gca_io.GcsSource(uris=[data_source])
if data_format is not None and data_format not in [TF_RECORD, CSV, JSONL]:
raise ValueError(
"Unsupported value. `data_format` must be one of %s, %s, or %s"
% (TF_RECORD, CSV, JSONL)
)
training_dataset.data_format = data_format
else:
training_dataset.dataset = data_source
self.training_dataset = training_dataset


class DriftDetectionConfig(_DriftDetectionConfig):
"""A class that configures prediction drift detection for models deployed to an endpoint.
Expand Down
36 changes: 27 additions & 9 deletions tests/system/aiplatform/test_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@
from google.api_core import exceptions as core_exceptions
from tests.system.aiplatform import e2e_base

from google.cloud.aiplatform_v1.types import (
io as gca_io,
model_monitoring as gca_model_monitoring,
)

# constants used for testing
USER_EMAIL = ""
MODEL_NAME = "churn"
MODEL_NAME2 = "churn2"
MODEL_DISPLAYNAME_KEY = "churn"
MODEL_DISPLAYNAME_KEY2 = "churn2"
IMAGE = "us-docker.pkg.dev/cloud-aiplatform/prediction/tf2-cpu.2-5:latest"
ENDPOINT = "us-central1-aiplatform.googleapis.com"
CHURN_MODEL_PATH = "gs://mco-mm/churn"
Expand Down Expand Up @@ -139,7 +144,7 @@ def temp_endpoint(self, shared_state):
)

model = aiplatform.Model.upload(
display_name=self._make_display_name(key=MODEL_NAME),
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY),
artifact_uri=CHURN_MODEL_PATH,
serving_container_image_uri=IMAGE,
)
Expand All @@ -157,19 +162,19 @@ def temp_endpoint_with_two_models(self, shared_state):
)

model1 = aiplatform.Model.upload(
display_name=self._make_display_name(key=MODEL_NAME),
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY),
artifact_uri=CHURN_MODEL_PATH,
serving_container_image_uri=IMAGE,
)

model2 = aiplatform.Model.upload(
display_name=self._make_display_name(key=MODEL_NAME),
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY2),
artifact_uri=CHURN_MODEL_PATH,
serving_container_image_uri=IMAGE,
)
shared_state["resources"] = [model1, model2]
endpoint = aiplatform.Endpoint.create(
display_name=self._make_display_name(key=MODEL_NAME)
display_name=self._make_display_name(key=MODEL_DISPLAYNAME_KEY)
)
endpoint.deploy(
model=model1, machine_type="n1-standard-2", traffic_percentage=100
Expand Down Expand Up @@ -224,7 +229,14 @@ def test_mdm_one_model_one_valid_config(self, shared_state):
gca_obj_config = gapic_job.model_deployment_monitoring_objective_configs[
0
].objective_config
assert gca_obj_config.training_dataset == skew_config.training_dataset

expected_training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
target_field=TARGET,
)
)
assert gca_obj_config.training_dataset == expected_training_dataset
assert (
gca_obj_config.training_prediction_skew_detection_config
== skew_config.as_proto()
Expand Down Expand Up @@ -297,12 +309,18 @@ def test_mdm_two_models_two_valid_configs(self, shared_state):
)
assert gapic_job.model_monitoring_alert_config.enable_logging

expected_training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
bigquery_source=gca_io.BigQuerySource(input_uri=DATASET_BQ_URI),
target_field=TARGET,
)
)

for config in gapic_job.model_deployment_monitoring_objective_configs:
gca_obj_config = config.objective_config
deployed_model_id = config.deployed_model_id
assert (
gca_obj_config.training_dataset
== all_configs[deployed_model_id].skew_detection_config.training_dataset
gca_obj_config.as_proto().training_dataset == expected_training_dataset
)
assert (
gca_obj_config.training_prediction_skew_detection_config
Expand Down
Loading

0 comments on commit bbec998

Please sign in to comment.