diff --git a/tests/README.md b/tests/README.md
index f4f5e02..b902279 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -8,8 +8,10 @@
 2. Run the test suite:
    ```bash
    python test_case.py --model_dir "jan-hq/Jan-Llama3-0708" \\
+       --max_length 1024 \\
+       --data_dir "jan-hq/instruction-speech-conversation-test" \\
        --mode "audio" \\
-       --num_rows 100 \\
+       --num_rows 5 \\
    ```
 
 ## Test Configuration
diff --git a/tests/test_case.py b/tests/test_case.py
index da95f5b..6743849 100644
--- a/tests/test_case.py
+++ b/tests/test_case.py
@@ -9,10 +9,87 @@
 from nltk.translate.bleu_score import sentence_bleu, corpus_bleu
 import argparse
 import os
+import sys
+from io import StringIO
+import time
+
+# Decorator Class
+class CustomTestResult(unittest.TestResult):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.successes = []
+
+    def addSuccess(self, test):
+        super().addSuccess(test)
+        self.successes.append(test)
+
+class CustomTestRunner(unittest.TextTestRunner):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.stream = StringIO()
+        self.results = []
+
+    def run(self, test):
+        result = CustomTestResult()
+        start_time = time.time()
+        test(result)
+        time_taken = time.time() - start_time
+        self.results.append((result, time_taken))
+        return result
+
+    def print_results(self):
+        print("\n=== Test Results ===")
+        total_tests = 0
+        total_successes = 0
+        total_failures = 0
+        total_errors = 0
+        total_time = 0
+
+        for result, time_taken in self.results:
+            total_tests += result.testsRun
+            total_successes += len(result.successes)
+            total_failures += len(result.failures)
+            total_errors += len(result.errors)
+            total_time += time_taken
+
+        print(f"Ran {total_tests} tests in {total_time:.3f} seconds")
+        print(f"Successes: {total_successes}")
+        print(f"Failures: {total_failures}")
+        print(f"Errors: {total_errors}")
+
+        print("\nDetailed Results:")
+        for result, time_taken in self.results:
+            # todo: add time taken for each test
+            for test in result.successes:
+                print(f"PASS: {test._testMethodName}")
+            for test, _ in result.failures:
+                print(f"FAIL: {test._testMethodName}")
+            for test, _ in result.errors:
+                test_name = getattr(test, '_testMethodName', str(test))
+                print(f"ERROR: {test_name}")
+
+        if total_failures > 0 or total_errors > 0:
+            print("\nFailure and Error Details:")
+            for result, _ in self.results:
+                for test, traceback in result.failures:
+                    print(f"\nFAILURE: {test._testMethodName}")
+                    print(traceback)
+                for test, traceback in result.errors:
+                    test_name = getattr(test, '_testMethodName', str(test))
+                    print(f"\nERROR: {test_name}")
+                    print(traceback)
+        else:
+            print("\nAll tests passed successfully!")
+
+def test_name(name):
+    def decorator(func):
+        func.__name__ = name
+        return func
+    return decorator
 
 def parse_arguments():
     parser = argparse.ArgumentParser(description="Run inference on a Sound-To-Text Model.")
     parser.add_argument("--model_dir", type=str, default="jan-hq/Jan-Llama3-0708", help="Hugging Face model link or local_dir")
+    parser.add_argument("--max_length", type=int, default=1024, help="Maximum length of the output")
     parser.add_argument("--data_dir", type=str, required=True, help="Hugging Face model repository link or Data path")
     parser.add_argument("--cache_dir", type=str, default=".", help="Absolute path to save the model and dataset")
     parser.add_argument("--mode", type=str, default="audio", help="Mode of the model (audio or text)")
@@ -29,18 +106,20 @@ def setUpClass(cls):
         cls.save_dir_output = f'{args.output_file}/{model_name}-{args.mode}-Result.csv'
         if not os.path.exists(args.output_file):
             os.makedirs(args.output_file)
 
-        cls.sampling_params = SamplingParams(temperature=0.0, max_tokens=1024, skip_special_tokens=False)
+        cls.sampling_params = SamplingParams(temperature=0.0, max_tokens=args.max_length, skip_special_tokens=False)
         # Download model
-        if not os.path.exists(args.model_dir):
-            snapshot_download(args.model_dir, local_dir=args.cache_dir, max_workers=64)
+        model_save_dir = os.path.join(args.cache_dir, args.model_dir)
+        if not os.path.exists(model_save_dir):
+            snapshot_download(args.model_dir, local_dir=model_save_dir, max_workers=64)
         else:
-            print(f"Found {args.model_dir}. Skipping download.")
+            print(f"Found {model_save_dir}. Skipping download.")
         # Model loading using vllm
-        cls.tokenizer = AutoTokenizer.from_pretrained(args.cache_dir)
-        cls.llm = LLM(args.cache_dir, tokenizer=args.cache_dir)
+        cls.tokenizer = AutoTokenizer.from_pretrained(model_save_dir)
+        cls.llm = LLM(model_save_dir, tokenizer=model_save_dir)
         # Load dataset
-        cls.dataset = load_dataset(args.data_dir, cache_dir=args.cache_dir)['train']
+        data_save_dir = os.path.join(args.cache_dir, args.data_dir)
+        cls.dataset = load_dataset(args.data_dir, split='train')
         cls.num_rows = min(args.num_rows, len(cls.dataset))
         cls.inference_results = []
         if args.mode == "audio":
@@ -83,6 +162,7 @@ def vllm_qna_inference(self, sample_id):
         #     return input_str, output_based_on_input, expected_output_str, output_token_ids
 
 
+    @test_name("Output validation (non-empty, correct type)")
     def test_model_output(self):
         for text_input_str, output_based_on_sound, expected_output_str, output_token_ids in self.inference_results:
             # Test 1: Check if output is not empty
@@ -99,7 +179,7 @@ def test_model_output(self):
         #     output_words = set(output_based_on_sound.lower().split())
         #     relevance_score = corpus_bleu(output_words, reference_words)
         #     self.assertGreater(relevance_score, 0.3)
-
+    @test_name("Test Special Tokens Handling")
     def test_special_tokens(self):
         # Test 5: Check if special tokens are handled correctly
         special_tokens = [self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token]
@@ -114,12 +194,12 @@ def test_special_tokens(self):
         #     results = [self.inference_results[0][1] for _ in range(3)]
         #     self.assertEqual(results[0], results[1])
         #     self.assertEqual(results[1], results[2])
-
+    @test_name("Test for NaN outputs")
     def test_no_nan_outputs(self):
         # Test 7: Check for NaN outputs
         for _, output, _, _ in self.inference_results:
             self.assertFalse(any(np.isnan(float(word)) for word in output.split() if word.replace('.', '').isdigit()))
-
+    @test_name("Test for EOS token generation")
     def test_eos_token_generation(self):
         # Test 8: Check if EOS token is generated
         for _, output_based_on_sound, _, output_token_ids in self.inference_results:
@@ -137,4 +217,6 @@
 
 
 if __name__ == "__main__":
-    unittest.main(argv=['first-arg-is-ignored'], exit=False)
\ No newline at end of file
+    runner = CustomTestRunner(stream=sys.stdout, verbosity=2)
+    unittest.main(argv=['first-arg-is-ignored'], exit=False, testRunner=runner)
+    runner.print_results()
\ No newline at end of file
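For reference, a minimal, self-contained sketch of the result-collection pattern this patch introduces: a `unittest.TestResult` subclass that also records passing tests (as `CustomTestResult` does), with a summary printed after the run in the spirit of `CustomTestRunner.print_results()`. The names `RecordingResult` and `DemoTest` are illustrative only and are not part of the repository.

```python
import time
import unittest


class RecordingResult(unittest.TestResult):
    """Mirrors the patch's CustomTestResult: additionally keeps the tests that passed."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.successes = []

    def addSuccess(self, test):
        super().addSuccess(test)
        self.successes.append(test)


class DemoTest(unittest.TestCase):
    """Illustrative test case only; not part of the repository."""

    def test_passes(self):
        self.assertEqual(1 + 1, 2)

    def test_fails(self):
        self.assertTrue(False)


if __name__ == "__main__":
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(DemoTest)
    result = RecordingResult()
    start = time.time()
    suite(result)  # running the suite populates the result object
    elapsed = time.time() - start

    # Summary in the same spirit as CustomTestRunner.print_results()
    print(f"Ran {result.testsRun} tests in {elapsed:.3f} seconds")
    for test in result.successes:
        print(f"PASS: {test._testMethodName}")
    for test, _ in result.failures:
        print(f"FAIL: {test._testMethodName}")
    for test, _ in result.errors:
        print(f"ERROR: {getattr(test, '_testMethodName', str(test))}")
```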