
[Performance]: Throughput and Latency degradation with a single LoRA adapter on A100 40 GB #10062

Open
kaushikmitr opened this issue Nov 6, 2024 · 10 comments
Labels
performance Performance-related issues

Comments

@kaushikmitr

Proposal to improve performance

No response

Report of performance regression

No response

Misc discussion on performance


Setup Summary for vLLM Benchmarking with Llama-2 Model:

  • Hardware: A100 40 GB (a2-highgpu-2g) on Google Kubernetes Engine (GKE)

  • Model: meta-llama/Llama-2-7b-hf

  • GPU Count: 1

  • Experiments:

    • Experiment 1: Requests using the base model meta-llama/Llama-2-7b-hf.
    • Experiment 2: vLLM deployed with LoRA adapter vineetsharma/qlora-adapter-Llama-2-7b-hf-TweetSumm (size 160 MB).
    • Experiment 3: vLLM deployed with LoRA adapter xtuner/Llama-2-7b-qlora-moss-003-sft (size 640 MB).

    For all three experiments, we used the same input prompts (from the ShareGPT dataset) and observed similar output lengths (a sketch of the kind of benchmarking invocation is shown below).
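
A hedged sketch of that kind of invocation, assuming vLLM's benchmark_serving.py with the ShareGPT dataset; the dataset path, prompt count, and request rate here are illustrative assumptions, not the literal command used for these runs:

# Drives the server with ShareGPT prompts; swap --model between the base model
# name and a registered LoRA adapter name to cover the three experiments.
python benchmark_serving.py \
  --backend vllm \
  --model meta-llama/Llama-2-7b-hf \
  --dataset-name sharegpt \
  --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
  --num-prompts 1000 \
  --request-rate 4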

Settings:

  • Eager Mode: Not enabled.
  • Max GPU Utilization: Default at 90%.

Benchmark Metrics:
We measured:

  • Latency per output token
  • Throughput (output tokens per second)

You can view detailed results in the benchmark document: Benchmark 1 server - Sheet7.pdf.


Observations and Questions:

  • Using LoRA adapters led to a notable degradation in throughput and latency compared to the base model. Specifically, we observed up to a 50% drop in maximum throughput with LoRA compared to the base model.
  • Is this performance degradation expected with LoRA adapters?
  • Are there parameters or tuning options that could improve LoRA performance?

Deployment Command:

command: ["python3", "-m", "vllm.entrypoints.openai.api_server"]
args:
  - "--model"
  - "meta-llama/Llama-2-7b-hf"
  - "--tensor-parallel-size"
  - "1"
  - "--port"
  - "8000"
  - "--disable-log-requests"
  - "--enable-lora"
  - "--max-loras"
  - "3"
  - "--max-cpu-loras"
  - "15"
  - "--max-lora-rank"
  - "64"
  - "--gpu-memory-utilization"
  - "0.9"
  - "--lora-modules"
  - xtuner/Llama-2-7b-qlora-moss-003-sft
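
As a quick sanity check (not part of the original setup), the OpenAI-compatible server lists the base model plus every adapter registered via --lora-modules, so the served names can be verified before benchmarking:

curl ${IP}:${PORT}/v1/models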

Your current environment (if you think it is necessary)


Sample Query:

curl -i ${IP}:${PORT}/v1/completions -H 'Content-Type: application/json' -d '{
  "model": "tweet-summary",
  "prompt": "Write as if you were a critic: San Francisco",
  "max_tokens": 100,
  "temperature": 0
}'
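
For the base-model runs (Experiment 1), only the model field should need to change; a hedged example, assuming the base model is served under its Hugging Face name (no --served-model-name override):

curl -i ${IP}:${PORT}/v1/completions -H 'Content-Type: application/json' -d '{
  "model": "meta-llama/Llama-2-7b-hf",
  "prompt": "Write as if you were a critic: San Francisco",
  "max_tokens": 100,
  "temperature": 0
}'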

Deployment YAML Configuration:

---
apiVersion: v1
kind: Service
metadata:
  name: vllm-llama2-7b-pool
spec:
  selector:
    app: vllm-llama2-7b-pool
  ports:
    - protocol: TCP
      port: 8000
      targetPort: 8000
  type: LoadBalancer

---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: vllm-llama2-7b-pool
spec:
  replicas: 1
  selector:
    matchLabels:
      app: vllm-llama2-7b-pool
  template:
    metadata:
      labels:
        app: vllm-llama2-7b-pool
    spec:
      containers:
        - name: lora
          image: "vllm/vllm-openai:latest"
          imagePullPolicy: Always
          command: ["python3", "-m", "vllm.entrypoints.openai.api_server"]
          args:
            - "--model"
            - "meta-llama/Llama-2-7b-hf"
            - "--tensor-parallel-size"
            - "1"
            - "--port"
            - "8000"
            - "--disable-log-requests"
            - "--enable-lora"
            - "--max-loras"
            - "3"
            - "--max-cpu-loras"
            - "15"
            - "--max-lora-rank"
            - "64"
            - "--gpu-memory-utilization"
            - "0.9"
            - "--lora-modules"
            - "tweet-summary-0=/adapters/vineetsharma/qlora-adapter-Llama-2-7b-hf-TweetSumm_0"
          env:
            - name: PORT
              value: "8000"
            - name: HUGGING_FACE_HUB_TOKEN
              valueFrom:
                secretKeyRef:
                  name: hf-token
                  key: token
          ports:
            - containerPort: 8000
              name: http
              protocol: TCP
          livenessProbe:
            failureThreshold: 240
            httpGet:
              path: /health
              port: http
              scheme: HTTP
            initialDelaySeconds: 5
            periodSeconds: 5
            successThreshold: 1
            timeoutSeconds: 1
          readinessProbe:
            failureThreshold: 600
            httpGet:
              path: /health
              port: http
              scheme: HTTP
            initialDelaySeconds: 5
            periodSeconds: 5
            successThreshold: 1
            timeoutSeconds: 1
          resources:
            limits:
              nvidia.com/gpu: 1
            requests:
              nvidia.com/gpu: 1
          volumeMounts:
            - mountPath: /data
              name: data
            - mountPath: /dev/shm
              name: shm
            - name: adapters
              mountPath: "/adapters"
      initContainers:
        - name: adapter-loader
          image: ghcr.io/tomatillo-and-multiverse/adapter-puller:demo
          command: ["python"]
          args:
            - ./pull_adapters.py
            - --adapter
            - xtuner/Llama-2-7b-qlora-moss-003-sft
            - --adapter
            - yard1/llama-2-7b-sql-lora-test
            - --adapter
            - vineetsharma/qlora-adapter-Llama-2-7b-hf-TweetSumm
            - --duplicate-count
            - "5"
          env:
            - name: HF_TOKEN
              valueFrom:
                secretKeyRef:
                  name: hf-token
                  key: token
            - name: HF_HOME
              value: /adapters
          volumeMounts:
            - name: adapters
              mountPath: "/adapters"
      restartPolicy: Always
      schedulerName: default-scheduler
      terminationGracePeriodSeconds: 30
      volumes:
        - name: data
          emptyDir: {}
        - name: shm
          emptyDir:
            medium: Memory
        - name: adapters
          emptyDir: {}

This deployment configuration sets up the vLLM server with LoRA adapters on GKE, with health probes, GPU limits, and a volume configuration for adapter management.
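
For completeness, a minimal sketch of applying and reaching this deployment with standard kubectl (the manifest filename is a placeholder):

kubectl apply -f vllm-llama2-7b-pool.yaml
kubectl get service vllm-llama2-7b-pool    # note the LoadBalancer EXTERNAL-IP to use as ${IP}
# Or, for a quick local test without the LoadBalancer:
kubectl port-forward deployment/vllm-llama2-7b-pool 8000:8000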

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
kaushikmitr added the performance (Performance-related issues) label Nov 6, 2024
kaushikmitr changed the title from "[Performance]: Throughput and Latency degradation with and without single LoRA adapter" to "[Performance]: Throughput and Latency degradation with a single LoRA adapter on A100 40 GB" Nov 6, 2024
@jeejeelee
Collaborator

This is a very detailed and excellent description.
Which vLLM version are you using?

@ahg-g

ahg-g commented Nov 6, 2024

The latest version; see the container image in the YAML (vllm/vllm-openai:latest). That would be https://github.com/vllm-project/vllm/releases/tag/v0.6.3.post1, since those tests were run this week.

Do you recommend a specific version to test with?

@jeejeelee
Collaborator

jeejeelee commented Nov 6, 2024

There are some similar issues, see #9496 and #9452. The main cause there was enforce_eager=True, but I can't find that argument in your script.

BTW, if I remember correctly, some versions had a CUDA graph bug, so I suggest you run profiling and confirm that eager mode is actually disabled; you can refer to: https://docs.vllm.ai/en/latest/dev/profiling/profiling_index.html#openai-server.

PS: I will try to reproduce your results tomorrow in my local timezone, and #5036 provides some test results.
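
A sketch of driving the profiler for the OpenAI server, assuming the VLLM_TORCH_PROFILER_DIR mechanism described in that doc (the output directory and the serve flags here are placeholders):

# Start the server with torch profiler traces written to a local directory
VLLM_TORCH_PROFILER_DIR=/tmp/vllm_profile vllm serve meta-llama/Llama-2-7b-hf --enable-lora --max-lora-rank 64

# Bracket the benchmark traffic with the profiler endpoints
curl -X POST ${IP}:${PORT}/start_profile
# ... run the benchmark ...
curl -X POST ${IP}:${PORT}/stop_profile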

@ahg-g

ahg-g commented Nov 7, 2024

Thanks @jeejeelee it will be great if you can reproduce!

@jeejeelee
Collaborator

jeejeelee commented Nov 7, 2024

We conducted testing on a local A800 (A800-SXM4-80GB).

  • The serving command is:
 vllm serve meta-llama/Llama-2-7b-chat-hf --gpu-memory-utilization 0.90 --served-model-name base --enable-lora --max-loras 3 --max-cpu-loras 15 --max-lora-rank 64 --lora-modules moss=xtuner/Llama-2-7b-qlora-moss-003-sft
  • The benchmarking script is:
python benchmark_serving.py --model base --tokenizer meta-llama/Llama-2-7b-chat-hf --dataset-name random --random-input-len 512 --random-output-len 128 --ignore-eos --num-prompts 24 --metric-percentiles 90 --request-rate 20 (with the request rate varied from 1 to 24)
  • Although we set the total number of requests relatively low, we still observed trends similar to yours, see:
    benchmark.pdf. Therefore, I believe I can answer your questions:
Q:
Is this performance degradation expected with LoRA adapters?
Are there parameters or tuning options that could improve LoRA performance?
A:
This behavior is expected. Unlike the base model, as the request rate increases, LoRA becomes progressively compute-bound, leading to increased latency. Currently no tuning option is provided; you could try using tensor parallelism.
  • We will try to increase the total number of requests for future tests, and will continue to investigate and work on resolving these issues.
  • BTW @simon-mo has added a fea-lora channel on Slack where we can discuss these topics.
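
Regarding the tensor-parallel suggestion, a minimal variant of the serve command above (assumes 2 GPUs are available; otherwise identical):

vllm serve meta-llama/Llama-2-7b-chat-hf --gpu-memory-utilization 0.90 --served-model-name base --enable-lora --max-loras 3 --max-cpu-loras 15 --max-lora-rank 64 --lora-modules moss=xtuner/Llama-2-7b-qlora-moss-003-sft --tensor-parallel-size 2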

@ahg-g

ahg-g commented Nov 7, 2024

Thanks @jeejeelee, this is very insightful, did you try the smaller adapter?

@jeejeelee
Collaborator

> Thanks @jeejeelee, this is very insightful, did you try the smaller adapter?

Not yet, I will tomorrow

@Bingogogogogo

Same question, any solution yet?

@Jeffwan
Contributor

Jeffwan commented Nov 21, 2024

I will also help reproduce this issue from my end this week. What I observe is around 20%-25% overhead, which is expected. It seems we need to standardize the LoRA workloads and benchmark to better help users reproduce the results.

@kaushikmitr
Author

kaushikmitr commented Jan 22, 2025

Throughput/latency vs. KV cache utilization

I did some new benchmarks and noticed that max LoRA rank has a significant impact on performance; it's best to set it equal to the rank of the LoRA adapter (or the rank of the largest-rank adapter if serving multiple LoRAs). This is consistent with what is documented here.

With rank = 16, the throughput hit is about 27% at 80% KV cache utilization.
With rank = 64, the throughput hit is about 50% at 80% KV cache utilization (same as my initial benchmarking above).

[Image: throughput and latency vs. KV cache utilization, rank 16 vs. rank 64]

(tp-2 indicates tensor parallelism = 2, i.e., 2 GPUs were used)
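
For anyone reproducing this, the only change between the two configurations compared above should be --max-lora-rank set to the adapter's actual rank (16 here, per the comparison) instead of 64; a hedged sketch based on the deployment args earlier in this issue:

python3 -m vllm.entrypoints.openai.api_server \
  --model meta-llama/Llama-2-7b-hf \
  --enable-lora \
  --max-loras 3 --max-cpu-loras 15 \
  --max-lora-rank 16 \
  --gpu-memory-utilization 0.9 \
  --lora-modules tweet-summary-0=/adapters/vineetsharma/qlora-adapter-Llama-2-7b-hf-TweetSumm_0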

I also enabled the vLLM profiler to get a more granular understanding of where the performance hit is coming from.

Performance Analysis:

vLLM's profiler provides slice flamegraphs. For the tweet-summary adapter (max rank 64) serving 96 prompts online, the flamegraph revealed cudaMemcpyAsync as a major latency contributor, accounting for 47% of the total 35 seconds.

[Image: slice flamegraph for the LoRA (tweet-summary, max rank 64) run]

The base model's slice flamegraph showed cudaMemcpyAsync taking 40% of the 27.96 seconds (96 prompts).

[Image: slice flamegraph for the base model run]

The 8-second difference between the base and LoRA runs (same number of prompts) was largely due to cudaMemcpyAsync: about 60% of the delta can be explained by the extra time spent in cudaMemcpyAsync. Given that the LoRA weights are less than 2% of the base model size, the reason for such a large difference in cudaMemcpyAsync time with and without LoRA is unclear.
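
For reference, one way to quantify that delta from the exported traces, assuming they are Chrome-trace JSON files whose traceEvents entries carry name and dur (microseconds), as the torch profiler emits; the filenames are placeholders:

# Total cudaMemcpyAsync time in seconds for each trace; compare the two
jq '[.traceEvents[] | select(.name == "cudaMemcpyAsync") | .dur] | add / 1000000' lora_trace.json
jq '[.traceEvents[] | select(.name == "cudaMemcpyAsync") | .dur] | add / 1000000' base_trace.json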
