
Conversation

@liangan1 (Owner) commented Jun 5, 2024

No description provided.

Comment on lines 13 to 16
cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
):

Consider adding documentation for these args. It would also be helpful to note which are owned by the object and which are shared; I guess cache is shared among multiple PagedTensors? Also, what are the shapes of these tensors?
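
For example, a docstring sketch; the shapes below are assumptions based on a typical paged KV cache layout (vLLM-style), not necessarily what this PR uses:

import torch

class PagedTensor:
    """Documentation sketch only; shapes are illustrative assumptions.

    Args:
        cache: physical KV cache, shared among multiple PagedTensors,
            e.g. [num_blocks, num_kv_heads, block_size, head_size].
        block_tables: per-sequence mapping from logical block index to
            physical block id, [num_seqs, max_num_blocks_per_seq]; owned.
        context_lens: number of valid tokens per sequence, [num_seqs]; owned.
    """

    def __init__(
        self,
        cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
    ):
        self.cache = cache
        self.block_tables = block_tables
        self.context_lens = context_lens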

):
self.block_tables = block_tables
self.cache = cache
self.context_lens = context_lens

I am not sure it is good or general enough to incorporate "context length" into the semantics of a PagedTensor. Context length sounds like an application-level concept, not a general tensor-level concept.

key_cache = key_tensor.cache
value_cache = value_tensor.cache
num_kv_head = key_cache.size(1)
num_queries_per_kv = query.size(1) // num_kv_head

Should we add an assertion here to make sure query.size(1) % num_kv_head == 0?
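
For example, a minimal sketch of such a guard (the helper name and message wording are just illustrations):

import torch

def queries_per_kv(query: torch.Tensor, key_cache: torch.Tensor) -> int:
    # Guard against a query head count that is not a multiple of the KV head
    # count, which the integer division below would otherwise silently truncate.
    num_kv_head = key_cache.size(1)
    assert query.size(1) % num_kv_head == 0, (
        f"query heads ({query.size(1)}) must be divisible by "
        f"KV heads ({num_kv_head})"
    )
    return query.size(1) // num_kv_head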

query,
key_cache,
value_cache,
head_mapping,

Can we remove this head_mapping and move it into the implementation, assuming we always do the even mapping here?

: 0;
int64_t mStrideM = has_attn_mask ? attn_mask.value().stride(2) : 0;

auto max_num_partitions =

nit: The name "partition" sounds too general. Suggest specifying that it partitions the sequence, e.g., max_num_seq_partitions. Same comment for other related names.

Comment on lines 227 to 228
* @param out Output tensor [num_seqs, 1, num_heads, head_size].
* @param query Query tensor [num_seqs, 1, num_heads, head_size].

Please add a runtime assertion in the code to make sure the query has sequence length 1 here. BTW, can we extend the implementation to support query sequence length > 1? That would benefit chunked prefill and multi-turn conversation cases.
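
For example, a sketch of the suggested check, assuming the [num_seqs, q_len, num_heads, head_size] query layout documented above (the helper name is hypothetical):

import torch

def check_decode_only_query(query: torch.Tensor) -> None:
    # Sketch: assumes the documented layout [num_seqs, q_len, num_heads, head_size].
    if query.dim() != 4 or query.size(1) != 1:
        raise ValueError(
            "paged_attention currently supports only query seq length 1, "
            f"got shape {tuple(query.shape)}"
        )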

Comment on lines 59 to 61
head_mapping = torch.repeat_interleave(
torch.arange(num_kv_head, dtype=torch.int32, device="cpu"), num_queries_per_kv
)

Can we compute this inside the paged_attention C++ kernel so that we don't need to pass this head_mapping arg to it? That would simplify the kernel interface.
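
For reference, the even mapping can be recovered from the query head index alone, so a kernel that assumes it would not need this table; a small Python sketch of the equivalence (example sizes, not from the PR):

import torch

num_kv_head, num_queries_per_kv = 8, 4  # example sizes

# Explicit table, as built above: [0, 0, 0, 0, 1, 1, 1, 1, ...]
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_head, dtype=torch.int32), num_queries_per_kv
)

# Even mapping a kernel could compute on the fly: kv_head = q_head // num_queries_per_kv
for q_head in range(num_kv_head * num_queries_per_kv):
    assert head_mapping[q_head].item() == q_head // num_queries_per_kv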

Comment on lines 268 to 269
reshape_attn_mask_to_4d(attn_mask.value(), num_seqs, num_heads, q_len,
attn_mask.value().size(-1));

Does this convert the attn_mask to 4D, or just view it as 4D? Since we are working on raw pointers, perhaps we don't need to expand it into a 4D view here.
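
For comparison, a view-only 4D interpretation leaves the data untouched, while materializing it would copy; a Python sketch with assumed shapes:

import torch

num_seqs, num_heads, q_len, kv_len = 2, 8, 1, 128  # example sizes

# A mask that is broadcast over the head dimension.
attn_mask = torch.randn(num_seqs, 1, q_len, kv_len)

# View-only: no data movement; the broadcast head dim gets stride 0.
mask_view = attn_mask.expand(num_seqs, num_heads, q_len, kv_len)
assert mask_view.data_ptr() == attn_mask.data_ptr()
assert mask_view.stride(1) == 0

# Materializing would allocate num_heads times as much memory.
mask_copy = mask_view.contiguous()
assert mask_copy.data_ptr() != attn_mask.data_ptr()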

if (has_attn_mask) {
_scale_attn_mask_fusion_kernel<accum_t, accum_t>(
logits,
attn_mask_ptr + seq_id * mStrideB + head_id * mStrideH +

I guess we need to carefully handle the case where the mask has size 1 in some dim (i.e., a broadcast dim) here.
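
One way to handle this, sketched in Python (the real code would do the same for the mStride* values), is to zero the stride of any size-1 dim so the same indexing expression reuses the single row instead of reading out of bounds:

import torch

def broadcast_safe_strides(attn_mask: torch.Tensor) -> tuple[int, int, int]:
    # Sketch: assumes a 4D mask shaped [num_seqs|1, num_heads|1, q_len|1, kv_len].
    # With stride 0 for broadcast dims, the pointer arithmetic
    #   base + seq_id * stride_b + head_id * stride_h + m * stride_m
    # stays within the allocated mask for every seq_id, head_id, and m.
    stride_b = 0 if attn_mask.size(0) == 1 else attn_mask.stride(0)
    stride_h = 0 if attn_mask.size(1) == 1 else attn_mask.stride(1)
    stride_m = 0 if attn_mask.size(2) == 1 else attn_mask.stride(2)
    return stride_b, stride_h, stride_m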

liangan1 pushed a commit that referenced this pull request May 22, 2025
* feat: starting layout implementation

fix: namespace of common modules

chore: remove not needed test file

fix: op name being registered

chore: can compile the cuda kernel

fix: segmentation fault

chore: wip - paste test code just to check if everything passes

feat: wip - adding layout. unpack not working

fix: circular import

feat: wip - can almost revert

feat: can unpack. just needs cleanup

chore: improve layout code

chore: wip - mm needs work

feat: wip - something seems wrong

fix: e2e test

feat: wip - add group param

fix: unpack weights

feat: marlin is implemented and correct

chore: rebase

chore: remove old import

feat: use int4 instead of dequantizing

chore: remove unused fn

feat: add checks and validation

feat: add new kernel and refactor code (#1)

* feat: wip - adding new kernel

* feat: wip - continue working on the unpack

* feat: wip - working on unpacking

* feat: remove old op

* feat: more code changes

* chore: remove old code

* feat: more code

* chore: more code changes

* chore: more code changes

* feat: add more documentation

* fix: dataclass

* feat: add more docs

* feat: remove assert

chore: block 8 bits

chore: update comment

feat: refactor dispatch

chore: add validation on group size

chore: wip - working on fixing unpack

feat: add small readme with sources

feat: add checks

feat: tests pass & can execute llama2

* compile kind of working

* fix: batching and layout outputs correct results

* fix: torch.compile

* wip

* feat: wip

* chore: cleanup

* chore: review

* chore: review v2

* update benchmarks + README

---------

Co-authored-by: Jesse Cai <jcjessecai@gmail.com>
liangan1 pushed a commit that referenced this pull request May 22, 2025
* Lint fixes;

* Ruff auto-format
liangan1 pushed a commit that referenced this pull request May 22, 2025
Revert "Lint fixes #1 torchao/dtypes (pytorch#827)"

This reverts commit 144445a.

Co-authored-by: Mark Saroufim <marksaroufim@gmail.com>