Skip to content

Commit

Permalink
refactor the code using pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyenhoangthuan99 committed Aug 6, 2024
1 parent 8a2b8d2 commit 3eb76cc
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 4 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/test-branch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@ jobs:
- name: Run tests
working-directory: ./tests
run: |
python3 test_case.py --model_dir ${{ github.event.inputs.model_id || 'jan-hq/Jan-Llama3-0708' }} --data_dir ${{ github.event.inputs.dataset_id || 'jan-hq/instruction-speech-conversation-test' }} ${{ github.event.inputs.extra_args || '--mode audio --num_rows 5' }}
pytest unit_test.py -v -s --cov=. --cov-report=xml --model_dir ${{ github.event.inputs.model_id }} --data_dir ${{ github.event.inputs.dataset_id }} ${{ github.event.inputs.extra_args }}
2 changes: 1 addition & 1 deletion .github/workflows/test-models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
- name: Run tests
working-directory: ./tests
run: |
python3 test_case.py --model_dir ${{ github.event.inputs.model_id }} --data_dir ${{ github.event.inputs.dataset_id }} ${{ github.event.inputs.extra_args }}
pytest unit_test.py -v -s --cov=. --cov-report=xml --model_dir ${{ github.event.inputs.model_id }} --data_dir ${{ github.event.inputs.dataset_id }} ${{ github.event.inputs.extra_args }}
- name: Install benchmark dependencies
if: ${{ github.event.inputs.run_benchmark == 'true' }}
Expand Down
22 changes: 22 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest

def pytest_addoption(parser):
parser.addoption("--model_dir", type=str, default="jan-hq/Jan-Llama3-0708", help="Hugging Face model link or local_dir")
parser.addoption("--max_length", type=int, default=1024, help="Maximum length of the output")
parser.addoption("--data_dir", type=str, required=True, help="Hugging Face model repository link or Data path")
parser.addoption("--cache_dir", type=str, default=".", help="Absolute path to save the model and dataset")
parser.addoption("--mode", type=str, default="audio", help="Mode of the model (audio or text)")
parser.addoption("--num_rows", type=int, default=5, help="Number of dataset rows to process")
parser.addoption("--output_file", type=str, default="output/", help="Output file path")

@pytest.fixture(scope="session")
def custom_args(request):
return {
"model_dir": request.config.getoption("--model_dir"),
"max_length": request.config.getoption("--max_length"),
"data_dir": request.config.getoption("--data_dir"),
"cache_dir": request.config.getoption("--cache_dir"),
"mode": request.config.getoption("--mode"),
"num_rows": request.config.getoption("--num_rows"),
"output_file": request.config.getoption("--output_file"),
}
6 changes: 4 additions & 2 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
datasets==2.20.0
torch==2.3.0
datasets==2.20.0
transformers
vllm
huggingface_hub==0.23.4
pandas==2.2.2
nltk
nltk
pytest
pytest-cov
102 changes: 102 additions & 0 deletions tests/unit_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import pytest
from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from datasets import load_dataset
import pandas as pd
import numpy as np
import os
import time

@pytest.fixture(scope="module")
def model_setup(custom_args):
args = custom_args
model_name = args.model_dir.split("/")[-1]
save_dir_output = f'{args.output_file}/{model_name}-{args.mode}-Result.csv'
if not os.path.exists(args.output_file):
os.makedirs(args.output_file)

sampling_params = SamplingParams(temperature=0.0, max_tokens=args.max_length, skip_special_tokens=False)

model_save_dir = os.path.join(args.cache_dir, args.model_dir)
if not os.path.exists(model_save_dir):
snapshot_download(args.model_dir, local_dir=model_save_dir, max_workers=64)
else:
print(f"Found {model_save_dir}. Skipping download.")

tokenizer = AutoTokenizer.from_pretrained(model_save_dir)
llm = LLM(model_save_dir, tokenizer=model_save_dir, gpu_memory_utilization=0.3)

data_save_dir = os.path.join(args.cache_dir, args.data_dir)
dataset = load_dataset(args.data_dir, split='train')
num_rows = min(args.num_rows, len(dataset))

return args, tokenizer, llm, dataset, num_rows, sampling_params, save_dir_output

@pytest.fixture(scope="module")
def inference_results(model_setup):
args, tokenizer, llm, dataset, num_rows, sampling_params, _ = model_setup
results = []

def vllm_sound_inference(sample_id):
sound_messages = dataset[sample_id]['sound_convo'][0]
expected_output_messages = dataset[sample_id]['sound_convo'][1]
sound_input_str = tokenizer.apply_chat_template([sound_messages], tokenize=False, add_generation_prompt=True)
text_input_str = dataset[sample_id]['prompt']
expected_output_str = tokenizer.apply_chat_template([expected_output_messages], tokenize=False)

outputs = llm.generate(sound_input_str, sampling_params)
output_based_on_sound = outputs[0].outputs[0].text
output_token_ids = outputs[0].outputs[0].token_ids

return text_input_str, output_based_on_sound, expected_output_str, output_token_ids

def vllm_qna_inference(sample_id):
text_input_str = dataset[sample_id]['prompt']
expected_answer_str = dataset[sample_id]['answer']
question_str = tokenizer.apply_chat_template([text_input_str], tokenize=False, add_generation_prompt=True)
outputs = llm.generate(question_str, sampling_params)
output_based_on_question = outputs[0].outputs[0].text
output_token_ids = outputs[0].outputs[0].token_ids

return text_input_str, output_based_on_question, expected_answer_str, output_token_ids
if args.mode == "audio":
for i in range(num_rows):
results.append(vllm_sound_inference(i))
elif args.mode == "text":
for i in range(num_rows):
results.append(vllm_qna_inference(i))

df_results = pd.DataFrame(results, columns=['input', 'output', 'expected_output', 'output_token_ids'])
df_results.to_csv(save_dir_output, index=False, encoding='utf-8')
print(f"Successfully saved in {save_dir_output}")

return results

def test_model_output(inference_results):
for text_input_str, output_based_on_sound, expected_output_str, output_token_ids in inference_results:
assert len(output_based_on_sound) > 0, "Output should not be empty"
assert isinstance(output_based_on_sound, str), "Output should be a string"
assert all(token >= 0 for token in output_token_ids), "Output tokens should be valid"

def test_special_tokens(model_setup, inference_results):
_, tokenizer, _, _, _, _, _ = model_setup
special_tokens = [tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token]
for token in special_tokens:
if token:
encoded = tokenizer.encode(token)
assert encoded[0] != -100, f"Special token {token} should not be ignored"

def test_no_nan_outputs(inference_results):
for _, output, _, _ in inference_results:
assert not any(np.isnan(float(word)) for word in output.split() if word.replace('.', '').isdigit()), "Output should not contain NaN values"

def test_eos_token_generation(model_setup, inference_results):
_, tokenizer, _, _, _, _, _ = model_setup
eos_token_id = tokenizer.eos_token_id
for _, _, _, output_token_ids in inference_results:
assert eos_token_id in output_token_ids, "EOS token not found in the generated output"
assert output_token_ids[-1] == eos_token_id, "EOS token is not at the end of the sequence"
assert output_token_ids.count(eos_token_id) == 1, f"Expected 1 EOS token, but found {output_token_ids.count(eos_token_id)}"

# Additional tests can be added here

0 comments on commit 3eb76cc

Please sign in to comment.