
Commit a1c6dc2

Update llama.cpp API and supplement the State/sessions API

Signed-off-by: JamePeng <jame_peng@sina.com>
1 parent: 914893d

2 files changed: +88 -5 lines changed

llama_cpp/_internals.py

Lines changed: 80 additions & 4 deletions
@@ -347,16 +347,92 @@ def memory_seq_pos_max(self, seq_id: int) -> int:
     def memory_seq_pos_min(self, seq_id: int) -> int:
         return llama_cpp.llama_memory_seq_pos_min(self.get_memory(), seq_id)
 
+    # // State / sessions API
+
     def get_state_size(self) -> int:
         return llama_cpp.llama_state_get_size(self.ctx)
 
-    # TODO: copy_state_data
+    def get_state_data(self, dst: ctypes.Array[ctypes.c_uint8], size: int) -> int:
+        return llama_cpp.llama_state_get_data(self.ctx, dst, size)
+
+    def set_state_data(self, src: ctypes.Array[ctypes.c_uint8], size: int) -> int:
+        return llama_cpp.llama_state_set_data(self.ctx, src, size)
+
+    def load_state_file(
+        self,
+        path_session: bytes,
+        tokens_out: ctypes.Array[llama_cpp.llama_token],
+        n_token_capacity: ctypes.c_size_t,
+        n_token_count_out: ctypes.pointer(ctypes.c_size_t)
+    ) -> bool:
+        return llama_cpp.llama_state_load_file(self.ctx, path_session, tokens_out, n_token_capacity, n_token_count_out)
+
+    def save_state_file(
+        self,
+        path_session: bytes,
+        tokens: ctypes.Array[llama_cpp.llama_token],
+        n_token_count: ctypes.c_size_t
+    ) -> bool:
+        return llama_cpp.llama_state_save_file(self.ctx, path_session, tokens, n_token_count)
+
+    def get_state_seq_size(self, seq_id: int) -> int:
+        return llama_cpp.llama_state_seq_get_size(self.ctx, seq_id)
+
+    def get_state_seq_data(self, dst: ctypes.Array[ctypes.c_uint8], size: int, seq_id: int) -> int:
+        return llama_cpp.llama_state_seq_get_data(self.ctx, dst, size, seq_id)
+
+    def set_state_seq_data(self, src: ctypes.Array[ctypes.c_uint8], size: int, dest_seq_id: int) -> int:
+        return llama_cpp.llama_state_seq_set_data(self.ctx, src, size, dest_seq_id)
+
+    def load_state_seq_file(
+        self,
+        filepath: bytes,
+        dest_seq_id: int,
+        tokens_out: ctypes.Array[llama_cpp.llama_token],
+        n_token_capacity: ctypes.c_size_t,
+        n_token_count_out: ctypes.pointer(ctypes.c_size_t)
+    ) -> int:
+        return llama_cpp.llama_state_seq_load_file(self.ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out)
+
+    def save_state_seq_file(
+        self,
+        filepath: bytes,
+        seq_id: int,
+        tokens: ctypes.Array[llama_cpp.llama_token],
+        n_token_count: ctypes.c_size_t
+    ) -> int:
+        return llama_cpp.llama_state_seq_save_file(self.ctx, filepath, seq_id, tokens, n_token_count)
+
+    def get_state_seq_size_ext(self, seq_id: int, flags: llama_cpp.llama_state_seq_flags) -> int:
+        return llama_cpp.llama_state_seq_get_size_ext(self.ctx, seq_id, flags)
+
+    def get_state_seq_data_ext(
+        self,
+        dst: ctypes.Array[ctypes.c_uint8],
+        size: int,
+        seq_id: int,
+        flags: llama_cpp.llama_state_seq_flags
+    ) -> int:
+        return llama_cpp.llama_state_seq_get_data_ext(self.ctx, dst, size, seq_id, flags)
 
-    # TODO: set_state_data
+    def set_state_seq_data_ext(
+        self,
+        src: ctypes.Array[ctypes.c_uint8],
+        size: int,
+        dest_seq_id: int,
+        flags: llama_cpp.llama_state_seq_flags
+    ) -> int:
+        return llama_cpp.llama_state_seq_set_data_ext(self.ctx, src, size, dest_seq_id, flags)
 
-    # TODO: llama_load_session_file
+    # // Decoding API
 
-    # TODO: llama_save_session_file
+    def encode(self, batch: LlamaBatch):
+        return_code = llama_cpp.llama_encode(
+            self.ctx,
+            batch.batch,
+        )
+        if return_code != 0:
+            raise RuntimeError(f"llama_encode returned {return_code}")
 
     def decode(self, batch: LlamaBatch):
         return_code = llama_cpp.llama_decode(
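The new whole-context pair follows the usual two-step ctypes pattern: query the size, allocate a byte buffer of exactly that size, then copy. A minimal sketch of a snapshot-and-rollback round trip, assuming `ctx` is an instance of the wrapper class patched above (the variable name and surrounding setup are illustrative, not part of the commit):

import ctypes

# Snapshot the full context state into a Python-owned buffer.
size = ctx.get_state_size()              # bytes required for the snapshot
buf = (ctypes.c_uint8 * size)()          # zero-initialized ctypes byte array
written = ctx.get_state_data(buf, size)  # number of bytes copied out

# ... decode further tokens, then roll the context back ...
read = ctx.set_state_data(buf, written)  # number of bytes copied back in
assert read == written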

llama_cpp/llama_cpp.py

Lines changed: 8 additions & 1 deletion
@@ -1403,6 +1403,7 @@ def llama_supports_rpc() -> bool:
 
 # // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
 # // In some cases the requested values via llama_context_params may differ from the actual values used by the context
+# // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
 # LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
 @ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
 def llama_n_ctx(ctx: llama_context_p, /) -> int:
@@ -1503,6 +1504,12 @@ def llama_model_n_embd(model: llama_model_p, /) -> int:
     ...
 
 
+# LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
+@ctypes_function("llama_model_n_embd_inp", [llama_model_p_ctypes], ctypes.c_int32)
+def llama_model_n_embd_inp(model: llama_model_p, /) -> int:
+    ...
+
+
 # LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
 @ctypes_function("llama_model_n_layer", [llama_model_p_ctypes], ctypes.c_int32)
 def llama_model_n_layer(model: llama_model_p, /) -> int:
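For readers unfamiliar with the `@ctypes_function` helper, it looks up the named symbol on the shared library, registers the given argtypes and restype, and leaves the `...` stub as documentation. A rough plain-ctypes equivalent of the new `llama_model_n_embd_inp` binding, with the library path and the opaque pointer stand-in assumed for illustration:

import ctypes

# Rough plain-ctypes equivalent of the decorated stub above.
lib = ctypes.CDLL("libllama.so")  # library path is an assumption
llama_model_p = ctypes.c_void_p   # stand-in for the opaque `struct llama_model *`

lib.llama_model_n_embd_inp.argtypes = [llama_model_p]
lib.llama_model_n_embd_inp.restype = ctypes.c_int32

# Usage, given a loaded model handle:
# n_embd_inp = lib.llama_model_n_embd_inp(model)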
@@ -2440,7 +2447,7 @@ def llama_save_session_file(
 @ctypes_function(
     "llama_state_seq_get_size",
     [llama_context_p_ctypes, llama_seq_id],
-    ctypes.c_size_t,
+    llama_seq_id,
 )
 def llama_state_seq_get_size(ctx: llama_context_p, seq_id: llama_seq_id, /) -> int:
     """Get the exact size needed to copy the state of a single sequence"""
