diff --git a/ingestion/src/metadata/data_quality/builders/i_validator_builder.py b/ingestion/src/metadata/data_quality/builders/i_validator_builder.py index 8cfabb96d5d4..d24df9034178 100644 --- a/ingestion/src/metadata/data_quality/builders/i_validator_builder.py +++ b/ingestion/src/metadata/data_quality/builders/i_validator_builder.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from datetime import datetime, timezone -from typing import TYPE_CHECKING, Optional, Type, Union +from typing import TYPE_CHECKING, Set, Type, Union from metadata.data_quality.validations.base_test_handler import BaseTestValidator from metadata.data_quality.validations.runtime_param_setter.param_setter import ( @@ -66,24 +66,20 @@ def __init__( ) self.reset() - def set_runtime_params( - self, runtime_params_setter: Optional[RuntimeParameterSetter] - ): + def set_runtime_params(self, runtime_params_setters: Set[RuntimeParameterSetter]): """Set the runtime parameters for the validator object - # TODO: We should support setting n runtime parameters - Args: - runtime_params_setter (Optional[RuntimeParameterSetter]): The runtime parameter setter + runtime_params_setters (Optional[RuntimeParameterSetter]): The runtime parameter setter """ - if runtime_params_setter: - params = runtime_params_setter.get_parameters(self.test_case) + for setter in runtime_params_setters: + params = setter.get_parameters(self.test_case) if not self.test_case.parameterValues: # If there are no parameters, create a new list self.test_case.parameterValues = [] self.test_case.parameterValues.append( TestCaseParameterValue( - name="runtimeParams", value=params.model_dump_json() + name=type(params).__name__, value=params.model_dump_json() ) ) diff --git a/ingestion/src/metadata/data_quality/interface/test_suite_interface.py b/ingestion/src/metadata/data_quality/interface/test_suite_interface.py index 79c6d8303aa5..bcf02a41af05 100644 --- a/ingestion/src/metadata/data_quality/interface/test_suite_interface.py +++ b/ingestion/src/metadata/data_quality/interface/test_suite_interface.py @@ -15,7 +15,7 @@ """ from abc import ABC, abstractmethod -from typing import Optional, Type +from typing import Optional, Set, Type from metadata.data_quality.builders.i_validator_builder import IValidatorBuilder from metadata.data_quality.validations.base_test_handler import BaseTestValidator @@ -97,9 +97,9 @@ def run_test_case(self, test_case: TestCase) -> Optional[TestCaseResult]: runtime_params_setter_fact: RuntimeParameterSetterFactory = ( self._get_runtime_params_setter_fact() ) # type: ignore - runtime_params_setter: Optional[ + runtime_params_setters: Set[ RuntimeParameterSetter - ] = runtime_params_setter_fact.get_runtime_param_setter( + ] = runtime_params_setter_fact.get_runtime_param_setters( test_case.testDefinition.fullyQualifiedName, # type: ignore self.ometa_client, self.service_connection_config, @@ -113,7 +113,7 @@ def run_test_case(self, test_case: TestCase) -> Optional[TestCaseResult]: ).entityType.value validator_builder = self._get_validator_builder(test_case, entity_type) - validator_builder.set_runtime_params(runtime_params_setter) + validator_builder.set_runtime_params(runtime_params_setters) validator: BaseTestValidator = validator_builder.validator try: return validator.run_validation() diff --git a/ingestion/src/metadata/data_quality/validations/base_test_handler.py b/ingestion/src/metadata/data_quality/validations/base_test_handler.py index 3e1363c93864..ddd49e1dc9cf 100644 --- a/ingestion/src/metadata/data_quality/validations/base_test_handler.py +++ b/ingestion/src/metadata/data_quality/validations/base_test_handler.py @@ -19,10 +19,9 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Callable, List, Optional, Type, TypeVar, Union +from pydantic import BaseModel + from metadata.data_quality.validations import utils -from metadata.data_quality.validations.runtime_param_setter.param_setter import ( - RuntimeParameterSetter, -) from metadata.generated.schema.tests.basic import ( TestCaseResult, TestCaseStatus, @@ -37,6 +36,7 @@ T = TypeVar("T", bound=Callable) R = TypeVar("R") +S = TypeVar("S", bound=BaseModel) class BaseTestValidator(ABC): @@ -45,8 +45,6 @@ class BaseTestValidator(ABC): This can be useful to resolve complex test parameters based on the parameters gibven by the user. """ - runtime_parameter_setter: Optional[Type[RuntimeParameterSetter]] = None - def __init__( self, runner: Union[QueryRunner, List["DataFrame"]], @@ -168,3 +166,10 @@ def get_max_bound(self, param_name: str) -> Optional[float]: def get_predicted_value(self) -> Optional[str]: """Get predicted value""" return None + + def get_runtime_parameters(self, setter_class: Type[S]) -> S: + """Get runtime parameters""" + for param in self.test_case.parameterValues or []: + if param.name == setter_class.__name__: + return setter_class.model_validate_json(param.value) + raise ValueError(f"Runtime parameter {setter_class.__name__} not found") diff --git a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/param_setter_factory.py b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/param_setter_factory.py index 15653fee4427..f8f03935faed 100644 --- a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/param_setter_factory.py +++ b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/param_setter_factory.py @@ -13,8 +13,8 @@ This class is responsible for creating instances of the RuntimeParameterSetter based on the test case. """ - -from typing import Optional +import sys +from typing import Dict, Set, Type from metadata.data_quality.validations.runtime_param_setter.param_setter import ( RuntimeParameterSetter, @@ -22,7 +22,35 @@ from metadata.data_quality.validations.runtime_param_setter.table_diff_params_setter import ( TableDiffParamsSetter, ) +from metadata.data_quality.validations.table.sqlalchemy.tableDiff import ( + TableDiffValidator, +) +from metadata.generated.schema.entity.data.table import Table from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.profiler.processor.sampler.sqlalchemy.sampler import SQASampler + + +def removesuffix(s: str, suffix: str) -> str: + """A custom implementation of removesuffix for python versions < 3.9 + + Args: + s (str): The string to remove the suffix from + suffix (str): The suffix to remove + + Returns: + str: The string with the suffix removed + """ + if sys.version_info >= (3, 9): + return s.removesuffix(suffix) + if s.endswith(suffix): + return s[: -len(suffix)] + return s + + +def validator_name(test_case_class: Type) -> str: + return removesuffix( + test_case_class.__name__[0].lower() + test_case_class.__name__[1:], "Validator" + ) class RuntimeParameterSetterFactory: @@ -30,25 +58,25 @@ class RuntimeParameterSetterFactory: def __init__(self) -> None: """Set""" - self._setter_map = { - TableDiffParamsSetter: {"tableDiff"}, + self._setter_map: Dict[str, Set[Type[RuntimeParameterSetter]]] = { + validator_name(TableDiffValidator): {TableDiffParamsSetter}, } - def get_runtime_param_setter( + def get_runtime_param_setters( self, name: str, ometa: OpenMetadata, service_connection_config, - table_entity, - sampler, - ) -> Optional[RuntimeParameterSetter]: + table_entity: Table, + sampler: SQASampler, + ) -> Set[RuntimeParameterSetter]: """Get the runtime parameter setter""" - for setter_cls, validator_names in self._setter_map.items(): - if name in validator_names: - return setter_cls( - ometa, - service_connection_config, - table_entity, - sampler, - ) - return None + return { + setter( + ometa, + service_connection_config, + table_entity, + sampler, + ) + for setter in self._setter_map.get(name, set()) + } diff --git a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py index 02aea8f77417..b081323173fd 100644 --- a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py +++ b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py @@ -143,7 +143,7 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): runtime_params: TableDiffRuntimeParameters def run_validation(self) -> TestCaseResult: - self.runtime_params = self.get_runtime_params() + self.runtime_params = self.get_runtime_parameters(TableDiffRuntimeParameters) try: self._validate_dialects() return self._run() @@ -414,13 +414,6 @@ def calculate_nounce(self, max_nounce=2**32 - 1) -> int: ) raise ValueError("Invalid profile sample type") - def get_runtime_params(self) -> TableDiffRuntimeParameters: - raw = self.get_test_case_param_value( - self.test_case.parameterValues, "runtimeParams", str - ) - runtime_params = TableDiffRuntimeParameters.model_validate_json(raw) - return runtime_params - def get_row_diff_test_case_result( self, threshold: int,