From 8c17cf41280960f087538b78073853b0486f6c6a Mon Sep 17 00:00:00 2001 From: Eyob Date: Mon, 27 May 2024 12:04:28 +0300 Subject: [PATCH] Refactor validation logic and add global regeneration max parameter --- src/sherpa_ai/agents/base.py | 39 +++++++++---------- src/sherpa_ai/output_parsers/base.py | 3 ++ .../test_regression_validator_flow.py | 12 +++--- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/src/sherpa_ai/agents/base.py b/src/sherpa_ai/agents/base.py index 87c4e00e..1f87369e 100644 --- a/src/sherpa_ai/agents/base.py +++ b/src/sherpa_ai/agents/base.py @@ -106,22 +106,21 @@ def run(self): self.shared_memory.add(EventType.result, self.name, result) return result - - # The validation_iterator function is responsible for iterating through each instantiated validation in the 'instantiated_validations' list. + # The validation_iterator function is responsible for iterating through each instantiated validation in the 'self.validations' list. # It performs the necessary validation steps for each validation, updating the belief system and synthesizing output if needed. # It keeps track of the global regeneration count, whether all validations have passed, and if any validation has been escaped. # The function returns the updated global regeneration count, the status of all validations, whether any validation has been escaped, and the synthesized output. def validation_iterator( self, - instantiated_validations, + validations, global_regen_count, all_pass, validation_is_scaped, result, ): - - for i in range(len(instantiated_validations)): - validation = instantiated_validations[i] + + for i in range(len(validations)): + validation = validations[i] logger.info(f"validation_running: {validation.__class__.__name__}") logger.info(f"validation_count: {validation.count}") # this checks if the validator has already exceeded the validation steps limit. 
@@ -142,16 +141,16 @@ def validation_iterator( break # if all validations passed then set all_pass to True - elif i == len(instantiated_validations) - 1: + elif i == len(validations) - 1: result = validation_result.result all_pass = True else: result = validation_result.result # if validation is the last one and surpassed the validation steps limit then finish the loop with all_pass and mention there is a scaped validation. - elif i == len(instantiated_validations) - 1: + elif i == len(validations) - 1: validation_is_scaped = True all_pass = True - + else: # if the validation has already reached the validation steps limit then continue to the next validation. validation_is_scaped = True @@ -173,21 +172,21 @@ def validate_output(self): """ result = "" # create array of instance of validation so that we can keep track of how many times regeneration happened. - instantiated_validations = [] - for validation in self.validations: - try: - instantiated_validations.append(validation()) - except Exception as e: - instantiated_validations.append(validation) - all_pass = False validation_is_scaped = False iteration_count = 0 result = self.synthesize_output() global_regen_count = 0 + + # reset the state of all the validations before starting the validation process. 
+ for validation in self.validations: + validation.reset_state() + + validations = self.validations + # this loop will run until max regeneration reached or all validations have failed while self.global_regen_max > global_regen_count and not all_pass: - logger.info(f"validations_size: {len(instantiated_validations)}") + logger.info(f"validations_size: {len(validations)}") iteration_count += 1 logger.info(f"main_iteration: {iteration_count}") logger.info(f"regen_count: {global_regen_count}") @@ -197,7 +196,7 @@ def validate_output(self): all_pass=all_pass, global_regen_count=global_regen_count, validation_is_scaped=validation_is_scaped, - instantiated_validations=instantiated_validations, + validations=validations, result=result, ) ) @@ -205,7 +204,7 @@ def validate_output(self): if validation_is_scaped or self.global_regen_max >= global_regen_count: failed_validations = [] - for validation in instantiated_validations: + for validation in validations: validation_result = validation.process_output( text=result, belief=self.belief, llm=self.llm ) @@ -228,7 +227,7 @@ def validate_output(self): if inst_val.count == self.validation_steps else "" ) - for inst_val in instantiated_validations + for inst_val in validations ) self.belief.update_internal(EventType.result, self.name, result) diff --git a/src/sherpa_ai/output_parsers/base.py b/src/sherpa_ai/output_parsers/base.py index 110a5978..46dbead3 100644 --- a/src/sherpa_ai/output_parsers/base.py +++ b/src/sherpa_ai/output_parsers/base.py @@ -54,6 +54,9 @@ class BaseOutputProcessor(ABC): count: int = 0 + def reset_state(self): + self.count = 0 + @abstractmethod def process_output(self, text: str, **kwargs) -> ValidationResult: """ diff --git a/src/tests/integration_tests/test_regression_validator_flow.py b/src/tests/integration_tests/test_regression_validator_flow.py index 46560895..b1240987 100644 --- a/src/tests/integration_tests/test_regression_validator_flow.py +++ 
b/src/tests/integration_tests/test_regression_validator_flow.py @@ -51,10 +51,10 @@ def test_regression_validator_flow( agent_pool=None, ) - entity_validation = EntityValidation - number_validation = NumberValidation - number_validation_two = NumberValidation - citation_validation = CitationValidation + entity_validation = EntityValidation() + number_validation = NumberValidation() + number_validation_two = NumberValidation() + citation_validation = CitationValidation() with patch.object(SearchTool, "_run", return_value=data): task_agent = QAAgent( llm=llm, @@ -82,7 +82,7 @@ def test_regression_validator_flow( logger.info(results[0].content) final_result = results[0].content - if not entity_validation().get_failure_message() in final_result: + if not entity_validation.get_failure_message() in final_result: result_entities = [s.lower() for s in extract_entities(results[0].content)] expected_entities = [s.lower() for s in expected_entities] for entity in expected_entities: @@ -96,7 +96,7 @@ def test_regression_validator_flow( if not match_found: assert False, entity + " was not found in resource" - if not number_validation().get_failure_message() in final_result: + if not number_validation.get_failure_message() in final_result: for number in expected_number: if not number in combined_number_extractor(results[0].content): assert False, number + " was not found in resource"