Make default system prompt configurable on web
jart committed Nov 16, 2024
1 parent 81ed1cf · commit 35bc088
Showing 19 changed files with 349 additions and 123 deletions.
17 changes: 15 additions & 2 deletions llama.cpp/common.cpp
@@ -276,6 +276,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
         params.seed = std::stoul(argv[i]);
         sparams.seed = std::stoul(argv[i]);
+        FLAG_seed = sparams.seed; // [jart]
         return true;
     }
     if (arg == "-t" || arg == "--threads") {
@@ -490,17 +491,20 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     if (arg == "--top-p") {
         CHECK_ARG
         sparams.top_p = std::stof(argv[i]);
+        FLAG_top_p = sparams.top_p; // [jart]
         return true;
     }
     if (arg == "--min-p") {
         CHECK_ARG
         sparams.min_p = std::stof(argv[i]);
         return true;
     }
-    if (arg == "--temp") {
+    if (arg == "--temp" || //
+        arg == "--temperature") { // [jart]
         CHECK_ARG
         sparams.temp = std::stof(argv[i]);
         sparams.temp = std::max(sparams.temp, 0.0f);
+        FLAG_temperature = sparams.temp; // [jart]
         return true;
     }
     if (arg == "--tfs") {
@@ -527,11 +531,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     if (arg == "--frequency-penalty") {
         CHECK_ARG
         sparams.penalty_freq = std::stof(argv[i]);
+        FLAG_frequency_penalty = sparams.penalty_freq; // [jart]
         return true;
     }
     if (arg == "--presence-penalty") {
         CHECK_ARG
         sparams.penalty_present = std::stof(argv[i]);
+        FLAG_presence_penalty = sparams.penalty_present; // [jart]
         return true;
     }
     if (arg == "--dynatemp-range") {
@@ -903,8 +909,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.verbose_prompt = true;
         return true;
     }
-    if (arg == "--no-display-prompt" || arg == "--silent-prompt") {
+    if (arg == "--no-display-prompt" || //
+        arg == "--silent-prompt") { // [jart]
         params.display_prompt = false;
+        FLAG_no_display_prompt = true; // [jart]
         return true;
     }
+    if (arg == "--display-prompt") { // [jart]
+        params.display_prompt = true;
+        FLAG_no_display_prompt = false;
+        return true;
+    }
     if (arg == "-r" || arg == "--reverse-prompt") {
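After this hunk, --temperature works as an alias for --temp, and --display-prompt undoes an earlier --no-display-prompt or --silent-prompt (flags are parsed left to right, so the last one should win). A hypothetical invocation, assuming a local model file:

    llamafile -m model.gguf -p 'hello world' --temperature 0.8
    llamafile -m model.gguf -p 'hello world' --silent-prompt --display-prompt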
9 changes: 9 additions & 0 deletions llama.cpp/server/server.cpp
@@ -2593,6 +2593,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
        {
            FLAG_nologo = true;
        }
+       else if (arg == "--no-display-prompt" || //
+                arg == "--silent-prompt")
+       {
+           FLAG_no_display_prompt = true;
+       }
+       else if (arg == "--display-prompt")
+       {
+           FLAG_no_display_prompt = false;
+       }
        else if (arg == "--trap")
        {
            FLAG_trap = true;
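The server's own flag parser now accepts the same pair, so prompt display can be toggled on a hypothetical server command line as well (server mode and model path illustrative):

    llamafile --server -m model.gguf --no-display-prompt
    llamafile --server -m model.gguf --display-prompt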
72 changes: 56 additions & 16 deletions llamafile/flags.cpp
@@ -43,6 +43,7 @@ bool FLAG_iq = false;
 bool FLAG_log_disable = false;
 bool FLAG_mlock = false;
 bool FLAG_mmap = true;
+bool FLAG_no_display_prompt = false;
 bool FLAG_nocompile = false;
 bool FLAG_nologo = false;
 bool FLAG_precise = false;
@@ -59,7 +60,10 @@ const char *FLAG_prompt = nullptr;
 const char *FLAG_url_prefix = "";
 const char *FLAG_www_root = "/zip/www";
 double FLAG_token_rate = 1;
-float FLAG_temp = 0.8;
+float FLAG_frequency_penalty = 0;
+float FLAG_presence_penalty = 0;
+float FLAG_temperature = .8;
+float FLAG_top_p = .95;
 int FLAG_batch = 2048;
 int FLAG_ctx_size = 8192;
 int FLAG_flash_attn = false;
@@ -69,7 +73,6 @@ int FLAG_http_obuf_size = 1024 * 1024;
 int FLAG_keepalive = 5;
 int FLAG_main_gpu = 0;
 int FLAG_n_gpu_layers = -1;
-int FLAG_seed = LLAMA_DEFAULT_SEED;
 int FLAG_slots = 1;
 int FLAG_split_mode = LLAMA_SPLIT_MODE_LAYER;
 int FLAG_threads = MIN(cpu_get_num_math(), 20);
@@ -80,6 +83,7 @@ int FLAG_ubatch = 512;
 int FLAG_verbose = 0;
 int FLAG_warmup = true;
 int FLAG_workers;
+unsigned FLAG_seed = LLAMA_DEFAULT_SEED;
 
 std::vector<std::string> FLAG_headers;
 
@@ -153,6 +157,17 @@ void llamafile_get_flags(int argc, char **argv) {
            continue;
        }
 
+       if (!strcmp(flag, "--no-display-prompt") || //
+           !strcmp(flag, "--silent-prompt")) {
+           FLAG_no_display_prompt = true;
+           continue;
+       }
+
+       if (!strcmp(flag, "--display-prompt")) {
+           FLAG_no_display_prompt = false;
+           continue;
+       }
+
        //////////////////////////////////////////////////////////////////////
        // server flags
 
@@ -278,6 +293,45 @@ void llamafile_get_flags(int argc, char **argv) {
            continue;
        }
 
+       //////////////////////////////////////////////////////////////////////
+       // sampling flags
+
+       if (!strcmp(flag, "--seed")) {
+           if (i == argc)
+               missing("--seed");
+           FLAG_seed = strtol(argv[i++], 0, 0);
+           continue;
+       }
+
+       if (!strcmp(flag, "--temp") || //
+           !strcmp(flag, "--temperature")) {
+           if (i == argc)
+               missing("--temp");
+           FLAG_temperature = atof(argv[i++]);
+           continue;
+       }
+
+       if (!strcmp(flag, "--top-p")) {
+           if (i == argc)
+               missing("--top-p");
+           FLAG_top_p = atof(argv[i++]);
+           continue;
+       }
+
+       if (!strcmp(flag, "--frequency-penalty")) {
+           if (i == argc)
+               missing("--frequency-penalty");
+           FLAG_frequency_penalty = atof(argv[i++]);
+           continue;
+       }
+
+       if (!strcmp(flag, "--presence-penalty")) {
+           if (i == argc)
+               missing("--presence-penalty");
+           FLAG_presence_penalty = atof(argv[i++]);
+           continue;
+       }
+
        //////////////////////////////////////////////////////////////////////
        // model flags
 
@@ -319,20 +373,6 @@ void llamafile_get_flags(int argc, char **argv) {
            continue;
        }
 
-       if (!strcmp(flag, "--seed")) {
-           if (i == argc)
-               missing("--seed");
-           FLAG_seed = atoi(argv[i++]);
-           continue;
-       }
-
-       if (!strcmp(flag, "--temp")) {
-           if (i == argc)
-               missing("--temp");
-           FLAG_temp = atof(argv[i++]);
-           continue;
-       }
-
        if (!strcmp(flag, "-t") || !strcmp(flag, "--threads")) {
            if (i == argc)
                missing("--threads");
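The sampling flags that moved into llamafile_get_flags can be combined on one hypothetical command line:

    llamafile -m model.gguf --seed 42 --temperature 0.7 --top-p 0.9 \
        --frequency-penalty 0.1 --presence-penalty 0.1

Since the new --seed handler uses strtol with base 0 rather than atoi, hex and octal literals such as 0x2a should be accepted too.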
8 changes: 6 additions & 2 deletions llamafile/llamafile.h
@@ -13,6 +13,7 @@ extern bool FLAG_iq;
 extern bool FLAG_log_disable;
 extern bool FLAG_mlock;
 extern bool FLAG_mmap;
+extern bool FLAG_no_display_prompt;
 extern bool FLAG_nocompile;
 extern bool FLAG_nologo;
 extern bool FLAG_precise;
@@ -30,7 +31,10 @@ extern const char *FLAG_prompt;
 extern const char *FLAG_url_prefix;
 extern const char *FLAG_www_root;
 extern double FLAG_token_rate;
-extern float FLAG_temp;
+extern float FLAG_frequency_penalty;
+extern float FLAG_presence_penalty;
+extern float FLAG_temperature;
+extern float FLAG_top_p;
 extern int FLAG_batch;
 extern int FLAG_ctx_size;
 extern int FLAG_flash_attn;
@@ -41,7 +45,6 @@ extern int FLAG_http_obuf_size;
 extern int FLAG_keepalive;
 extern int FLAG_main_gpu;
 extern int FLAG_n_gpu_layers;
-extern int FLAG_seed;
 extern int FLAG_slots;
 extern int FLAG_split_mode;
 extern int FLAG_threads;
@@ -52,6 +55,7 @@ extern int FLAG_ubatch;
 extern int FLAG_verbose;
 extern int FLAG_warmup;
 extern int FLAG_workers;
+extern unsigned FLAG_seed;
 
 struct llamafile;
 struct llamafile *llamafile_open_gguf(const char *, const char *);
2 changes: 2 additions & 0 deletions llamafile/server/client.cpp
@@ -654,6 +654,8 @@ Client::dispatcher()
        return v1_chat_completions();
    if (p1 == "slotz")
        return slotz();
+   if (p1 == "flagz")
+       return flagz();
 
    // serve static endpoints
    int infd;
2 changes: 2 additions & 0 deletions llamafile/server/client.h
@@ -66,6 +66,7 @@ struct Client
    char* params_memory_ = nullptr;
    std::string_view payload_;
    std::string resolved_;
+   std::string dump_;
    Cleanup* cleanups_;
    Buffer ibuf_;
    Buffer obuf_;
@@ -112,6 +113,7 @@ struct Client
    bool get_v1_chat_completions_params(V1ChatCompletionParams*) __wur;
 
    bool slotz() __wur;
+   bool flagz() __wur;
 };
 
 } // namespace server
2 changes: 2 additions & 0 deletions llamafile/server/embedding.cpp
@@ -27,6 +27,8 @@
 #include <sys/resource.h>
 #include <vector>
 
+using jt::Json;
+
 namespace lf {
 namespace server {
 
49 changes: 49 additions & 0 deletions llamafile/server/flagz.cpp
@@ -0,0 +1,49 @@
+// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
+// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
+//
+// Copyright 2024 Mozilla Foundation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "client.h"
+#include "llama.cpp/llama.h"
+#include "llamafile/llamafile.h"
+#include "llamafile/server/json.h"
+
+namespace lf {
+namespace server {
+
+bool
+Client::flagz()
+{
+    jt::Json json;
+    json["prompt"] = FLAG_prompt;
+    json["no_display_prompt"] = FLAG_no_display_prompt;
+    json["nologo"] = FLAG_nologo;
+    json["temperature"] = FLAG_temperature;
+    json["presence_penalty"] = FLAG_presence_penalty;
+    json["frequency_penalty"] = FLAG_frequency_penalty;
+    if (FLAG_seed == LLAMA_DEFAULT_SEED) {
+        json["seed"] = nullptr;
+    } else {
+        json["seed"] = FLAG_seed;
+    }
+    dump_ = json.toStringPretty();
+    dump_ += '\n';
+    char* p = append_http_response_message(obuf_.p, 200);
+    p = stpcpy(p, "Content-Type: application/json\r\n");
+    return send_response(obuf_.p, p, dump_);
+}
+
+} // namespace server
+} // namespace lf
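Once the server is running, the new endpoint can be probed with any HTTP client; a hypothetical session (port, key order, and number formatting illustrative, showing the defaults):

    $ curl http://localhost:8080/flagz
    {
      "prompt": null,
      "no_display_prompt": false,
      "nologo": false,
      "temperature": 0.8,
      "presence_penalty": 0,
      "frequency_penalty": 0,
      "seed": null
    }

This is the hook that lets the web UI read the server's default system prompt and sampling settings instead of hardcoding them.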