Add prefill/decode from seq lens in BaseCausalLMModel #383

Open · wants to merge 2 commits into base: main
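This PR moves the prefill/decode plumbing that was duplicated in the export script and the eager example (mask construction from seq_lens, tensor-parallel replication of inputs, and unsharding of the returned logits) into two new methods on BaseCausalLMModel: prefill_from_seq_lens and decode_from_seq_lens. BaseCausalLMModel also becomes a BaseLayer instead of a ThetaLayer, so it no longer takes a theta argument. In rough terms, callers now drive the model as sketched below (building the model, tokens, and cache state is elided; the shape comments are taken from the new method signatures):

    # Sketch of the new call pattern; model/cache setup is assumed to exist elsewhere.
    logits = model.prefill_from_seq_lens(
        tokens,                            # [bs, batch_seq_len]
        seq_lens=seq_lens,                 # [bs]
        seq_block_ids=seq_block_ids,       # [bs, batch_seq_len // block_seq_stride]
        cache_state=cache_state,
    )
    next_logits = model.decode_from_seq_lens(
        next_tokens,                       # [bs, 1]
        seq_lens=seq_lens,
        start_positions=start_positions,   # [bs] of starting positions
        seq_block_ids=seq_block_ids,
        cache_state=cache_state,
    )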
sharktank/sharktank/examples/export_paged_llm_v1.py (67 changes: 18 additions & 49 deletions)

@@ -174,7 +174,7 @@ def generate_batch_prefill(bs: int):
             "tokens": {1: sl_dim},
             "seq_lens": {},
             "seq_block_ids": {1: block_dim},
-            "cs": cache_dynamic_shapes,
+            "cache_state": cache_dynamic_shapes,
         }

         print(f"Exporting prefill_bs{bs}")
@@ -186,40 +186,22 @@ def generate_batch_prefill(bs: int):
             strict=args.strict,
             arg_device=arg_affinities,
         )
-        def _(model, tokens, seq_lens, seq_block_ids, cs):
+        def _(model, tokens, seq_lens, seq_block_ids, cache_state):
             if (
                 model.config.tensor_parallelism_size == 1
                 and model.config.kv_cache_type == "direct"
             ):
-                cache_tensors = torch.unbind(cs)
-            else:
-                cache_tensors = cs
-
-            sl = tokens.shape[1]
-            input_mask = model.input_mask(seq_lens, sl)
-            attention_mask = model.attention_mask(input_mask)
-
-            if llama_config.tensor_parallelism_size != 1:
-                shard_count = llama_config.tensor_parallelism_size
-
-                tokens = ops.replicate(tokens, count=shard_count)
-                attention_mask = ops.replicate(attention_mask, count=shard_count)
-                seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
-
-                cache_tensors = repack_cache(cs, cache_shard_dim)
+                cache_state = torch.unbind(cache_state)
+            if model.config.tensor_parallelism_size != 1:
+                cache_state = repack_cache(cache_state, cache_shard_dim)

-            logits = model.prefill(
-                tokens,
-                attention_mask=attention_mask,
+            return model.prefill_from_seq_lens(
+                tokens=tokens,
+                seq_lens=seq_lens,
                 seq_block_ids=seq_block_ids,
-                cache_state=cache_tensors,
+                cache_state=cache_state,
             )
-
-            if llama_config.tensor_parallelism_size != 1:
-                logits = ops.unshard(logits)
-
-            return logits

     def generate_batch_decode(bs: int):
         tokens = torch.ones(bs, 1, dtype=torch.int64)
         seq_lens = torch.ones(bs, dtype=torch.int64)
@@ -274,34 +256,21 @@ def _(
             seq_block_ids,
             cache_state,
         ):
-            input_mask = model.input_mask(
-                seq_lens, seq_block_ids.shape[1] * model.cache.block_seq_stride
-            )
-            attention_mask = model.decode_attention_mask(input_mask)
-
-            if llama_config.tensor_parallelism_size != 1:
-                shard_count = llama_config.tensor_parallelism_size
-
-                tokens = ops.replicate(tokens, count=shard_count)
-                attention_mask = ops.replicate(attention_mask, count=shard_count)
-                start_positions = ops.replicate(start_positions, count=shard_count)
-                seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
-
             if (
                 model.config.tensor_parallelism_size == 1
                 and model.config.kv_cache_type == "direct"
             ):
                 cache_state = torch.unbind(cache_state)
+            if model.config.tensor_parallelism_size != 1:
+                cache_state = repack_cache(cache_state, cache_shard_dim)

-            logits = model.decode(
-                tokens,
-                attention_mask=attention_mask,
+            return model.decode_from_seq_lens(
+                tokens=tokens,
+                seq_lens=seq_lens,
                 start_positions=start_positions,
                 seq_block_ids=seq_block_ids,
                 cache_state=cache_state,
             )
-
-            if llama_config.tensor_parallelism_size != 1:
-                logits = ops.unshard(logits)
-
-            return logits

     bsizes = []
     for bs in args.bs:
         generate_batch_prefill(bs)
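The masking code removed from this exporter now lives behind prefill_from_seq_lens/decode_from_seq_lens, which call input_mask/attention_mask internally. As a rough standalone illustration of what that masking amounts to (an assumption-laden PyTorch sketch, not the sharktank helpers themselves):

    import torch

    def sketch_prefill_attention_mask(seq_lens: torch.Tensor, batch_seq_len: int):
        """Illustrative only: fold a padding mask derived from seq_lens and a causal
        mask into one additive float mask (0 where attention is allowed, -inf elsewhere)."""
        positions = torch.arange(batch_seq_len)[None, :]           # [1, batch_seq_len]
        input_mask = positions >= seq_lens[:, None]                 # [bs, batch_seq_len], True at padding
        causal = torch.triu(
            torch.full((batch_seq_len, batch_seq_len), float("-inf")), diagonal=1
        )                                                           # [batch_seq_len, batch_seq_len]
        pad = torch.zeros(input_mask.shape, dtype=torch.float32).masked_fill(
            input_mask, float("-inf")
        )                                                           # [bs, batch_seq_len]
        return causal[None, None, :, :] + pad[:, None, None, :]     # [bs, 1, sl, sl]

    mask = sketch_prefill_attention_mask(torch.tensor([3, 5]), batch_seq_len=8)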
sharktank/sharktank/examples/paged_llm_v1.py (29 changes: 10 additions & 19 deletions)

@@ -32,7 +32,7 @@ class TorchGenerator:

     def __init__(
         self,
-        model: PagedLlamaModelV1,
+        model: BaseCausalLMModel,
         tokenizer: InferenceTokenizer,
         page_cache_size: int = 128,
         # Need to look at the model more for this.
@@ -162,17 +162,14 @@ def compute_prefill_logits(

     def prefill(self):
         model = self.parent.model
-        attention_mask = model.attention_mask(
-            model.input_mask(self.seq_lens, self.token_ids.shape[1])
-        )
         seq_block_ids_tensor = self.pad_block_ids()
         print(f":: Invoke prefill:")
         trace_tensor("prefill.token_ids", self.token_ids)
+        trace_tensor("prefill.seq_lens", self.seq_lens)
         trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor)
-        trace_tensor("prefill.attention_mask", attention_mask)
-        logits = model.prefill(
+        self.logits = model.prefill_from_seq_lens(
             self.token_ids,
-            attention_mask=attention_mask,
+            seq_lens=self.seq_lens,
             seq_block_ids=seq_block_ids_tensor,
             cache_state=self.cache_state,
         )
@@ -181,7 +178,7 @@ def prefill(self):
         # TODO: Normalize the output of extract_tokens_from_logits into
         # tensor [bs, 1].
         tokens = torch.tensor(
-            model.extract_tokens_from_logits(logits, self.seq_lens)
+            model.extract_tokens_from_logits(self.logits, self.seq_lens)
         ).unsqueeze(1)
         print(f":: Prefill results:\n{tokens.tolist()}")
         self.add_result_token(tokens)
@@ -194,28 +191,22 @@ def decode(self):
         self.allocate_seq_block_ids()
         # TODO: Allocate more blocks on overflow.
         seq_block_ids_tensor = self.pad_block_ids()
-        decode_attention_mask = model.decode_attention_mask(
-            model.input_mask(
-                self.seq_lens,
-                seq_block_ids_tensor.shape[1] * self.parent.block_seq_stride,
-            )
-        )
         trace_tensor("decode.token_ids", self.next_tokens)
+        trace_tensor("decode.seq_lens", self.seq_lens)
         trace_tensor("decode.start_positions", start_positions)
         trace_tensor("decode.seq_block_ids", seq_block_ids_tensor)
-        trace_tensor("decode.attention_mask", decode_attention_mask)
-        logits = model.decode(
+        self.logits = model.decode_from_seq_lens(
             self.next_tokens,
-            attention_mask=decode_attention_mask,
+            seq_lens=self.seq_lens,
             start_positions=start_positions,
             seq_block_ids=seq_block_ids_tensor,
             cache_state=self.cache_state,
         )
-        trace_tensor("decode.logits", logits)
+        trace_tensor("decode.logits", self.logits)
         # TODO: Normalize the output of extract_tokens_from_logits into
         # tensor [bs, 1].
         tokens = torch.tensor(
-            model.extract_tokens_from_logits(logits, [1] * self.bs),
+            model.extract_tokens_from_logits(self.logits, [1] * self.bs),
             device=self.parent.model.device,
         ).unsqueeze(1)
         self.add_result_token(tokens)
sharktank/sharktank/layers/__init__.py (4 changes: 3 additions & 1 deletion)

@@ -7,7 +7,9 @@
 from .base import BaseLayer, ThetaLayer
 from .conv import Conv2DLayer
 from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache
-from .causal_llm import BaseCausalLMModel
+from .causal_llm import (
+    BaseCausalLMModel,
+)
 from .linear import LinearLayer
 from .norm import RMSNormLayer
 from .rotary_embedding import RotaryEmbeddingLayer
sharktank/sharktank/layers/causal_llm.py (108 changes: 101 additions & 7 deletions)

@@ -4,17 +4,19 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

-from typing import Optional
+from typing import Optional, Union
 from abc import ABC, abstractmethod

 import torch

-from ..types import Theta
+from ..types import SplitPrimitiveTensor, ReplicatedTensor
+from .. import ops
 from .base import (
-    ThetaLayer,
+    BaseLayer,
 )


-class BaseCausalLMModel(ThetaLayer):
+class BaseCausalLMModel(BaseLayer):
     """Base class for causal LM models.

     This provides some utilities and common API surface related to masking
@@ -25,16 +27,14 @@ class BaseCausalLMModel(ThetaLayer):

     def __init__(
         self,
-        theta: Theta,
         *,
         context_length: int,
         static_tables: bool = True,
         static_context_mask: bool = False,
         device: Optional[torch.device] = None,
         activation_dtype: torch.dtype = torch.float32,
         attention_dtype: torch.dtype = torch.float32,
     ):
-        super().__init__(theta)
+        super().__init__()
         self.device = device
         self.activation_dtype = activation_dtype
         self.attention_dtype = attention_dtype
@@ -149,3 +149,97 @@ def extract_tokens_from_logits(
             step_logits = logits[batch, seq_len - 1]
             results.append(torch.argmax(step_logits))
         return results
+
+    def prefill(
+        self,
+        # [bs, batch_seq_len]
+        tokens: Union[torch.Tensor, ReplicatedTensor],
+        *,
+        # [1, 1, batch_seq_len, batch_seq_len]
+        attention_mask: Union[torch.Tensor, ReplicatedTensor],
+        # [bs, batch_seq_len // block_seq_stride]
+        seq_block_ids: Union[torch.Tensor, ReplicatedTensor],
+        cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+    ):
+        raise NotImplementedError()
+
+    def prefill_from_seq_lens(
+        self,
+        tokens: torch.Tensor,
+        *,
+        seq_lens: torch.Tensor,
+        seq_block_ids: torch.Tensor,
+        cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+    ):
+        batch_seq_len = tokens.shape[1]
+        input_mask = self.input_mask(seq_lens, batch_seq_len)
+        attention_mask = self.attention_mask(input_mask)
+
+        if self.config.tensor_parallelism_size != 1:
+            shard_count = self.config.tensor_parallelism_size
+
+            tokens = ops.replicate(tokens, count=shard_count)
+            attention_mask = ops.replicate(attention_mask, count=shard_count)
+            seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
+
+        logits = self.prefill(
+            tokens,
+            attention_mask=attention_mask,
+            seq_block_ids=seq_block_ids,
+            cache_state=cache_state,
+        )
+
+        if self.config.tensor_parallelism_size != 1:
+            logits = ops.unshard(logits)
+
+        return logits
+
+    def decode(
+        self,
+        # [bs, 1]
+        tokens: Union[torch.Tensor, ReplicatedTensor],
+        *,
+        # [bs, 1, 1, batch_seq_len]
+        attention_mask: Union[torch.Tensor, ReplicatedTensor],
+        # [bs] of starting positions
+        start_positions: Union[torch.Tensor, ReplicatedTensor],
+        # [bs, batch_seq_len // block_seq_stride]
+        seq_block_ids: Union[torch.Tensor, ReplicatedTensor],
+        cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+    ):
+        raise NotImplementedError()
+
+    def decode_from_seq_lens(
+        self,
+        tokens: torch.Tensor,
+        *,
+        seq_lens: torch.Tensor,
+        start_positions: torch.Tensor,
+        seq_block_ids: torch.Tensor,
+        cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+    ):
+        input_mask = self.input_mask(
+            seq_lens, seq_block_ids.shape[1] * self.cache.block_seq_stride
+        )
+        attention_mask = self.decode_attention_mask(input_mask)
+
+        if self.config.tensor_parallelism_size != 1:
+            shard_count = self.config.tensor_parallelism_size
+
+            tokens = ops.replicate(tokens, count=shard_count)
+            attention_mask = ops.replicate(attention_mask, count=shard_count)
+            start_positions = ops.replicate(start_positions, count=shard_count)
+            seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
+
+        logits = self.decode(
+            tokens,
+            attention_mask=attention_mask,
+            start_positions=start_positions,
+            seq_block_ids=seq_block_ids,
+            cache_state=cache_state,
+        )
+
+        if self.config.tensor_parallelism_size != 1:
+            logits = ops.unshard(logits)
+
+        return logits
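To make the intended division of labor concrete: the *_from_seq_lens wrappers build the masks and handle replication/unsharding, then delegate to the model-specific prefill/decode. A hypothetical toy subclass (not part of this PR; the config/cache stubs and the random logits are assumptions made only for this sketch) would look roughly like:

    from types import SimpleNamespace

    import torch

    from sharktank.layers import BaseCausalLMModel


    class ToyCausalLM(BaseCausalLMModel):
        def __init__(self):
            super().__init__(context_length=128)
            # prefill_from_seq_lens/decode_from_seq_lens read these attributes;
            # real models get them from their LlamaModelConfig and KV cache.
            self.config = SimpleNamespace(tensor_parallelism_size=1)
            self.cache = SimpleNamespace(block_seq_stride=16)
            self.vocab_size = 32

        def prefill(self, tokens, *, attention_mask, seq_block_ids, cache_state):
            # A real model runs attention against the paged KV cache here.
            bs, sl = tokens.shape
            return torch.randn(bs, sl, self.vocab_size)


    model = ToyCausalLM()
    logits = model.prefill_from_seq_lens(
        tokens=torch.zeros(2, 16, dtype=torch.int64),
        seq_lens=torch.tensor([5, 9]),
        seq_block_ids=torch.zeros(2, 1, dtype=torch.int64),
        cache_state=[],
    )

decode_from_seq_lens works the same way once a decode override is provided; the wrappers stay identical across Llama, Grok, and Mixtral.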
sharktank/sharktank/models/grok/grok.py (1 change: 0 additions & 1 deletion)

@@ -51,7 +51,6 @@ class PagedGrokModelV1(BaseCausalLMModel):
     def __init__(self, theta: Theta, config: LlamaModelConfig):
         hp = config.hp
         super().__init__(
-            theta,
             context_length=config.hp.context_length,
             device=config.device,
             activation_dtype=config.activation_dtype,
sharktank/sharktank/models/llama/llama.py (1 change: 0 additions & 1 deletion)

@@ -65,7 +65,6 @@ class PagedLlamaModelV1(BaseCausalLMModel):
     def __init__(self, theta: Theta, config: LlamaModelConfig):
         hp = config.hp
         super().__init__(
-            theta,
             context_length=config.hp.context_length,
             static_tables=config.static_tables,
             device=config.device,
sharktank/sharktank/models/mixtral/mixtral.py (1 change: 0 additions & 1 deletion)

@@ -53,7 +53,6 @@ class PagedMixtralModelV1(BaseCausalLMModel):
     def __init__(self, theta: Theta, config: LlamaModelConfig):
         hp = config.hp
         super().__init__(
-            theta,
             context_length=config.hp.context_length,
             device=config.device,
             activation_dtype=config.activation_dtype,
sharktank/tests/models/llama/kv_cache_test.py (4 changes: 1 addition & 3 deletions)

@@ -107,9 +107,7 @@ def setUp(self):
         self.embedding_batch_mask = self.attention_embedding.compute_batch_mask(
             self.start_positions, batch_seq_len=1
         )
-        self.model = causal_llm.BaseCausalLMModel(
-            self.attention_block_theta, context_length=self.max_seq_len
-        )
+        self.model = causal_llm.BaseCausalLMModel(context_length=self.max_seq_len)
         self.prefill_attention_mask = self.model.attention_mask(
             self.model.input_mask(self.start_positions, self.seq_len)
         )
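The updated test doubles as a usage example for the theta-free constructor: BaseCausalLMModel can now be instantiated directly to get at the mask helpers. A minimal mirror of what setUp does (shapes are illustrative; the comments reflect a reading of the helpers, not documented guarantees):

    import torch

    from sharktank.layers import causal_llm

    model = causal_llm.BaseCausalLMModel(context_length=4096)
    seq_lens = torch.tensor([3, 5])
    input_mask = model.input_mask(seq_lens, 8)          # marks positions past each seq_len
    attention_mask = model.attention_mask(input_mask)   # prefill-style causal + padding mask
    decode_mask = model.decode_attention_mask(input_mask)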