
Add torchao quant (int4/int8/fp8) to llama models #1341

Merged (8 commits) on Sep 9, 2024

Conversation

@jerryzh168 (Contributor) commented Sep 6, 2024

Summary:
This is a quick hack before we work on a proper solution.

The proper solution will be to rewrite the llama model with tensor parallelism (https://pytorch.org/docs/stable/distributed.tensor.parallel.html, using DTensor underneath); that work is being tracked in pytorch/ao#785.
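For context, a minimal sketch of the tensor-parallel API the linked docs describe (not part of this PR); the layer shapes are illustrative and the script is assumed to run under torchrun so a process group exists:

```
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Build a 1-D mesh over the local GPUs (run under torchrun).
mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))

# Toy two-layer MLP: shard the first linear column-wise and the second row-wise;
# DTensor handles the sharded parameters and communication underneath.
mlp = torch.nn.Sequential(
    torch.nn.Linear(4096, 11008, bias=False),
    torch.nn.Linear(11008, 4096, bias=False),
).cuda()
mlp = parallelize_module(mlp, mesh, {"0": ColwiseParallel(), "1": RowwiseParallel()})
```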

Test Plan:

python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8

max_total_num_tokens=432196
Warmup ...
Prefill. latency: 0.03214 s, throughput:   3983.19 token/s
Decode.  latency: 0.01383 s, throughput:     72.31 token/s
Decode.  latency: 0.01354 s, throughput:     73.88 token/s
Decode.  latency: 0.01338 s, throughput:     74.75 token/s
Decode.  latency: 0.01330 s, throughput:     75.17 token/s
Decode.  median latency: 0.01346 s, median throughput:     74.31 token/s
Total. latency:  0.086 s, throughput:   1531.66 token/s
Benchmark ...
Prefill. latency: 0.02514 s, throughput:   5092.40 token/s
Decode.  latency: 0.01337 s, throughput:     74.80 token/s
Decode.  latency: 0.01338 s, throughput:     74.74 token/s
Decode.  latency: 0.01339 s, throughput:     74.68 token/s
Decode.  latency: 0.01321 s, throughput:     75.68 token/s
Decode.  latency: 0.01295 s, throughput:     77.23 token/s
Decode.  median latency: 0.01337 s, median throughput:     74.77 token/s
Total. latency:  0.132 s, throughput:   1032.13 token/s

python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --torchao-config int4wo

max_total_num_tokens=505188
Warmup ...
Prefill. latency: 0.10929 s, throughput:   1171.18 token/s
Decode.  latency: 0.00790 s, throughput:    126.57 token/s
Decode.  latency: 0.00738 s, throughput:    135.54 token/s
Decode.  latency: 0.00724 s, throughput:    138.16 token/s
Decode.  latency: 0.00726 s, throughput:    137.71 token/s
Decode.  median latency: 0.00732 s, median throughput:    136.62 token/s
Total. latency:  0.139 s, throughput:    949.17 token/s
Benchmark ...
Prefill. latency: 0.10405 s, throughput:   1230.13 token/s
Decode.  latency: 0.00769 s, throughput:    129.96 token/s
Decode.  latency: 0.00725 s, throughput:    137.85 token/s
Decode.  latency: 0.00724 s, throughput:    138.11 token/s
Decode.  latency: 0.00731 s, throughput:    136.72 token/s
Decode.  latency: 0.00744 s, throughput:    134.47 token/s
Decode.  median latency: 0.00730 s, median throughput:    136.97 token/s
Total. latency:  0.163 s, throughput:    834.99 token/s

With torch.compile enabled:
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --enable-torch-compile

max_total_num_tokens=432196
Warmup ...
Prefill. latency: 0.03095 s, throughput:   4135.63 token/s
Decode.  latency: 0.01258 s, throughput:     79.49 token/s
Decode.  latency: 0.01250 s, throughput:     79.98 token/s
Decode.  latency: 0.01262 s, throughput:     79.22 token/s
Decode.  latency: 0.01246 s, throughput:     80.26 token/s
Decode.  median latency: 0.01254 s, median throughput:     79.73 token/s
Total. latency:  0.081 s, throughput:   1627.28 token/s
Benchmark ...
Prefill. latency: 0.02445 s, throughput:   5234.65 token/s
Decode.  latency: 0.01249 s, throughput:     80.09 token/s
Decode.  latency: 0.01245 s, throughput:     80.30 token/s
Decode.  latency: 0.01195 s, throughput:     83.65 token/s
Decode.  latency: 0.01170 s, throughput:     85.45 token/s
Decode.  latency: 0.01179 s, throughput:     84.85 token/s
Decode.  median latency: 0.01179 s, median throughput:     84.79 token/s
Total. latency:  0.120 s, throughput:   1131.80 token/s

python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --enable-torch-compile --torchao-config int4wo-128
max_total_num_tokens=506932
Warmup ...
Prefill. latency: 0.10922 s, throughput:   1171.92 token/s
Decode.  latency: 0.00739 s, throughput:    135.25 token/s
Decode.  latency: 0.00692 s, throughput:    144.58 token/s
Decode.  latency: 0.00673 s, throughput:    148.68 token/s
Decode.  latency: 0.00673 s, throughput:    148.68 token/s
Decode.  median latency: 0.00682 s, median throughput:    146.60 token/s
Total. latency:  0.137 s, throughput:    963.61 token/s
Benchmark ...
Prefill. latency: 0.10405 s, throughput:   1230.12 token/s
Decode.  latency: 0.00733 s, throughput:    136.34 token/s
Decode.  latency: 0.00689 s, throughput:    145.13 token/s
Decode.  latency: 0.00682 s, throughput:    146.59 token/s
Decode.  latency: 0.00677 s, throughput:    147.67 token/s
Decode.  latency: 0.00676 s, throughput:    147.85 token/s
Decode.  median latency: 0.00677 s, median throughput:    147.76 token/s
Total. latency:  0.159 s, throughput:    856.11 token/s

python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --enable-torch-compile --torchao-config int8wo

max_total_num_tokens=484068
Warmup ...
Prefill. latency: 0.05860 s, throughput:   2184.33 token/s
Decode.  latency: 0.00921 s, throughput:    108.52 token/s
Decode.  latency: 0.00883 s, throughput:    113.29 token/s
Decode.  latency: 0.00880 s, throughput:    113.69 token/s
Decode.  latency: 0.00868 s, throughput:    115.15 token/s
Decode.  median latency: 0.00881 s, median throughput:    113.49 token/s
Total. latency:  0.094 s, throughput:   1402.45 token/s
Benchmark ...
Prefill. latency: 0.05269 s, throughput:   2429.12 token/s
Decode.  latency: 0.00928 s, throughput:    107.80 token/s
Decode.  latency: 0.00892 s, throughput:    112.10 token/s
Decode.  latency: 0.00881 s, throughput:    113.53 token/s
Decode.  latency: 0.00870 s, throughput:    114.90 token/s
Decode.  latency: 0.00864 s, throughput:    115.68 token/s
Decode.  median latency: 0.00876 s, median throughput:    114.15 token/s
Total. latency:  0.123 s, throughput:   1102.80 token/s

python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --enable-torch-compile --torchao-config int8dq
max_total_num_tokens=432196
Warmup ...
Prefill. latency: 0.03086 s, throughput:   4147.16 token/s
Decode.  latency: 0.01249 s, throughput:     80.03 token/s
Decode.  latency: 0.01246 s, throughput:     80.29 token/s
Decode.  latency: 0.01275 s, throughput:     78.42 token/s
Decode.  latency: 0.01244 s, throughput:     80.39 token/s
Decode.  median latency: 0.01247 s, median throughput:     80.16 token/s
Total. latency:  0.081 s, throughput:   1629.54 token/s
Benchmark ...
Prefill. latency: 0.02428 s, throughput:   5270.99 token/s
Decode.  latency: 0.01250 s, throughput:     79.99 token/s
Decode.  latency: 0.01252 s, throughput:     79.87 token/s
Decode.  latency: 0.01260 s, throughput:     79.34 token/s
Decode.  latency: 0.01234 s, throughput:     81.02 token/s
Decode.  latency: 0.01236 s, throughput:     80.92 token/s
Decode.  median latency: 0.01241 s, median throughput:     80.56 token/s
Total. latency:  0.124 s, throughput:   1098.26 token/s

python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --enable-torch-compile --torchao-config fp8wo
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89

Reviewers:

Subscribers:

Tasks:

Tags:

Summary:
This is a quick hack before we work on a proper solution.

The proper solution will be to rewrite the llama model with tensor parallelism (https://pytorch.org/docs/stable/distributed.tensor.parallel.html, using DTensor underneath); that work is being tracked in pytorch/ao#785.

Test Plan:
Change `ENABLE_TORCHAO` to True/False in `python/sglang/srt/models/llama.py` to compare the baseline vs. torchao int4 weight-only quantization performance.

python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8

```

max_total_num_tokens=432196
Warmup ...
Prefill. latency: 0.03214 s, throughput:   3983.19 token/s
Decode.  latency: 0.01383 s, throughput:     72.31 token/s
Decode.  latency: 0.01354 s, throughput:     73.88 token/s
Decode.  latency: 0.01338 s, throughput:     74.75 token/s
Decode.  latency: 0.01330 s, throughput:     75.17 token/s
Decode.  median latency: 0.01346 s, median throughput:     74.31 token/s
Total. latency:  0.086 s, throughput:   1531.66 token/s
Benchmark ...
Prefill. latency: 0.02514 s, throughput:   5092.40 token/s
Decode.  latency: 0.01337 s, throughput:     74.80 token/s
Decode.  latency: 0.01338 s, throughput:     74.74 token/s
Decode.  latency: 0.01339 s, throughput:     74.68 token/s
Decode.  latency: 0.01321 s, throughput:     75.68 token/s
Decode.  latency: 0.01295 s, throughput:     77.23 token/s
Decode.  median latency: 0.01337 s, median throughput:     74.77 token/s
Total. latency:  0.132 s, throughput:   1032.13 token/s

max_total_num_tokens=505188
Warmup ...
Prefill. latency: 0.10929 s, throughput:   1171.18 token/s
Decode.  latency: 0.00790 s, throughput:    126.57 token/s
Decode.  latency: 0.00738 s, throughput:    135.54 token/s
Decode.  latency: 0.00724 s, throughput:    138.16 token/s
Decode.  latency: 0.00726 s, throughput:    137.71 token/s
Decode.  median latency: 0.00732 s, median throughput:    136.62 token/s
Total. latency:  0.139 s, throughput:    949.17 token/s
Benchmark ...
Prefill. latency: 0.10405 s, throughput:   1230.13 token/s
Decode.  latency: 0.00769 s, throughput:    129.96 token/s
Decode.  latency: 0.00725 s, throughput:    137.85 token/s
Decode.  latency: 0.00724 s, throughput:    138.11 token/s
Decode.  latency: 0.00731 s, throughput:    136.72 token/s
Decode.  latency: 0.00744 s, throughput:    134.47 token/s
Decode.  median latency: 0.00730 s, median throughput:    136.97 token/s
Total. latency:  0.163 s, throughput:    834.99 token/s

Warmup ...
Prefill. latency: 0.05868 s, throughput:   2181.51 token/s
Decode.  latency: 0.04475 s, throughput:     22.35 token/s
Decode.  latency: 0.04463 s, throughput:     22.41 token/s
Decode.  latency: 0.04467 s, throughput:     22.39 token/s
Decode.  latency: 0.04478 s, throughput:     22.33 token/s
Decode.  median latency: 0.04471 s, median throughput:     22.37 token/s
Total. latency:  0.238 s, throughput:    555.78 token/s
Benchmark ...
Prefill. latency: 0.05274 s, throughput:   2427.22 token/s
Decode.  latency: 0.04463 s, throughput:     22.41 token/s
Decode.  latency: 0.04456 s, throughput:     22.44 token/s
Decode.  latency: 0.04453 s, throughput:     22.45 token/s
Decode.  latency: 0.04469 s, throughput:     22.38 token/s
Decode.  latency: 0.04457 s, throughput:     22.44 token/s
Decode.  median latency: 0.04457 s, median throughput:     22.44 token/s
Total. latency:  0.409 s, throughput:    332.13 token/s
```

Reviewers:

Subscribers:

Tasks:

Tags:
@zhyncs (Member) commented Sep 6, 2024

Hi @msaroufim @jerryzh168, nice work! It looks like the CI failures are due to a missing torchao dependency. Could you please add it to https://github.com/sgl-project/sglang/blob/main/python/pyproject.toml? Thanks.

@merrymercy (Contributor) left a comment:


Given that the changes are minimal, I feel we can merge this and start iterating. Here are a few thoughts on making it cleaner:

  1. Add a new argument --enable-torchao, similar to the existing --enable-torch-compile flag.
  2. Alternatively, instead of adding --enable-torchao, add a new quantization format torchao-int4 to the existing --quantization choices.
  3. Pass the chosen argument through this global variable and read it in llama.py (a rough sketch of the plumbing follows the list):

    global_server_args_dict.update(
        {
            "disable_flashinfer": server_args.disable_flashinfer,
            "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
            "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
            "enable_mla": server_args.enable_mla,
        }
    )
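A minimal sketch of that plumbing, assuming an argparse-style parser and a plain dict for global_server_args_dict (the real sglang wiring differs in the details):

```
import argparse

# Hypothetical stand-in for sglang's global server-args dict.
global_server_args_dict = {}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--torchao-config",
    type=str,
    default="",
    help="torchao quantization config, e.g. int4wo, int4wo-128, int8wo, int8dq, fp8wo",
)
server_args = parser.parse_args(["--torchao-config", "int4wo-128"])

# Propagate the flag so model code can read it without importing server args directly.
global_server_args_dict.update({"torchao_config": server_args.torchao_config})

# In the model code (e.g. llama.py), read the value and decide whether to quantize.
torchao_config = global_server_args_dict.get("torchao_config", "")
if torchao_config:
    print(f"would apply torchao config: {torchao_config}")
```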

@zhyncs zhyncs self-assigned this Sep 6, 2024
@jerryzh168 (Contributor, Author) commented Sep 7, 2024

> 2. Alternatively, instead of adding --enable-torchao, we can add a new quantization format torchao-int4

@merrymercy thanks for the suggestions. I can start from https://github.com/pytorch/ao/blob/1ce7da941b311ef6ab416d9c7a22be20e65b7495/torchao/_models/llama/generate.py#L462. Ideally we'd want people to be able to specify different types of quantization, e.g. quantization_config = TorchAoConfig("int4_weight_only", group_size=128) (see https://huggingface.co/docs/transformers/main/en/quantization/torchao), so they can try out different quantization methods (a rough sketch is below).

I could use `quant_config: Optional[QuantizationConfig] = None` instead, but that requires updates to vLLM, unless we copy the quantization config code over here.
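For illustration, a minimal sketch of mapping such config strings onto torchao's quantize_ API (the exact string format and the helpers that ended up in torchao_utils.py may differ; fp8wo needs torchao >= 0.5 and an SM89+ GPU, and int4 weight-only expects bfloat16 weights):

```
import torch
from torchao.quantization import (
    int4_weight_only,
    int8_dynamic_activation_int8_weight,
    int8_weight_only,
    quantize_,
)


def apply_torchao_config(model: torch.nn.Module, torchao_config: str) -> None:
    # Map a config string such as "int4wo-128" onto the corresponding torchao call.
    if torchao_config.startswith("int4wo"):
        # "int4wo-128" -> group_size=128; bare "int4wo" falls back to a default.
        group_size = int(torchao_config.split("-")[1]) if "-" in torchao_config else 128
        quantize_(model, int4_weight_only(group_size=group_size))
    elif torchao_config == "int8wo":
        quantize_(model, int8_weight_only())
    elif torchao_config == "int8dq":
        quantize_(model, int8_dynamic_activation_int8_weight())
    elif torchao_config == "fp8wo":
        from torchao.quantization import float8_weight_only  # torchao >= 0.5

        quantize_(model, float8_weight_only())
    elif torchao_config:
        raise ValueError(f"Unknown torchao config: {torchao_config}")


# Example usage on a toy module (requires a CUDA GPU).
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).bfloat16().cuda()
apply_torchao_config(model, "int4wo-128")
```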

@merrymercy (Contributor) commented Sep 8, 2024

Feel free to copy over the code, as we will gradually remove the dependency on vLLM (likely in a month or so).

@merrymercy merrymercy changed the title Add torchao quant to sgl llama model for testing Add torchao quant (int4/int8) to sgl llama model for testing Sep 9, 2024
@merrymercy merrymercy changed the title Add torchao quant (int4/int8) to sgl llama model for testing Add torchao quant (int4/int8/fp8) to llama models Sep 9, 2024
@merrymercy merrymercy merged commit a7c47e0 into sgl-project:main Sep 9, 2024
9 checks passed
@merrymercy (Contributor) commented:

@jerryzh168 Thanks for the contribution. It is merged.

I pushed some updates:

  • Moved the utils to python/sglang/srt/layers/torchao_utils.py
  • Added a unit test: test/srt/test_torchao.py

@merrymercy (Contributor) commented Sep 9, 2024

I verified the performance on H100. Here are the results, with a few comments:

  1. The int4 decoding speed is pretty good. However, the prefill throughput seems much worse than the fp16 baseline. Do you have any ongoing efforts to improve it?
  2. int8 without torch.compile is very slow, but we cannot afford torch.compile for all batch sizes. Since fp8 is a much better alternative at the 8-bit budget, we will likely only use fp8.
# fp16 baseline
# Decode.  median latency: 0.00728 s, median throughput:    137.41 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8

# fp8 baseline
# Decode.  median latency: 0.00590 s, median throughput:    169.51 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --quant fp8

# torchao int4
# Decode.  median latency: 0.00448 s, median throughput:    223.43 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --torchao-config int4wo-128

# torchao int8 w/o compile
# Decode.  median latency: 0.02629 s, median throughput:     38.04 token/s
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --torchao-config int8wo

I also got some errors with the current torchao release on PyPI when trying the following settings. Which version should I use?

  1. fp8 is not supported
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --torchao-config fp8wo

[rank0]:   File "/root/sglang/python/sglang/srt/layers/torchao_utils.py", line 31, in torchao_quantize_param_data
[rank0]:     from torchao.quantization import float8_weight_only
[rank0]: ImportError: cannot import name 'float8_weight_only' from 'torchao.quantization' (/usr/local/lib/python3.10/dist-packages/torchao/quantization/__init__.py)
  2. torch.compile + torchao is not supported
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --enable-torch-compile --torchao-config int4wo-128

[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torchao/dtypes/affine_quantized_tensor.py", line 155, in dequantize
[rank0]:     int_data, scale, zero_point = self.layout_tensor.get_plain()
[rank0]: torch._dynamo.exc.InternalTorchDynamoError: 'FakeTensor' object has no attribute 'get_plain'

@jerryzh168 (Contributor, Author) commented:

Thanks @merrymercy for the fixes, the tests, and for merging the PR!

> int4 decoding speed is pretty good. However, the prefill throughput seems much worse than the fp16 baseline. Do you have any ongoing efforts to improve it?

Not yet, I think, but we can take a look to understand why.

> int8 w/o torch.compile is very bad, but we cannot afford torch.compile for all batch sizes. Since fp8 is a much better alternative at the budget of 8-bit, we will likely only use fp8.

Yeah, int8 requires torch.compile to get a speedup; using fp8 sounds good to us as well.

> fp8 is not supported

Currently we have torchao 0.4 on PyPI; we are doing a new 0.5 release this week, so fp8 should be supported this week.

> torch.compile + torchao is not supported

I've seen this before. I think it is caused by an issue in the PyTorch dynamo/inductor code, and it should be fixed in PyTorch nightly; maybe you can try installing PyTorch nightly first to verify that the issue goes away. Also, do you have a version requirement for PyTorch in sglang?

@zhyncs (Member) commented Sep 9, 2024

> Also, do you have a version requirement for PyTorch in sglang?

@jerryzh168 We use torch==2.4.0.

@jerryzh168 (Contributor, Author) commented:

I see. For 2.4.0 and below, I think we'd need to call https://github.com/pytorch/ao/blob/e1039abac7f429a8d7f489d047d9b34d6ac6afe2/torchao/utils.py#L269 for the model to be compilable (see the sketch below); I'll try to verify as well when I get a chance today.
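A minimal sketch of that workaround, assuming the helper linked above is torchao.utils.unwrap_tensor_subclass (the referenced line may point elsewhere):

```
import torch
from torchao.quantization import int4_weight_only, quantize_
from torchao.utils import unwrap_tensor_subclass  # assumption: the helper linked above

# Toy quantized module (requires a CUDA GPU; int4 weight-only expects bfloat16).
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).bfloat16().cuda()
quantize_(model, int4_weight_only(group_size=128))

# For torch <= 2.4, flatten torchao's tensor subclasses into plain parameters so that
# dynamo can trace the model (avoids errors like the 'get_plain' failure above).
model = unwrap_tensor_subclass(model)
model = torch.compile(model, mode="max-autotune")
```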

@jerryzh168 (Contributor, Author) commented:

For 2.4.0 there seems to be some issue even with the workaround; I'm trying to debug it now.

@jerryzh168 (Contributor, Author) commented:

@merrymercy @zhyncs

For the second issue ("torch.compile + torchao is not supported"), I can only repro it on 2.4.0 and am not sure how to fix it; the fix is probably needed in PyTorch itself, and we can't really backport it to 2.4.0. Would it be OK for you to use torch nightly for testing for now? PyTorch 2.5.0 is going to be released in about a month.

@merrymercy merrymercy mentioned this pull request Sep 13, 2024
jerryzh168 added a commit to jerryzh168/sglang that referenced this pull request Sep 14, 2024
Summary:
Similar to sgl-project#1341, we add torchao quantization to the mixtral model.

Test Plan:
Note: compile is not working yet, and I can't install torch nightly locally and make it work either.
I'll wait for the PyTorch 2.5 release (mid October), or check this again later.

python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8

Warmup ...
Prefill. latency: 0.05532 s, throughput:   2313.73 token/s
Decode.  latency: 0.00896 s, throughput:    111.65 token/s
Decode.  latency: 0.00833 s, throughput:    120.04 token/s
Decode.  latency: 0.00869 s, throughput:    115.06 token/s
Decode.  latency: 0.00842 s, throughput:    118.79 token/s
Decode.  median latency: 0.00855 s, median throughput:    116.89 token/s
Total. latency:  0.090 s, throughput:   1471.26 token/s
Benchmark ...
Prefill. latency: 0.04294 s, throughput:   2980.61 token/s
Decode.  latency: 0.00839 s, throughput:    119.12 token/s
Decode.  latency: 0.00828 s, throughput:    120.78 token/s
Decode.  latency: 0.00857 s, throughput:    116.64 token/s
Decode.  latency: 0.00853 s, throughput:    117.19 token/s
Decode.  latency: 0.00859 s, throughput:    116.39 token/s
Decode.  median latency: 0.00853 s, median throughput:    117.17 token/s
Total. latency:  0.111 s, throughput:   1226.84 token/s

python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --torchao-config int4wo-128

Warmup ...
Prefill. latency: 0.06413 s, throughput:   1996.05 token/s
Decode.  latency: 0.00764 s, throughput:    130.84 token/s
Decode.  latency: 0.00748 s, throughput:    133.73 token/s
Decode.  latency: 0.00725 s, throughput:    137.84 token/s
Decode.  latency: 0.00721 s, throughput:    138.74 token/s
Decode.  median latency: 0.00737 s, median throughput:    135.76 token/s
Total. latency:  0.094 s, throughput:   1408.61 token/s
Benchmark ...
Prefill. latency: 0.05239 s, throughput:   2443.43 token/s
Decode.  latency: 0.00739 s, throughput:    135.25 token/s
Decode.  latency: 0.00720 s, throughput:    138.90 token/s
Decode.  latency: 0.00718 s, throughput:    139.21 token/s
Decode.  latency: 0.00722 s, throughput:    138.42 token/s
Decode.  latency: 0.00745 s, throughput:    134.30 token/s
Decode.  median latency: 0.00731 s, median throughput:    136.82 token/s
Total. latency:  0.111 s, throughput:   1223.51 token/s

A100, no compile
python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --torchao-config fp8wo
max_total_num_tokens=199454
Warmup ...
Prefill. latency: 0.06958 s, throughput:   1839.60 token/s
Decode.  latency: 0.02343 s, throughput:     42.68 token/s
Decode.  latency: 0.02342 s, throughput:     42.70 token/s
Decode.  latency: 0.02368 s, throughput:     42.23 token/s
Decode.  latency: 0.02337 s, throughput:     42.80 token/s
Decode.  median latency: 0.02342 s, median throughput:     42.69 token/s
Total. latency:  0.163 s, throughput:    807.48 token/s
Benchmark ...
Prefill. latency: 0.05767 s, throughput:   2219.36 token/s
Decode.  latency: 0.02293 s, throughput:     43.61 token/s
Decode.  latency: 0.02026 s, throughput:     49.36 token/s
Decode.  latency: 0.02029 s, throughput:     49.29 token/s
Decode.  latency: 0.02024 s, throughput:     49.41 token/s
Decode.  latency: 0.02026 s, throughput:     49.36 token/s
Decode.  median latency: 0.02025 s, median throughput:     49.39 token/s
Total. latency:  0.222 s, throughput:    611.87 token/s

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 added a commit to jerryzh168/sglang that referenced this pull request Sep 14, 2024
@zhyncs (Member) commented Sep 19, 2024

> Would it be OK for you to use torch nightly for testing for now?

@jerryzh168 I think we can give it a try; I will verify ASAP.

@zhyncs zhyncs added the quant LLM Quantization label Sep 19, 2024
@LIHUA919 commented:

interesting
