@@ -2567,6 +2567,85 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
     return nread;
 }
 
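+// Load a session file written by llama_save_session_file(): verify the
+// magic/version header, check that the stored hparams match the current
+// model, read the prompt tokens back into tokens_out, then restore the
+// serialized context state.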
+bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    llama_file file(path_session, "rb");
+
+    // sanity checks
+    {
+        const uint32_t magic   = file.read_u32();
+        const uint32_t version = file.read_u32();
+
+        if (!(magic == LLAMA_SESSION_MAGIC && version == LLAMA_SESSION_VERSION)) {
+            fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
+            return false;
+        }
+
+        llama_hparams session_hparams;
+        file.read_raw(&session_hparams, sizeof(llama_hparams));
+
+        if (session_hparams != ctx->model.hparams) {
+            fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__);
+            return false;
+        }
+    }
+
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();
+
+        if (n_token_count > n_token_capacity) {
+            fprintf(stderr, "%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return false;
+        }
+
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
+    }
+
+    // restore the context state
+    {
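+        // the rest of the file holds the serialized context state; its size
+        // must match llama_get_state_size() exactly or the file is rejected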
+        const size_t n_state_size_cur = file.size - file.tell();
+        const size_t n_state_size_exp = llama_get_state_size(ctx);
+
+        if (n_state_size_cur != n_state_size_exp) {
+            fprintf(stderr, "%s : the state size in session file didn't match! expected %zu, got %zu\n", __func__, n_state_size_exp, n_state_size_cur);
+            return false;
+        }
+
+        std::vector<uint8_t> state_data(n_state_size_cur);
+        file.read_raw(state_data.data(), n_state_size_cur);
+
+        llama_set_state_data(ctx, state_data.data());
+    }
+
+    return true;
+}
+
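+// Save the current session to a file: magic/version header, model hparams,
+// the prompt tokens, then the full serialized context state.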
+bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
+    llama_file file(path_session, "wb");
+
+    file.write_u32(LLAMA_SESSION_MAGIC);
+    file.write_u32(LLAMA_SESSION_VERSION);
+
+    file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));
+
+    // save the prompt
+    file.write_u32((uint32_t) n_token_count);
+    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
+
+    // save the context state
+    {
+        const size_t n_state_size = llama_get_state_size(ctx);
+
+        std::vector<uint8_t> state_data(n_state_size);
+        llama_copy_state_data(ctx, state_data.data());
+
+        file.write_raw(state_data.data(), n_state_size);
+    }
+
+    return true;
+}
+
 int llama_eval(
         struct llama_context * ctx,
            const llama_token * tokens,
@@ -2694,57 +2773,3 @@ const char * llama_print_system_info(void) {
 std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx) {
     return ctx->model.tensors_by_name;
 }
-
-size_t llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    // TODO leverage mmap
-    llama_file file(path_session, "rb");
-    const uint32_t magic   = file.read_u32();
-    const uint32_t version = file.read_u32();
-
-    if (!(magic == 'ggsn' && version == 0)) {
-        fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
-        return 0;
-    }
-
-    llama_hparams session_hparams;
-    file.read_raw(&session_hparams, sizeof(llama_hparams));
-
-    // REVIEW
-    if (session_hparams != ctx->model.hparams) {
-        fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__);
-        return 0;
-    }
-
-    const uint32_t n_token_count = file.read_u32();
-    LLAMA_ASSERT(n_token_capacity >= n_token_count);
-    file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
-    *n_token_count_out = n_token_count;
-
-    const size_t n_state_size = file.size - file.tell();
-    const size_t n_orig_state_size = llama_get_state_size(ctx);
-    if (n_state_size != n_orig_state_size) {
-        fprintf(stderr, "%s : failed to validate state size\n", __func__);
-    }
-    std::unique_ptr<uint8_t[]> state_data(new uint8_t[n_state_size]);
-    file.read_raw(state_data.get(), n_state_size);
-    return llama_set_state_data(ctx, state_data.get());
-}
-
-size_t llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    // TODO save temp & swap
-    llama_file file(path_session, "wb");
-
-    const size_t n_state_size = llama_get_state_size(ctx);
-    std::unique_ptr<uint8_t[]> state_data(new uint8_t[n_state_size]);
-    llama_copy_state_data(ctx, state_data.get());
-
-    file.write_u32('ggsn'); // magic
-    file.write_u32(0); // version
-    file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));
-
-    file.write_u32((uint32_t) n_token_count); // REVIEW
-    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
-
-    file.write_raw(state_data.get(), n_state_size);
-    return n_state_size; // REVIEW
-}
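
A minimal usage sketch (not part of the diff): assuming `ctx` was created with llama_init_from_file and "session.bin" is an arbitrary path, a caller might round-trip a session like this:

    // size the token buffer to the context length and try to resume
    std::vector<llama_token> session_tokens(llama_n_ctx(ctx));
    size_t n_session_tokens = 0;

    if (llama_load_session_file(ctx, "session.bin", session_tokens.data(),
                                session_tokens.size(), &n_session_tokens)) {
        session_tokens.resize(n_session_tokens); // tokens already evaluated
    } else {
        session_tokens.clear(); // no usable session - start fresh
    }

    // ... tokenize, llama_eval(), and append the new tokens to session_tokens ...

    // persist the prompt tokens and full context state for the next run
    llama_save_session_file(ctx, "session.bin", session_tokens.data(), session_tokens.size());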