Merge pull request #21 from janhq/CI-CD/bach
Ci cd/bach
hahuyhoang411 authored Jul 13, 2024
2 parents 02af579 + 85c4070 commit 546dba7
Showing 2 changed files with 96 additions and 12 deletions.
4 changes: 3 additions & 1 deletion tests/README.md
@@ -8,8 +8,10 @@
2. Run the test suite:
```bash
python test_case.py --model_dir "jan-hq/Jan-Llama3-0708" \
+   --max_length 1024 \
    --data_dir "jan-hq/instruction-speech-conversation-test" \
+   --mode "audio" \
-   --num_rows 100 \
+   --num_rows 5
```
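
A text-mode run follows the same pattern (illustrative sketch; apart from `--mode "text"`, the flag values are taken from the example above and the script's argparse defaults, and are not part of this commit):

```bash
python test_case.py --model_dir "jan-hq/Jan-Llama3-0708" \
    --max_length 1024 \
    --data_dir "jan-hq/instruction-speech-conversation-test" \
    --mode "text" \
    --num_rows 5
```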
## Test Configuration

104 changes: 93 additions & 11 deletions tests/test_case.py
@@ -9,10 +9,87 @@
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu
import argparse
import os
import sys
from io import StringIO
import time
# Helper classes for collecting per-test results and printing a custom summary
class CustomTestResult(unittest.TestResult):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.successes = []

    def addSuccess(self, test):
        super().addSuccess(test)
        self.successes.append(test)

class CustomTestRunner(unittest.TextTestRunner):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.stream = StringIO()
        self.results = []

    def run(self, test):
        result = CustomTestResult()
        start_time = time.time()
        test(result)
        time_taken = time.time() - start_time
        self.results.append((result, time_taken))
        return result

    def print_results(self):
        print("\n=== Test Results ===")
        total_tests = 0
        total_successes = 0
        total_failures = 0
        total_errors = 0
        total_time = 0

        for result, time_taken in self.results:
            total_tests += result.testsRun
            total_successes += len(result.successes)
            total_failures += len(result.failures)
            total_errors += len(result.errors)
            total_time += time_taken

        print(f"Ran {total_tests} tests in {total_time:.3f} seconds")
        print(f"Successes: {total_successes}")
        print(f"Failures: {total_failures}")
        print(f"Errors: {total_errors}")

        print("\nDetailed Results:")
        for result, time_taken in self.results:
            # todo: add time taken for each test
            for test in result.successes:
                print(f"PASS: {test._testMethodName}")
            for test, _ in result.failures:
                print(f"FAIL: {test._testMethodName}")
            for test, _ in result.errors:
                test_name = getattr(test, '_testMethodName', str(test))
                print(f"ERROR: {test_name}")

        if total_failures > 0 or total_errors > 0:
            print("\nFailure and Error Details:")
            for result, _ in self.results:
                for test, traceback in result.failures:
                    print(f"\nFAILURE: {test._testMethodName}")
                    print(traceback)
                for test, traceback in result.errors:
                    test_name = getattr(test, '_testMethodName', str(test))
                    print(f"\nERROR: {test_name}")
                    print(traceback)
        else:
            print("\nAll tests passed successfully!")

def test_name(name):
    def decorator(func):
        func.__name__ = name
        return func
    return decorator
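# --- Editor's note: illustrative usage sketch, not part of this commit. ---
# Shows how the helpers above fit together: CustomTestRunner.run() collects
# passing tests in CustomTestResult.successes, and print_results() prints the
# aggregated summary afterwards.
#
#   class _DemoTest(unittest.TestCase):
#       @test_name("Demo: arithmetic sanity check")
#       def test_addition(self):
#           self.assertEqual(1 + 1, 2)
#
#   suite = unittest.defaultTestLoader.loadTestsFromTestCase(_DemoTest)
#   demo_runner = CustomTestRunner(verbosity=2)
#   demo_runner.run(suite)
#   demo_runner.print_results()  # Ran 1 tests ... Successes: 1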

def parse_arguments():
    parser = argparse.ArgumentParser(description="Run inference on a Sound-To-Text Model.")
    parser.add_argument("--model_dir", type=str, default="jan-hq/Jan-Llama3-0708", help="Hugging Face model link or local_dir")
    parser.add_argument("--max_length", type=int, default=1024, help="Maximum length of the output")
    parser.add_argument("--data_dir", type=str, required=True, help="Hugging Face dataset repository or local data path")
    parser.add_argument("--cache_dir", type=str, default=".", help="Absolute path to save the model and dataset")
    parser.add_argument("--mode", type=str, default="audio", help="Mode of the model (audio or text)")
@@ -29,18 +106,20 @@ def setUpClass(cls):
        cls.save_dir_output = f'{args.output_file}/{model_name}-{args.mode}-Result.csv'
        if not os.path.exists(args.output_file):
            os.makedirs(args.output_file)
-       cls.sampling_params = SamplingParams(temperature=0.0, max_tokens=1024, skip_special_tokens=False)
+       cls.sampling_params = SamplingParams(temperature=0.0, max_tokens=args.max_length, skip_special_tokens=False)
        # Download model
-       if not os.path.exists(args.model_dir):
-           snapshot_download(args.model_dir, local_dir=args.cache_dir, max_workers=64)
+       model_save_dir = os.path.join(args.cache_dir, args.model_dir)
+       if not os.path.exists(model_save_dir):
+           snapshot_download(args.model_dir, local_dir=model_save_dir, max_workers=64)
        else:
-           print(f"Found {args.model_dir}. Skipping download.")
+           print(f"Found {model_save_dir}. Skipping download.")
        # Model loading using vllm
-       cls.tokenizer = AutoTokenizer.from_pretrained(args.cache_dir)
-       cls.llm = LLM(args.cache_dir, tokenizer=args.cache_dir)
+       cls.tokenizer = AutoTokenizer.from_pretrained(model_save_dir)
+       cls.llm = LLM(model_save_dir, tokenizer=model_save_dir)
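        # --- Editor's note: illustrative sketch, not part of this commit. ---
        # Typical single-prompt usage of the vLLM objects created above; the
        # llm.generate(...) call is an assumption based on the public vLLM API
        # and is not shown in this hunk.
        #
        #   prompt = "<tokenized conversation rendered as a string>"
        #   params = SamplingParams(temperature=0.0, max_tokens=args.max_length,
        #                           skip_special_tokens=False)
        #   outputs = cls.llm.generate([prompt], params)
        #   text = outputs[0].outputs[0].text            # generated completion
        #   token_ids = outputs[0].outputs[0].token_ids  # generated token ids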

        # Load dataset
-       cls.dataset = load_dataset(args.data_dir, cache_dir=args.cache_dir)['train']
+       data_save_dir = os.path.join(args.cache_dir, args.data_dir)
+       cls.dataset = load_dataset(args.data_dir, split='train')
        cls.num_rows = min(args.num_rows, len(cls.dataset))
        cls.inference_results = []
        if args.mode == "audio":
@@ -83,6 +162,7 @@ def vllm_qna_inference(self, sample_id):


        # return input_str, output_based_on_input, expected_output_str, output_token_ids
    @test_name("Output validation (non-empty, correct type)")
    def test_model_output(self):
        for text_input_str, output_based_on_sound, expected_output_str, output_token_ids in self.inference_results:
            # Test 1: Check if output is not empty
@@ -99,7 +179,7 @@ def test_model_output(self):
            # output_words = set(output_based_on_sound.lower().split())
            # relevance_score = corpus_bleu(output_words, reference_words)
            # self.assertGreater(relevance_score, 0.3)

    @test_name("Test Special Tokens Handling")
    def test_special_tokens(self):
        # Test 5: Check if special tokens are handled correctly
        special_tokens = [self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token]
@@ -114,12 +194,12 @@ def test_special_tokens(self):
        # results = [self.inference_results[0][1] for _ in range(3)]
        # self.assertEqual(results[0], results[1])
        # self.assertEqual(results[1], results[2])

    @test_name("Test for NaN outputs")
    def test_no_nan_outputs(self):
        # Test 7: Check for NaN outputs
        for _, output, _, _ in self.inference_results:
            self.assertFalse(any(np.isnan(float(word)) for word in output.split() if word.replace('.', '').isdigit()))

    @test_name("Test for EOS token generation")
    def test_eos_token_generation(self):
        # Test 8: Check if EOS token is generated
        for _, output_based_on_sound, _, output_token_ids in self.inference_results:
@@ -137,4 +217,6 @@ def test_eos_token_generation(self):


if __name__ == "__main__":
-   unittest.main(argv=['first-arg-is-ignored'], exit=False)
+   runner = CustomTestRunner(stream=sys.stdout, verbosity=2)
+   unittest.main(argv=['first-arg-is-ignored'], exit=False, testRunner=runner)
+   runner.print_results()
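
For reference, with all checks passing, `print_results()` emits a summary along these lines (illustrative only; the timing is a placeholder and only the four test methods visible in this diff are assumed to run):

```
=== Test Results ===
Ran 4 tests in 12.345 seconds
Successes: 4
Failures: 0
Errors: 0

Detailed Results:
PASS: test_eos_token_generation
PASS: test_model_output
PASS: test_no_nan_outputs
PASS: test_special_tokens

All tests passed successfully!
```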
