From b924a9ebfcdb29f034f5f991802258098fe6a243 Mon Sep 17 00:00:00 2001
From: Douglas Hanley <thesecretaryofwar@gmail.com>
Date: Wed, 21 Feb 2024 16:34:15 -0600
Subject: [PATCH 1/3] first shot at parallel generation

---
 llama_cpp/_internals.py |  11 ++++
 llama_cpp/llama.py      | 114 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 125 insertions(+)

diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index c60fdff7b..3c7cc58ae 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -530,6 +530,17 @@ def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
             self.batch.logits[i] = logits_all
         self.batch.logits[n_tokens - 1] = True
 
+    def set_batch_parallel(self, batch: Sequence[int], position: int, logits_all: bool):
+        assert self.batch is not None
+        n_tokens = len(batch)
+        self.batch.n_tokens = n_tokens
+        for i in range(n_tokens):
+            self.batch.token[i] = batch[i]
+            self.batch.pos[i] = position
+            self.batch.seq_id[i][0] = i
+            self.batch.n_seq_id[i] = 1
+            self.batch.logits[i] = logits_all
+
     def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
         assert self.batch is not None
         n_tokens = len(batch)
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 30cab0af9..c811fd9c8 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -523,6 +523,27 @@ def eval(self, tokens: Sequence[int]):
         # Update n_tokens
         self.n_tokens += n_tokens
 
+    def eval_parallel(self, tokens: Sequence[int]):
+        """Evaluate a list of tokens in different sequences but at the same position.
+
+        Args:
+            position: The position to evaluate the tokens at.
+            tokens: The list of tokens to evaluate.
+        """
+        assert self._ctx.ctx is not None
+        assert self._batch.batch is not None
+        self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
+
+        n_past = self.n_tokens
+        n_tokens = len(tokens)
+        self._batch.set_batch_parallel(batch=tokens, position=n_past, logits_all=True)
+        self._ctx.decode(self._batch)
+        # Save logits
+        size = n_tokens * self._n_vocab
+        self._scores.reshape(-1)[:size] = self._ctx.get_logits()[:size]
+        # Update n_tokens
+        self.n_tokens += 1
+
     def sample(
         self,
         top_k: int = 40,
@@ -714,6 +735,91 @@ def generate(
                     ]
                 )
 
+    @staticmethod
+    def longest_common_prefix(vecs):
+        if (max_len := min([len(s) for s in vecs], default=0)) == 0:
+            return []
+        for i in range(max_len):
+            if len(set([s[i] for s in vecs])) > 1:
+                return vecs[0][:i]
+        else:
+            return vecs[0][:max_len]
+
+    def generate_parallel(self, prompts, max_tokens=None, **kwargs):
+        # tokenize the prompts
+        n_parallel = len(prompts)
+        tokens_list = [self.tokenize(p.encode("utf-8")) for p in prompts]
+        ntoks_list = [len(toks) for toks in tokens_list]
+
+        # set default max_tokens
+        max_prompt = max(ntoks_list)
+        if max_tokens is None:
+            max_tokens = self._n_ctx // n_parallel - max_prompt
+        max_length = max_prompt + max_tokens
+
+        # check for overflows
+        if max_tokens <= 0:
+            raise ValueError(f"Maximum number of tokens exceeded")
+
+        # Run longest prefix in serial to populate kv cache. In the simplest case we look for the
+        # longest common prefix, but in general we could look for prefixes that are shared by certain
+        # subsets of the prompts.
+
+        # find the longest common prefix
+        prefix_tokens = self.longest_common_prefix(tokens_list)
+        prefix_len = len(prefix_tokens)
+
+        # reset batch and run prefix eval
+        self.reset()
+        self.eval(prefix_tokens)
+
+        # copy the kv_cache to other streams
+        for i in range(n_parallel):
+            llama_cpp.llama_kv_cache_seq_cp(self.ctx, 0, i, 0, prefix_len - 1)
+
+        # remember the batch index of the last token for each parallel sequence
+        i_batch = [prefix_len - 1 for _ in range(n_parallel)]
+
+        # since the prompts may be of different lengths, just yield the common prefix
+        for i in range(prefix_len):
+            result = [tokens_list[j][i] for j in range(n_parallel)]
+            yield result
+
+        # run the decoding loop
+        for k in range(prefix_len, max_length):
+            # sample the next token for each parallel sequence / stream
+            new_ids = []
+            for i in range(n_parallel):
+                # if the stream has already finished
+                if i_batch[i] < 0:
+                    continue
+
+                # see if we're still in the prompt
+                if k < ntoks_list[i]:
+                    new_id = tokens_list[i][k]
+                else:
+                    # get last logits and sample a new token
+                    new_id = self.sample(idx=i_batch[i], **kwargs)
+
+                    # is it an end of stream? -> mark the stream as finished
+                    if new_id == self._token_eos:
+                        i_batch[i] = -1
+                        continue
+
+                # increment counters
+                i_batch[i] = len(new_ids)
+                new_ids.append(new_id)
+
+            # check for done or run next eval
+            if len(new_ids) == 0:
+                break
+            else:
+                self.eval_parallel(new_ids)
+
+            # yield new tokens
+            result = [new_ids[j] if j >= 0 else None for j in i_batch]
+            yield result
+
     def create_embedding(
         self, input: Union[str, List[str]], model: Optional[str] = None
     ) -> CreateEmbeddingResponse:
@@ -1460,6 +1566,14 @@ def create_completion(
         completion: Completion = next(completion_or_chunks)  # type: ignore
         return completion
 
+    def create_completion_parallel(self, prompts, **kwargs):
+        streams = ["" for _ in prompts]
+        for toks in self.generate_parallel(prompts, **kwargs):
+            for i, tok in enumerate(toks):
+                if tok is not None:
+                    streams[i] += self.detokenize([tok]).decode("utf-8")
+        return streams
+
     def __call__(
         self,
         prompt: str,

From c82bcf3abc04093ce62fc64eb5bd85c9114f4a52 Mon Sep 17 00:00:00 2001
From: Douglas Hanley <thesecretaryofwar@gmail.com>
Date: Wed, 21 Feb 2024 20:05:10 -0600
Subject: [PATCH 2/3] better stream options

---
 llama_cpp/llama.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index c811fd9c8..33a8cbf62 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1566,13 +1566,19 @@ def create_completion(
         completion: Completion = next(completion_or_chunks)  # type: ignore
         return completion
 
-    def create_completion_parallel(self, prompts, **kwargs):
-        streams = ["" for _ in prompts]
+    def _create_completion_parallel(self, prompts, stream=False, **kwargs):
         for toks in self.generate_parallel(prompts, **kwargs):
-            for i, tok in enumerate(toks):
-                if tok is not None:
-                    streams[i] += self.detokenize([tok]).decode("utf-8")
-        return streams
+            yield [
+                self.detokenize([tok]).decode("utf-8") if tok is not None else ""
+                for tok in toks
+            ]
+
+    def create_completion_parallel(self, prompts, stream=False, **kwargs):
+        genpar = self._create_completion_parallel(prompts, **kwargs)
+        if stream:
+            return genpar
+        else:
+            return ["".join(toks) for toks in zip(*genpar)]
 
     def __call__(
         self,
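As of this second commit, the parallel API can be exercised roughly as follows. This is an illustrative sketch rather than part of the patch: the model path, context size, prompts, and sampling arguments below are placeholders.

    from llama_cpp import Llama

    # placeholder model path; any GGUF model works here
    llm = Llama(model_path="./models/model.gguf", n_ctx=4096)

    prompts = [
        "Q: What is the capital of France? A:",
        "Q: What is the capital of Germany? A:",
    ]

    # non-streaming: returns one joined string per prompt
    completions = llm.create_completion_parallel(prompts, max_tokens=32)

    # streaming: each step yields a list of per-prompt text chunks
    for chunks in llm.create_completion_parallel(prompts, max_tokens=32, stream=True):
        print(chunks)

Note that generate_parallel also yields the prompt tokens, so the strings assembled this way contain the prompt text followed by the completion.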
From 4787ec37086749401916777d5b728a08f13bd932 Mon Sep 17 00:00:00 2001
From: Douglas Hanley <thesecretaryofwar@gmail.com>
Date: Wed, 21 Feb 2024 22:56:37 -0600
Subject: [PATCH 3/3] cleanup and type hints

---
 llama_cpp/llama.py | 47 +++++++++++++++++++++++++++++-----------------
 1 file changed, 30 insertions(+), 17 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 33a8cbf62..bfda45ef8 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -523,7 +523,7 @@ def eval(self, tokens: Sequence[int]):
         # Update n_tokens
         self.n_tokens += n_tokens
 
-    def eval_parallel(self, tokens: Sequence[int]):
+    def eval_parallel(self, tokens: List[int]):
         """Evaluate a list of tokens in different sequences but at the same position.
 
         Args:
@@ -745,14 +745,18 @@ def longest_common_prefix(vecs):
         else:
             return vecs[0][:max_len]
 
-    def generate_parallel(self, prompts, max_tokens=None, **kwargs):
-        # tokenize the prompts
-        n_parallel = len(prompts)
-        tokens_list = [self.tokenize(p.encode("utf-8")) for p in prompts]
-        ntoks_list = [len(toks) for toks in tokens_list]
+    def generate_parallel(
+        self,
+        tokens: List[List[int]],
+        max_tokens: Optional[int] = None,
+        **kwargs
+    ) -> Iterator[List[int]]:
+        # get prompt and token counts
+        n_parallel = len(tokens)
+        n_tokens = [len(toks) for toks in tokens]
 
         # set default max_tokens
-        max_prompt = max(ntoks_list)
+        max_prompt = max(n_tokens)
         if max_tokens is None:
             max_tokens = self._n_ctx // n_parallel - max_prompt
         max_length = max_prompt + max_tokens
@@ -766,7 +770,7 @@ def generate_parallel(self, prompts, max_tokens=None, **kwargs):
         # subsets of the prompts.
 
         # find the longest common prefix
-        prefix_tokens = self.longest_common_prefix(tokens_list)
+        prefix_tokens = self.longest_common_prefix(tokens)
         prefix_len = len(prefix_tokens)
 
         # reset batch and run prefix eval
@@ -782,7 +786,7 @@ def generate_parallel(self, prompts, max_tokens=None, **kwargs):
 
         # since the prompts may be of different lengths, just yield the common prefix
         for i in range(prefix_len):
-            result = [tokens_list[j][i] for j in range(n_parallel)]
+            result = [tokens[j][i] for j in range(n_parallel)]
             yield result
 
         # run the decoding loop
@@ -795,8 +799,8 @@ def generate_parallel(self, prompts, max_tokens=None, **kwargs):
                    continue
 
                # see if we're still in the prompt
-                if k < ntoks_list[i]:
-                    new_id = tokens_list[i][k]
+                if k < n_tokens[i]:
+                    new_id = tokens[i][k]
                 else:
                     # get last logits and sample a new token
                     new_id = self.sample(idx=i_batch[i], **kwargs)
@@ -817,8 +821,7 @@ def generate_parallel(self, prompts, max_tokens=None, **kwargs):
                 self.eval_parallel(new_ids)
 
             # yield new tokens
-            result = [new_ids[j] if j >= 0 else None for j in i_batch]
-            yield result
+            yield [new_ids[j] if j >= 0 else None for j in i_batch]
 
     def create_embedding(
         self, input: Union[str, List[str]], model: Optional[str] = None
@@ -1566,15 +1569,25 @@ def create_completion(
         completion: Completion = next(completion_or_chunks)  # type: ignore
         return completion
 
-    def _create_completion_parallel(self, prompts, stream=False, **kwargs):
-        for toks in self.generate_parallel(prompts, **kwargs):
+    def _create_completion_parallel(
+        self,
+        prompts: List[str],
+        **kwargs
+    ) -> Iterator[List[str]]:
+        tokens: List[List[int]] = [self.tokenize(p.encode("utf-8")) for p in prompts]
+        for toks in self.generate_parallel(tokens, **kwargs):
             yield [
                 self.detokenize([tok]).decode("utf-8") if tok is not None else ""
                 for tok in toks
             ]
 
-    def create_completion_parallel(self, prompts, stream=False, **kwargs):
-        genpar = self._create_completion_parallel(prompts, **kwargs)
+    def create_completion_parallel(
+        self,
+        prompts: List[str],
+        stream: bool = False,
+        **kwargs
+    ) -> List[str]:
+        genpar: Iterator[List[str]] = self._create_completion_parallel(prompts, **kwargs)
         if stream:
             return genpar
         else:
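With the final patch applied, generate_parallel operates on pre-tokenized prompts and yields one list of token ids per step (None for streams that have already hit end-of-sequence), while create_completion_parallel still accepts plain strings. A rough sketch of driving the lower-level interface directly; the model path, prompts, and sampling settings are placeholders, not part of the patch.

    from llama_cpp import Llama

    # placeholder model path; substitute any GGUF model
    llm = Llama(model_path="./models/model.gguf", n_ctx=4096)

    prompts = ["Q: Name a color. A:", "Q: Name a fruit. A:"]
    tokens = [llm.tokenize(p.encode("utf-8")) for p in prompts]

    texts = ["" for _ in prompts]
    for step in llm.generate_parallel(tokens, max_tokens=16, temp=0.7):
        for i, tok in enumerate(step):
            if tok is not None:
                # single tokens may be partial UTF-8, so ignore decode errors here
                texts[i] += llm.detokenize([tok]).decode("utf-8", errors="ignore")

    # each stream accumulates its prompt followed by its completion
    print(texts)

Because prompt tokens are yielded alongside sampled ones, a caller that wants only the completions would need to skip the first len(prompt) tokens of each stream.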