From b924a9ebfcdb29f034f5f991802258098fe6a243 Mon Sep 17 00:00:00 2001
From: Douglas Hanley <thesecretaryofwar@gmail.com>
Date: Wed, 21 Feb 2024 16:34:15 -0600
Subject: [PATCH 1/3] first shot at parallel generation

---
 llama_cpp/_internals.py |  11 ++++
 llama_cpp/llama.py      | 114 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 125 insertions(+)

diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index c60fdff7b..3c7cc58ae 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -530,6 +530,17 @@ def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
             self.batch.logits[i] = logits_all
         self.batch.logits[n_tokens - 1] = True
 
+    def set_batch_parallel(self, batch: Sequence[int], position: int, logits_all: bool):
+        assert self.batch is not None
+        n_tokens = len(batch)
+        self.batch.n_tokens = n_tokens
+        for i in range(n_tokens):
+            self.batch.token[i] = batch[i]
+            self.batch.pos[i] = position
+            self.batch.seq_id[i][0] = i
+            self.batch.n_seq_id[i] = 1
+            self.batch.logits[i] = logits_all
+
     def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
         assert self.batch is not None
         n_tokens = len(batch)
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 30cab0af9..c811fd9c8 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -523,6 +523,27 @@ def eval(self, tokens: Sequence[int]):
             # Update n_tokens
             self.n_tokens += n_tokens
 
+    def eval_parallel(self, tokens: Sequence[int]):
+        """Evaluate a list of tokens in different sequences but at the same position.
+
+        Args:
+            tokens: The list of tokens to evaluate, one token per parallel
+                sequence, all placed at the same position (self.n_tokens).
+        """
+        assert self._ctx.ctx is not None
+        assert self._batch.batch is not None
+        self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
+
+        n_past = self.n_tokens
+        n_tokens = len(tokens)
+        self._batch.set_batch_parallel(batch=tokens, position=n_past, logits_all=True)
+        self._ctx.decode(self._batch)
+        # Save logits
+        size = n_tokens * self._n_vocab
+        self._scores.reshape(-1)[:size] = self._ctx.get_logits()[:size]
+        # Update n_tokens
+        self.n_tokens += 1
+
     def sample(
         self,
         top_k: int = 40,
@@ -714,6 +735,91 @@ def generate(
                     ]
                 )
 
+    @staticmethod
+    def longest_common_prefix(vecs):
+        if (max_len := min([len(s) for s in vecs], default=0)) == 0:
+            return []
+        for i in range(max_len):
+            if len(set([s[i] for s in vecs])) > 1:
+                return vecs[0][:i]
+        else:
+            return vecs[0][:max_len]
+
+    def generate_parallel(self, prompts, max_tokens=None, **kwargs):
+        # tokenize the prompts
+        n_parallel = len(prompts)
+        tokens_list = [self.tokenize(p.encode("utf-8")) for p in prompts]
+        ntoks_list = [len(toks) for toks in tokens_list]
+
+        # set default max_tokens
+        max_prompt = max(ntoks_list)
+        if max_tokens is None:
+            max_tokens = self._n_ctx // n_parallel - max_prompt
+        max_length = max_prompt + max_tokens
+
+        # check for overflows
+        if max_tokens <= 0:
+            raise ValueError("max_tokens must be positive; the prompts do not fit in the context window")
+
+        # Run the shared prefix in serial to populate the kv cache. In the simplest case we look for the
+        # longest common prefix, but in general we could look for prefixes that are shared by certain
+        # subsets of the prompts.
+
+        # find the longest common prefix
+        prefix_tokens = self.longest_common_prefix(tokens_list)
+        prefix_len = len(prefix_tokens)
+
+        # reset batch and run prefix eval
+        self.reset()
+        self.eval(prefix_tokens)
+
+        # copy the kv_cache to other streams
+        for i in range(n_parallel):
+            llama_cpp.llama_kv_cache_seq_cp(self.ctx, 0, i, 0, prefix_len - 1)
+
+        # remember the batch index of the last token for each parallel sequence
+        i_batch = [prefix_len - 1 for _ in range(n_parallel)]
+
+        # since the prompts may be of different lengths, just yield the common prefix
+        for i in range(prefix_len):
+            result = [tokens_list[j][i] for j in range(n_parallel)]
+            yield result
+
+        # run the decoding loop
+        for k in range(prefix_len, max_length):
+            # sample the next token for each parallel sequence / stream
+            new_ids = []
+            for i in range(n_parallel):
+                # if the stream has already finished
+                if i_batch[i] < 0:
+                    continue
+
+                # see if we're still in the prompt
+                if k < ntoks_list[i]:
+                    new_id = tokens_list[i][k]
+                else:
+                    # get last logits and sample a new token
+                    new_id = self.sample(idx=i_batch[i], **kwargs)
+
+                    # is it an end of stream? -> mark the stream as finished
+                    if new_id == self._token_eos:
+                        i_batch[i] = -1
+                        continue
+
+                # increment counters
+                i_batch[i] = len(new_ids)
+                new_ids.append(new_id)
+
+            # check for done or run next eval
+            if len(new_ids) == 0:
+                break
+            else:
+                self.eval_parallel(new_ids)
+
+            # yield new tokens
+            result = [new_ids[j] if j >= 0 else None for j in i_batch]
+            yield result
+
     def create_embedding(
         self, input: Union[str, List[str]], model: Optional[str] = None
     ) -> CreateEmbeddingResponse:
@@ -1460,6 +1566,14 @@ def create_completion(
         completion: Completion = next(completion_or_chunks)  # type: ignore
         return completion
 
+    def create_completion_parallel(self, prompts, **kwargs):
+        streams = ["" for _ in prompts]
+        for toks in self.generate_parallel(prompts, **kwargs):
+            for i, tok in enumerate(toks):
+                if tok is not None:
+                    streams[i] += self.detokenize([tok]).decode("utf-8")
+        return streams
+
     def __call__(
         self,
         prompt: str,

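For reference, a minimal usage sketch of the API added in this patch; it is not
part of the diff itself. The model path, prompts, and the temp keyword
(forwarded through **kwargs to Llama.sample) are illustrative assumptions.

    from llama_cpp import Llama

    # Placeholder model path; any GGUF model accepted by Llama works here.
    llm = Llama(model_path="model.gguf", n_ctx=2048)

    prompts = [
        "The capital of France is",
        "The capital of Germany is",
    ]

    # create_completion_parallel tokenizes the prompts, evaluates their common
    # prefix once, then decodes all sequences together one batch per step.
    # Extra keyword arguments are forwarded to Llama.sample for every stream.
    # In this first patch, each returned string echoes the prompt followed by
    # the generated continuation.
    texts = llm.create_completion_parallel(prompts, max_tokens=32, temp=0.8)
    for prompt, text in zip(prompts, texts):
        print(prompt, "->", text)
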
From c82bcf3abc04093ce62fc64eb5bd85c9114f4a52 Mon Sep 17 00:00:00 2001
From: Douglas Hanley <thesecretaryofwar@gmail.com>
Date: Wed, 21 Feb 2024 20:05:10 -0600
Subject: [PATCH 2/3] better stream options

---
 llama_cpp/llama.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index c811fd9c8..33a8cbf62 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1566,13 +1566,19 @@ def create_completion(
         completion: Completion = next(completion_or_chunks)  # type: ignore
         return completion
 
-    def create_completion_parallel(self, prompts, **kwargs):
-        streams = ["" for _ in prompts]
+    def _create_completion_parallel(self, prompts, stream=False, **kwargs):
         for toks in self.generate_parallel(prompts, **kwargs):
-            for i, tok in enumerate(toks):
-                if tok is not None:
-                    streams[i] += self.detokenize([tok]).decode("utf-8")
-        return streams
+            yield [
+                self.detokenize([tok]).decode("utf-8") if tok is not None else ""
+                for tok in toks
+            ]
+
+    def create_completion_parallel(self, prompts, stream=False, **kwargs):
+        genpar = self._create_completion_parallel(prompts, **kwargs)
+        if stream:
+            return genpar
+        else:
+            return ["".join(toks) for toks in zip(*genpar)]
 
     def __call__(
         self,

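A rough sketch of consuming the streaming mode added here; llm is assumed to be
an existing Llama instance and the prompts are placeholders.

    prompts = ["Count to five:", "Name three colors:"]
    texts = ["" for _ in prompts]

    # With stream=True the call returns a generator yielding one list of text
    # fragments per decoding step: an empty string for streams that have
    # finished, and (in this patch) the echoed prompt tokens as well as the
    # newly sampled ones.
    for fragments in llm.create_completion_parallel(prompts, stream=True, max_tokens=16):
        for i, fragment in enumerate(fragments):
            texts[i] += fragment

    print(texts)
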
From 4787ec37086749401916777d5b728a08f13bd932 Mon Sep 17 00:00:00 2001
From: Douglas Hanley <thesecretaryofwar@gmail.com>
Date: Wed, 21 Feb 2024 22:56:37 -0600
Subject: [PATCH 3/3] cleanup and type hints

---
 llama_cpp/llama.py | 47 +++++++++++++++++++++++++++++-----------------
 1 file changed, 30 insertions(+), 17 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 33a8cbf62..bfda45ef8 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -523,7 +523,7 @@ def eval(self, tokens: Sequence[int]):
             # Update n_tokens
             self.n_tokens += n_tokens
 
-    def eval_parallel(self, tokens: Sequence[int]):
+    def eval_parallel(self, tokens: List[int]):
         """Evaluate a list of tokens in different sequences but at the same position.
 
         Args:
@@ -745,14 +745,18 @@ def longest_common_prefix(vecs):
         else:
             return vecs[0][:max_len]
 
-    def generate_parallel(self, prompts, max_tokens=None, **kwargs):
-        # tokenize the prompts
-        n_parallel = len(prompts)
-        tokens_list = [self.tokenize(p.encode("utf-8")) for p in prompts]
-        ntoks_list = [len(toks) for toks in tokens_list]
+    def generate_parallel(
+        self,
+        tokens: List[List[int]],
+        max_tokens: Optional[int] = None,
+        **kwargs
+    ) -> Iterator[List[Optional[int]]]:
+        # get prompt and token counts
+        n_parallel = len(tokens)
+        n_tokens = [len(toks) for toks in tokens]
 
         # set default max_tokens
-        max_prompt = max(ntoks_list)
+        max_prompt = max(n_tokens)
         if max_tokens is None:
             max_tokens = self._n_ctx // n_parallel - max_prompt
         max_length = max_prompt + max_tokens
@@ -766,7 +770,7 @@ def generate_parallel(self, prompts, max_tokens=None, **kwargs):
         # subsets of the prompts.
 
         # find the longest common prefix
-        prefix_tokens = self.longest_common_prefix(tokens_list)
+        prefix_tokens = self.longest_common_prefix(tokens)
         prefix_len = len(prefix_tokens)
 
         # reset batch and run prefix eval
@@ -782,7 +786,7 @@ def generate_parallel(self, prompts, max_tokens=None, **kwargs):
 
         # since the prompts may be of different lengths, just yield the common prefix
         for i in range(prefix_len):
-            result = [tokens_list[j][i] for j in range(n_parallel)]
+            result = [tokens[j][i] for j in range(n_parallel)]
             yield result
 
         # run the decoding loop
@@ -795,8 +799,8 @@ def generate_parallel(self, prompts, max_tokens=None, **kwargs):
                     continue
 
                 # see if we're still in the prompt
-                if k < ntoks_list[i]:
-                    new_id = tokens_list[i][k]
+                if k < n_tokens[i]:
+                    new_id = tokens[i][k]
                 else:
                     # get last logits and sample a new token
                     new_id = self.sample(idx=i_batch[i], **kwargs)
@@ -817,8 +821,7 @@ def generate_parallel(self, prompts, max_tokens=None, **kwargs):
                 self.eval_parallel(new_ids)
 
             # yield new tokens
-            result = [new_ids[j] if j >= 0 else None for j in i_batch]
-            yield result
+            yield [new_ids[j] if j >= 0 else None for j in i_batch]
 
     def create_embedding(
         self, input: Union[str, List[str]], model: Optional[str] = None
@@ -1566,15 +1569,25 @@ def create_completion(
         completion: Completion = next(completion_or_chunks)  # type: ignore
         return completion
 
-    def _create_completion_parallel(self, prompts, stream=False, **kwargs):
-        for toks in self.generate_parallel(prompts, **kwargs):
+    def _create_completion_parallel(
+        self,
+        prompts: List[str],
+        **kwargs
+    ) -> Iterator[List[str]]:
+        tokens: List[List[int]] = [self.tokenize(p.encode("utf-8")) for p in prompts]
+        for toks in self.generate_parallel(tokens, **kwargs):
             yield [
                 self.detokenize([tok]).decode("utf-8") if tok is not None else ""
                 for tok in toks
             ]
 
-    def create_completion_parallel(self, prompts, stream=False, **kwargs):
-        genpar = self._create_completion_parallel(prompts, **kwargs)
+    def create_completion_parallel(
+        self,
+        prompts: List[str],
+        stream: bool = False,
+        **kwargs
+    ) -> Union[List[str], Iterator[List[str]]]:
+        genpar: Iterator[List[str]] = self._create_completion_parallel(prompts, **kwargs)
         if stream:
             return genpar
         else:
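Finally, a sketch of driving the now-typed generate_parallel directly with
pre-tokenized prompts; as before, llm and the prompt strings are illustrative
assumptions rather than part of the patch.

    prompts = ["Write a haiku about the sea.", "Write a haiku about the moon."]
    token_lists = [llm.tokenize(p.encode("utf-8")) for p in prompts]

    # generate_parallel now takes List[List[int]] and yields one list per step,
    # holding a token id for each active stream and None for finished ones.
    outputs = [[] for _ in prompts]
    for step in llm.generate_parallel(token_lists, max_tokens=48):
        for i, tok in enumerate(step):
            if tok is not None:
                outputs[i].append(tok)

    # Detokenize each stream; as in the patch, the echoed prompt tokens are included.
    for toks in outputs:
        print(llm.detokenize(toks).decode("utf-8", errors="replace"))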