Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

benchmark decoding attention kernel with cudnn #2467

Merged
merged 3 commits into from
Dec 17, 2024

Conversation

bjmsong
Copy link
Contributor

@bjmsong bjmsong commented Dec 12, 2024

Motivation

follow this pr, add cudnn

Modifications

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

@merrymercy
Copy link
Contributor

Can you share any results?

@bjmsong
Copy link
Contributor Author

bjmsong commented Dec 13, 2024

image

<style> </style>
head_num,batch_size,kv_len SGLang[triton] FlashInfer cuDNN
32,1,64 265.22 270.34 430.08
32,1,128 265.73 271.36 429.57
32,1,256 275.46 285.70 440.32
32,1,512 297.98 299.01 449.54
32,1,1024 316.42 289.79 496.64
32,1,2048 346.11 293.89 522.24
32,1,4096 435.20 326.66 676.86
32,4,64 258.05 270.34 429.06
32,4,128 271.87 270.34 432.13
32,4,256 287.74 288.77 439.30
32,4,512 359.42 295.42 452.61
32,4,1024 428.03 330.75 547.84
32,4,2048 608.26 387.07 943.10
32,4,4096 1005.57 505.34 1775.62
32,16,64 288.77 270.34 442.37
32,16,128 331.78 279.04 447.49
32,16,256 420.86 314.37 565.25
32,16,512 610.30 374.78 949.25
32,16,1024 1003.52 492.54 1730.56
32,16,2048 1823.23 726.02 3349.50
32,16,4096 3307.52 1199.10 6626.30
32,64,64 579.07 311.30 563.20
32,64,128 730.11 367.62 1056.77
32,64,256 1064.96 493.57 1803.78
32,64,512 1806.34 734.21 3406.85
32,64,1024 3356.67 1218.56 6606.85
32,64,2048 6454.27 2183.17 13052.93
32,64,4096 12635.14 3829.76 26053.63
64,1,64 253.95 264.19 423.94
64,1,128 253.95 262.66 417.79
64,1,256 284.67 281.60 429.06
64,1,512 302.08 282.62 453.63
64,1,1024 343.55 288.77 475.14
64,1,2048 404.48 317.44 568.32
64,1,4096 592.90 366.59 979.97
64,4,64 263.68 264.19 430.08
64,4,128 285.70 265.22 429.06
64,4,256 321.54 276.48 459.78
64,4,512 477.18 303.10 555.01
64,4,1024 602.11 351.23 931.84
64,4,2048 992.26 455.68 1730.56
64,4,4096 1806.34 666.11 3369.47
64,16,64 380.93 277.50 464.90
64,16,128 459.78 305.15 592.90
64,16,256 624.64 351.23 980.99
64,16,512 989.18 460.80 1762.82
64,16,1024 1756.16 670.72 3357.70
64,16,2048 3312.64 1099.78 6560.77
64,16,4096 6398.98 1992.70 12994.56
64,64,64 960.51 369.66 973.82
64,64,128 1247.23 489.47 1938.43
64,64,256 1910.78 679.94 3522.05
64,64,512 3371.01 1101.31 6738.43
64,64,1024 6470.66 2007.04 13148.16
64,64,2048 12626.94 3844.10 25998.34
64,64,4096 24993.79 6835.71 47107.07

@zhyncs
Copy link
Member

zhyncs commented Dec 13, 2024

The hopper optimization will be released soon on FlashInfer's latest main branch. cc @yzh119

@zhyncs
Copy link
Member

zhyncs commented Dec 13, 2024

@bjmsong Nice work!!!

@zhyncs zhyncs self-assigned this Dec 13, 2024
@zhyncs zhyncs requested review from yzh119 and zhyncs December 13, 2024 08:40
@zhyncs
Copy link
Member

zhyncs commented Dec 13, 2024

@bjmsong Is this data execution time, meaning the lower, the better?

@bjmsong
Copy link
Contributor Author

bjmsong commented Dec 13, 2024

@bjmsong Is this data execution time, meaning the lower, the better?

yeah, it's in microsecond, the lower, the better.

@zhyncs
Copy link
Member

zhyncs commented Dec 13, 2024

I'm just a bit confused about why cuDNN is even worse than Triton.

@bjmsong
Copy link
Contributor Author

bjmsong commented Dec 15, 2024

I will optimize the performance later.

@bjmsong
Copy link
Contributor Author

bjmsong commented Dec 16, 2024

  • update:only measure the kernel execution time, excluding the data processing part.

image

<style> </style>
head_num,batch_size,kv_len cuDNN SGLang[triton] FlashInfer
32,1,64 37.88 84.84 28.75
32,1,128 35.45 89.70 28.25
32,1,256 39.78 79.64 33.11
32,1,512 45.10 80.74 33.07
32,1,1024 55.44 82.06 34.19
32,1,2048 71.03 84.60 43.19
32,1,4096 103.06 134.33 77.41
32,4,64 34.28 84.74 28.31
32,4,128 38.45 88.59 29.74
32,4,256 52.45 90.46 33.21
32,4,512 70.65 83.49 43.79
32,4,1024 104.42 105.05 76.36
32,4,2048 166.11 172.79 139.27
32,4,4096 287.11 311.79 256.56
32,16,64 56.99 127.50 29.26
32,16,128 78.53 131.28 39.81
32,16,256 116.72 140.26 73.93
32,16,512 178.32 173.57 131.03
32,16,1024 305.11 294.56 237.77
32,16,2048 556.01 533.53 483.58
32,16,4096 1061.43 1020.02 954.83
32,64,64 165.24 445.24 75.97
32,64,128 229.19 452.60 119.45
32,64,256 366.27 478.52 223.89
32,64,512 618.63 567.01 461.61
32,64,1024 1132.89 1023.42 934.39
32,64,2048 2143.48 1957.95 1914.50
32,64,4096 3379.35 3144.74 2898.98
64,1,64 35.13 85.39 29.26
64,1,128 36.26 78.99 28.35
64,1,256 40.90 80.19 34.23
64,1,512 53.57 79.05 33.35
64,1,1024 68.45 80.47 44.00
64,1,2048 101.84 99.62 76.40
64,1,4096 160.52 167.03 130.98
64,4,64 43.80 81.62 28.39
64,4,128 52.91 82.30 29.14
64,4,256 69.67 86.65 35.52
64,4,512 102.96 105.07 69.38
64,4,1024 170.00 169.51 125.83
64,4,2048 294.75 300.52 234.91
64,4,4096 533.40 565.29 460.78
64,16,64 86.15 232.47 45.80
64,16,128 124.32 237.48 73.89
64,16,256 193.64 251.59 127.38
64,16,512 320.78 305.25 236.49
64,16,1024 575.09 531.26 484.83
64,16,2048 1077.27 994.54 960.10
64,16,4096 2073.99 1915.31 1941.34
64,64,64 281.46 867.22 133.79
64,64,128 413.01 877.68 224.82
64,64,256 679.76 919.92 453.97
64,64,512 1188.34 1082.00 930.13
64,64,1024 2208.71 1989.55 1891.40
64,64,2048 3383.27 3234.63 2957.94
64,64,4096 6510.72 6140.89 5833.15

@ispobock
Copy link
Collaborator

ispobock commented Dec 16, 2024

  1. Did you consider warmup for these kernels before benchmark? Since the Triton kernel is JIT compiled, compilation overhead may be included.
  2. It seems you only tested MHA here, while there are many popular models use GQA. Could you also include GQA in that benchmark?

@bjmsong
Copy link
Contributor Author

bjmsong commented Dec 17, 2024

Update:Add warmup & GQA

Here're some results:
image

<style> </style>
(batch_size,kv_len) cuDNN Triton FlashInfer
(1,64) 31.67 78.47 27.47
(1,128) 33.89 76.19 27.52
(1,256) 39.62 76.14 31.54
(1,512) 44.29 78.13 31.89
(1,1024) 54.87 82.72 34.28
(1,2048) 70.26 84.93 42.92
(1,4096) 102.61 133.10 76.31
(4,64) 32.80 79.01 32.76
(4,128) 37.38 80.16 28.25
(4,256) 53.32 77.67 32.96
(4,512) 70.30 81.66 43.75
(4,1024) 103.56 104.41 79.21
(4,2048) 164.40 172.06 144.71
(4,4096) 285.75 314.03 251.61
(16,64) 55.95 125.68 29.11
(16,128) 78.14 131.68 39.18
(16,256) 113.02 140.26 76.06
(16,512) 177.33 170.94 131.64
(16,1024) 305.33 294.01 243.98
(16,2048) 557.51 534.84 471.28
(16,4096) 1055.98 1017.17 976.81
(64,64) 165.15 444.49 75.20
(64,128) 231.59 451.85 122.77
(64,256) 362.76 478.56 225.86
(64,512) 620.90 564.61 462.78
(64,1024) 1126.41 1025.39 944.30
(64,2048) 2146.21 1953.23 1918.42
(64,4096) 3307.69 3091.87 2898.67

image

<style> </style>
cuDNN Triton FlashInfer
31.86 84.42 31.54
34.55 79.58 31.16
36.42 82.85 37.53
40.71 81.03 36.46
45.70 87.36 36.70
51.72 80.57 39.14
61.44 82.22 43.16
34.12 87.26 32.73
34.91 81.41 34.45
40.57 84.48 38.14
45.47 78.70 38.05
56.83 82.04 47.40
73.38 83.88 56.81
107.25 117.13 89.00
37.31 84.59 32.47
39.06 81.56 32.97
55.57 80.72 33.84
73.59 87.54 49.32
108.06 108.73 82.31
171.88 179.98 136.74
296.34 320.65 250.79
70.56 136.65 54.47
91.39 141.97 62.13
128.15 150.03 89.96
195.93 216.71 140.90
322.75 344.60 247.24
581.26 594.42 472.30
1103.89 1083.44 884.71

@merrymercy merrymercy merged commit e210266 into sgl-project:main Dec 17, 2024
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants