
Commit

Add KVCache for the google's version
feifeibear authored Sep 19, 2023
2 parents 8043c01 + af307f4 commit c3e97f5
Showing 2 changed files with 140 additions and 46 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -8,6 +8,11 @@ Therefore, it decodes by running the target model in parallel on the outputs of

Speculative sampling was proposed by Google and DeepMind independently, so I implement two slightly different versions: [Google's](https://arxiv.org/abs/2211.17192) and [DeepMind's](https://arxiv.org/abs/2302.01318).
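
Both versions share the same verification rule: the target model scores all draft tokens in one forward pass, each draft token is kept with probability min(1, p(x)/q(x)), and the first rejected position is resampled from the normalized residual max(0, p - q). The snippet below is a simplified standalone illustration of that step (tensor names and shapes are assumed for clarity, not the repo's exact code):

```python
import torch

def accept_reject_step(draft_tokens, q_probs, p_probs):
    # draft_tokens: (gamma,) token ids proposed by the approximation model
    # q_probs, p_probs: (gamma, vocab) distributions from the approx / target model
    for i, tok in enumerate(draft_tokens):
        if torch.rand(1) > p_probs[i, tok] / q_probs[i, tok]:
            # rejected: resample once from the residual distribution max(0, p - q)
            residual = torch.clamp(p_probs[i] - q_probs[i], min=0)
            resampled = torch.multinomial(residual / residual.sum(), num_samples=1)
            return draft_tokens[:i], resampled
    return draft_tokens, None  # all drafts accepted; the caller samples one extra token from p
```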

## Update Logs

2023.09.19 Add KVCache optimization to Google's version.

2023.08.16 First release, implementing the paper's algorithm.

## Usage
In the sample, I use [bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1/tree/main) as the target model and [bloom-560m](https://huggingface.co/bigscience/bloom-560m/tree/main) as the approximation model.
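
For orientation, here is a minimal sketch of how the two models could be wired into `speculative_sampling` from `sample.py` (this assumes `sample.py` is importable as a module; the repository's actual entry point and arguments may differ):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from sample import speculative_sampling  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
target_model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-7b1")
approx_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

prefix = tokenizer("Translate to English: Je t'aime.", return_tensors="pt").input_ids
output = speculative_sampling(prefix, approx_model, target_model,
                              max_len=64, gamma=4, temperature=1.0, top_k=20, top_p=0.9)
print(tokenizer.decode(output[0]))
```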
181 changes: 135 additions & 46 deletions sample.py
@@ -57,14 +57,15 @@ def norm_logits(logits : torch.Tensor, temperature : float, top_k : float, top_p
"""
Args:
logits (torch.Tensor): shape (batch, seqlen, vocab)
logits (torch.Tensor): shape (1, vocab)
temperature (float): temperature
top_k (float): top_k
top_p (float): top_p
Returns:
torch.Tensor: normalized probabilities with shape (1, vocab)
"""
assert logits.dim() == 2
logits = logits / temperature
logits = top_k_top_p_filter(logits, top_k=top_k, top_p=top_p)
probs = F.softmax(logits, dim=1)
@@ -80,12 +81,75 @@ def sample(probs : torch.Tensor, num_samples: int = 1):
def _debug_show_kvcache(past_key_values):
if past_key_values is None:
return
print("show k,v cache shapes")
for elem in past_key_values:
k, v = elem
print(f"k shape {k.shape}, v shape {v.shape}")
print(f"kv cache: k shape {k.shape}, v shape {v.shape}")
break

def trim_kv_cache(past_key_values : Tuple[Tuple[torch.Tensor, torch.Tensor]], q : torch.Tensor, end_pos : int):
"""
trim the KV cache to the end_pos
Args:
past_key_values (Tuple): KV Cache
end_pos (int): the position of the valid prefix
Returns:
Tuple: the trimmed KV Cache
"""
past_key_values_trimmed = []
for kv in past_key_values:
k, v = kv
# NOTE: this indexing is specific to bloom's KV cache layout and won't work for other models.
# For example, llama's k and v are shaped (batch, num_head, seq_len, hidden_dim).
k = k[:, :, :end_pos]
v = v[:, :end_pos, :]
kv_trimmed = (k, v)
past_key_values_trimmed.append(kv_trimmed)

q = q[:, :end_pos, :]
return past_key_values_trimmed, q


def forward_with_kvcache(model, input_ids, past_key_values, cached_q, temperature, top_k, top_p, use_debug = False):
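"""
Forward `model` while reusing and extending its KV cache.
On the first call (past_key_values is None) the whole prompt is forwarded and every position's
logits are normalized into cached_q; on later calls only the not-yet-cached tokens are forwarded
and their normalized distributions are appended to cached_q.
Returns: (normalized distribution of the last token, updated past_key_values, updated cached_q)
"""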
if past_key_values is None:
assert cached_q is None
# the first forward returns the prompt's logits
outputs = model(input_ids)
cached_q = outputs.logits
for i in range(cached_q.shape[-2]):
cached_q[:, i, :] = norm_logits(cached_q[:, i, :], temperature, top_k, top_p)
last_q = cached_q[:, -1, :]
else:
# later forwards only compute logits for the tokens that are not cached yet
cached_len = 0
for kv in past_key_values:
k, v = kv
cached_len = k.shape[2]

last_input_id = input_ids[:, cached_len:]
if last_input_id.dim() == 1:
last_input_id = torch.unsqueeze(last_input_id, 0)

if use_debug:
print(f"last_input_id shape {last_input_id.shape}")
_debug_show_kvcache(past_key_values)

outputs = model(last_input_id, past_key_values=past_key_values, use_cache=True)

not_cached_q = outputs.logits
if not_cached_q.dim() == 2:
not_cached_q = torch.unsqueeze(not_cached_q, 0)

for i in range(not_cached_q.shape[-2]):
not_cached_q[:, i, :] = norm_logits(not_cached_q[:, i, :], temperature, top_k, top_p)

if cached_q is not None:
cached_q = torch.cat([cached_q, not_cached_q], dim=1)
last_q = not_cached_q[:, -1, :]

return last_q, outputs.past_key_values, cached_q

@torch.no_grad()
def autoregressive_sampling(x : torch.Tensor, model : torch.nn.Module, N : int,
temperature : float = 1, top_k : int = 0, top_p : float = 0):
@@ -95,9 +159,11 @@ def autoregressive_sampling(x : torch.Tensor, model : torch.nn.Module, N : int,
past_key_values = None
with tqdm(total=N, desc="autoregressive sampling") as pbar:
while n < T:
outputs = model(x)
logits = outputs.logits[::, -1, :]
idx_next = sample(norm_logits(logits, temperature, top_k, top_p))
# outputs = model(x)
last_q, past_key_values, _ = forward_with_kvcache(model, x, past_key_values, None, temperature, top_k, top_p)
# logits = outputs.logits[::, -1, :]
# past_key_values = outputs.past_key_values
idx_next = sample(last_q)
x = torch.cat((x, idx_next), dim=1)
n += 1
pbar.update(1)
@@ -114,18 +180,17 @@ def max_fn(x):
x_max_sum = torch.sum(x_max, dim=1, keepdim=True)
return x_max / x_max_sum

# def norm_logits(p : torch.Tensor):
# """
# normalize logits using softmax to probabilities along the last dimension.
# """
# return F.softmax(p, dim=-1)



def _approx_model_serial_forward(prefix : torch.Tensor, gamma : int, approx_model : torch.nn.Module,
temperature : float, top_k : float, top_p : float,
past_key_values : Tuple[Tuple[torch.Tensor, torch.Tensor]] = None,) -> Tuple[torch.Tensor, torch.Tensor, Tuple[Tuple[torch.Tensor, torch.Tensor]]]:
""" forward approx model gamma times
def generate_with_kvcache(prefix : torch.Tensor,
gamma : int,
approx_model : torch.nn.Module,
temperature : float,
top_k : float,
top_p : float,
past_key_values : Tuple[Tuple[torch.Tensor, torch.Tensor]] = None,
cached_q = None,
use_debug = False) -> Tuple[torch.Tensor, torch.Tensor, Tuple[Tuple[torch.Tensor, torch.Tensor]]]:
""" forward the model gamma times
Args:
prefix (torch.Tensor): the prefix
@@ -134,24 +199,20 @@ def _approx_model_serial_forward(prefix : torch.Tensor, gamma : int, approx_mode
temperature (float): temp for sampling
top_k (float): top_k for sampling
top_p (float): top_p for sampling
past_key_values : valid kv cache
cached_q: cached probability distributions over the vocab for all of the prefix tokens
Returns:
Tuple[torch.Tensor, torch.Tensor]: prefix+generated tokens, probability distribution of approx model's output
Tuple[torch.Tensor, Tuple, torch.Tensor]: prefix + generated tokens, past key value cache, probability distributions over the vocab for all of the tokens
"""
x = prefix

for _ in range(gamma):
output = approx_model(x)
q = output.logits
next_tok = sample(norm_logits(q[:, -1, :],
temperature, top_k, top_p))
q, past_key_values, cached_q = forward_with_kvcache(approx_model, x, past_key_values, cached_q, temperature, top_k, top_p, use_debug)
next_tok = sample(q)
x = torch.cat((x, next_tok), dim=1)

# normalize the logits
for i in range(q.shape[1]):
q[:,i,:] = norm_logits(q[:,i,:],
temperature, top_k, top_p)
return x, q, None
return x, past_key_values, cached_q

@torch.no_grad()
def speculative_sampling_v2(prefix : torch.Tensor, approx_model : torch.nn.Module, target_model : torch.nn.Module,
@@ -161,7 +222,7 @@ def speculative_sampling_v2(prefix : torch.Tensor, approx_model : torch.nn.Modul
DeepMind version Speculative Sampling.
Accelerating Large Language Model Decoding with Speculative Sampling
https://arxiv.org/abs/2302.01318
No KV Cache Optimization
Args:
prefix (torch.Tensor): input sequence, (batch, prefix_seqlen); note that the batch dim is always 1 for now.
@@ -184,9 +245,19 @@ def speculative_sampling_v2(prefix : torch.Tensor, approx_model : torch.nn.Modul
with tqdm(total=T, desc="speculative sampling") as pbar:
while prefix.shape[1] < T:
# q = M_q[prefix + x_0, x_1, .., x_(gamma-2)]
x = prefix
prefix_len = prefix.shape[1]
x, q, _ = _approx_model_serial_forward(prefix, gamma, approx_model, temperature, top_k, top_p)
for _ in range(gamma):
# approx_model(x).logits has shape (batch, seq, vocab)
q = approx_model(x).logits
next_tok = sample(norm_logits(q[:, -1, :],
temperature, top_k, top_p))
x = torch.cat((x, next_tok), dim=1)

# normalize the logits
for i in range(q.shape[1]):
q[:,i,:] = norm_logits(q[:,i,:],
temperature, top_k, top_p)
# p = M_p[prefix + x_0, x_0, .., x_(gamma-1)]
p = target_model(x).logits
for i in range(p.shape[1]):
@@ -221,14 +292,16 @@ def speculative_sampling_v2(prefix : torch.Tensor, approx_model : torch.nn.Modul

return prefix


@torch.no_grad()
def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module, target_model : torch.nn.Module,
max_len : int , gamma : int = 4,
temperature : float = 1, top_k : int = 0, top_p : float = 0, verbose : bool = False) -> torch.Tensor:
"""
Google version Speculative Sampling.
https://arxiv.org/pdf/2211.17192.pdf
Adapted with KV Cache Optimization.
Args:
prefix (torch.Tensor): input sequence, (batch, prefix_seqlen); note that the batch dim is always 1 for now.
approx_model (torch.nn.Module): approx model, the small one
@@ -254,30 +327,38 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,

target_model_past_key_values = None
approx_model_past_key_values = None
# TODO: we could reduce the memory footprint of cached_q
cached_q = None
cached_p = None
with tqdm(total=T, desc="speculative sampling") as pbar:
while prefix.shape[1] < T:
# q = M_q[prefix + x_0, x_1, .., x_(gamma-2)]
prefix_len = prefix.shape[1]

x, q, _ = _approx_model_serial_forward(prefix, gamma, approx_model, temperature, top_k, top_p, None)
x, approx_model_past_key_values, cached_q = generate_with_kvcache(
prefix, gamma, approx_model,
temperature, top_k, top_p, approx_model_past_key_values, cached_q)

# q (batch_size, prefix_len+gamma, vocab)
# assert q.shape[-2] == gamma, f"q.shape {q.shape} does not match gamma {gamma}"

# p = M_p[prefix + x_0, x_0, .., x_(gamma-1)]
# print(type(target_model_past_key_values))
target_model_outputs = target_model(x)
p = target_model_outputs.logits
_, target_model_past_key_values, cached_p = forward_with_kvcache(
target_model,
x,
target_model_past_key_values, cached_p,
temperature, top_k, top_p,
use_debug=False)

for i in range(p.shape[1]):
p[:,i,:] = norm_logits(p[:,i,:],
temperature, top_k, top_p)

# n: the end position of the valid prefix
# x = x_[:prefix_len-1] + x_0, ... x_(gamma-1)
n = prefix_len + gamma - 1
for i in range(gamma):
r = torch.rand(1, device = p.device)
r = torch.rand(1, device = cached_q.device)
j = x[:, prefix_len + i]

if r > (p[:, prefix_len + i - 1, j]) / (q[:, prefix_len + i - 1, j]):
# print(f"cached_q {cached_q.shape}, p {cached_p.shape} prefix_len {prefix_len}")
if r > (cached_p[:, prefix_len + i - 1, j]) / (cached_q[:, prefix_len + i - 1, j]):
# reject
n = prefix_len + i - 1
break
@@ -288,19 +369,27 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,
# print(f"n : {n}, i : {i}, prefix_len + gamma - 1: {prefix_len + gamma - 1}")
assert n >= prefix_len - 1, f"n {n}, prefix_len {prefix_len}"
prefix = x[:, :n + 1]
approx_model_past_key_values, cached_q = trim_kv_cache(approx_model_past_key_values, cached_q, n+1)
assert cached_q.shape[-2] <= n + 1, f"cached_q.shape {cached_q.shape}, n {n}"

if n < prefix_len + gamma - 1:
# a draft token was rejected; resample at position n from the residual distribution
t = sample(max_fn(p[:, n, :] - q[:, n, :]))
t = sample(max_fn(cached_p[:, n, :] - cached_q[:, n, :]))
if verbose:
print(f"target resamples {n}: \033[34m{DECODER.decode(t)}\033[0m")

# target_model_past_key_values = None
# cached_p = None
target_model_past_key_values, cached_p = trim_kv_cache(target_model_past_key_values, cached_p, n+1)
else:
# all approx model draft tokens were accepted
assert n == p.shape[1] - 1
t = sample(p[:, -1, :])
assert n == cached_p.shape[1] - 1
t = sample(cached_p[:, -1, :])
if verbose:
print(f"target samples {n}: \033[35m{DECODER.decode(t)}\033[0m")

target_model_past_key_values, cached_p = trim_kv_cache(target_model_past_key_values, cached_p, n+2)


prefix = torch.cat((prefix, t), dim=1)

if not verbose:
