diff --git a/great_expectations/metrics/metric_results.py b/great_expectations/metrics/metric_results.py index 4748d151f353..40eaba75eaa7 100644 --- a/great_expectations/metrics/metric_results.py +++ b/great_expectations/metrics/metric_results.py @@ -1,78 +1,37 @@ -from typing import Any, Literal, Union +from typing import Any, Generic, TypeVar, Union -from great_expectations.compatibility.pydantic import BaseModel, validator +from great_expectations.compatibility.pydantic import BaseModel, GenericModel from great_expectations.validator.metric_configuration import MetricConfigurationID +_MetricResultValue = TypeVar("_MetricResultValue") -class _MetricResult(BaseModel): - id: MetricConfigurationID - value: Any - success: bool - - -class _SuccessfulMetricResult(_MetricResult): - success: Literal[True] = True +class MetricResult(GenericModel, Generic[_MetricResultValue]): + id: MetricConfigurationID + value: _MetricResultValue -class InvalidMetricError(TypeError): - def __init__(self, expected_metric: str, actual_metric: str): - super().__init__( - f"Invalid metric: expected {expected_metric} but received {actual_metric}." - ) +class MetricErrorResult(MetricResult[dict[str, Union[int, dict, str]]]): ... -class TableColumns(_SuccessfulMetricResult): - @validator("id") - def validate_id(cls, v): - if v[0] != "table.columns": - raise InvalidMetricError(expected_metric="table.columns", actual_metric=v[0]) - value: list[str] +class TableColumnsResult(MetricResult[list[str]]): ... class _ColumnType(BaseModel): + class Config: + extra = "allow" # some backends return extra values + name: str type: str -class TableColumnTypes(_SuccessfulMetricResult): - @validator("id") - def validate_id(cls, v): - if v[0] != "table.column_types": - raise InvalidMetricError(expected_metric="table.column_types", actual_metric=v[0]) - - value: list[_ColumnType] - - -class UnexpectedCount(_SuccessfulMetricResult): - @validator("id") - def validate_id(cls, v): - metric_name = v[0] - if metric_name.split(".")[-1] != "unexpected_count": - raise InvalidMetricError(expected_metric="unexpected_count", actual_metric=metric_name) - - value: int - - -class UnexpectedValues(_SuccessfulMetricResult): - @validator("id") - def validate_id(cls, v): - metric_name = v[0] - if metric_name.split(".")[-1] != "unexpected_values": - raise InvalidMetricError(expected_metric="unexpected_values", actual_metric=metric_name) +class TableColumnTypesResult(MetricResult[list[_ColumnType]]): ... - value: list[Any] # unknowable type, since this is a sample of user data +class UnexpectedCountResult(MetricResult[int]): ... -class TableRowCount(_SuccessfulMetricResult): - @validator("id") - def validate_id(cls, v): - if v[0] != "table.row_count": - raise InvalidMetricError(expected_metric="table.row_count", actual_metric=v[0]) - value: int +class UnexpectedValuesResult(MetricResult[list[Any]]): ... -class ErrorMetricResult(_MetricResult): - value: dict[str, Union[int, dict, str]] - success: Literal[False] = False +class TableRowCountResult(MetricResult[int]): ... diff --git a/tests/metrics/test_metric_results.py b/tests/metrics/test_metric_results.py index f5069d27ed27..bd90fbfee9ac 100644 --- a/tests/metrics/test_metric_results.py +++ b/tests/metrics/test_metric_results.py @@ -1,50 +1,106 @@ from great_expectations.metrics.metric_results import ( - TableColumns, - TableColumnTypes, - UnexpectedCount, + MetricErrorResult, + TableColumnsResult, + TableColumnTypesResult, + UnexpectedCountResult, + UnexpectedValuesResult, _ColumnType, ) +from great_expectations.validator.metric_configuration import MetricConfigurationID class TestMetricResultInstantiation: - def test_unexpected_count(self): - metric_id = ("column_values.null.unexpected_count", "73d1f59d321e58e8e8a0cfc2d22cca1f", ()) + def test_unexpected_count_result(self): + metric_id = MetricConfigurationID( + metric_name="column_values.null.unexpected_count", + metric_domain_kwargs_id="73d1f59d321e58e8e8a0cfc2d22cca1f", + metric_value_kwargs_id=(), + ) metric_value = 0 - metric_result = UnexpectedCount( + metric_result = UnexpectedCountResult( id=metric_id, value=metric_value, ) - assert not metric_result.error + assert metric_result.dict() == {"id": metric_id, "value": metric_value} + + def test_unexpected_values_result(self): + metric_id = MetricConfigurationID( + metric_name="column_values.null.unexpected_values", + metric_domain_kwargs_id="a8ef4ee749d02d0e5f92719fc6ee8010", + metric_value_kwargs_id="include_nested=True", + ) + metric_value = [ + # these are values coming from the user, so can be anything + "foo", + 3.14, + False, + ] + + metric_result = UnexpectedValuesResult( + id=metric_id, + value=metric_value, + ) + assert metric_result.dict() == {"id": metric_id, "value": metric_value} def test_table_columns_metric_result(self): - metric_id = ("table.columns", "a8ef4ee749d02d0e5f92719fc6ee8010", ()) + metric_id = MetricConfigurationID( + metric_name="table.columns", + metric_domain_kwargs_id="a8ef4ee749d02d0e5f92719fc6ee8010", + metric_value_kwargs_id=(), + ) metric_value = [ "existing_column", "another_existing_column", ] - metric_result = TableColumns( + metric_result = TableColumnsResult( id=metric_id, value=metric_value, ) - assert not metric_result.error - assert all(isinstance(val, str) for val in metric_result.value) + assert metric_result.dict() == {"id": metric_id, "value": metric_value} def test_table_column_types_result(self): - metric_id = ( - "table.column_types", - "a8ef4ee749d02d0e5f92719fc6ee8010", - "include_nested=True", + metric_id = MetricConfigurationID( + metric_name="table.column_types", + metric_domain_kwargs_id="a8ef4ee749d02d0e5f92719fc6ee8010", + metric_value_kwargs_id="include_nested=True", ) metric_value = [ {"name": "existing_column", "type": "int64"}, {"name": "another_existing_column", "type": "object"}, ] - metric_result = TableColumnTypes( + metric_result = TableColumnTypesResult( id=metric_id, value=metric_value, ) - assert not metric_result.error assert all(isinstance(val, _ColumnType) for val in metric_result.value) + assert metric_result.dict() == {"id": metric_id, "value": metric_value} + + def test_metric_error_result(self): + metric_id = MetricConfigurationID( + metric_name="column.mean", + metric_domain_kwargs_id="8a975130e802d66f85ab0cac8d10fbec", + metric_value_kwargs_id=(), + ) + metric_value = { + "exception_info": { + "exception_traceback": "Traceback (most recent call last)...", + "exception_message": "Error: The column does not exist.", + "raised_exception": True, + }, + "metric_configuration": { + "metric_name": "column.mean", + "metric_domain_kwargs": {"table": None, "column": "doesnt exist"}, + "metric_d..._kwargs": {}, + "metric_value_kwargs_id": [], + "id": ["column.mean", "8a975130e802d66f85ab0cac8d10fbec", []], + }, + "num_failures": 3, + } + metric_result = MetricErrorResult( + id=metric_id, + value=metric_value, + ) + assert metric_result.dict() == {"id": metric_id, "value": metric_value}