@@ -347,16 +347,92 @@ def memory_seq_pos_max(self, seq_id: int) -> int:
347347 def memory_seq_pos_min (self , seq_id : int ) -> int :
348348 return llama_cpp .llama_memory_seq_pos_min (self .get_memory (), seq_id )
349349
350+ # // State / sessions API
351+
350352 def get_state_size (self ) -> int :
351353 return llama_cpp .llama_state_get_size (self .ctx )
352354
353- # TODO: copy_state_data
355+ def get_state_data (self , dst :ctypes .Array [ctypes .c_uint8 ], size : int ) -> int :
356+ return llama_cpp .llama_state_get_data (self .ctx , dst , size )
357+
358+ def set_state_data (self , src :ctypes .Array [ctypes .c_uint8 ], size : int ) -> int :
359+ return llama_cpp .llama_state_set_data (self .ctx , src , size )
360+
361+ def load_state_file (
362+ self ,
363+ path_session : bytes ,
364+ tokens_out : ctypes .Array [llama_cpp .llama_token ],
365+ n_token_capacity : ctypes .c_size_t ,
366+ n_token_count_out : ctypes .pointer (ctypes .c_size_t )
367+ ) -> bool :
368+ return llama_cpp .llama_state_load_file (self .ctx , path_session , tokens_out , n_token_capacity , n_token_count_out )
369+
370+ def save_state_file (
371+ self ,
372+ path_session : bytes ,
373+ tokens : ctypes .Array [llama_cpp .llama_token ],
374+ n_token_count : ctypes .c_size_t
375+ ) -> bool :
376+ return llama_cpp .llama_state_save_file (self .ctx , path_session , tokens , n_token_count )
377+
378+ def get_state_seq_size (self , seq_id : int ) -> int :
379+ return llama_cpp .llama_state_seq_get_size (self .ctx , seq_id )
380+
381+ def get_state_seq_data (self , dst : ctypes .Array [ctypes .c_uint8 ], size : int , seq_id : int ) -> int :
382+ return llama_cpp .llama_state_seq_get_data (self .ctx , dst , size , seq_id )
383+
384+ def set_state_seq_data (self , src : ctypes .Array [ctypes .c_uint8 ], size : int , dest_seq_id : int ) -> int :
385+ return llama_cpp .llama_state_seq_set_data (self .ctx , src , size , dest_seq_id )
386+
387+ def load_state_seq_file (
388+ self ,
389+ filepath : bytes ,
390+ dest_seq_id : int ,
391+ tokens_out : ctypes .Array [llama_cpp .llama_token ],
392+ n_token_capacity : ctypes .c_size_t ,
393+ n_token_count_out : ctypes .pointer (ctypes .c_size_t )
394+ ) -> int :
395+ return llama_cpp .llama_state_seq_load_file (self .ctx , filepath , dest_seq_id , tokens_out , n_token_capacity , n_token_count_out )
396+
397+ def save_state_seq_file (
398+ self ,
399+ filepath : bytes ,
400+ seq_id : int ,
401+ tokens : ctypes .Array [llama_cpp .llama_token ],
402+ n_token_count : ctypes .c_size_t
403+ ) -> int :
404+ return llama_cpp .llama_state_seq_save_file (self .ctx , filepath , seq_id , tokens , n_token_count )
405+
406+ def get_state_seq_size_ext (self , seq_id : int , flags : llama_cpp .llama_state_seq_flags ) -> int :
407+ return llama_cpp .llama_state_seq_get_size_ext (self .ctx , seq_id , flags )
408+
409+ def get_state_seq_data_ext (
410+ self ,
411+ dst :ctypes .Array [ctypes .c_uint8 ],
412+ size : int ,
413+ seq_id : int ,
414+ flags : llama_cpp .llama_state_seq_flags
415+ ) -> int :
416+ return llama_cpp .llama_state_seq_get_data_ext (self .ctx , dst , size , seq_id , flags )
354417
355- # TODO: set_state_data
418+ def set_state_seq_data_ext (
419+ self ,
420+ src :ctypes .Array [ctypes .c_uint8 ],
421+ size : int ,
422+ dest_seq_id : int ,
423+ flags : llama_cpp .llama_state_seq_flags
424+ ) -> int :
425+ return llama_cpp .llama_state_seq_set_data_ext (self .ctx , src , size , dest_seq_id , flags )
356426
357- # TODO: llama_load_session_file
427+ # // Decoding API
358428
359- # TODO: llama_save_session_file
429+ def encode (self , batch : LlamaBatch ):
430+ return_code = llama_cpp .llama_encode (
431+ self .ctx ,
432+ batch .batch ,
433+ )
434+ if return_code != 0 :
435+ raise RuntimeError (f"llama_encode returned { return_code } " )
360436
361437 def decode (self , batch : LlamaBatch ):
362438 return_code = llama_cpp .llama_decode (
0 commit comments