
Merge pull request #194 from L-jasmine/feat/ctx-size
Feat/ctx size
jmbejar authored Aug 6, 2024
2 parents 60715bd + ecb4cb0 commit 8bb9d87
Showing 6 changed files with 42 additions and 9 deletions.
35 changes: 27 additions & 8 deletions moxin-backend/src/backend_impls/api_server.rs
@@ -23,6 +23,7 @@ static WASM: &[u8] = include_bytes!("../../wasm/llama-api-server.wasm");
pub struct LLamaEdgeApiServer {
id: String,
listen_addr: SocketAddr,
load_model_options: LoadModelOptions,
wasm_module: Module,
running_controller: tokio::sync::broadcast::Sender<()>,
#[allow(dead_code)]
@@ -35,15 +36,23 @@ fn create_wasi(
load_model: &LoadModelOptions,
) -> wasmedge_sdk::WasmEdgeResult<WasiModule> {
// use model metadata context size
let ctx_size = Some(format!("{}", file.context_size.min(8 * 1024)));
let ctx_size = if let Some(n_ctx) = load_model.n_ctx {
Some(format!("{}", n_ctx))
} else {
Some(format!("{}", file.context_size.min(8 * 1024)))
};

let n_gpu_layers = match load_model.gpu_layers {
moxin_protocol::protocol::GPULayers::Specific(n) => Some(n.to_string()),
moxin_protocol::protocol::GPULayers::Max => None,
};

// Set n_batch to a fixed value of 128.
let batch_size = Some(format!("128"));
let batch_size = if let Some(n_batch) = load_model.n_batch {
Some(format!("{}", n_batch))
} else {
Some("128".to_string())
};

let mut prompt_template = load_model.prompt_template.clone();
if prompt_template.is_none() && !file.prompt_template.is_empty() {
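
The new logic above lets a caller-supplied n_ctx or n_batch take precedence: otherwise the context size falls back to the model file's own context size capped at 8 * 1024, and the batch size falls back to the previous fixed value of 128. A minimal equivalent sketch of that fallback, not the committed code (the u64 type for the model's context size is an assumption):

fn ctx_and_batch(n_ctx: Option<u32>, n_batch: Option<u32>, model_ctx: u64) -> (String, String) {
    // A caller-provided n_ctx wins; otherwise cap the model's context size at 8 * 1024.
    let ctx_size = n_ctx
        .map(|n| n.to_string())
        .unwrap_or_else(|| model_ctx.min(8 * 1024).to_string());
    // A caller-provided n_batch wins; otherwise keep the old fixed value of 128.
    let batch_size = n_batch
        .map(|n| n.to_string())
        .unwrap_or_else(|| "128".to_string());
    (ctx_size, batch_size)
}
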
@@ -133,17 +142,23 @@ impl BackendModel for LLamaEdgeApiServer {
options: moxin_protocol::protocol::LoadModelOptions,
tx: std::sync::mpsc::Sender<anyhow::Result<moxin_protocol::protocol::LoadModelResponse>>,
) -> Self {
let load_model_options = options.clone();
let mut need_reload = true;
let (wasm_module, listen_addr) = if let Some(old_model) = &old_model {
if old_model.id == file.id.as_str() {
if old_model.id == file.id.as_str()
&& old_model.load_model_options.n_ctx == options.n_ctx
&& old_model.load_model_options.n_batch == options.n_batch
{
need_reload = false;
}
(old_model.wasm_module.clone(), old_model.listen_addr)
} else {
(
Module::from_bytes(None, WASM).unwrap(),
([0, 0, 0, 0], 8080).into(),
)
let new_addr = std::net::TcpListener::bind("localhost:0")
.unwrap()
.local_addr()
.unwrap();

(Module::from_bytes(None, WASM).unwrap(), new_addr)
};

if !need_reload {
@@ -152,6 +167,7 @@ impl BackendModel for LLamaEdgeApiServer {
file_id: file.id.to_string(),
model_id: file.model_id,
information: "".to_string(),
listen_port: listen_addr.port(),
},
)));
return old_model.unwrap();
@@ -165,7 +181,8 @@ impl BackendModel for LLamaEdgeApiServer {

let file_id = file.id.to_string();

let url = format!("http://localhost:{}/echo", listen_addr.port());
let listen_port = listen_addr.port();
let url = format!("http://localhost:{}/echo", listen_port);

let file_ = file.clone();

@@ -197,6 +214,7 @@ impl BackendModel for LLamaEdgeApiServer {
file_id: file_.id.to_string(),
model_id: file_.model_id,
information: "".to_string(),
listen_port,
},
)));
} else {
@@ -212,6 +230,7 @@ impl BackendModel for LLamaEdgeApiServer {
listen_addr,
running_controller,
model_thread,
load_model_options,
};

new_model
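The listen address is no longer hard-coded to 0.0.0.0:8080. Binding a TcpListener to port 0 asks the OS for a free ephemeral port, and local_addr() reports which port was assigned; that port is then returned to the client via LoadedModelInfo. A stand-alone sketch of the technique (hypothetical function name, not project code):

use std::net::TcpListener;

// Bind to port 0 so the OS picks any free port, then read back the assignment.
fn pick_free_port() -> std::io::Result<u16> {
    let listener = TcpListener::bind("localhost:0")?;
    Ok(listener.local_addr()?.port())
}

fn main() -> std::io::Result<()> {
    let port = pick_free_port()?;
    println!("the API server could listen on http://localhost:{port}/echo");
    Ok(())
}

Note that the temporary listener is dropped before the wasm server binds the same port, so there is a short window in which another process could claim it; presumably the goal is simply to avoid collisions on a fixed 8080.
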
2 changes: 2 additions & 0 deletions moxin-backend/src/backend_impls/chat_ui.rs
@@ -228,6 +228,7 @@ fn get_input(
file_id,
model_id,
information: String::new(),
listen_port: 0,
})));
}

@@ -430,6 +431,7 @@ impl super::BackendModel for ChatBotModel {
file_id: file.id.to_string(),
model_id: file.model_id,
information: "".to_string(),
listen_port: 0,
})));
return old_model.unwrap();
}
4 changes: 4 additions & 0 deletions moxin-backend/src/backend_impls/mod.rs
@@ -139,6 +139,8 @@ fn test_chat() {
rope_freq_scale: 0.0,
rope_freq_base: 0.0,
context_overflow_policy: moxin_protocol::protocol::ContextOverflowPolicy::StopAtLimit,
n_batch: Some(128),
n_ctx: Some(1024),
},
tx,
);
@@ -209,6 +211,8 @@ fn test_chat_stop() {
prompt_template: None,
gpu_layers: moxin_protocol::protocol::GPULayers::Max,
use_mlock: false,
n_batch: Some(128),
n_ctx: Some(1024),
rope_freq_scale: 0.0,
rope_freq_base: 0.0,
context_overflow_policy: moxin_protocol::protocol::ContextOverflowPolicy::StopAtLimit,
1 change: 1 addition & 0 deletions moxin-protocol/src/open_ai.rs
@@ -106,6 +106,7 @@ pub struct ChatResponseData {
pub choices: Vec<ChoiceData>,
pub created: u32,
pub model: ModelID,
#[serde(default)]
pub system_fingerprint: String,
pub usage: UsageData,

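The only change in open_ai.rs is #[serde(default)] on system_fingerprint, so a chat response that omits that field now deserializes to an empty string instead of failing. A minimal sketch of the effect with a hypothetical struct (assumes the serde and serde_json crates), not the crate's actual ChatResponseData:

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct Response {
    model: String,
    #[serde(default)] // missing field -> String::default(), i.e. ""
    system_fingerprint: String,
}

fn main() {
    // Presumably some local OpenAI-compatible servers omit "system_fingerprint".
    let json = r#"{ "model": "example-model" }"#;
    let parsed: Response = serde_json::from_str(json).unwrap();
    assert_eq!(parsed.system_fingerprint, "");
}
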
7 changes: 6 additions & 1 deletion moxin-protocol/src/protocol.rs
@@ -28,9 +28,10 @@ pub struct LoadModelOptions {
pub prompt_template: Option<String>,
pub gpu_layers: GPULayers,
pub use_mlock: bool,
pub n_batch: Option<u32>,
pub n_ctx: Option<u32>,
pub rope_freq_scale: f32,
pub rope_freq_base: f32,

// TBD Not really sure if this is something backend manages or if it is matter of
// the client (if it is done by tweaking the JSON payload for the chat completion)
pub context_overflow_policy: ContextOverflowPolicy,
@@ -41,6 +42,10 @@ pub struct LoadedModelInfo {
pub file_id: FileID,
pub model_id: ModelID,

// The port where the local server is listening for the model.
// If 0, the server is not running.
pub listen_port: u16,

// JSON formatted string with the model information. See "Model Inspector" in LMStudio.
pub information: String,
}
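With listen_port now part of LoadedModelInfo, a frontend can build the local server's URL instead of assuming a fixed port; as the comment above notes, a port of 0 means no server is running (the chat_ui backend reports 0). An illustrative caller-side sketch follows; the helper name and the choice to return None for port 0 are assumptions, not part of this commit:

use moxin_protocol::protocol::LoadedModelInfo;

// Derive the local API server's base URL from the load-model response.
fn server_base_url(info: &LoadedModelInfo) -> Option<String> {
    if info.listen_port == 0 {
        None // no HTTP server behind this model (e.g. the chat_ui backend)
    } else {
        Some(format!("http://localhost:{}", info.listen_port))
    }
}
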
2 changes: 2 additions & 0 deletions src/data/chats/model_loader.rs
@@ -34,6 +34,8 @@ impl ModelLoader {
rope_freq_base: 0.0,
context_overflow_policy:
moxin_protocol::protocol::ContextOverflowPolicy::StopAtLimit,
n_batch: None,
n_ctx: None,
},
tx,
);
