@@ -692,6 +692,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             if (params.logdir.back() != DIRECTORY_SEPARATOR) {
                 params.logdir += DIRECTORY_SEPARATOR;
             }
+        } else if (arg == "-lcs" || arg == "--lookup-cache-static") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lookup_cache_static = argv[i];
         } else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1064,6 +1070,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        draft model for speculative decoding\n");
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
     printf("                        path under which to save YAML logs (no logging if unset)\n");
+    printf("  -lcs FNAME, --lookup-cache-static FNAME\n");
+    printf("                        path to static lookup cache to use for lookup decoding\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
     printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
@@ -1805,3 +1813,228 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
 
     printf("\n=== Done dumping\n");
 }
+
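+// Update the caches in ncs with the last nnew tokens of inp.
+// ncs[k] holds the statistics for n-grams of size ngram_min + k: for every such
+// n-gram seen in inp it counts how often each token followed it.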
+void llama_ngram_cache_update(std::vector<llama_ngram_cache> & ncs, int ngram_min,
+                              std::vector<llama_token> & inp, int nnew, bool print_progress) {
+    const int64_t t_start_ms = ggml_time_ms();
+    const int ngram_max = ngram_min + ncs.size() - 1;
+    const int inp_size  = inp.size();
+
+    for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
+        llama_ngram_cache & nc = ncs[ngram_size - ngram_min];
+
+        const int i_start = std::max(inp_size - nnew, ngram_size);
+        for (int i = i_start; i < inp_size; ++i) {
+            const int ngram_start = i - ngram_size;
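+            // Pack the n-gram ending right before position i into a single uint64_t
+            // key, 16 bits per token (assumes token IDs fit in 16 bits, hence the FIXME).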
+            uint64_t ngram = inp[ngram_start];
+            for (int j = ngram_start + 1; j < ngram_start + ngram_size; ++j) { // FIXME
+                const uint64_t ngram_part = inp[j];
+                ngram <<= 16;
+                ngram |= ngram_part;
+            }
+            const llama_token token = inp[i];
+
+            llama_ngram_cache::iterator part_it = nc.find(ngram);
+            if (part_it == nc.end()) {
+                llama_ngram_cache_part part;
+                part.emplace(token, 1);
+                nc.emplace(ngram, part);
+            } else {
+                llama_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
+                if (token_count_it == part_it->second.end()) {
+                    part_it->second.emplace(token, 1);
+                } else {
+                    token_count_it->second++;
+                }
+            }
+            if (print_progress && i % 10000000 == 0) {
+                const int64_t t_now_ms = ggml_time_ms();
+                const int64_t eta_ms   = (inp_size - i) * (t_now_ms - t_start_ms) / i;
+                const int64_t eta_min  = eta_ms / (60*1000);
+                const int64_t eta_s    = (eta_ms - 60*1000*eta_min) / 1000;
+
+                fprintf(stderr, "%s: %d/%d done, ETA: %02ld:%02ld\n", __func__, i, inp_size, eta_min, eta_s);
+            }
+        }
+    }
+}
+
+// Helper function to get a token from the combined, speculative sequence of inp and draft.
+static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
+    return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
+}
+
+// If the sample size or the percentage in context is below these thresholds the draft is aborted early:
+constexpr int draft_min_sample_size[LLAMA_NGRAM_MAX] = { 2,  2,  1,  1};
+constexpr int draft_min_percent[LLAMA_NGRAM_MAX]     = {50, 50, 50, 50};
+
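+// Try to extend draft to n_draft tokens. ncs_t1 holds one cache per n-gram size
+// (ngram_min .. ngram_min + ncs_t1.size() - 1) built from the current context;
+// nc_t2 is a separate bigram cache (e.g. one loaded with llama_ngram_cache_load)
+// used to re-weight the context counts and as a fallback.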
+void llama_ngram_cache_draft(
+    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft,
+    std::vector<llama_ngram_cache> & ncs_t1, int ngram_min, llama_ngram_cache & nc_t2
+) {
+    const int inp_size  = inp.size();
+    const int ngram_max = ngram_min + ncs_t1.size() - 1;
+
+    while ((int) draft.size()-1 < n_draft) {
+        bool draft_success = false;
+
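+        // Key of the bigram that ends at the current draft position, looked up in the static cache.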
+        const int ngram_start_t2 = inp_size-2 + draft.size()-1;
+        uint64_t ngram_t2 = get_token(inp, draft, ngram_start_t2);
+        for (int j = ngram_start_t2+1; j < ngram_start_t2 + 2; ++j) {
+            const uint64_t token = get_token(inp, draft, j);
+            ngram_t2 <<= 16;
+            ngram_t2 |= token;
+        }
+        llama_ngram_cache::iterator part_t2_it = nc_t2.find(ngram_t2);
+        llama_ngram_cache_part part_t2;
+        if (part_t2_it != nc_t2.end()) {
+            part_t2 = part_t2_it->second;
+        }
+
+        for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
+            if (ngram_size > inp_size) {
+                continue;
+            }
+
+            llama_ngram_cache & nc_t1 = ncs_t1[ngram_size - ngram_min];
+
+            const int ngram_start_t1 = inp_size-ngram_size + draft.size()-1;
+            uint64_t ngram_t1 = get_token(inp, draft, ngram_start_t1);
+            for (int j = ngram_start_t1+1; j < ngram_start_t1 + ngram_size; ++j) {
+                const uint64_t token = get_token(inp, draft, j);
+                ngram_t1 <<= 16;
+                ngram_t1 |= token;
+            }
+
+            llama_ngram_cache::iterator part_t1_it = nc_t1.find(ngram_t1);
+            if (part_t1_it == nc_t1.end()) {
+                continue;
+            }
+            const llama_ngram_cache_part part_t1 = part_t1_it->second;
+
+            int max_count_t1 = 0;
+            int max_count_t2 = 0;
+            int sum_count_t1 = 0;
+            llama_token max_token = -1;
+
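+            // Rank candidate tokens by count_t1*count_t2, where count_t2 is 100x the
+            // static-cache count for the token (or 1 if the static cache has no entry for it).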
+            for (std::pair<llama_token, int> token_count_t1 : part_t1) {
+                const llama_token token = token_count_t1.first;
+
+                llama_ngram_cache_part::iterator token_count_t2_it = part_t2.find(token);
+                const int32_t count_t1 = token_count_t1.second;
+                const int32_t count_t2 = token_count_t2_it != part_t2.end() ? 100*token_count_t2_it->second : 1;
+
+                if (count_t1*count_t2 > max_count_t1*max_count_t2) {
+                    max_token    = token;
+                    max_count_t1 = count_t1;
+                    max_count_t2 = count_t2;
+                }
+                sum_count_t1 += count_t1;
+            }
+            // Skip this candidate if the sample size is too low:
+            if (sum_count_t1 < draft_min_sample_size[ngram_size-1]) {
+                continue;
+            }
+            // Skip this candidate if the empirically most likely follow-up token is not likely enough:
+            if (100*max_count_t1 < draft_min_percent[ngram_size-1]*sum_count_t1) {
+                continue;
+            }
+
+            LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count_t1);
+            draft.push_back(max_token);
+            draft_success = true;
+            break;
+        }
+
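+        // No context n-gram produced a confident draft token: fall back to the static cache alone.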
+        if (!draft_success) {
+            int max_count_t2 = 0;
+            int sum_count_t2 = 0;
+            llama_token max_token = -1;
+
+            for (std::pair<llama_token, int> token_count_t2 : part_t2) {
+                const llama_token token = token_count_t2.first;
+                const int32_t count_t2 = token_count_t2.second;
+
+                if (count_t2 > max_count_t2) {
+                    max_token    = token;
+                    max_count_t2 = count_t2;
+                }
+                sum_count_t2 += count_t2;
+            }
+
+            // Skip this candidate if the sample size is too low:
+            if (sum_count_t2 < draft_min_sample_size[2-1]) {
+                break;
+            }
+            // Skip this candidate if the empirically most likely follow-up token is not likely enough:
+            if (100*max_count_t2 < draft_min_percent[2-1]*sum_count_t2) {
+                break;
+            }
+
+            LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count_t2);
+            draft.push_back(max_token);
+            draft_success = true;
+            break;
+        }
+
+        if (!draft_success) {
+            break;
+        }
+    }
+}
+
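+// Binary file format: for every n-gram, the 64-bit key, the number of distinct
+// follow-up tokens (int32_t), then one (llama_token, int32_t count) pair per token.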
+void llama_ngram_cache_save(std::vector<llama_ngram_cache> & ngram_cache, std::string & filename) {
+    GGML_ASSERT(ngram_cache.size() == 1);
+    std::ofstream file_out(filename, std::ios::binary);
+    for (std::pair<uint64_t, llama_ngram_cache_part> item : ngram_cache[0]) {
+        const uint64_t         ngram        = item.first;
+        llama_ngram_cache_part token_counts = item.second;
+        GGML_ASSERT(!token_counts.empty());
+        const int32_t ntokens = token_counts.size();
+
+        file_out.write(reinterpret_cast<const char *>(&ngram),   sizeof(uint64_t));
+        file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
+        for (std::pair<llama_token, int32_t> item2 : token_counts) {
+            const llama_token token = item2.first;
+            const int32_t     count = item2.second;
+            file_out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
+            file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
+        }
+    }
+}
+
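+// Load a cache written by llama_ngram_cache_save; aborts if the file cannot be opened.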
+llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
+    std::ifstream hashmap_file(filename, std::ios::binary);
+    if (!hashmap_file) {
+        fprintf(stderr, "error: failed to open file '%s'\n", filename.c_str());
+        exit(1);
+    }
+    llama_ngram_cache ngram_cache;
+
+    uint64_t    ngram;
+    int32_t     ntokens;
+    llama_token token;
+    int32_t     count;
+
+    char * ngramc   = reinterpret_cast<char *>(&ngram);
+    char * ntokensc = reinterpret_cast<char *>(&ntokens);
+    char * tokenc   = reinterpret_cast<char *>(&token);
+    char * countc   = reinterpret_cast<char *>(&count);
+    while (hashmap_file.read(ngramc, sizeof(uint64_t))) {
+        GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
+        llama_ngram_cache_part token_counts;
+
+        for (int i = 0; i < ntokens; ++i) {
+            GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
+            GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
+            token_counts.emplace(token, count);
+        }
+
+        ngram_cache.emplace(ngram, token_counts);
+    }
+    GGML_ASSERT(hashmap_file.eof());
+
+    return ngram_cache;
+}