
Use numpy for internal buffers #277


Merged · 10 commits · May 30, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added first version of the changelog
- Use numpy for internal buffers to reduce memory usage and improve performance.

### Fixed

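To make the changelog claim concrete, here is a rough standalone sketch (not part of the diff) of why one contiguous numpy buffer is cheaper than a Python list per evaluated token; the sizes are illustrative and vary by platform and Python build:

```python
# Rough comparison of the two buffer layouts; numbers are approximate.
import sys
import numpy as np

n_tokens, n_vocab = 512, 32_000

# List-based layout: one Python list of floats per token's logits.
as_lists = [[0.0] * n_vocab for _ in range(n_tokens)]
list_bytes = sys.getsizeof(as_lists) + sum(sys.getsizeof(r) for r in as_lists)

# numpy layout: a single contiguous (n_tokens, n_vocab) float32 matrix.
as_array = np.zeros((n_tokens, n_vocab), dtype=np.single)

print(f"lists: ~{list_bytes / 1e6:.0f} MB of pointers alone")
print(f"numpy: ~{as_array.nbytes / 1e6:.0f} MB total")
```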
90 changes: 58 additions & 32 deletions llama_cpp/llama.py
@@ -20,6 +20,9 @@
from . import llama_cpp
from .llama_types import *

import numpy as np
import numpy.typing as npt


class LlamaCache:
"""Cache for a llama.cpp model."""
@@ -73,11 +76,15 @@ def __init__(
self,
eval_tokens: Deque[int],
eval_logits: Deque[List[float]],
input_ids: npt.NDArray[np.intc],
scores: npt.NDArray[np.single],
llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8]
llama_state_size: int,
):
self.eval_tokens = eval_tokens
self.eval_logits = eval_logits
self.input_ids = input_ids
self.scores = scores
self.llama_state = llama_state
self.llama_state_size = llama_state_size

@@ -207,27 +214,27 @@ def __init__(

self._n_vocab = self.n_vocab()
self._n_ctx = self.n_ctx()
data = (llama_cpp.llama_token_data * self._n_vocab)(
*[
llama_cpp.llama_token_data(
id=llama_cpp.llama_token(i),
logit=llama_cpp.c_float(0.0),
p=llama_cpp.c_float(0.0),
)
for i in range(self._n_vocab)
]
)
size = llama_cpp.c_size_t(self._n_vocab)
sorted = False
sorted = llama_cpp.c_bool(False)
self._candidates_data = np.array(
[],
dtype=np.dtype(
[("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
),
)
self._candidates_data.resize(3, self._n_vocab)
candidates = llama_cpp.llama_token_data_array(
data=data,
data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
size=size,
sorted=sorted,
)
self._candidates = candidates
self._token_nl = Llama.token_nl()
self._token_eos = Llama.token_eos()

self._input_ids = np.array([], dtype=np.intc)
self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)

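The `align=True` structured dtype above is chosen so each array element has the same byte layout as the C `llama_token_data` struct, letting `ctypes.data_as` hand llama.cpp a raw pointer into the array with no copying. A minimal self-contained sketch of that technique, using a hypothetical stand-in struct rather than the real binding:

```python
import ctypes
import numpy as np

class TokenData(ctypes.Structure):  # hypothetical mirror of llama_token_data
    _fields_ = [("id", ctypes.c_int), ("logit", ctypes.c_float), ("p", ctypes.c_float)]

dtype = np.dtype([("id", np.intc), ("logit", np.single), ("p", np.single)], align=True)
assert dtype.itemsize == ctypes.sizeof(TokenData)  # identical byte layout

arr = np.zeros(4, dtype=dtype)
arr["logit"] = [0.1, 0.2, 0.3, 0.4]

# C code given this pointer reads the very same memory, no copy involved.
ptr = arr.ctypes.data_as(ctypes.POINTER(TokenData))
print(round(ptr[2].logit, 2))  # 0.3
```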
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
"""Tokenize a string.

@@ -295,6 +302,8 @@ def reset(self):
"""Reset the model state."""
self.eval_tokens.clear()
self.eval_logits.clear()
self._input_ids = np.array([], dtype=np.intc)
self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)

def eval(self, tokens: Sequence[int]):
"""Evaluate a list of tokens.
@@ -306,7 +315,7 @@ def eval(self, tokens: Sequence[int]):
n_ctx = self._n_ctx
for i in range(0, len(tokens), self.n_batch):
batch = tokens[i : min(len(tokens), i + self.n_batch)]
n_past = min(n_ctx - len(batch), len(self.eval_tokens))
n_past = min(n_ctx - len(batch), len(self._input_ids))
n_tokens = len(batch)
return_code = llama_cpp.llama_eval(
ctx=self.ctx,
@@ -319,13 +328,19 @@
raise RuntimeError(f"llama_eval returned {return_code}")
# Save tokens
self.eval_tokens.extend(batch)
self._input_ids: npt.NDArray[np.intc] = np.concatenate(
(self._input_ids, np.array(batch, dtype=np.intc)), axis=0
)
# Save logits
rows = n_tokens if self.params.logits_all else 1
n_vocab = self._n_vocab
cols = n_vocab
logits_view = llama_cpp.llama_get_logits(self.ctx)
logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
self.eval_logits.extend(logits)
self._scores: npt.NDArray[np.single] = np.concatenate(
(self._scores, np.array(logits, dtype=np.single)), axis=0
)

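In effect, `eval()` now mirrors its deque bookkeeping into two growing numpy buffers: token ids in a 1-D int array and one row of logits per token in a 2-D float32 matrix. A toy sketch of that shape invariant, with random stand-in logits and `logits_all` assumed true:

```python
import numpy as np

n_vocab = 8
input_ids = np.array([], dtype=np.intc)
scores = np.ndarray((0, n_vocab), dtype=np.single)

for batch in ([1, 2, 3], [4, 5]):
    input_ids = np.concatenate((input_ids, np.array(batch, dtype=np.intc)))
    fake_logits = np.random.rand(len(batch), n_vocab).astype(np.single)  # stand-in
    scores = np.concatenate((scores, fake_logits), axis=0)

# With logits_all=True there is exactly one score row per evaluated token.
assert scores.shape == (len(input_ids), n_vocab)
```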
def _sample(
self,
@@ -346,6 +361,7 @@
):
assert self.ctx is not None
assert len(self.eval_logits) > 0
assert self._scores.shape[0] > 0
n_vocab = self._n_vocab
n_ctx = self._n_ctx
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -354,18 +370,23 @@
if last_n_tokens_size.value < 0
else last_n_tokens_size
)
logits = self.eval_logits[-1]
logits: npt.NDArray[np.single] = self._scores[-1, :]

if logits_processor is not None:
logits = logits_processor(list(self.eval_tokens), logits)
self.eval_logits[-1] = logits
logits = np.array(
logits_processor(self._input_ids.tolist(), logits.tolist()),
dtype=np.single,
)
self._scores[-1, :] = logits
self.eval_logits[-1] = logits.tolist()

nl_logit = logits[self._token_nl]
candidates = self._candidates
for i, logit in enumerate(logits):
candidates.data[i].id = llama_cpp.llama_token(i)
candidates.data[i].logit = llama_cpp.c_float(logit)
candidates.data[i].p = llama_cpp.c_float(0.0)
candidates_data = self._candidates_data
candidates_data["id"] = np.arange(n_vocab, dtype=np.intc) # type: ignore
candidates_data["logit"] = logits
candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)
candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
candidates.sorted = llama_cpp.c_bool(False)
candidates.size = llama_cpp.c_size_t(n_vocab)
llama_cpp.llama_sample_repetition_penalty(
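The field-wise assignments above replace a per-token Python loop over the whole vocabulary (32k+ iterations for typical models) with three vectorized writes. A small sketch of that refresh pattern on a toy vocabulary (hypothetical names, toy logits):

```python
import numpy as np

n_vocab = 6
logits = np.random.rand(n_vocab).astype(np.single)  # toy stand-in logits

dtype = np.dtype([("id", np.intc), ("logit", np.single), ("p", np.single)], align=True)
candidates_data = np.zeros(n_vocab, dtype=dtype)

# Three bulk writes instead of one Python-level loop over the vocabulary:
candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)
candidates_data["logit"] = logits
candidates_data["p"] = 0.0
print(candidates_data[:3])
```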
@@ -483,8 +504,8 @@ def sample(
"""
assert self.ctx is not None
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
0, self.last_n_tokens_size - len(self.eval_tokens)
) + list(self.eval_tokens)[-self.last_n_tokens_size :]
0, self.last_n_tokens_size - len(self._input_ids)
) + self._input_ids[-self.last_n_tokens_size :].tolist()
return self._sample(
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
*last_n_tokens_data
@@ -542,9 +563,9 @@ def generate(
"""
assert self.ctx is not None

if reset and len(self.eval_tokens) > 0:
if reset and len(self._input_ids) > 0:
longest_prefix = 0
for a, b in zip(self.eval_tokens, tokens[:-1]):
for a, b in zip(self._input_ids, tokens[:-1]):
if a == b:
longest_prefix += 1
else:
@@ -554,6 +575,8 @@
print("Llama.generate: prefix-match hit", file=sys.stderr)
reset = False
tokens = tokens[longest_prefix:]
self._input_ids = self._input_ids[:longest_prefix]
self._scores = self._scores[:longest_prefix, :]
for _ in range(len(self.eval_tokens) - longest_prefix):
self.eval_tokens.pop()
try:
@@ -580,7 +603,7 @@
logits_processor=logits_processor,
)
if stopping_criteria is not None and stopping_criteria(
list(self.eval_tokens), self.eval_logits[-1]
self._input_ids.tolist(), self._scores[-1, :].tolist()
):
return
tokens_or_none = yield token
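The prefix-match branch above is what lets a repeated prompt skip re-evaluation: the new token list is compared against the cached ids, both numpy buffers are truncated to the shared prefix, and only the tail is evaluated. A standalone sketch of that logic (comparing against `tokens[:-1]` so at least one token is always re-evaluated to produce fresh logits):

```python
import numpy as np

def longest_prefix(a, b):
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

input_ids = np.array([1, 2, 3, 4, 5], dtype=np.intc)  # from the previous run
tokens = [1, 2, 3, 9]                                  # new prompt

keep = longest_prefix(input_ids, tokens[:-1])
input_ids = input_ids[:keep]   # truncate cached ids (scores are cut the same way)
tokens = tokens[keep:]         # only this tail needs llama_eval
print(keep, tokens)            # 3 [9]
```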
@@ -715,10 +738,10 @@ def _create_completion(
try:
cache_item = self.cache[prompt_tokens]
cache_prefix_len = Llama.longest_token_prefix(
cache_item.eval_tokens, prompt_tokens
cache_item.input_ids.tolist(), prompt_tokens
)
eval_prefix_len = Llama.longest_token_prefix(
self.eval_tokens, prompt_tokens
self._input_ids.tolist(), prompt_tokens
)
if cache_prefix_len > eval_prefix_len:
self.load_state(cache_item)
@@ -807,7 +830,7 @@ def _create_completion(
self.detokenize(completion_tokens[:returned_tokens])
)
token_offset = len(prompt_tokens) + returned_tokens
logits = self.eval_logits[token_offset - 1]
logits = self._scores[token_offset - 1, :].tolist()
current_logprobs = Llama.logits_to_logprobs(logits)
sorted_logprobs = list(
sorted(
@@ -856,7 +879,7 @@ def _create_completion(
break

if stopping_criteria is not None and stopping_criteria(
list(self.eval_tokens), self.eval_logits[-1]
self._input_ids.tolist(), self._scores[-1, :].tolist()
):
text = self.detokenize(completion_tokens)
finish_reason = "stop"
@@ -886,7 +909,7 @@ def _create_completion(
self.detokenize(completion_tokens[:returned_tokens])
)
token_offset = len(prompt_tokens) + returned_tokens - 1
logits = self.eval_logits[token_offset]
logits = self._scores[token_offset, :].tolist()
current_logprobs = Llama.logits_to_logprobs(logits)
sorted_logprobs = list(
sorted(
@@ -988,8 +1011,7 @@ def _create_completion(
for token in all_tokens
]
all_logprobs = [
Llama.logits_to_logprobs(list(map(float, row)))
for row in self.eval_logits
Llama.logits_to_logprobs(row.tolist()) for row in self._scores
][token_offset:]
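`logits_to_logprobs` is, in effect, a log-softmax over one score row; a numerically stable numpy version of that computation (a sketch, not necessarily the library's exact implementation):

```python
import numpy as np

def logits_to_logprobs(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max()                # stabilize exp() against overflow
    return shifted - np.log(np.exp(shifted).sum())

row = np.array([1.0, 2.0, 3.0], dtype=np.single)
print(logits_to_logprobs(row))  # exponentiates back to a distribution summing to 1
```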
for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs
@@ -1373,6 +1395,8 @@ def save_state(self) -> LlamaState:
return LlamaState(
eval_tokens=self.eval_tokens.copy(),
eval_logits=self.eval_logits.copy(),
scores=self._scores.copy(),
input_ids=self._input_ids.copy(),
llama_state=llama_state_compact,
llama_state_size=n_bytes,
)
@@ -1381,6 +1405,8 @@ def load_state(self, state: LlamaState) -> None:
assert self.ctx is not None
self.eval_tokens = state.eval_tokens.copy()
self.eval_logits = state.eval_logits.copy()
self._scores = state.scores.copy()
self._input_ids = state.input_ids.copy()
state_size = state.llama_state_size
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
raise RuntimeError("Failed to set llama state data")
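Note the explicit `.copy()` on both buffers in `save_state`/`load_state`: plain assignment of a numpy array aliases the same memory, so copying keeps a cached state immune to later generation. A toy illustration of the difference:

```python
import numpy as np

live = np.array([1, 2, 3], dtype=np.intc)

aliased = live          # shares memory with the live buffer
snapshot = live.copy()  # owns its own memory

live[0] = 99
print(aliased[0], snapshot[0])  # 99 1
```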
4 changes: 1 addition & 3 deletions setup.py
@@ -16,9 +16,7 @@
license="MIT",
package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"},
packages=["llama_cpp", "llama_cpp.server"],
install_requires=[
"typing-extensions>=4.5.0",
],
install_requires=["typing-extensions>=4.5.0", "numpy>=1.20.0"],
extras_require={
"server": ["uvicorn>=0.21.1", "fastapi>=0.95.0", "sse-starlette>=1.3.3"],
},