diff --git a/benchmark/__init__.py b/benchmark/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/benchmark/base.py b/benchmark/base.py
new file mode 100644
index 0000000..b2448df
--- /dev/null
+++ b/benchmark/base.py
@@ -0,0 +1,89 @@
+import argparse
+import logging
+import sys
+import torch
+import time
+from collections import defaultdict
+from benchmark.utils import profile_latency, profile_usage
+
+
+class LlamaBenchmarkBase:
+    def __init__(self, model_dir_path: str, device: str, *args, **kwargs) -> None:
+        self.model_dir_path, self.device = model_dir_path, device
+        self.results = []
+
+    def load_model(self):
+        return self
+
+    def run_model(self, prompt: str, max_tokens: int, *args, **kwargs):
+        raise NotImplementedError
+
+    def benchmark(self, prompt: str, max_tokens: int, repetitions: int, *args, **kwargs) -> None:
+        for i in range(repetitions):
+            logging.info(
+                f"Running repetition [{str(i+1).zfill(len(str(repetitions)))}/{repetitions}]"
+            )
+            # Wrap run_model at call time so the profilers are applied to the
+            # subclass implementation as well, not only to the base stub.
+            profiled_run = profile_usage(profile_latency(self.run_model))
+            # profile_latency returns (result, latency); profile_usage wraps
+            # that into ((result, latency), memory_usage).
+            (results, latency), memory_usage = profiled_run(prompt, max_tokens, *args, **kwargs)
+
+            logging.info(f"Latency: {latency} s, memory usage: {memory_usage} bytes")
+
+            self.results.append((latency, memory_usage))
+
+        del self.model
+        if self.device == "cuda":
+            torch.cuda.synchronize()
+
+
+def benchmark_arg_parser(name: str, benchmark_class):
+    parser = argparse.ArgumentParser(description=f"{name} Benchmark.")
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        help="The prompt for the model.",
+    )
+    parser.add_argument("--max_tokens", type=int, help="The maximum number of tokens.")
+    parser.add_argument(
+        "--repetitions",
+        type=int,
+        help="The number of repetitions for the benchmark.",
+    )
+    parser.add_argument(
+        "--device",
+        help="Device to use for the benchmark.",
+    )
+    parser.add_argument(
+        "--log_file",
+        type=str,
+        help="Path to the log file for writing logs (in append mode).",
+    )
+    parser.add_argument(
+        "--models_dir",
+        type=str,
+        help="Path to the models directory.",
+    )
+
+    args = parser.parse_args()
+
+    logging.info(
+        f"Running benchmark with: max_tokens={args.max_tokens} prompt={args.prompt} "
+        + f"repetitions={args.repetitions} device={args.device}"
+    )
+    report = defaultdict(lambda: defaultdict(float))
+
+    for precision in ("fp32", "fp16", "int4"):
+        logging.info(f"Running {name} benchmark on Llama with {precision} precision.")
+
+        llama_bench = benchmark_class(
+            f"{args.models_dir}/llama-2-7b-hf"
+            if precision != "int4"
+            else f"{args.models_dir}/llama-2-7b-autoawq",
+            device=args.device,
+            precision=precision,
+        ).load_model()
+
+        llama_bench.benchmark(
+            max_tokens=args.max_tokens, prompt=args.prompt, repetitions=args.repetitions
+        )
\ No newline at end of file
diff --git a/bench_vllm/README.md b/benchmark/bench_vllm/README.md
similarity index 100%
rename from bench_vllm/README.md
rename to benchmark/bench_vllm/README.md
diff --git a/benchmark/bench_vllm/__init__.py b/benchmark/bench_vllm/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/benchmark/bench_vllm/__init__.py
@@ -0,0 +1 @@
+
diff --git a/benchmark/bench_vllm/bench.py b/benchmark/bench_vllm/bench.py
new file mode 100644
index 0000000..88b9c8c
--- /dev/null
+++ b/benchmark/bench_vllm/bench.py
@@ -0,0 +1,46 @@
+import logging
+import sys
+
+from vllm import LLM, SamplingParams
+from vllm.model_executor.parallel_utils import parallel_state
+
+from benchmark.base import LlamaBenchmarkBase, benchmark_arg_parser
+
+logging.getLogger("vllm").setLevel(logging.ERROR)
+logging.basicConfig(
+    stream=sys.stdout,
+    level=logging.INFO,
+    format="%(asctime)s - %(levelname)s - %(message)s",
+)
+
+
+class LlamavLLMBenchmark(LlamaBenchmarkBase):
+    def __init__(self, model_dir_path: str, device: str, precision: str) -> None:
+        if device != "cuda":
+            raise ValueError("Only the 'cuda' device is supported.")
+        if precision not in ("fp16", "fp32", "int4"):
+            raise ValueError("Supported precisions are: fp16, fp32 and int4.")
+
+        self.precision = precision
+        self.precision_map = {"fp16": "float16", "fp32": "float32"}
+        super().__init__(model_dir_path=model_dir_path, device=device)
+
+    def load_model(self):
+        # The dtype has to be passed at construction time; assigning it to
+        # the LLM object afterwards has no effect.
+        if self.precision != "int4":
+            self.model = LLM(
+                model=self.model_dir_path, dtype=self.precision_map[self.precision]
+            )
+        else:
+            self.model = LLM(model=self.model_dir_path, quantization="AWQ")
+        return self
+
+    def run_model(self, prompt: str, max_tokens: int):
+        # max_num_seqs limits concurrent sequences, not output length; use
+        # SamplingParams to cap the number of generated tokens instead.
+        sampling_params = SamplingParams(max_tokens=max_tokens)
+        return self.model.generate(prompts=[prompt], sampling_params=sampling_params)
+
+    def benchmark(self, prompt: str, max_tokens: int, repetitions: int, *args, **kwargs) -> None:
+        super().benchmark(prompt, max_tokens, repetitions, *args, **kwargs)
+
+        if self.device == "cuda":
+            parallel_state.destroy_model_parallel()
+
+
+if __name__ == "__main__":
+    benchmark_arg_parser(name="vLLM", benchmark_class=LlamavLLMBenchmark)
\ No newline at end of file
diff --git a/bench_vllm/bench.sh b/benchmark/bench_vllm/bench.sh
similarity index 99%
rename from bench_vllm/bench.sh
rename to benchmark/bench_vllm/bench.sh
index eb2afad..0868e73 100755
--- a/bench_vllm/bench.sh
+++ b/benchmark/bench_vllm/bench.sh
@@ -168,6 +168,6 @@
 REPETITIONS="${REPETITIONS:-10}"
 MAX_TOKENS="${MAX_TOKENS:-512}"
 DEVICE="${DEVICE:-'cuda'}"
 LOG_FILENAME="${LOG_FILENAME:-"$LOGS_FOLDER/benchmark_vllm_$(date +'%Y%m%d%H%M%S').log"}"
-MODELS_DIR="${MODELS_DIR:-"./models"}"
+MODELS_DIR="${MODELS_DIR:-"../models"}"
 run_benchmarks "$PROMPT" "$REPETITIONS" "$MAX_TOKENS" "$DEVICE" "$LOG_FILENAME" "$MODELS_DIR"
diff --git a/bench_vllm/bench.py b/benchmark/bench_vllm/bench2.py
similarity index 100%
rename from bench_vllm/bench.py
rename to benchmark/bench_vllm/bench2.py
diff --git a/bench_vllm/setup.sh b/benchmark/bench_vllm/setup.sh
similarity index 98%
rename from bench_vllm/setup.sh
rename to benchmark/bench_vllm/setup.sh
index 9153211..00832f8 100755
--- a/bench_vllm/setup.sh
+++ b/benchmark/bench_vllm/setup.sh
@@ -8,7 +8,7 @@
 
 set -euo pipefail
 
-AWQ_WEIGHTS_FOLDER="${AWQ_WEIGHTS_FOLDER:-"./models/llama-2-7b-awq"}"
+AWQ_WEIGHTS_FOLDER="${AWQ_WEIGHTS_FOLDER:-"../models/llama-2-7b-awq"}"
 
 check_python() {
     if command -v python &> /dev/null; then
diff --git a/benchmark/utils.py b/benchmark/utils.py
new file mode 100644
index 0000000..9ac89a8
--- /dev/null
+++ b/benchmark/utils.py
@@ -0,0 +1,40 @@
+import os
+import subprocess
+import time
+import psutil
+import functools
+from contextlib import contextmanager
+from multiprocessing import Pipe, Process
+from multiprocessing.connection import Connection
+from memory_profiler import profile as mem_profile
+from line_profiler import LineProfiler
+
+
+def profile_usage(func):
+    """Measure the change in used system memory (in bytes) across a call."""
+
+    @functools.wraps(func)
+    def wrapper_profile_usage(*args, **kwargs):
+        mem_before = psutil.virtual_memory().used
+        result = func(*args, **kwargs)
+        mem_after = psutil.virtual_memory().used
+        mem_usage = mem_after - mem_before
+        print(f"Memory usage: {mem_usage} bytes")
+        return result, mem_usage
+
+    return wrapper_profile_usage
+
+
+def profile_latency(func):
+    """Measure wall-clock latency of a call and print line-level timings."""
+
+    @functools.wraps(func)
+    def wrapper_profile_latency(*args, **kwargs):
+        profiler = LineProfiler()
+        profiler.add_function(func)
+        profiler.enable_by_count()
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        end_time = time.time()
+        profiler.disable_by_count()
+        profiler.print_stats()
+        latency = end_time - start_time
+        print(f"Latency: {latency} seconds")
+        return result, latency
+
+    return wrapper_profile_latency
\ No newline at end of file
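
Usage note: the two profilers compose, so a call wrapped as profile_usage(profile_latency(fn)) returns the nested tuple ((result, latency), memory_usage), which is exactly what LlamaBenchmarkBase.benchmark unpacks. A minimal sketch of that composition, assuming the decorators from benchmark/utils.py above and a hypothetical fake_generate stand-in for a model call:

    # Illustrative sketch; fake_generate is a placeholder, not part of the patch.
    from benchmark.utils import profile_latency, profile_usage

    def fake_generate(prompt: str, max_tokens: int) -> str:
        # Stand-in for an expensive model call such as LLM.generate().
        return prompt[:max_tokens]

    profiled = profile_usage(profile_latency(fake_generate))
    # profile_latency returns (result, latency); profile_usage wraps that
    # into ((result, latency), memory_usage).
    (result, latency), memory_usage = profiled("hello world", max_tokens=5)
    print(result, latency, memory_usage)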