Skip to content

Commit

Permalink
simplify to generic
Browse files Browse the repository at this point in the history
  • Loading branch information
joshua-stauffer committed Feb 13, 2025
1 parent b47e842 commit daca188
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 73 deletions.
71 changes: 15 additions & 56 deletions great_expectations/metrics/metric_results.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,37 @@
from typing import Any, Literal, Union
from typing import Any, Generic, TypeVar, Union

from great_expectations.compatibility.pydantic import BaseModel, validator
from great_expectations.compatibility.pydantic import BaseModel, GenericModel
from great_expectations.validator.metric_configuration import MetricConfigurationID

# Type variable standing for the type of the ``value`` field of the generic MetricResult model.
_MetricResultValue = TypeVar("_MetricResultValue")

class _MetricResult(BaseModel):
    """Common shape of a metric result: the metric's id, its value, and a success flag."""

    id: MetricConfigurationID  # identifies the metric configuration this result belongs to
    value: Any  # untyped here; subclasses narrow it
    success: bool


class _SuccessfulMetricResult(_MetricResult):
    """Metric result whose ``success`` flag is pinned to ``True``."""

    success: Literal[True] = True

class MetricResult(GenericModel, Generic[_MetricResultValue]):
    """Generic metric result pairing a metric configuration id with a typed value.

    The value type is bound by parametrizing the model, e.g. ``MetricResult[int]``.
    """

    id: MetricConfigurationID
    value: _MetricResultValue

class InvalidMetricError(TypeError):
    """Raised when a metric result receives an id for a different metric than it expects.

    Attributes:
        expected_metric: Name of the metric the result model accepts.
        actual_metric: Name of the metric that was actually supplied.
    """

    def __init__(self, expected_metric: str, actual_metric: str):
        # Keep both names on the instance so callers can inspect them
        # without parsing the message text.
        self.expected_metric = expected_metric
        self.actual_metric = actual_metric
        super().__init__(
            f"Invalid metric: expected {expected_metric} but received {actual_metric}."
        )

class MetricErrorResult(MetricResult[dict[str, Union[int, dict, str]]]):
    """Metric error result; the value is a str-keyed dict whose entries are int, dict, or str."""

class TableColumns(_SuccessfulMetricResult):
    """Successful result for the ``table.columns`` metric; the value is a list of strings."""

    value: list[str]

    @validator("id")
    def validate_id(cls, v):
        """Accept only ids whose first element is ``table.columns``.

        Returns:
            The id, unchanged.

        Raises:
            InvalidMetricError: if the first element of the id names another metric.
        """
        if v[0] != "table.columns":
            raise InvalidMetricError(expected_metric="table.columns", actual_metric=v[0])
        # A validator's return value becomes the field value; without this the id is lost.
        return v
class TableColumnsResult(MetricResult[list[str]]):
    """Metric result whose value is a list of strings."""


class _ColumnType(BaseModel):
    """Name and type of one column, both as strings; additional keys are accepted."""

    name: str
    type: str

    class Config:
        extra = "allow"  # some backends return extra values


class TableColumnTypes(_SuccessfulMetricResult):
    """Successful result for the ``table.column_types`` metric; one ``_ColumnType`` per entry."""

    value: list[_ColumnType]

    @validator("id")
    def validate_id(cls, v):
        """Accept only ids whose first element is ``table.column_types``.

        Returns:
            The id, unchanged.

        Raises:
            InvalidMetricError: if the first element of the id names another metric.
        """
        if v[0] != "table.column_types":
            raise InvalidMetricError(expected_metric="table.column_types", actual_metric=v[0])
        # A validator's return value becomes the field value; without this the id is lost.
        return v


class UnexpectedCount(_SuccessfulMetricResult):
    """Successful result for any metric whose dotted name ends in ``unexpected_count``."""

    value: int

    @validator("id")
    def validate_id(cls, v):
        """Accept only ids whose metric name has ``unexpected_count`` as its last dotted segment.

        Returns:
            The id, unchanged.

        Raises:
            InvalidMetricError: if the last segment of the metric name differs.
        """
        metric_name = v[0]
        if metric_name.split(".")[-1] != "unexpected_count":
            raise InvalidMetricError(expected_metric="unexpected_count", actual_metric=metric_name)
        # A validator's return value becomes the field value; without this the id is lost.
        return v


class UnexpectedValues(_SuccessfulMetricResult):
    """Successful result for any metric whose dotted name ends in ``unexpected_values``."""

    value: list[Any]  # unknowable type, since this is a sample of user data

    @validator("id")
    def validate_id(cls, v):
        """Accept only ids whose metric name has ``unexpected_values`` as its last dotted segment.

        Returns:
            The id, unchanged.

        Raises:
            InvalidMetricError: if the last segment of the metric name differs.
        """
        metric_name = v[0]
        if metric_name.split(".")[-1] != "unexpected_values":
            raise InvalidMetricError(expected_metric="unexpected_values", actual_metric=metric_name)
        # A validator's return value becomes the field value; without this the id is lost.
        return v


class TableColumnTypesResult(MetricResult[list[_ColumnType]]):
    """Metric result whose value is a list of ``_ColumnType`` entries."""

class UnexpectedCountResult(MetricResult[int]):
    """Metric result whose value is an integer."""

class TableRowCount(_SuccessfulMetricResult):
    """Successful result for the ``table.row_count`` metric; the value is an integer."""

    value: int

    @validator("id")
    def validate_id(cls, v):
        """Accept only ids whose first element is ``table.row_count``.

        Returns:
            The id, unchanged.

        Raises:
            InvalidMetricError: if the first element of the id names another metric.
        """
        if v[0] != "table.row_count":
            raise InvalidMetricError(expected_metric="table.row_count", actual_metric=v[0])
        # A validator's return value becomes the field value; without this the id is lost.
        return v
class UnexpectedValuesResult(MetricResult[list[Any]]):
    """Metric result whose value is a list with elements of any type."""


class ErrorMetricResult(_MetricResult):
    """Metric result whose ``success`` flag is pinned to ``False``.

    The value is a str-keyed dict whose entries are int, dict, or str.
    """

    value: dict[str, Union[int, dict, str]]
    success: Literal[False] = False
class TableRowCountResult(MetricResult[int]):
    """Metric result whose value is an integer."""
90 changes: 73 additions & 17 deletions tests/metrics/test_metric_results.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,106 @@
from great_expectations.metrics.metric_results import (
TableColumns,
TableColumnTypes,
UnexpectedCount,
MetricErrorResult,
TableColumnsResult,
TableColumnTypesResult,
UnexpectedCountResult,
UnexpectedValuesResult,
_ColumnType,
)
from great_expectations.validator.metric_configuration import MetricConfigurationID


class TestMetricResultInstantiation:
    """Builds each ``*Result`` model from a representative id/value pair and checks ``.dict()``."""

    def test_unexpected_count_result(self):
        """An integer count round-trips through ``UnexpectedCountResult``."""
        metric_id = MetricConfigurationID(
            metric_name="column_values.null.unexpected_count",
            metric_domain_kwargs_id="73d1f59d321e58e8e8a0cfc2d22cca1f",
            metric_value_kwargs_id=(),
        )
        metric_value = 0

        metric_result = UnexpectedCountResult(
            id=metric_id,
            value=metric_value,
        )
        assert metric_result.dict() == {"id": metric_id, "value": metric_value}

    def test_unexpected_values_result(self):
        """A heterogeneous list round-trips through ``UnexpectedValuesResult``."""
        metric_id = MetricConfigurationID(
            metric_name="column_values.null.unexpected_values",
            metric_domain_kwargs_id="a8ef4ee749d02d0e5f92719fc6ee8010",
            metric_value_kwargs_id="include_nested=True",
        )
        metric_value = [
            # these are values coming from the user, so can be anything
            "foo",
            3.14,
            False,
        ]

        metric_result = UnexpectedValuesResult(
            id=metric_id,
            value=metric_value,
        )
        assert metric_result.dict() == {"id": metric_id, "value": metric_value}

    def test_table_columns_metric_result(self):
        """A list of strings round-trips through ``TableColumnsResult``."""
        metric_id = MetricConfigurationID(
            metric_name="table.columns",
            metric_domain_kwargs_id="a8ef4ee749d02d0e5f92719fc6ee8010",
            metric_value_kwargs_id=(),
        )
        metric_value = [
            "existing_column",
            "another_existing_column",
        ]
        metric_result = TableColumnsResult(
            id=metric_id,
            value=metric_value,
        )
        assert all(isinstance(val, str) for val in metric_result.value)
        assert metric_result.dict() == {"id": metric_id, "value": metric_value}

    def test_table_column_types_result(self):
        """Plain dicts are parsed into ``_ColumnType`` and serialized back by ``.dict()``."""
        metric_id = MetricConfigurationID(
            metric_name="table.column_types",
            metric_domain_kwargs_id="a8ef4ee749d02d0e5f92719fc6ee8010",
            metric_value_kwargs_id="include_nested=True",
        )
        metric_value = [
            {"name": "existing_column", "type": "int64"},
            {"name": "another_existing_column", "type": "object"},
        ]

        metric_result = TableColumnTypesResult(
            id=metric_id,
            value=metric_value,
        )

        assert all(isinstance(val, _ColumnType) for val in metric_result.value)
        assert metric_result.dict() == {"id": metric_id, "value": metric_value}

    def test_metric_error_result(self):
        """A nested error payload round-trips through ``MetricErrorResult``."""
        metric_id = MetricConfigurationID(
            metric_name="column.mean",
            metric_domain_kwargs_id="8a975130e802d66f85ab0cac8d10fbec",
            metric_value_kwargs_id=(),
        )
        metric_value = {
            "exception_info": {
                "exception_traceback": "Traceback (most recent call last)...",
                "exception_message": "Error: The column does not exist.",
                "raised_exception": True,
            },
            "metric_configuration": {
                "metric_name": "column.mean",
                "metric_domain_kwargs": {"table": None, "column": "doesnt exist"},
                # NOTE(review): this key looks truncated in the rendered view; confirm
                # the real key name against the repository before relying on it.
                "metric_d..._kwargs": {},
                "metric_value_kwargs_id": [],
                "id": ["column.mean", "8a975130e802d66f85ab0cac8d10fbec", []],
            },
            "num_failures": 3,
        }
        metric_result = MetricErrorResult(
            id=metric_id,
            value=metric_value,
        )
        assert metric_result.dict() == {"id": metric_id, "value": metric_value}

0 comments on commit daca188

Please sign in to comment.