Skip to content

Commit

Permalink
feat: Add default skew threshold to be an optional input at _SkewDete…
Browse files Browse the repository at this point in the history
…ctionConfig and also mark the target_field and data_source of skew config to optional.

PiperOrigin-RevId: 496543878
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Dec 20, 2022
1 parent c23a8bd commit 7da4164
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 35 deletions.
63 changes: 38 additions & 25 deletions google/cloud/aiplatform/model_monitoring/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#

from typing import Optional, Dict
from typing import Optional, Dict, Union

from google.cloud.aiplatform_v1.types import (
io as gca_io_v1,
Expand All @@ -39,27 +39,30 @@
class _SkewDetectionConfig:
def __init__(
self,
data_source: str,
skew_thresholds: Dict[str, float],
target_field: str,
attribute_skew_thresholds: Dict[str, float],
data_source: Optional[str] = None,
skew_thresholds: Union[Dict[str, float], float, None] = None,
target_field: Optional[str] = None,
attribute_skew_thresholds: Optional[Dict[str, float]] = None,
data_format: Optional[str] = None,
):
"""Base class for training-serving skew detection.
Args:
data_source (str):
Required. Path to training dataset.
Optional. Path to training dataset.
skew_thresholds (Dict[str, float]):
skew_thresholds: Union[Dict[str, float], float, None]:
Optional. Key is the feature name and value is the
threshold. If a feature needs to be monitored
for skew, a value threshold must be configured
for that feature. The threshold here is against
feature distribution distance between the
training and prediction feature.
training and prediction feature. If a float is passed,
then all features will be monitored using the same
threshold. If None is passed, all feature will be monitored
using alert threshold 0.3 (Backend default).
target_field (str):
Required. The target field name the model is to
Optional. The target field name the model is to
predict. This field will be excluded when doing
Predict and (or) Explain for the training data.
Expand Down Expand Up @@ -93,12 +96,18 @@ def as_proto(self):
"""Returns _SkewDetectionConfig as a proto message."""
skew_thresholds_mapping = {}
attribution_score_skew_thresholds_mapping = {}
default_skew_threshold = None
if self.skew_thresholds is not None:
for key in self.skew_thresholds.keys():
skew_threshold = gca_model_monitoring.ThresholdConfig(
value=self.skew_thresholds[key]
if isinstance(self.skew_thresholds, float):
default_skew_threshold = gca_model_monitoring.ThresholdConfig(
value=self.skew_thresholds
)
skew_thresholds_mapping[key] = skew_threshold
else:
for key in self.skew_thresholds.keys():
skew_threshold = gca_model_monitoring.ThresholdConfig(
value=self.skew_thresholds[key]
)
skew_thresholds_mapping[key] = skew_threshold
if self.attribute_skew_thresholds is not None:
for key in self.attribute_skew_thresholds.keys():
attribution_score_skew_threshold = gca_model_monitoring.ThresholdConfig(
Expand All @@ -110,6 +119,7 @@ def as_proto(self):
return gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig(
skew_thresholds=skew_thresholds_mapping,
attribution_score_skew_thresholds=attribution_score_skew_thresholds_mapping,
default_skew_threshold=default_skew_threshold,
)


Expand Down Expand Up @@ -266,30 +276,33 @@ class SkewDetectionConfig(_SkewDetectionConfig):

def __init__(
self,
data_source: str,
target_field: str,
skew_thresholds: Optional[Dict[str, float]] = None,
data_source: Optional[str] = None,
target_field: Optional[str] = None,
skew_thresholds: Union[Dict[str, float], float, None] = None,
attribute_skew_thresholds: Optional[Dict[str, float]] = None,
data_format: Optional[str] = None,
):
"""Initializer for SkewDetectionConfig.
Args:
data_source (str):
Required. Path to training dataset.
Optional. Path to training dataset.
target_field (str):
Required. The target field name the model is to
Optional. The target field name the model is to
predict. This field will be excluded when doing
Predict and (or) Explain for the training data.
skew_thresholds (Dict[str, float]):
skew_thresholds: Union[Dict[str, float], float, None]:
Optional. Key is the feature name and value is the
threshold. If a feature needs to be monitored
for skew, a value threshold must be configured
for that feature. The threshold here is against
feature distribution distance between the
training and prediction feature.
training and prediction feature. If a float is passed,
then all features will be monitored using the same
threshold. If None is passed, all feature will be monitored
using alert threshold 0.3 (Backend default).
attribute_skew_thresholds (Dict[str, float]):
Optional. Key is the feature name and value is the
Expand All @@ -315,11 +328,11 @@ def __init__(
ValueError for unsupported data formats.
"""
super().__init__(
data_source,
skew_thresholds,
target_field,
attribute_skew_thresholds,
data_format,
data_source=data_source,
skew_thresholds=skew_thresholds,
target_field=target_field,
attribute_skew_thresholds=attribute_skew_thresholds,
data_format=data_format,
)


Expand Down
78 changes: 68 additions & 10 deletions tests/unit/aiplatform/test_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,79 @@
model_monitoring as gca_model_monitoring,
)

_TEST_THRESHOLD = 0.1
_TEST_TARGET_FIELD = "target"
_TEST_BQ_DATASOURCE = "bq://test/data"
_TEST_GCS_DATASOURCE = "gs://test/data"
_TEST_OTHER_DATASOURCE = ""
_TEST_KEY = "key"
_TEST_DRIFT_TRESHOLD = {"key": 0.2}
_TEST_EMAIL1 = "test1"
_TEST_EMAIL2 = "test2"
_TEST_VALID_DATA_FORMATS = ["tf-record", "csv", "jsonl"]
_TEST_SAMPLING_RATE = 0.8
_TEST_MONITORING_INTERVAL = 1
_TEST_SKEW_THRESHOLDS = [None, 0.2, {"key": 0.1}]
_TEST_ATTRIBUTE_SKEW_THRESHOLDS = [None, {"key": 0.1}]


class TestModelMonitoringConfigs:
"""Tests for model monitoring configs."""

@pytest.mark.parametrize(
"data_source",
[_TEST_BQ_DATASOURCE, _TEST_GCS_DATASOURCE, _TEST_OTHER_DATASOURCE],
)
@pytest.mark.parametrize("data_format", _TEST_VALID_DATA_FORMATS)
def test_valid_configs(self, data_source, data_format):
@pytest.mark.parametrize("skew_thresholds", _TEST_SKEW_THRESHOLDS)
def test_skew_config_proto_value(self, data_source, data_format, skew_thresholds):
"""Tests if skew config can be constrctued properly to gapic proto."""
attribute_skew_thresholds = {"key": 0.1}
skew_config = model_monitoring.SkewDetectionConfig(
data_source=data_source,
skew_thresholds=skew_thresholds,
target_field=_TEST_TARGET_FIELD,
attribute_skew_thresholds=attribute_skew_thresholds,
data_format=data_format,
)
# data_format and data source are not used at
# TrainingPredictionSkewDetectionConfig.
if isinstance(skew_thresholds, dict):
expected_gapic_proto = gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig(
skew_thresholds={
key: gca_model_monitoring.ThresholdConfig(value=val)
for key, val in skew_thresholds.items()
},
attribution_score_skew_thresholds={
key: gca_model_monitoring.ThresholdConfig(value=val)
for key, val in attribute_skew_thresholds.items()
},
)
else:
expected_gapic_proto = gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig(
default_skew_threshold=gca_model_monitoring.ThresholdConfig(
value=skew_thresholds
)
if skew_thresholds is not None
else None,
attribution_score_skew_thresholds={
key: gca_model_monitoring.ThresholdConfig(value=val)
for key, val in attribute_skew_thresholds.items()
},
)
assert skew_config.as_proto() == expected_gapic_proto

@pytest.mark.parametrize(
"data_source",
[_TEST_BQ_DATASOURCE, _TEST_GCS_DATASOURCE, _TEST_OTHER_DATASOURCE],
)
@pytest.mark.parametrize("data_format", _TEST_VALID_DATA_FORMATS)
@pytest.mark.parametrize("skew_thresholds", _TEST_SKEW_THRESHOLDS)
@pytest.mark.parametrize(
"attribute_skew_thresholds", _TEST_ATTRIBUTE_SKEW_THRESHOLDS
)
def test_valid_configs(
self, data_source, data_format, skew_thresholds, attribute_skew_thresholds
):
"""Test config creation validity."""
random_sample_config = model_monitoring.RandomSampleConfig(
sample_rate=_TEST_SAMPLING_RATE
)
Expand All @@ -57,17 +110,16 @@ def test_valid_configs(self, data_source, data_format):
)

prediction_drift_config = model_monitoring.DriftDetectionConfig(
drift_thresholds={_TEST_KEY: _TEST_THRESHOLD}
drift_thresholds=_TEST_DRIFT_TRESHOLD
)

skew_config = model_monitoring.SkewDetectionConfig(
data_source=data_source,
skew_thresholds={_TEST_KEY: _TEST_THRESHOLD},
skew_thresholds=skew_thresholds,
target_field=_TEST_TARGET_FIELD,
attribute_skew_thresholds={_TEST_KEY: _TEST_THRESHOLD},
attribute_skew_thresholds=attribute_skew_thresholds,
data_format=data_format,
)

expected_training_dataset = (
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
bigquery_source=gca_io.BigQuerySource(input_uri=_TEST_BQ_DATASOURCE),
Expand Down Expand Up @@ -110,15 +162,21 @@ def test_valid_configs(self, data_source, data_format):

@pytest.mark.parametrize("data_source", [_TEST_GCS_DATASOURCE])
@pytest.mark.parametrize("data_format", ["other"])
def test_invalid_data_format(self, data_source, data_format):
@pytest.mark.parametrize("skew_thresholds", _TEST_SKEW_THRESHOLDS)
@pytest.mark.parametrize(
"attribute_skew_thresholds", _TEST_ATTRIBUTE_SKEW_THRESHOLDS
)
def test_invalid_data_format(
self, data_source, data_format, skew_thresholds, attribute_skew_thresholds
):
if data_format == "other":
with pytest.raises(ValueError) as e:
model_monitoring.ObjectiveConfig(
skew_detection_config=model_monitoring.SkewDetectionConfig(
data_source=data_source,
skew_thresholds={_TEST_KEY: _TEST_THRESHOLD},
skew_thresholds=skew_thresholds,
target_field=_TEST_TARGET_FIELD,
attribute_skew_thresholds={_TEST_KEY: _TEST_THRESHOLD},
attribute_skew_thresholds=attribute_skew_thresholds,
data_format=data_format,
)
).as_proto()
Expand Down

0 comments on commit 7da4164

Please sign in to comment.