add support for batched input_pos to model (#1700)
t-vi authored Aug 29, 2024
1 parent 1bfc24d commit fdf6a12
Showing 2 changed files with 98 additions and 5 deletions.
60 changes: 55 additions & 5 deletions litgpt/model.py
@@ -76,11 +76,15 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -
             raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
 
         if input_pos is not None:  # use the kv cache
-            cos = self.cos.index_select(0, input_pos)
-            sin = self.sin.index_select(0, input_pos)
+            cos = batched_index_select(self.cos, 0, input_pos)
+            sin = batched_index_select(self.sin, 0, input_pos)
             if self.mask_cache is None:
                 raise TypeError("You need to call `gpt.set_kv_cache()`")
-            mask = self.mask_cache.index_select(2, input_pos)
+            mask = batched_index_select(self.mask_cache, 2, input_pos)
+            if mask.dim() > 4:
+                # the mask cache has a batch dim of 1 in addition to the one
+                # we get if input_pos has a batch dimension
+                mask = mask.squeeze(1)
         else:
            cos = self.cos[:T]
            sin = self.sin[:T]
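
Why the squeeze: the mask cache is built once with a leading batch dim of 1, and a batched input_pos prepends a second batch dim in front of it. A minimal shape sketch, illustrative only and assuming this commit's batched_index_select is importable from litgpt.model:

    import torch
    from litgpt.model import batched_index_select

    mask_cache = torch.ones(1, 1, 16, 16).tril().bool()    # (1, 1, max_T, max_T)
    input_pos = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 2]]) # (B=2, S=4)
    mask = batched_index_select(mask_cache, 2, input_pos)  # (2, 1, 1, 4, 16)
    mask = mask.squeeze(1)                                 # (2, 1, 4, 16), broadcasts over heads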
@@ -425,11 +429,57 @@ def build_rope_cache(
     return torch.cos(idx_theta), torch.sin(idx_theta)
 
 
+def batched_index_select(t, dim, idx):
+    """index_select for a batched idx and an unbatched t"""
+    if idx.dim() == 1:
+        return torch.index_select(t, dim, idx)
+
+    *batch_shape, idx_size = idx.shape
+    res = torch.index_select(t, dim, idx.reshape(-1))  # flat index
+    # split the flattened batch and index dims back apart
+    res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :])
+    # move the batch dim to the front; this is np.rollaxis(res, dim, 0) for tensors
+    dims = [dim] + list(range(res.dim()))
+    del dims[dim + 1]
+    res = res.permute(dims)
+    # unflatten the batch dims
+    res = res.view(*batch_shape, *res.shape[1:])
+    return res
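
A quick illustration of the semantics, not part of the commit: each row of a batched idx selects its own slices from the same unbatched tensor, and the batch dims of idx end up in front of the result:

    import torch
    from litgpt.model import batched_index_select

    t = torch.arange(24).reshape(6, 4)     # unbatched table, shape (6, 4)
    idx = torch.tensor([[0, 1], [2, 3]])   # batched indices, shape (B=2, 2)
    out = batched_index_select(t, 0, idx)  # shape (2, 2, 4)
    assert torch.equal(out[0], t[:2]) and torch.equal(out[1], t[2:4])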


+def batched_index_copy_(t, dim, idx, val):
+    """index_copy_ for batched t, idx, val"""
+    if idx.dim() == 1:
+        return t.index_copy_(dim, idx, val)
+
+    assert idx.dim() == 2, f"multiple batch dims not yet supported: {idx.shape=}"
+    assert dim != 0, "cannot index the batch dim"
+    batch_size, idx_size = idx.shape
+    assert batch_size == t.size(0)
+    assert batch_size == val.size(0)
+
+    # if we could view the batch and indexed dimensions together we could use
+    # index trickery; sadly that is not the case for the kv cache, so we fall
+    # back to a for loop over the batch
+    for i in range(batch_size):
+        unbatched_dim = dim if dim < 0 else dim - 1
+        t[i].index_copy_(unbatched_dim, idx[i], val[i])
+    return t
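
A usage sketch with illustrative shapes, not from the commit: t carries the batch in dim 0, and each example copies val into its own positions along the indexed dim:

    import torch
    from litgpt.model import batched_index_copy_

    t = torch.zeros(2, 3, 4)
    idx = torch.tensor([[0, 2], [1, 3]])  # per-example target positions, shape (2, 2)
    val = torch.ones(2, 3, 2)
    batched_index_copy_(t, -1, idx, val)  # in place, like index_copy_
    assert t[0, :, [0, 2]].eq(1).all() and t[1, :, [1, 3]].eq(1).all()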


 def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
     head_size = x.size(-1)
     x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
     x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
     rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
+    if cos.dim() > 1:
+        # batch dimensions must align: sin/cos are (B, T, hs), so we
+        # unsqueeze at -3 to broadcast over nh; we count from the back
+        # because the rest of apply_rope does
+        cos = cos.unsqueeze(-3)
+        sin = sin.unsqueeze(-3)
+
     roped = (x * cos) + (rotated * sin)
     return roped.to(dtype=x.dtype)
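
The unsqueeze lets a batched rope cache broadcast over the head dimension. A shape-only sketch with random values, assuming litgpt's (B, nh, T, hs) layout:

    import torch
    from litgpt.model import apply_rope

    B, nh, T, hs = 2, 8, 4, 16
    x = torch.randn(B, nh, T, hs)
    cos, sin = torch.randn(B, T, hs), torch.randn(B, T, hs)  # one row of positions per example
    assert apply_rope(x, cos, sin).shape == (B, nh, T, hs)   # cos/sin unsqueezed to (B, 1, T, hs)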

@@ -452,8 +502,8 @@ def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) ->
         self.v = self.v.to(v.dtype)
         # update the cache
         n = k.size(0)
-        k = self.k[:n, ...].index_copy_(2, input_pos, k)
-        v = self.v[:n, ...].index_copy_(2, input_pos, v)
+        k = batched_index_copy_(self.k[:n, ...], -2, input_pos, k)
+        v = batched_index_copy_(self.v[:n, ...], -2, input_pos, v)
         return k, v

     def reset_parameters(self) -> None:
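
With batched positions, each sequence in the batch can write to a different cache slot. A sketch of the update this hunk enables, assuming the usual (B, n_head, max_T, head_size) cache layout:

    import torch
    from litgpt.model import batched_index_copy_

    k_cache = torch.zeros(2, 8, 16, 4)    # (B, n_head, max_T, head_size)
    input_pos = torch.tensor([[4], [3]])  # sequence 0 writes slot 4, sequence 1 slot 3
    k_new = torch.randn(2, 8, 1, 4)
    batched_index_copy_(k_cache, -2, input_pos, k_new)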
43 changes: 43 additions & 0 deletions tests/test_batch.py
@@ -2,9 +2,11 @@
 import pytest
 import warnings
 from pathlib import Path
+import litgpt
 from litgpt.generate.base import next_token, batched_next_token
 from litgpt.api import LLM, GPT
 from litgpt.scripts.download import download_from_hub
+from tests.conftest import RunIf
 
 warnings.filterwarnings("ignore")

@@ -63,3 +65,44 @@ def test_batched_equivalence(tmp_path):
     # Assert that single and batched next token generation are equivalent
     assert all(t == tok_1 for t in toks_1), f"{tok_1} != {toks_1}"
     assert all(t == tok_2 for t in toks_2), f"{tok_2} != {toks_2}"


+@RunIf(min_cuda_gpus=1)
+def test_simple_batch():
+    old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
+    torch.backends.cuda.matmul.allow_tf32 = False
+    config = litgpt.Config.from_name(
+        "Llama-3.1-8B", padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=256
+    )
+    with torch.device("cuda"):
+        m = litgpt.GPT(config).requires_grad_(False).eval()
+        x0 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 7]])
+        input_pos0 = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 2]])
+        x1 = torch.tensor([[1], [2]])
+        input_pos1 = torch.tensor([[4], [3]])
+
+    with torch.device("cuda"):
+        m.set_kv_cache(2)
+    outs0 = m(x0, input_pos0)
+    outs1 = m(x1, input_pos1)
+
+    with torch.device("cuda"):
+        m.set_kv_cache(1)
+
+    outs0_ref0 = m(x0[:1], input_pos0[0])
+    outs1_ref0 = m(x1[:1], input_pos1[0])
+
+    with torch.device("cuda"):
+        m.set_kv_cache(1)
+
+    outs0_ref1 = m(x0[1:], input_pos0[1])
+    outs1_ref1 = m(x1[1:], input_pos1[1])
+
+    outs0_ref = torch.cat([outs0_ref0, outs0_ref1])
+    outs1_ref = torch.cat([outs1_ref0, outs1_ref1])
+
+    print(outs0_ref - outs0)
+    print(outs0.shape)
+    torch.testing.assert_close(outs0, outs0_ref)
+    torch.testing.assert_close(outs1, outs1_ref)
+    torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32
