Add benchmark using sglang server,
Add sgl server benchmark to workflow file,
Restructure `app_tests/benchmark_tests`
stbaione committed Nov 22, 2024
1 parent a6cb442 commit 3c21be0
Showing 7 changed files with 210 additions and 28 deletions.
113 changes: 101 additions & 12 deletions .github/workflows/ci-sglang-benchmark.yml
@@ -7,10 +7,14 @@
name: SGLang Llama Benchmarking Tests

on:
# TODO: Remove PR trigger after verification
pull_request:
workflow_dispatch:
schedule:
# Weekdays at 4:00 AM UTC = 9:00 PM PST.
- cron: "0 4 * * 1-5"
# Weekdays at 6:00 AM UTC = 10:00 PM PST.
# This is a pretty GPU-intensive test, so we want to avoid conflicting
# with other potentially scheduled tests.
- cron: "0 6 * * 1-5"

concurrency:
# A PR number if a pull request and otherwise the commit hash. This cancels
@@ -21,9 +25,9 @@ concurrency:
cancel-in-progress: true

jobs:
sglang_bench_serve:
benchmark_shortfin:
if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
name: "SGLang Serving Benchmark Tests"
name: "SGLang Serving Benchmark With Shortfin"
strategy:
matrix:
version: [3.11]
@@ -77,13 +81,98 @@ jobs:
- name: Install SGLang
run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"

- name: Launch Shortfin Server
run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html
- name: Run Shortfin Benchmark Tests
run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmarks/shortfin_benchmark_test.py --log-cli-level=INFO --html=out/llm/shortfin/index.html

- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
with:
github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
publish_dir: ./out/llm/sglang
destination_dir: ./llm/sglang
keep_files: true
# TODO: Uncomment after verification
# - name: Deploy to GitHub Pages
# uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
# with:
# github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
# publish_dir: ./out/llm/sgl_benchmark/shortfin
# destination_dir: ./llm/sgl_benchmark/shortfin
# keep_files: true

benchmark_sglang:
if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
name: "SGLang Serving Benchmark With SGLang"
needs: benchmark_shortfin
strategy:
matrix:
version: [3.11]
fail-fast: false
runs-on: llama-mi300x-3
defaults:
run:
shell: bash
env:
PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
steps:
- name: Get Current Date
id: date
run: echo "::set-output name=date::$(date +'%Y-%m-%d')"

- name: "Setting up Python"
id: setup_python
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{matrix.version}}

- name: Install SGLang
run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

# Instructions for the SGLang image are sourced from here:
# https://sgl-project.github.io/start/install.html#method-3-using-docker
# We have to run in a docker container due to their vLLM dependency.
# From their pyproject.toml:
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
# srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
- name: Pull SGLang Image (Had issues with sglang:v0.3.5.post1-rocm620)
run: |
docker pull lmsysorg/sglang:v0.3.5.post1-rocm620
- name: Run SGLang Server
run: |
docker run -d --rm \
--name sglang-server \
--device=/dev/kfd \
--device=/dev/dri \
--ipc=host \
--shm-size 16G \
--group-add video \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
-v $HOME/dockerx:/dockerx \
-v /data:/data \
-p 30000:30000 \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HF_TOKEN={{ secrets.HF_TOKEN }}" \
lmsysorg/sglang:v0.3.5.post1-rocm620 \
python3 -m sglang.launch_server \
--model-path meta-llama/Llama-3.1-8b \
--host 0.0.0.0 \
--port 30000 \
--tp 1 \
--dtype float16
- name: Run SGLang Benchmark Tests
run: |
pytest -v app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py --port 30000 --log-cli-level=INFO --html=out/llm/sglang/index.html
- name: Stop sglang-server
run: docker stop sglang-server || true # Stop container if it's running

- name: Cleanup SGLang Image
run: docker image rm lmsysorg/sglang:v0.3.5.post1-rocm620

# TODO: Uncomment after verifying
# - name: Deploy to GitHub Pages
# uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
# with:
# github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
# publish_dir: ./out/llm/sgl_benchmark/sglang
# destination_dir: ./llm/sgl_benchmark/sglang
# keep_files: true
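
Note on the new benchmark_sglang job: the docker-launched server needs time to download and load Llama 3.1 8B before the benchmark step can run. The pytest suite covers this with wait_for_server, but as a rough standalone illustration, a readiness poll against the server could look like the sketch below (Python; the /health endpoint, the 600-second budget, and the helper name are assumptions for illustration, not taken from this commit).

# Minimal readiness poll, assuming the server answers HTTP on its base URL
# once the model has loaded. Endpoint and timings are illustrative only.
import time
import requests

def wait_until_ready(base_url: str = "http://localhost:30000", timeout: int = 600) -> None:
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            # Any HTTP response at all means the server socket is up.
            requests.get(f"{base_url}/health", timeout=5)
            return
        except requests.exceptions.RequestException:
            time.sleep(10)
    raise TimeoutError(f"{base_url} did not come up within {timeout} seconds")

if __name__ == "__main__":
    wait_until_ready()
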
5 changes: 5 additions & 0 deletions app_tests/benchmark_tests/llm/sglang_benchmarks/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
@@ -9,7 +9,9 @@
import pytest
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
)
from integration_tests.llm.utils import compile_model, export_paged_llm_v1


@@ -44,3 +46,17 @@ def pre_process_model(request, tmp_path_factory):
compile_model(mlir_path, vmfb_path, settings)

return tmp_dir


def pytest_addoption(parser):
parser.addoption(
"--port",
action="store",
default="30000",
help="Port that SGLang server is running on",
)


@pytest.fixture(scope="module")
def sglang_args(request):
return request.config.getoption("--port")
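
The new --port option defaults to 30000 and reaches the tests through the sglang_args fixture; the workflow overrides it with `pytest ... --port 30000`. A minimal, hypothetical consumer of the fixture (not part of this commit) would look like:

# Hypothetical test showing how the sglang_args fixture surfaces --port.
# The option value is stored as a string, so convert before numeric use.
def test_port_option_plumbing(sglang_args):
    port = int(sglang_args)
    base_url = f"http://localhost:{port}"
    assert 0 < port < 65536, base_url
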
app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py
@@ -0,0 +1,66 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
from pathlib import Path
import pytest
import time
from unittest.mock import patch

pytest.importorskip("sglang")
from sglang import bench_serving

from .utils import SGLangBenchmarkArgs, log_jsonl_result

from integration_tests.llm.utils import wait_for_server

logger = logging.getLogger(__name__)

TOKENIZER_DIR = Path("/data/llama3.1/8b/")


@pytest.mark.parametrize(
"request_rate",
[1, 2, 4, 8, 16, 32],
)
def test_sglang_benchmark(request_rate, sglang_args, tmp_path_factory):
tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")
logger.info("Beginning SGLang benchmark test...")

port = sglang_args
base_url = f"http://localhost:{port}"

# Setting a high timeout gives enough time for downloading model artifacts
# and starting up server... Takes a little longer than shortfin.
wait_for_server(base_url, timeout=600)

benchmark_args = SGLangBenchmarkArgs(
backend="sglang",
num_prompt=10,
base_url=f"http://localhost:{port}",
tokenizer=TOKENIZER_DIR,
request_rate=request_rate,
)
output_file = (
tmp_dir
/ f"{benchmark_args.backend}_{benchmark_args.num_prompt}_{benchmark_args.request_rate}.jsonl"
)
benchmark_args.output_file = output_file

logger.info("Running SGLang Benchmark with the following args:")
logger.info(benchmark_args)

try:
start = time.time()
with patch.object(bench_serving, "print", side_effect=logger.info):
bench_serving.run_benchmark(
benchmark_args.as_namespace(),
)
logger.info(f"Benchmark run completed in {str(time.time() - start)} seconds")
logger.info("======== RESULTS ========")
log_jsonl_result(benchmark_args.output_file)
except Exception as e:
logger.error(e)
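
The patch.object(bench_serving, "print", side_effect=logger.info) call above redirects bench_serving's console output into the test logger. A self-contained illustration of that pattern, using a throwaway stand-in module rather than the real sglang.bench_serving, is sketched below:

# Standalone illustration of rerouting a module's print() calls to a logger.
import logging
import types
from unittest.mock import patch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")

# A throwaway module standing in for sglang.bench_serving: its function
# calls the builtin print().
fake_bench = types.ModuleType("fake_bench")
exec("def run_benchmark():\n    print('benchmark finished')", fake_bench.__dict__)

# Patching 'print' on the module routes those calls into logger.info,
# which is how the test above captures the benchmark's console output.
with patch.object(fake_bench, "print", side_effect=logger.info, create=True):
    fake_bench.run_benchmark()  # logs 'benchmark finished' instead of printing
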
app_tests/benchmark_tests/llm/sglang_benchmarks/shortfin_benchmark_test.py
@@ -4,7 +4,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import json
import logging
import multiprocessing
import os
@@ -16,14 +15,17 @@
pytest.importorskip("sglang")
from sglang import bench_serving

from utils import SGLangBenchmarkArgs
from app_tests.benchmark_tests.llm.sglang_benchmarks.utils import (
SGLangBenchmarkArgs,
log_jsonl_result,
)

from integration_tests.llm.utils import (
find_available_port,
start_llm_server,
)

logger = logging.getLogger("__name__")
logger = logging.getLogger(__name__)

device_settings = {
"device_flags": [
@@ -38,15 +40,6 @@
TOKENIZER_DIR = Path("/data/llama3.1/8b/")


def log_jsonl_result(file_path):
with open(file_path, "r") as file:
json_string = file.readline().strip()

json_data = json.loads(json_string)
for key, val in json_data.items():
logger.info(f"{key.upper()}: {val}")


@pytest.mark.parametrize(
"request_rate",
[1, 2, 4, 8, 16, 32],
@@ -64,7 +57,7 @@ def log_jsonl_result(file_path):
],
indirect=True,
)
def test_sglang_benchmark_server(request_rate, pre_process_model):
def test_shortfin_benchmark(request_rate, pre_process_model):
# TODO: Remove when multi-device is fixed
os.environ["ROCR_VISIBLE_DEVICES"] = "1"

@@ -116,7 +109,7 @@ def test_sglang_benchmark_server(request_rate, pre_process_model):
logger.info("======== RESULTS ========")
log_jsonl_result(benchmark_args.output_file)
except Exception as e:
logger.info(e)
logger.error(e)

server_process.terminate()
server_process.wait()
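
find_available_port and start_llm_server are imported from integration_tests.llm.utils and are not part of this diff. For context, a find_available_port-style helper is usually just an ephemeral-port lookup; the sketch below is an assumption about its behavior, not the actual implementation:

# Sketch of a find_available_port-style helper: bind to port 0 so the OS
# picks a free ephemeral port, then release it for the server to claim.
import socket

def find_available_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]
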
app_tests/benchmark_tests/llm/sglang_benchmarks/utils.py
@@ -6,8 +6,12 @@

from argparse import Namespace
from dataclasses import dataclass
import json
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


@dataclass
class SGLangBenchmarkArgs:
@@ -54,3 +58,12 @@ def __repr__(self):
f"Tokenizer: {self.tokenizer}\n"
f"Request Rate: {self.request_rate}"
)


def log_jsonl_result(file_path):
with open(file_path, "r") as file:
json_string = file.readline().strip()

json_data = json.loads(json_string)
for key, val in json_data.items():
logger.info(f"{key.upper()}: {val}")
2 changes: 1 addition & 1 deletion app_tests/integration_tests/llm/utils.py
@@ -15,7 +15,7 @@
import requests
from transformers import AutoTokenizer

logger = logging.getLogger("__name__")
logger = logging.getLogger(__name__)


class AccuracyValidationException(RuntimeError):