@@ -11,10 +11,6 @@
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
-# These are the 2 tunable parameters of the paged attention Pallas kernel.
-NUM_QUERIES_PER_BLOCK = 32
-NUM_KV_PAGES_PER_BLOCK = 128
-
 
 class PallasAttentionBackend(AttentionBackend):
 
@@ -115,13 +111,6 @@ def __init__(
         tpu_version = torch_xla.tpu.version()
         if tpu_version < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
-        # NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
-        # TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
-        # NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
-        if tpu_version == 4:
-            self.vmem_limit_bytes = 16 * 1024 * 1024
-        else:
-            self.vmem_limit_bytes = 64 * 1024 * 1024
 
     def forward(
         self,
@@ -165,9 +154,12 @@ def forward(
             attn_metadata.block_tables,
             attn_metadata.query_start_loc,
             attn_metadata.num_seqs,
-            num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
-            num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            vmem_limit_bytes=self.vmem_limit_bytes,
+            # By default, the system utilizes optimized block size and
+            # vmem_limit_bytes parameters from the kernel repository. However,
+            # these can be manually adjusted for debugging if necessary.
+            num_kv_pages_per_block=None,
+            num_queries_per_block=None,
+            vmem_limit_bytes=None,
             use_kernel=True,
             sm_scale=self.scale,
             sliding_window=self.sliding_window,
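The change drops the module-level tuning constants and the per-TPU-version vmem branching, and instead passes None so the Pallas kernel resolves its own tuned block sizes and vmem limit, while the new comment notes that explicit values can still be supplied for debugging. Below is a minimal, self-contained sketch of that "None means use the tuned default" pattern; the resolve_kernel_params helper and its fallback values (mirroring the removed module-level constants and the old non-v4 vmem limit) are illustrative assumptions, not vLLM or torch_xla APIs.

from typing import Optional


def resolve_kernel_params(
        num_kv_pages_per_block: Optional[int] = None,
        num_queries_per_block: Optional[int] = None,
        vmem_limit_bytes: Optional[int] = None) -> dict:
    # Hypothetical stand-in for the tuned defaults the kernel repository
    # would pick; the real values live kernel-side and are not part of
    # this change.
    defaults = {
        "num_kv_pages_per_block": 128,
        "num_queries_per_block": 32,
        "vmem_limit_bytes": 64 * 1024 * 1024,
    }
    overrides = {
        "num_kv_pages_per_block": num_kv_pages_per_block,
        "num_queries_per_block": num_queries_per_block,
        "vmem_limit_bytes": vmem_limit_bytes,
    }
    # None means "use the tuned default"; an explicit value wins.
    return {k: (v if v is not None else defaults[k])
            for k, v in overrides.items()}


# Default path taken by this change: all knobs are resolved kernel-side.
print(resolve_kernel_params())

# Debugging path: pin the block sizes explicitly to isolate a problem.
print(resolve_kernel_params(num_kv_pages_per_block=16,
                            num_queries_per_block=8))

Keeping the knobs as None at the call site rather than as module-level constants leaves the tuning decision in one place (the kernel repository) and removes the need to special-case TPU v4's 16MB vmem capacity in __init__.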