Commit 225de38

examples : fix build after sampling refactoring

1 parent: 4a7f43f

8 files changed: +145 -233 lines changed

Makefile (+1 -1)

```diff
@@ -545,7 +545,7 @@ llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h l
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
 COMMON_H_DEPS = common/common.h common/sampling.h build-info.h common/log.h
-COMMON_DEPS   = $(COMMON_H_DEPS) common.o sampling.o
+COMMON_DEPS   = $(COMMON_H_DEPS) common.o sampling.o grammar-parser.o
 
 common.o: common/common.cpp $(COMMON_H_DEPS)
 	$(CXX) $(CXXFLAGS) -c $< -o $@
```
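Note on this change: after the sampling refactor, `llama_sampling_context` (see `common/sampling.h` below) embeds a `grammar_parser::parse_state`, so sampling code now depends on the grammar parser. Adding `grammar-parser.o` to `$(COMMON_DEPS)` makes every example binary that links the common objects pull it in; without it, linking would presumably fail on unresolved `grammar_parser::` symbols.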

common/log.h (+69 -69)

```diff
@@ -579,75 +579,75 @@ inline std::string log_var_to_string_impl(const std::vector<int> & var)
     return buf.str();
 }
 
-#define LOG_TOKENS_TOSTR_PRETTY(ctx, tokens)                \
-    [&tokens, &ctx]()                                       \
-    {                                                       \
-        std::stringstream buf;                              \
-        buf << "[ ";                                        \
-                                                            \
-        bool first = true;                                  \
-        for (const auto &token : tokens)                    \
-        {                                                   \
-            if (!first)                                     \
-                buf << ", ";                                \
-            else                                            \
-                first = false;                              \
-                                                            \
-            auto detokenized = llama_token_to_piece(ctx, token); \
-                                                            \
-            detokenized.erase(                              \
-                std::remove_if(                             \
-                    detokenized.begin(),                    \
-                    detokenized.end(),                      \
-                    [](const unsigned char c) { return !std::isprint(c); }), \
-                detokenized.end());                         \
-                                                            \
-            buf                                             \
-                << "'" << detokenized << "'"                \
-                << ":" << std::to_string(token);            \
-        }                                                   \
-        buf << " ]";                                        \
-                                                            \
-        return buf.str();                                   \
-    }()                                                     \
-        .c_str()
-
-#define LOG_BATCH_TOSTR_PRETTY(ctx, batch)                  \
-    [&batch, &ctx]()                                        \
-    {                                                       \
-        std::stringstream buf;                              \
-        buf << "[ ";                                        \
-                                                            \
-        bool first = true;                                  \
-        for (int i = 0; i < batch.n_tokens; ++i)            \
-        {                                                   \
-            if (!first)                                     \
-                buf << ", ";                                \
-            else                                            \
-                first = false;                              \
-                                                            \
-            auto detokenized = llama_token_to_piece(ctx, batch.token[i]); \
-                                                            \
-            detokenized.erase(                              \
-                std::remove_if(                             \
-                    detokenized.begin(),                    \
-                    detokenized.end(),                      \
-                    [](const unsigned char c) { return !std::isprint(c); }), \
-                detokenized.end());                         \
-                                                            \
-            buf                                             \
-                << "\n" << std::to_string(i)                \
-                << ":token '" << detokenized << "'"         \
-                << ":pos " << std::to_string(batch.pos[i])  \
-                << ":n_seq_id " << std::to_string(batch.n_seq_id[i]) \
-                << ":seq_id " << std::to_string(batch.seq_id[i][0])  \
-                << ":logits " << std::to_string(batch.logits[i]);    \
-        }                                                   \
-        buf << " ]";                                        \
-                                                            \
-        return buf.str();                                   \
-    }()                                                     \
-        .c_str()
+template <typename C, typename T>
+inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens)
+{
+    std::stringstream buf;
+    buf << "[ ";
+
+    bool first = true;
+    for (const auto &token : tokens)
+    {
+        if (!first) {
+            buf << ", ";
+        } else {
+            first = false;
+        }
+
+        auto detokenized = llama_token_to_piece(ctx, token);
+
+        detokenized.erase(
+            std::remove_if(
+                detokenized.begin(),
+                detokenized.end(),
+                [](const unsigned char c) { return !std::isprint(c); }),
+            detokenized.end());
+
+        buf
+            << "'" << detokenized << "'"
+            << ":" << std::to_string(token);
+    }
+    buf << " ]";
+
+    return buf.str();
+}
+
+template <typename C, typename B>
+inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch)
+{
+    std::stringstream buf;
+    buf << "[ ";
+
+    bool first = true;
+    for (int i = 0; i < batch.n_tokens; ++i)
+    {
+        if (!first) {
+            buf << ", ";
+        } else {
+            first = false;
+        }
+
+        auto detokenized = llama_token_to_piece(ctx, batch.token[i]);
+
+        detokenized.erase(
+            std::remove_if(
+                detokenized.begin(),
+                detokenized.end(),
+                [](const unsigned char c) { return !std::isprint(c); }),
+            detokenized.end());
+
+        buf
+            << "\n" << std::to_string(i)
+            << ":token '" << detokenized << "'"
+            << ":pos " << std::to_string(batch.pos[i])
+            << ":n_seq_id " << std::to_string(batch.n_seq_id[i])
+            << ":seq_id " << std::to_string(batch.seq_id[i][0])
+            << ":logits " << std::to_string(batch.logits[i]);
+    }
+    buf << " ]";
+
+    return buf.str();
+}
 
 #ifdef LOG_DISABLE_LOGS
```
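The macros become function templates but keep their all-caps names, so call sites look almost the same; the visible difference is the return type. The old macro expansion ended in `.c_str()`, while the templates return a `std::string`, which is why every caller in the diffs below gains an explicit `.c_str()`. A minimal sketch of the new calling convention (the `tokens` vector here is illustrative, not from the diff):

```cpp
// Sketch only: how a call site changes after the macro -> template swap.
// `ctx` stands in for a real llama_context from the surrounding example.
std::vector<llama_token> tokens = ::llama_tokenize(ctx, "Hello", true);

// before: the macro already appended .c_str() internally
// LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, tokens));

// after: the template returns std::string, so the caller converts
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, tokens).c_str());
```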

common/sampling.h (+1)

```diff
@@ -50,6 +50,7 @@ struct llama_sampling_context {
     // internal
     grammar_parser::parse_state parsed_grammar;
 
+    // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
 };
```
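The new TODO marks `prev` as a candidate for a ring buffer: as the infill diff below shows, the history is currently shifted with `erase(begin())` plus `push_back()` on every token, which is linear in the history size. A hypothetical sketch of what the TODO could look like (`token_ring` is not part of the codebase):

```cpp
#include <cstdint>
#include <vector>

using llama_token = int32_t; // matches the typedef in llama.h

// Hypothetical replacement for the `prev` vector: O(1) push instead of
// the O(n) erase(begin()) + push_back() shuffle.
struct token_ring {
    std::vector<llama_token> data;
    size_t head = 0; // index of the slot holding the oldest token

    explicit token_ring(size_t n) : data(n, 0) {}

    void push(llama_token t) {
        data[head] = t;                  // overwrite the oldest entry
        head = (head + 1) % data.size(); // advance the logical start
    }
};
```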

examples/infill/infill.cpp (+18 -26)

```diff
@@ -257,12 +257,12 @@ int main(int argc, char ** argv) {
 
     LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
     LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
-    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
+    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
 
     // Should not run without any tokens
     if (embd_inp.empty()) {
         embd_inp.push_back(llama_token_bos(ctx));
-        LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
+        LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
     }
 
     // Tokenize negative prompt
@@ -273,10 +273,10 @@
         LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
 
         guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
-        LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
+        LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
 
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
-        LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
+        LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
         original_prompt_len = original_inp.size();
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
@@ -294,8 +294,8 @@
         params.n_keep = (int)embd_inp.size();
     }
 
-    LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
-    LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
+    LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
+    LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
 
 
     // enable interactive mode if interactive start is specified
@@ -388,9 +388,6 @@
             grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
     }
 
-    // TODO: replace with ring-buffer
-    std::vector<llama_token> last_tokens(n_ctx);
-    std::fill(last_tokens.begin(), last_tokens.end(), 0);
     LOG_TEE("\n##### Infill mode #####\n\n");
     if (params.infill) {
         printf("\n************\n");
@@ -433,11 +430,7 @@
     std::vector<llama_token> embd;
     std::vector<llama_token> embd_guidance;
 
-    const int n_vocab = llama_n_vocab(model);
-
-    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
 
     while (n_remain != 0 || params.interactive) {
         // predict
@@ -484,7 +477,7 @@
 
             LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
 
-            LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+            LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 
         }
 
@@ -512,7 +505,7 @@
                 input_buf  = embd_guidance.data();
                 input_size = embd_guidance.size();
 
-                LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance));
+                LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
             } else {
                 input_buf  = embd.data();
                 input_size = embd.size();
@@ -535,7 +528,7 @@
                 n_eval = params.n_batch;
             }
 
-            LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+            LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 
             if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
                 LOG_TEE("%s : failed to eval\n", __func__);
@@ -554,12 +547,11 @@
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
 
-            const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
+            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
 
-            last_tokens.erase(last_tokens.begin());
-            last_tokens.push_back(id);
+            llama_sampling_accept(ctx_sampling, ctx, id);
 
-            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
+            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
 
             embd.push_back(id);
 
@@ -575,8 +567,8 @@
             LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
-                last_tokens.erase(last_tokens.begin());
-                last_tokens.push_back(embd_inp[n_consumed]);
+                ctx_sampling->prev.erase(ctx_sampling->prev.begin());
+                ctx_sampling->prev.push_back(embd_inp[n_consumed]);
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
                     break;
@@ -608,7 +600,7 @@
         if ((int) embd_inp.size() <= n_consumed) {
 
             // deal with eot token in infill mode
-            if ((last_tokens.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
+            if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
                 if(is_interacting && !params.interactive_first) {
                     // print an eot token
                     printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@@ -675,7 +667,7 @@
                 is_interacting = false;
             }
             // deal with end of text token in interactive mode
-            else if (last_tokens.back() == llama_token_eos(ctx)) {
+            else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -727,7 +719,7 @@
             const size_t original_size = embd_inp.size();
 
             const auto line_inp = ::llama_tokenize(ctx, buffer, false);
-            LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
+            LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
 
             embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
```
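Net effect of the infill changes: the example no longer owns `last_tokens` or `candidates`; both live inside the sampling context, and the `llama_sampling_sample`/`llama_sampling_accept` pair replaces the manual bookkeeping. A minimal sketch of the resulting loop shape, assuming the post-refactor `common/sampling.h` API (with `llama_sampling_free` assumed as the counterpart to `llama_sampling_init`):

```cpp
// Sketch of the refactored generation loop; `params`, `ctx`, `ctx_guidance`,
// `embd`, and `n_remain` are the surrounding example's variables.
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);

while (n_remain != 0) {
    // the context owns the candidate buffer and the token history,
    // so the old last_tokens/candidates arguments disappear
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);

    // record the token so repetition penalties and grammar state advance
    llama_sampling_accept(ctx_sampling, ctx, id);

    embd.push_back(id);
    n_remain -= 1;
}

llama_sampling_free(ctx_sampling); // assumed cleanup call
```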
