14 changes: 14 additions & 0 deletions vllm-eval-harness/configs/deepseek-ai/DeepSeek-R1.yml
@@ -0,0 +1,14 @@
model_name: "deepseek-ai/DeepSeek-R1"
tasks:
- name: "gsm8k"
  device: b200
  tp: 8
  # Adopted from vLLM lm-eval-harness, set the value to 0 if there is no baseline
  metrics:
  - name: "exact_match,strict-match"
    value: 0
  - name: "exact_match,flexible-extract"
    value: 0
limit: 1000
num_fewshot: 5
trust_remote_code: True
14 changes: 14 additions & 0 deletions vllm-eval-harness/configs/deepseek-ai/DeepSeek-V3.1.yml
@@ -0,0 +1,14 @@
model_name: "deepseek-ai/DeepSeek-V3.1"
tasks:
- name: "gsm8k"
  device: b200
  tp: 8
  # Adopted from vLLM lm-eval-harness, set the value to 0 if there is no baseline
  metrics:
  - name: "exact_match,strict-match"
    value: 0
  - name: "exact_match,flexible-extract"
    value: 0
limit: 1000
num_fewshot: 5
trust_remote_code: True
14 changes: 14 additions & 0 deletions vllm-eval-harness/configs/deepseek-ai/DeepSeek-V3.2-Exp.yml
@@ -0,0 +1,14 @@
model_name: "deepseek-ai/DeepSeek-V3.2-Exp"
tasks:
- name: "gsm8k"
  device: b200
  tp: 8
  # Adopted from vLLM lm-eval-harness, set the value to 0 if there is no baseline
  metrics:
  - name: "exact_match,strict-match"
    value: 0
  - name: "exact_match,flexible-extract"
    value: 0
limit: 1000
num_fewshot: 5
trust_remote_code: True
14 changes: 14 additions & 0 deletions vllm-eval-harness/configs/google/gemma-3-27b-it.yml
@@ -0,0 +1,14 @@
model_name: "google/gemma-3-27b-it"
tasks:
- name: "gsm8k"
  device: b200
  tp: 8
  # Adopted from vLLM lm-eval-harness, set the value to 0 if there is no baseline
  metrics:
  - name: "exact_match,strict-match"
    value: 0
  - name: "exact_match,flexible-extract"
    value: 0
limit: 1000
num_fewshot: 5
trust_remote_code: True
@@ -0,0 +1,14 @@
model_name: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
tasks:
- name: "gsm8k"
  device: b200
  tp: 8
  # Adopted from vLLM lm-eval-harness, set the value to 0 if there is no baseline
  metrics:
  - name: "exact_match,strict-match"
    value: 0
  - name: "exact_match,flexible-extract"
    value: 0
limit: 1000
num_fewshot: 5
trust_remote_code: True
@@ -0,0 +1,14 @@
model_name: "meta-llama/Llama-4-Scout-17B-16E-Instruct"
tasks:
- name: "gsm8k"
  device: b200
  tp: 4
  # Adopted from vLLM lm-eval-harness, set the value to 0 if there is no baseline
  metrics:
  - name: "exact_match,strict-match"
    value: 0
  - name: "exact_match,flexible-extract"
    value: 0
limit: 1000
num_fewshot: 5
trust_remote_code: True
14 changes: 14 additions & 0 deletions vllm-eval-harness/configs/openai/gpt-oss-120b.yml
@@ -0,0 +1,14 @@
model_name: "openai/gpt-oss-120b"
tasks:
- name: "gsm8k"
  device: b200
  tp: 8
  # Adopted from vLLM lm-eval-harness, set the value to 0 if there is no baseline
  metrics:
  - name: "exact_match,strict-match"
    value: 0
  - name: "exact_match,flexible-extract"
    value: 0
limit: 1000
num_fewshot: 5
trust_remote_code: True
14 changes: 14 additions & 0 deletions vllm-eval-harness/configs/openai/gpt-oss-20b.yml
@@ -0,0 +1,14 @@
model_name: "openai/gpt-oss-20b"
tasks:
- name: "gsm8k"
  device: b200
  tp: 1
  # Adopted from vLLM lm-eval-harness, set the value to 0 if there is no baseline
  metrics:
  - name: "exact_match,strict-match"
    value: 0
  - name: "exact_match,flexible-extract"
    value: 0
limit: 1000
num_fewshot: 5
trust_remote_code: True
14 changes: 14 additions & 0 deletions vllm-eval-harness/configs/qwen/Qwen3-30B-A3B.yml
@@ -0,0 +1,14 @@
model_name: "Qwen/Qwen3-30B-A3B"
tasks:
- name: "gsm8k"
  device: b200
  tp: 8
  # Adopted from vLLM lm-eval-harness, set the value to 0 if there is no baseline
  metrics:
  - name: "exact_match,strict-match"
    value: 0
  - name: "exact_match,flexible-extract"
    value: 0
limit: 1000
num_fewshot: 5
trust_remote_code: True
14 changes: 14 additions & 0 deletions vllm-eval-harness/configs/qwen/Qwen3-8B.yml
@@ -0,0 +1,14 @@
model_name: "Qwen/Qwen3-8B"
tasks:
- name: "gsm8k"
  device: b200
  tp: 1
  # Adopted from vLLM lm-eval-harness, set the value to 0 if there is no baseline
  metrics:
  - name: "exact_match,strict-match"
    value: 0
  - name: "exact_match,flexible-extract"
    value: 0
limit: 1000
num_fewshot: 5
trust_remote_code: True
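
For reference, a config like the ones above parses into a top-level mapping with a per-task list, which is the shape the script below expects (name, device, and tp per task; limit, num_fewshot, and trust_remote_code at the top level). A minimal sketch of reading one of these files, assuming it sits at the path shown in its diff header:

import yaml

# Path taken from the Qwen3-8B config in this PR; adjust as needed.
with open("vllm-eval-harness/configs/qwen/Qwen3-8B.yml") as f:
    config = yaml.safe_load(f)

print(config["model_name"])  # "Qwen/Qwen3-8B"
for task in config["tasks"]:
    # Each task carries its own device/tp requirement and baseline metrics.
    print(task["name"], task["device"], task["tp"], task["metrics"])
print(config["limit"], config["num_fewshot"], config["trust_remote_code"])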
186 changes: 186 additions & 0 deletions vllm-eval-harness/run_vllm_eval_harness.py
@@ -0,0 +1,186 @@
import json
import os
import glob
import lm_eval
import yaml
from logging import warning, info
from argparse import Action, ArgumentParser, Namespace
import torch
from typing import Dict, Any, List, Optional


# See lm-eval docs for the list of acceptable values
LM_EVAL_MODEL_SOURCE = os.environ.get("LM_EVAL_MODEL_SOURCE", "vllm")


class ValidateDir(Action):
    def __call__(
        self,
        parser: ArgumentParser,
        namespace: Namespace,
        values: Any,
        option_string: Optional[str] = None,
    ) -> None:
        if os.path.isdir(values):
            setattr(namespace, self.dest, values)
            return

        parser.error(f"{values} is not a valid directory")


def parse_args() -> Any:
    parser = ArgumentParser("Run vLLM lm-eval harness")

    parser.add_argument(
        "--configs-dir",
        type=str,
        action=ValidateDir,
        help="the directory that contains vLLM lm-eval harness configs",
        required=True,
    )
    parser.add_argument(
        "--models",
        type=str,
        default="",
        help="the comma-separated list of models to evaluate (optional)",
    )
    parser.add_argument(
        "--tasks",
        type=str,
        default="",
        help="the comma-separated list of tasks to evaluate (optional)",
    )

    return parser.parse_args()


def convert_to_pytorch_benchmark_format(
    model_name: str, tp_size: int, results: Dict[str, Any]
) -> List[Any]:
    records = []
    configs = results.get("configs", {})

    for task_name, metrics in results.get("results", {}).items():
        for metric_name, metric_value in metrics.items():
            if type(metric_value) is str:
                continue

            record = {
                "benchmark": {
                    "name": "vLLM lm-eval harness",
                    "extra_info": {
                        "args": {
                            "tensor_parallel_size": tp_size,
                        },
                        "configs": configs.get(task_name, {}),
                    },
                },
                "model": {
                    "name": model_name,
                },
                "metric": {
                    "name": metric_name,
                    "benchmark_values": [metric_value],
                },
            }
            records.append(record)

    return records


def run(
    model_name: str, tasks: List[str], tp_size: int, config: Dict[str, Any]
) -> Dict[str, Any]:
    trust_remote_code = config.get("trust_remote_code", False)
    max_model_len = config.get("max_model_len", 8192)

Reviewer comment (on the max_model_len default):

this likely will impact the result. ideally it's set to auto.

Contributor Author:

It looks like vLLM lm-eval does not like that value and ends up with this error:

/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
INFO 10-10 01:22:45 [__init__.py:215] Automatically detected platform cuda.
Traceback (most recent call last):
  File "/vllm-workspace/pytorch-integration-testing/vllm-eval-harness/run_vllm_eval_harness.py", line 186, in <module>
    main()
  File "/vllm-workspace/pytorch-integration-testing/vllm-eval-harness/run_vllm_eval_harness.py", line 182, in main
    run_lm_eval(args.configs_dir, models, tasks)
  File "/vllm-workspace/pytorch-integration-testing/vllm-eval-harness/run_vllm_eval_harness.py", line 164, in run_lm_eval
    results = run(model_name, selected_tasks, tp_size, config)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/vllm-workspace/pytorch-integration-testing/vllm-eval-harness/run_vllm_eval_harness.py", line 105, in run
    return lm_eval.simple_evaluate(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/lm_eval/utils.py", line 456, in _wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/lm_eval/evaluator.py", line 245, in simple_evaluate
    lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/lm_eval/api/model.py", line 155, in create_from_arg_string
    return cls(**args, **args2)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/lm_eval/models/vllm_causallms.py", line 170, in __init__
    "max_model_len": int(self._max_length) if self._max_length else None,
                     ^^^^^^^^^^^^^^^^^^^^^
ValueError: invalid literal for int() with base 10: 'auto'

Let me take a closer look.
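
One possible workaround, sketched here purely as an assumption (the diff below keeps the original max_model_len wiring), is to forward max_model_len to the vLLM backend only when the config provides a concrete integer, since lm_eval's vLLM wrapper calls int() on the value; otherwise leave it out and let vLLM derive the length from the model:

    # Hypothetical sketch, not part of this PR.
    max_model_len = config.get("max_model_len")
    model_args_parts = [
        f"pretrained={model_name}",
        f"tensor_parallel_size={tp_size}",
        "add_bos_token=true",
        f"trust_remote_code={trust_remote_code}",
    ]
    if isinstance(max_model_len, int):
        # Only pass an explicit length when the config gives a number.
        model_args_parts.append(f"max_model_len={max_model_len}")
    model_args = ",".join(model_args_parts)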


    model_args = (
        f"pretrained={model_name},"
        f"tensor_parallel_size={tp_size},"
        f"add_bos_token=true,"
        f"trust_remote_code={trust_remote_code},"
        f"max_model_len={max_model_len}"
    )
    info(f"Evaluating {model_name} with {model_args}")
    return lm_eval.simple_evaluate(
        model=LM_EVAL_MODEL_SOURCE,
        model_args=model_args,
        tasks=tasks,
        num_fewshot=config["num_fewshot"],
        limit=config["limit"],
        batch_size="auto",
    )


def run_lm_eval(configs_dir: str, models: List[str], tasks: List[str]) -> None:
    device_name = torch.cuda.get_device_name().lower()
    device_count = torch.cuda.device_count()

    results_dir = os.path.join(configs_dir, "results")
    os.makedirs(results_dir, exist_ok=True)

    for file in glob.glob(f"{configs_dir}/**/*.yml", recursive=True):
        with open(file) as f:
            config = yaml.safe_load(f)
        # Check the model name
        model_name = config.get("model_name", "").lower()
        if models and model_name not in models:
            info(f"Skip {model_name} from {file}")
            continue

        tp_size = 0
        selected_tasks = []

        # Check the lm-eval tasks, the selected device, and tp
        for t in config.get("tasks", []):
            task_name = t["name"]
            if not task_name:
                warning(f"{model_name} from {file}: skip missing task")
                continue

            if tasks and task_name not in tasks:
                info(f"{model_name} from {file}: {task_name} not selected")
                continue

            selected_device = t["device"].lower()
            if selected_device not in device_name:
                continue

            tp = t["tp"]
            if device_count < tp:
                warning(
                    f"{model_name} from {file}: device count {device_count} < tp {tp} in {task_name}"
                )
                continue

            selected_tasks.append(task_name)
            if not tp_size:
                tp_size = tp
            assert tp_size == tp

        if not selected_tasks:
            info(f"Skip {model_name} from {file}: no task")
            continue

        results = run(model_name, selected_tasks, tp_size, config)
        results_pytorch_format = convert_to_pytorch_benchmark_format(
            model_name, tp_size, results
        )

        results_file = os.path.splitext(os.path.basename(file))[0]
        # Dump the results from lm-eval
        with open(os.path.join(results_dir, f"{results_file}_lm_eval.json"), "w") as f:
            json.dump(results, f, indent=2)
        # Dump the results that can be uploaded to PyTorch OSS benchmark infra
        with open(os.path.join(results_dir, f"{results_file}_pytorch.json"), "w") as f:
            json.dump(results_pytorch_format, f, indent=2)


def main() -> None:
    args = parse_args()
    models = [m.strip().lower() for m in args.models.split(",") if m.strip()]
    tasks = [m.strip().lower() for m in args.tasks.split(",") if m.strip()]
    run_lm_eval(args.configs_dir, models, tasks)


if __name__ == "__main__":
    main()
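
Taken together, a hypothetical invocation of the harness (the paths and filters here are illustrative, not taken from the PR) would be:

python run_vllm_eval_harness.py --configs-dir vllm-eval-harness/configs --models deepseek-ai/deepseek-r1 --tasks gsm8k

Model names are matched case-insensitively (both the config's model_name and the --models values are lowercased); --tasks values are lowercased and so match the lowercase task names used in the configs. Results are written under <configs-dir>/results/ as <config-basename>_lm_eval.json (raw lm-eval output) and <config-basename>_pytorch.json (the PyTorch OSS benchmark format).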