@@ -27,8 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
2727 // PagedAttention V2.
2828 ops.def (
2929 " paged_attention_v2("
30- " Tensor! out, Tensor exp_sums, Tensor max_logits,"
31- " Tensor tmp_out, Tensor query, Tensor key_cache,"
30+ " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
31+ " Tensor! tmp_out, Tensor query, Tensor key_cache,"
3232 " Tensor value_cache, int num_kv_heads, float scale,"
3333 " Tensor block_tables, Tensor seq_lens, int block_size,"
3434 " int max_seq_len, Tensor? alibi_slopes,"
@@ -95,8 +95,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
9595
9696 // Copy the cache blocks from src to dst.
9797 cache_ops.def (
98- " copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
99- " block_mapping) -> ()" );
98+ " copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
99+ " Tensor block_mapping) -> ()" );
100100 cache_ops.impl (" copy_blocks" , torch::kCPU , &copy_blocks);
101101
102102 // Reshape the key and value tensors and cache them.
0 commit comments