diff --git a/README.md b/README.md
index b4cf971..aeeb613 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,11 @@ Therefore, it decodes by running the target model in parallel on the outputs of
 The speculative sampling is proposed by Google and Deepmind independently. So I implement two slightly different versions of speculative sampling: [Google's](https://arxiv.org/abs/2211.17192) and [Deepmind's](https://arxiv.org/abs/2302.01318).
 
+## Update Logs
+
+2023.09.19 Add KV cache optimization to the Google version.
+
+2023.08.16 First release, implementing the paper's algorithm.
 
 ## Usage
 
 In the sample, I use [bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1/tree/main) as the target model, [bloom-560m](https://huggingface.co/bigscience/bloom-560m/tree/main) as the approximation model.
diff --git a/sample.py b/sample.py
index eaea18f..e368381 100644
--- a/sample.py
+++ b/sample.py
@@ -57,7 +57,7 @@ def norm_logits(logits : torch.Tensor, temperature : float, top_k : float, top_p
     """
     Args:
-        logits (torch.Tensor): shape (batch, seqlen, vocab)
+        logits (torch.Tensor): shape (1, vocab)
         temperature (float): temperature
         top_k (float): top_k
         top_p (float): top_p
 
@@ -65,6 +65,7 @@ def norm_logits(logits : torch.Tensor, temperature : float, top_k : float, top_p
     Returns:
         torch.Tensor: next token with shape as (batch, 1)
     """
+    assert logits.dim() == 2
     logits = logits / temperature
     logits = top_k_top_p_filter(logits, top_k=top_k, top_p=top_p)
     probs = F.softmax(logits, dim=1)
@@ -80,12 +81,75 @@ def sample(probs : torch.Tensor, num_samples: int = 1):
 
 def _debug_show_kvcache(past_key_values):
     if past_key_values is None:
         return
-    print("show k,v cache shapes")
     for elem in past_key_values:
         k, v = elem
-        print(f"k shape {k.shape}, v shape {v.shape}")
+        print(f"kv cache: k shape {k.shape}, v shape {v.shape}")
         break
 
+def trim_kv_cache(past_key_values : Tuple[Tuple[torch.Tensor, torch.Tensor]], q : torch.Tensor, end_pos : int):
+    """
+    Trim the KV cache and the cached probabilities to end_pos.
+
+    Args:
+        past_key_values (Tuple): KV cache
+        q (torch.Tensor): cached probabilities, shape (batch, seqlen, vocab)
+        end_pos (int): the length of the valid prefix
+
+    Returns:
+        Tuple: the trimmed KV cache and the trimmed q
+    """
+    past_key_values_trimmed = []
+    for kv in past_key_values:
+        k, v = kv
+        # NOTE: this indexing is specific to bloom's KV layout; it won't work for other models.
+        # For example, llama k, v should be (batch, num_head, seq_len, hidden_dim)
+        k = k[:, :, :end_pos]
+        v = v[:, :end_pos, :]
+        kv_trimmed = (k, v)
+        past_key_values_trimmed.append(kv_trimmed)
+
+    q = q[:, :end_pos, :]
+    return past_key_values_trimmed, q
+
+
+def forward_with_kvcache(model, input_ids, past_key_values, cached_q, temperature, top_k, top_p, use_debug = False):
+    if past_key_values is None:
+        assert cached_q is None
+        # the first forward returns the prompt's logits
+        outputs = model(input_ids)
+        cached_q = outputs.logits
+        for i in range(cached_q.shape[-2]):
+            cached_q[:, i, :] = norm_logits(cached_q[:, i, :], temperature, top_k, top_p)
+        last_q = cached_q[:, -1, :]
+    else:
+        # return the last token's logits
+        cached_len = 0
+        for kv in past_key_values:
+            k, v = kv
+            cached_len = k.shape[2]
+
+        last_input_id = input_ids[:, cached_len:]
+        if last_input_id.dim() == 1:
+            last_input_id = torch.unsqueeze(last_input_id, 0)
+
+        if use_debug:
+            print(f"last_input_id shape {last_input_id.shape}")
+            _debug_show_kvcache(past_key_values)
+
+        outputs = model(last_input_id, past_key_values=past_key_values, use_cache=True)
+
+        not_cached_q = outputs.logits
+        if not_cached_q.dim() == 2:
+            not_cached_q = torch.unsqueeze(not_cached_q, 0)
+
+        for i in range(not_cached_q.shape[-2]):
+            not_cached_q[:, i, :] = norm_logits(not_cached_q[:, i, :], temperature, top_k, top_p)
+
+        if cached_q is not None:
+            cached_q = torch.cat([cached_q, not_cached_q], dim=1)
+        last_q = not_cached_q[:, -1, :]
+
+    return last_q, outputs.past_key_values, cached_q
+
 @torch.no_grad()
 def autoregressive_sampling(x : torch.Tensor, model : torch.nn.Module, N : int,
                             temperature : float = 1, top_k : int = 0, top_p : float = 0):
@@ -95,9 +159,11 @@ def autoregressive_sampling(x : torch.Tensor, model : torch.nn.Module, N : int,
     past_key_values = None
     with tqdm(total=N, desc="autoregressive sampling") as pbar:
         while n < T:
-            outputs = model(x)
-            logits = outputs.logits[::, -1, :]
-            idx_next = sample(norm_logits(logits, temperature, top_k, top_p))
+            # outputs = model(x)
+            last_q, past_key_values, _ = forward_with_kvcache(model, x, past_key_values, None, temperature, top_k, top_p)
+            # logits = outputs.logits[::, -1, :]
+            # past_key_values = outputs.past_key_values
+            idx_next = sample(last_q)
             x = torch.cat((x, idx_next), dim=1)
             n += 1
             pbar.update(1)
@@ -114,18 +180,17 @@ def max_fn(x):
     x_max_sum = torch.sum(x_max, dim=1, keepdim=True)
     return x_max / x_max_sum
 
-# def norm_logits(p : torch.Tensor):
-#     """
-#     normalize logits using softmax to probabilities along the last dimension.
-#     """
-#     return F.softmax(p, dim=-1)
-
-
-def _approx_model_serial_forward(prefix : torch.Tensor, gamma : int, approx_model : torch.nn.Module,
-                                 temperature : float, top_k : float, top_p : float,
-                                 past_key_values : Tuple[Tuple[torch.Tensor, torch.Tensor]] = None,) -> Tuple[torch.Tensor, torch.Tensor, Tuple[Tuple[torch.Tensor, torch.Tensor]]]:
-    """ forward approx model gamma times
+def generate_with_kvcache(prefix : torch.Tensor,
+                          gamma : int,
+                          approx_model : torch.nn.Module,
+                          temperature : float,
+                          top_k : float,
+                          top_p : float,
+                          past_key_values : Tuple[Tuple[torch.Tensor, torch.Tensor]] = None,
+                          cached_q = None,
+                          use_debug = False) -> Tuple[torch.Tensor, torch.Tensor, Tuple[Tuple[torch.Tensor, torch.Tensor]]]:
+    """ forward the approx model gamma times
 
     Args:
         prefix (torch.Tensor): the prefix
@@ -134,24 +199,20 @@ def _approx_model_serial_forward(prefix : torch.Tensor, gamma : int, approx_mode
         temperature (float): temp for sampling
         top_k (float): top_k for sampling
        top_p (float): top_p for sampling
+        past_key_values : valid KV cache
+        cached_q: cached probability distribution over the vocab for all prefix tokens
 
     Returns:
-        Tuple[torch.Tensor, torch.Tensor]: prefix+generated tokens, probability distribution of approx model's output
+        Tuple[torch.Tensor, torch.Tensor, Tuple]: prefix+generated tokens, the KV cache, and the probability distribution over the vocab for all tokens
     """
     x = prefix
-
+
     for _ in range(gamma):
-        output = approx_model(x)
-        q = output.logits
-        next_tok = sample(norm_logits(q[:, -1, :],
-                          temperature, top_k, top_p))
+        q, past_key_values, cached_q = forward_with_kvcache(approx_model, x, past_key_values, cached_q, temperature, top_k, top_p, use_debug)
+        next_tok = sample(q)
         x = torch.cat((x, next_tok), dim=1)
 
-    # normalize the logits
-    for i in range(q.shape[1]):
-        q[:,i,:] = norm_logits(q[:,i,:],
-                  temperature, top_k, top_p)
-    return x, q, None
+    return x, past_key_values, cached_q
 
 @torch.no_grad()
 def speculative_sampling_v2(prefix : torch.Tensor, approx_model : torch.nn.Module, target_model : torch.nn.Module,
@@ -161,7 +222,7 @@ def speculative_sampling_v2(prefix : torch.Tensor, approx_model : torch.nn.Modul
     DeepMind version Speculative Sampling.
     Accelerating Large Language Model Decoding with Speculative Sampling
     https://arxiv.org/abs/2302.01318
-
+    No KV cache optimization.
 
     Args:
         x (torch.Tensor): input sequence, (batch, prefix_seqlen), Note that the batch dim is always 1 now.
@@ -184,9 +245,19 @@ def speculative_sampling_v2(prefix : torch.Tensor, approx_model : torch.nn.Modul
     with tqdm(total=T, desc="speculative sampling") as pbar:
         while prefix.shape[1] < T:
             # q = M_q[prefix + x_0, x_1, .., x_(gamma-2)]
+            x = prefix
             prefix_len = prefix.shape[1]
-            x, q, _ = _approx_model_serial_forward(prefix, gamma, approx_model, temperature, top_k, top_p)
+            for _ in range(gamma):
+                # approx_model(x).logits has shape (batch, seq, vocab)
+                q = approx_model(x).logits
+                next_tok = sample(norm_logits(q[:, -1, :],
+                                  temperature, top_k, top_p))
+                x = torch.cat((x, next_tok), dim=1)
+            # normalize the logits
+            for i in range(q.shape[1]):
+                q[:,i,:] = norm_logits(q[:,i,:],
+                                  temperature, top_k, top_p)
 
             # p = M_p[prefix + x_0, x_0, .., x_(gamma-1)]
             p = target_model(x).logits
             for i in range(p.shape[1]):
@@ -221,14 +292,16 @@ def speculative_sampling_v2(prefix : torch.Tensor, approx_model : torch.nn.Modul
 
     return prefix
 
-
 @torch.no_grad()
 def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module, target_model : torch.nn.Module,
                          max_len : int , gamma : int = 4,
                          temperature : float = 1, top_k : int = 0, top_p : float = 0, verbose : bool = False) -> torch.Tensor:
     """
     Google version Speculative Sampling.
-
+    https://arxiv.org/pdf/2211.17192.pdf
+
+    Adapted with KV cache optimization.
+
     Args:
         x (torch.Tensor): input sequence, (batch, prefix_seqlen), Note that the batch dim is always 1 now.
         approx_model (torch.nn.Module): approx model, the small one
@@ -254,30 +327,38 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,
     target_model_past_key_values = None
     approx_model_past_key_values = None
+    # TODO: we can reduce the volume of cached_q
+    cached_q = None
+    cached_p = None
     with tqdm(total=T, desc="speculative sampling") as pbar:
         while prefix.shape[1] < T:
             # q = M_q[prefix + x_0, x_1, .., x_(gamma-2)]
             prefix_len = prefix.shape[1]
-            x, q, _ = _approx_model_serial_forward(prefix, gamma, approx_model, temperature, top_k, top_p, None)
+            x, approx_model_past_key_values, cached_q = generate_with_kvcache(
+                prefix, gamma, approx_model,
+                temperature, top_k, top_p, approx_model_past_key_values, cached_q)
+
+            # cached_q has shape (batch_size, prefix_len + gamma, vocab)
+            # assert q.shape[-2] == gamma, f"q.shape {q.shape} does not match gamma {gamma}"
 
             # p = M_p[prefix + x_0, x_0, .., x_(gamma-1)]
-            # print(type(target_model_past_key_values))
-            target_model_outputs = target_model(x)
-            p = target_model_outputs.logits
+            _, target_model_past_key_values, cached_p = forward_with_kvcache(
+                target_model,
+                x,
+                target_model_past_key_values, cached_p,
+                temperature, top_k, top_p,
+                use_debug=False)
 
-            for i in range(p.shape[1]):
-                p[:,i,:] = norm_logits(p[:,i,:],
-                          temperature, top_k, top_p)
-
             # n the end position of the valid prefix
             # x = x_[:prefix_len-1] + x_0, ..., x_(gamma-1)
             n = prefix_len + gamma - 1
             for i in range(gamma):
-                r = torch.rand(1, device = p.device)
+                r = torch.rand(1, device = cached_q.device)
                 j = x[:, prefix_len + i]
-                if r > (p[:, prefix_len + i - 1, j]) / (q[:, prefix_len + i - 1, j]):
+                # print(f"cached_q {cached_q.shape}, p {cached_p.shape} prefix_len {prefix_len}")
+                if r > (cached_p[:, prefix_len + i - 1, j]) / (cached_q[:, prefix_len + i - 1, j]):
                     # reject
                     n = prefix_len + i - 1
                     break
@@ -288,19 +369,27 @@ def speculative_sampling(prefix : torch.Tensor, approx_model : torch.nn.Module,
 
             # print(f"n : {n}, i : {i}, prefix_len + gamma - 1: {prefix_len + gamma - 1}")
             assert n >= prefix_len - 1, f"n {n}, prefix_len {prefix_len}"
             prefix = x[:, :n + 1]
+            approx_model_past_key_values, cached_q = trim_kv_cache(approx_model_past_key_values, cached_q, n+1)
+            assert cached_q.shape[-2] <= n + 1, f"cached_q.shape {cached_q.shape}, n {n}"
             if n < prefix_len + gamma - 1:
                 # reject someone, sample from the pos n
-                t = sample(max_fn(p[:, n, :] - q[:, n, :]))
+                t = sample(max_fn(cached_p[:, n, :] - cached_q[:, n, :]))
                 if verbose:
                     print(f"target resamples {n}: \033[34m{DECODER.decode(t)}\033[0m")
+
+                # target_model_past_key_values = None
+                # cached_p = None
+                target_model_past_key_values, cached_p = trim_kv_cache(target_model_past_key_values, cached_p, n+1)
             else:
                 # all approx model decoding accepted
-                assert n == p.shape[1] - 1
-                t = sample(p[:, -1, :])
+                assert n == cached_p.shape[1] - 1
+                t = sample(cached_p[:, -1, :])
                 if verbose:
                     print(f"target samples {n}: \033[35m{DECODER.decode(t)}\033[0m")
-
+                target_model_past_key_values, cached_p = trim_kv_cache(target_model_past_key_values, cached_p, n+2)
+
+
             prefix = torch.cat((prefix, t), dim=1)
 
             if not verbose:
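Below the patch excerpt, for orientation only: a minimal sketch of how the KV-cache-enabled `speculative_sampling` above might be driven end to end. It assumes the BLOOM checkpoints named in the README, that `sample.py` is importable as a module, and a recent `transformers`; the prompt, device handling, and generation parameters are illustrative assumptions, not values taken from the repository.

```python
# Hypothetical driver; the model names follow the README, everything else
# (prompt, max_len, gamma, top_k/top_p) is an illustrative assumption.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from sample import speculative_sampling  # the KV-cache version patched above

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
target_model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-7b1").to(device)
approx_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m").to(device)

# prefix has shape (1, prefix_seqlen); the batch dim must be 1
prefix = tokenizer("Speculative sampling is", return_tensors="pt").input_ids.to(device)

output = speculative_sampling(prefix, approx_model, target_model,
                              max_len=64, gamma=4, top_k=20, top_p=0.9)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

In this flow the small model drafts `gamma` tokens per iteration from its own KV cache, the large model scores them in one cached forward pass, and both caches are rolled back with `trim_kv_cache` whenever a draft token is rejected.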