Skip to content

Commit

Permalink
Merge branch 'main' into fix-slack-app
Browse files Browse the repository at this point in the history
  • Loading branch information
Siddhanttimeline authored Nov 21, 2024
2 parents 13bce1b + 0169aad commit f39ae3d
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Optional, Type, Union
from typing import TYPE_CHECKING, Set, Type, Union

from metadata.data_quality.validations.base_test_handler import BaseTestValidator
from metadata.data_quality.validations.runtime_param_setter.param_setter import (
Expand Down Expand Up @@ -66,24 +66,20 @@ def __init__(
)
self.reset()

def set_runtime_params(
self, runtime_params_setter: Optional[RuntimeParameterSetter]
):
def set_runtime_params(self, runtime_params_setters: Set[RuntimeParameterSetter]):
"""Set the runtime parameters for the validator object
# TODO: We should support setting n runtime parameters
Args:
runtime_params_setter (Optional[RuntimeParameterSetter]): The runtime parameter setter
runtime_params_setters (Optional[RuntimeParameterSetter]): The runtime parameter setter
"""
if runtime_params_setter:
params = runtime_params_setter.get_parameters(self.test_case)
for setter in runtime_params_setters:
params = setter.get_parameters(self.test_case)
if not self.test_case.parameterValues:
# If there are no parameters, create a new list
self.test_case.parameterValues = []
self.test_case.parameterValues.append(
TestCaseParameterValue(
name="runtimeParams", value=params.model_dump_json()
name=type(params).__name__, value=params.model_dump_json()
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""

from abc import ABC, abstractmethod
from typing import Optional, Type
from typing import Optional, Set, Type

from metadata.data_quality.builders.i_validator_builder import IValidatorBuilder
from metadata.data_quality.validations.base_test_handler import BaseTestValidator
Expand Down Expand Up @@ -111,9 +111,9 @@ def run_test_case(self, test_case: TestCase) -> Optional[TestCaseResult]:
runtime_params_setter_fact: RuntimeParameterSetterFactory = (
self._get_runtime_params_setter_fact()
) # type: ignore
runtime_params_setter: Optional[
runtime_params_setters: Set[
RuntimeParameterSetter
] = runtime_params_setter_fact.get_runtime_param_setter(
] = runtime_params_setter_fact.get_runtime_param_setters(
test_case.testDefinition.fullyQualifiedName, # type: ignore
self.ometa_client,
self.service_connection_config,
Expand All @@ -127,7 +127,7 @@ def run_test_case(self, test_case: TestCase) -> Optional[TestCaseResult]:
).entityType.value

validator_builder = self._get_validator_builder(test_case, entity_type)
validator_builder.set_runtime_params(runtime_params_setter)
validator_builder.set_runtime_params(runtime_params_setters)
validator: BaseTestValidator = validator_builder.validator
try:
return validator.run_validation()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, List, Optional, Type, TypeVar, Union

from pydantic import BaseModel

from metadata.data_quality.validations import utils
from metadata.data_quality.validations.runtime_param_setter.param_setter import (
RuntimeParameterSetter,
)
from metadata.generated.schema.tests.basic import (
TestCaseResult,
TestCaseStatus,
Expand All @@ -37,6 +36,7 @@

T = TypeVar("T", bound=Callable)
R = TypeVar("R")
S = TypeVar("S", bound=BaseModel)


class BaseTestValidator(ABC):
Expand All @@ -45,8 +45,6 @@ class BaseTestValidator(ABC):
This can be useful to resolve complex test parameters based on the parameters gibven by the user.
"""

runtime_parameter_setter: Optional[Type[RuntimeParameterSetter]] = None

def __init__(
self,
runner: Union[QueryRunner, List["DataFrame"]],
Expand Down Expand Up @@ -168,3 +166,10 @@ def get_max_bound(self, param_name: str) -> Optional[float]:
def get_predicted_value(self) -> Optional[str]:
"""Get predicted value"""
return None

def get_runtime_parameters(self, setter_class: Type[S]) -> S:
"""Get runtime parameters"""
for param in self.test_case.parameterValues or []:
if param.name == setter_class.__name__:
return setter_class.model_validate_json(param.value)
raise ValueError(f"Runtime parameter {setter_class.__name__} not found")
Original file line number Diff line number Diff line change
Expand Up @@ -13,42 +13,70 @@
This class is responsible for creating instances of the RuntimeParameterSetter
based on the test case.
"""

from typing import Optional
import sys
from typing import Dict, Set, Type

from metadata.data_quality.validations.runtime_param_setter.param_setter import (
RuntimeParameterSetter,
)
from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import (
TableDiffParamsSetter,
)
from metadata.data_quality.validations.table.sqlalchemy.tableDiff import (
TableDiffValidator,
)
from metadata.generated.schema.entity.data.table import Table
from metadata.ingestion.ometa.ometa_api import OpenMetadata
from metadata.profiler.processor.sampler.sqlalchemy.sampler import SQASampler


def removesuffix(s: str, suffix: str) -> str:
"""A custom implementation of removesuffix for python versions < 3.9
Args:
s (str): The string to remove the suffix from
suffix (str): The suffix to remove
Returns:
str: The string with the suffix removed
"""
if sys.version_info >= (3, 9):
return s.removesuffix(suffix)
if s.endswith(suffix):
return s[: -len(suffix)]
return s


def validator_name(test_case_class: Type) -> str:
return removesuffix(
test_case_class.__name__[0].lower() + test_case_class.__name__[1:], "Validator"
)


class RuntimeParameterSetterFactory:
"""runtime parameter setter factory class"""

def __init__(self) -> None:
"""Set"""
self._setter_map = {
TableDiffParamsSetter: {"tableDiff"},
self._setter_map: Dict[str, Set[Type[RuntimeParameterSetter]]] = {
validator_name(TableDiffValidator): {TableDiffParamsSetter},
}

def get_runtime_param_setter(
def get_runtime_param_setters(
self,
name: str,
ometa: OpenMetadata,
service_connection_config,
table_entity,
sampler,
) -> Optional[RuntimeParameterSetter]:
table_entity: Table,
sampler: SQASampler,
) -> Set[RuntimeParameterSetter]:
"""Get the runtime parameter setter"""
for setter_cls, validator_names in self._setter_map.items():
if name in validator_names:
return setter_cls(
ometa,
service_connection_config,
table_entity,
sampler,
)
return None
return {
setter(
ometa,
service_connection_config,
table_entity,
sampler,
)
for setter in self._setter_map.get(name, set())
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin):
runtime_params: TableDiffRuntimeParameters

def run_validation(self) -> TestCaseResult:
self.runtime_params = self.get_runtime_params()
self.runtime_params = self.get_runtime_parameters(TableDiffRuntimeParameters)
try:
self._validate_dialects()
return self._run()
Expand Down Expand Up @@ -438,13 +438,6 @@ def calculate_nounce(self, max_nounce=2**32 - 1) -> int:
)
raise ValueError("Invalid profile sample type")

def get_runtime_params(self) -> TableDiffRuntimeParameters:
raw = self.get_test_case_param_value(
self.test_case.parameterValues, "runtimeParams", str
)
runtime_params = TableDiffRuntimeParameters.model_validate_json(raw)
return runtime_params

def get_row_diff_test_case_result(
self,
threshold: int,
Expand Down

0 comments on commit f39ae3d

Please sign in to comment.