[Bug] DeepSeek V2 H100 x8 Triton failure #913

Closed
3 tasks done
zhyncs opened this issue Aug 4, 2024 · 8 comments · Fixed by #1060
zhyncs (Member) commented Aug 4, 2024

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if this issue lacks environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve it, which reduces the likelihood of receiving feedback.

Describe the bug

[gpu=0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, #running-req: 0, #queue-req: 0
python: /project/lib/Analysis/Allocation.cpp:43: std::pair<llvm::SmallVector<unsigned int>, llvm::SmallVector<unsigned int> > mlir::triton::getCvtOrder(mlir::Attribute, mlir::Attribute): Assertion `!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) && "mma -> mma layout conversion is only supported on Ampere"' failed.
(the same assertion is printed once per tensor-parallel rank, eight times in total)
/usr/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

ref triton-lang/triton#4418
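
For context: the assertion message names the limitation directly. The pinned Triton release only supports mma -> mma layout conversions on Ampere, and the assertion fires in Triton's allocation analysis pass during kernel compilation, so on Hopper (H100) every TP worker aborts as soon as it compiles the attention kernel.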

Reproduction

The FlashInfer backend works fine; I just want to test with Triton.

python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 --disable-radix-cache --tp 8 --trust-remote-code --disable-flashinfer
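
Once the server is up, a single request is enough to hit the failing prefill path. A minimal sketch, assuming the server's default address http://127.0.0.1:30000 and sglang's native /generate endpoint:

# send one generation request to trigger a prefill batch
curl http://127.0.0.1:30000/generate -H "Content-Type: application/json" -d '{"text": "The capital of France is", "sampling_params": {"max_new_tokens": 16}}'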

Environment

Python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H100 80GB HBM3
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.1, V12.1.105
CUDA Driver Version: 550.54.15 (identical across all 8 GPUs)
PyTorch: 2.3.1+cu121
sglang: 0.2.9.post1
flashinfer: 0.1.3+cu121torch2.3
requests: 2.32.3
tqdm: 4.66.5
numpy: 1.26.3
aiohttp: 3.10.0
fastapi: 0.112.0
hf_transfer: 0.1.8
huggingface_hub: 0.24.5
interegular: 0.3.3
packaging: 23.2
PIL: 10.2.0
psutil: 5.9.8
pydantic: 2.8.2
uvicorn: 0.30.5
uvloop: 0.19.0
zmq: 24.0.1
vllm: 0.5.3.post1
multipart: 0.0.9
openai: 1.38.0
anthropic: 0.32.0
NVIDIA Topology:
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    SYS     SYS     0-51,104-155    0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    SYS     SYS     0-51,104-155    0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    SYS     SYS     0-51,104-155    0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    SYS     SYS     0-51,104-155    0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    SYS     SYS     52-103,156-207  1               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    SYS     SYS     52-103,156-207  1               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    PXB     PXB     52-103,156-207  1               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      PXB     PXB     52-103,156-207  1               N/A
NIC0    SYS     SYS     SYS     SYS     SYS     SYS     PXB     PXB      X      PIX
NIC1    SYS     SYS     SYS     SYS     SYS     SYS     PXB     PXB     PIX      X

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1


ulimit soft: 1048576
zhyncs added the bug label on Aug 4, 2024
zhyncs (Member, Author) commented Aug 4, 2024

ref #905

hardware: 8 x H100 SXM
model: deepseek-ai/DeepSeek-V2
dataset: ShareGPT 1k

# main branch (FlashInfer, default)
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 --disable-radix-cache --tp 8 --trust-remote-code

# main branch (Triton)
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 --disable-radix-cache --tp 8 --trust-remote-code --disable-flashinfer

# mla branch (Triton)
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 --disable-radix-cache --tp 8 --trust-remote-code --enable-mla

# client
python3 -m sglang.bench_serving --backend sglang
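
By default, bench_serving uses the ShareGPT dataset, matching the "ShareGPT 1k" setup above. A sketch with the prompt count made explicit, assuming the --num-prompts flag of the serving benchmark:

# client, with the prompt count spelled out
python3 -m sglang.bench_serving --backend sglang --num-prompts 1000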

main branch (with FlashInfer)

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     1000
Benchmark duration (s):                  446.53
Total input tokens:                      236142
Total generated tokens:                  215614
Total generated tokens (retokenized):    215049
Request throughput (req/s):              2.24
Input token throughput (tok/s):          528.84
Output token throughput (tok/s):         482.86
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   221235.71
Median E2E Latency (ms):                 224602.89
---------------Time to First Token----------------
Mean TTFT (ms):                          202381.00
Median TTFT (ms):                        207191.84
P99 TTFT (ms):                           407561.37
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          89.28
Median TPOT (ms):                        36.82
P99 TPOT (ms):                           1648.87
---------------Inter-token Latency----------------
Mean ITL (ms):                           1038.86
Median ITL (ms):                         32.85
P99 ITL (ms):                            197.76
==================================================

main branch (with Triton)

does not work on H100 right now

mla branch (with Triton)

does not work on H100 right now

Ying1123 (Member) commented Aug 4, 2024

Try the Triton nightly or a later version.
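
(At the time, Triton nightly wheels were published on a separate package index; a sketch, assuming the triton-nightly package name and index URL from the Triton README:)

# install the Triton nightly wheel over the pinned release
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly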

zhyncs (Member, Author) commented Aug 5, 2024

> Try the Triton nightly or a later version.

ok

zhyncs self-assigned this on Aug 7, 2024
zhyncs (Member, Author) commented Aug 11, 2024

cc @ispobock

zhyncs (Member, Author) commented Aug 11, 2024

ref triton-lang/triton#4492

Jokeren commented Aug 11, 2024

Will merge the PR soon. No worries.

zhyncs (Member, Author) commented Aug 12, 2024

Hi @Jokeren, thanks for the fix. How can I use the latest commit? It seems the nightly build is not the latest.
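
(One way to pick up a not-yet-released Triton commit is to build from source; a sketch, assuming the source-build steps from the Triton README, with the pinned wheel uninstalled first:)

# replace the installed wheel with a source build of the latest main
pip uninstall -y triton
git clone https://github.com/triton-lang/triton.git
cd triton
pip install -r python/requirements.txt
pip install -e python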

Jokeren commented Aug 12, 2024
