Paged attention #1
base: main
Conversation
        
          
torchao/kv_cache.py (Outdated)
    cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
):
Consider adding documentation for these args. It would also be helpful to note which tensors are owned by the object and which are shared; I guess cache is shared among multiple PagedTensors? What are the shapes of these tensors?
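For illustration, a minimal annotated sketch of what I have in mind; the shapes follow a vLLM-style paged KV cache layout and are my assumption, not something stated in this PR:

```python
import torch

class PagedTensor:
    """Hypothetical docstring sketch; shapes assume a vLLM-style paged layout.

    Args:
        cache: the physical KV block pool, shared among PagedTensors,
            assumed shape [num_blocks, num_kv_heads, block_size, head_size].
        block_tables: logical-to-physical block mapping owned by this object,
            assumed shape [num_seqs, max_num_blocks_per_seq].
        context_lens: number of valid tokens per sequence, owned by this object,
            assumed shape [num_seqs].
    """

    def __init__(
        self,
        cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
    ):
        self.cache = cache
        self.block_tables = block_tables
        self.context_lens = context_lens
```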
        
          
torchao/kv_cache.py (Outdated)
):
    self.block_tables = block_tables
    self.cache = cache
    self.context_lens = context_lens
I'm not sure it is general enough to incorporate "context length" into the semantics of a PagedTensor. Context length sounds like an application-level concept, not a general tensor-level one.
        
          
torchao/kv_cache.py (Outdated)
key_cache = key_tensor.cache
value_cache = value_tensor.cache
num_kv_head = key_cache.size(1)
num_queries_per_kv = query.size(1) // num_kv_head
Should we add an assertion here to make sure query.size(1) % num_kv_head == 0?
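Something like this sketch, wrapped as a hypothetical helper so it is self-contained (names mirror the snippet above):

```python
import torch

def _num_queries_per_kv(query: torch.Tensor, key_cache: torch.Tensor) -> int:
    # Hypothetical helper: fail fast when the query head count is not an
    # integer multiple of the KV head count (grouped-query attention invariant).
    num_kv_head = key_cache.size(1)
    assert query.size(1) % num_kv_head == 0, (
        f"number of query heads ({query.size(1)}) must be divisible by "
        f"number of KV heads ({num_kv_head})"
    )
    return query.size(1) // num_kv_head
```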
        
          
torchao/kv_cache.py (Outdated)
    query,
    key_cache,
    value_cache,
    head_mapping,
Can we remove this head_mapping argument and move it into the implementation, assuming we always do the even mapping here?
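For context, with an even mapping the KV head index is just an integer division of the query head index, so the kernel could derive it internally; a small sketch of the equivalence (names are illustrative):

```python
# Even mapping: every group of num_queries_per_kv consecutive query heads
# shares one KV head, so the mapping reduces to integer division.
def kv_head_for(q_head: int, num_queries_per_kv: int) -> int:
    return q_head // num_queries_per_kv

# e.g. 8 query heads over 2 KV heads -> [0, 0, 0, 0, 1, 1, 1, 1]
mapping = [kv_head_for(h, num_queries_per_kv=4) for h in range(8)]
```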
                                       : 0;
int64_t mStrideM = has_attn_mask ? attn_mask.value().stride(2) : 0;

auto max_num_partitions =
nit: The name "partition" sounds too general. Suggest specifying that it refers to the sequence dimension, e.g., max_num_seq_partitions. Same comment for the other related names.
 * @param out Output tensor [num_seqs, 1, num_heads, head_size].
 * @param query Query tensor [num_seqs, 1, num_heads, head_size].
Please add a runtime assertion in the code to make sure the query has sequence length 1 here. BTW, can we extend the implementation to support query sequence length > 1, which would benefit chunked prefill and multi-turn conversation cases?
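A minimal sketch of the assertion, assuming the [num_seqs, 1, num_heads, head_size] query layout documented above:

```python
import torch

def _check_decode_query(query: torch.Tensor) -> None:
    # Hypothetical check: the current kernel is decode-only, so the query
    # sequence length (dim 1) must be exactly 1.
    assert query.dim() == 4 and query.size(1) == 1, (
        "paged attention currently expects a query of shape "
        f"[num_seqs, 1, num_heads, head_size], got {tuple(query.shape)}"
    )
```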
        
          
torchao/kv_cache.py (Outdated)
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_head, dtype=torch.int32, device="cpu"), num_queries_per_kv
)
Can we do this inside the paged_attention C++ kernel so that we don't need to pass the head_mapping arg to it? That would simplify the kernel interface.
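Until it moves into the C++ kernel, the construction could at least be hidden behind a small Python wrapper so callers never pass it explicitly; a sketch with a hypothetical helper name:

```python
import torch

def _default_head_mapping(num_kv_head: int, num_queries_per_kv: int) -> torch.Tensor:
    # Even query-head -> KV-head mapping, built internally instead of being
    # threaded through the public interface.
    return torch.repeat_interleave(
        torch.arange(num_kv_head, dtype=torch.int32, device="cpu"), num_queries_per_kv
    )
```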
    reshape_attn_mask_to_4d(attn_mask.value(), num_seqs, num_heads, q_len,
                            attn_mask.value().size(-1));
Does this convert the attn_mask to 4D or just view it as 4D? Since we are working on raw pointers, perhaps we don't need to expand it to a 4D view here?
if (has_attn_mask) {
    _scale_attn_mask_fusion_kernel<accum_t, accum_t>(
        logits,
        attn_mask_ptr + seq_id * mStrideB + head_id * mStrideH +
I guess we need to carefully handle the case where some dim of the mask has size 1 (i.e., it broadcasts) here.
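Agreed; one way to express it is to force a zero stride for any broadcast (size-1) mask dim before handing the strides to the kernel. A sketch at the Python level, with names loosely mirroring the mStride* variables but not taken from this PR:

```python
import torch

def _mask_strides(attn_mask: torch.Tensor) -> tuple[int, int, int]:
    # Assumes a 4D mask laid out as [batch, heads, q_len, kv_len]; any dim of
    # size 1 must broadcast, so its effective stride is 0 when indexing with
    # raw pointers, otherwise we would read past the mask data.
    m_stride_b = attn_mask.stride(0) if attn_mask.size(0) > 1 else 0
    m_stride_h = attn_mask.stride(1) if attn_mask.size(1) > 1 else 0
    m_stride_m = attn_mask.stride(2) if attn_mask.size(2) > 1 else 0
    return m_stride_b, m_stride_h, m_stride_m
```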
Co-authored-by: Jiong Gong <jiong.gong@intel.com>
* feat: starting layout implementation
  fix: namespace of common modules
  chore: remove not needed test file
  fix: op name being registered
  chore: can compile the cuda kernel
  fix: segmentation fault
  chore: wip - paste test code just to check if everything passes
  feat: wip - adding layout. unpack not working
  fix: circular import
  feat: wip - can almost revert
  feat: can unpack. just needs cleanup
  chore: improve layout code
  chore: wip - mm needs work
  feat: wip - something seems wrong
  fix: e2e test
  feat: wip - add group param
  fix: unpack weights
  feat: marlin is implemented and correct
  chore: rebase
  chore: remove old import
  feat: use int4 instead of dequantizing
  chore: remove unused fn
  feat: add checks and validation
  feat: add new kernel and refactor code (#1)
    * feat: wip - adding new kernel
    * feat: wip - continue working on the unpack
    * feat: wip - working on unpacking
    * feat: remove old op
    * feat: more code changes
    * chore: remove old code
    * feat: more code
    * chore: more code changes
    * chore: more code changes
    * feat: add more documentation
    * fix: dataclass
    * feat: add more docs
    * feat: remove assert
  chore: block 8 bits
  chore: update comment
  feat: refactor dispatch
  chore: add validation on group size
  chore: wip - working on fixing unpack
  feat: add small readme with sources
  feat: add checks
  feat: tests pass & can execute llama2
* compile kind of working
* fix: batching and layout outputs correct results
* fix: torch.compile
* wip
* feat: wip
* chore: cleanup
* chore: review
* chore: review v2
* update benchmarks + README
---------
Co-authored-by: Jesse Cai <jcjessecai@gmail.com>
* Lint fixes
* Ruff auto-format
Revert "Lint fixes #1 torchao/dtypes (pytorch#827)" This reverts commit 144445a. Co-authored-by: Mark Saroufim <marksaroufim@gmail.com>
No description provided.