diff --git a/moxin-backend/src/backend_impls/api_server.rs b/moxin-backend/src/backend_impls/api_server.rs
new file mode 100644
index 00000000..74713666
--- /dev/null
+++ b/moxin-backend/src/backend_impls/api_server.rs
@@ -0,0 +1,267 @@
+use std::{collections::HashMap, net::SocketAddr};
+
+use anyhow::anyhow;
+use futures_util::StreamExt;
+use moxin_protocol::{
+    open_ai::{
+        ChatResponse, ChatResponseChunkData, ChatResponseData, ChunkChoiceData, MessageData, Role,
+        StopReason,
+    },
+    protocol::LoadModelOptions,
+};
+use wasmedge_sdk::{wasi::WasiModule, Module, Store, Vm};
+
+use crate::store::download_files::DownloadedFile;
+
+use super::BackendModel;
+
+static WASM: &[u8] = include_bytes!("../../wasm/llama-api-server.wasm");
+
+/// Backend model that serves chat through an OpenAI-compatible llama-api-server running in WasmEdge.
+pub struct LLamaEdgeApiServer {
+    id: String,
+    listen_addr: SocketAddr,
+    wasm_module: Module,
+    running_controller: tokio::sync::broadcast::Sender<()>,
+    #[allow(dead_code)]
+    model_thread: std::thread::JoinHandle<()>,
+}
+
+fn create_wasi(
+    listen_addr: SocketAddr,
+    file: &DownloadedFile,
+    load_model: &LoadModelOptions,
+) -> wasmedge_sdk::WasmEdgeResult<WasiModule> {
+    let ctx_size = if load_model.n_ctx > 0 {
+        Some(load_model.n_ctx.to_string())
+    } else {
+        None
+    };
+
+    let n_gpu_layers = match load_model.gpu_layers {
+        moxin_protocol::protocol::GPULayers::Specific(n) => Some(n.to_string()),
+        moxin_protocol::protocol::GPULayers::Max => None,
+    };
+
+    let batch_size = if load_model.n_batch > 0 {
+        Some(load_model.n_batch.to_string())
+    } else {
+        None
+    };
+
+    let mut prompt_template = load_model.prompt_template.clone();
+    if prompt_template.is_none() && !file.prompt_template.is_empty() {
+        prompt_template = Some(file.prompt_template.clone());
+    }
+
+    let reverse_prompt = if file.reverse_prompt.is_empty() {
+        None
+    } else {
+        Some(file.reverse_prompt.clone())
+    };
+
+    let listen_addr = Some(format!("{listen_addr}"));
+
+    let module_alias = file.name.as_ref();
+
+    let mut args = vec!["llama-api-server", "-a", module_alias, "-m", module_alias];
+
+    macro_rules!
add_args { + ($flag:expr, $value:expr) => { + if let Some(ref value) = $value { + args.push($flag); + args.push(value.as_str()); + } + }; + } + + add_args!("-c", ctx_size); + add_args!("-g", n_gpu_layers); + add_args!("-b", batch_size); + add_args!("-p", prompt_template); + add_args!("-r", reverse_prompt); + add_args!("--socket-addr", listen_addr); + + WasiModule::create(Some(args), None, None) +} + +pub fn run_wasm_by_downloaded_file( + listen_addr: SocketAddr, + wasm_module: Module, + file: DownloadedFile, + load_model: LoadModelOptions, +) { + use wasmedge_sdk::AsInstance; + + let mut instances = HashMap::new(); + + let mut wasi = create_wasi(listen_addr, &file, &load_model).unwrap(); + instances.insert(wasi.name().to_string(), wasi.as_mut()); + + let mut wasi_nn = wasmedge_sdk::plugin::PluginManager::load_plugin_wasi_nn().unwrap(); + instances.insert(wasi_nn.name().unwrap(), &mut wasi_nn); + + let mut wasi_logger = wasmedge_sdk::plugin::PluginManager::create_plugin_instance( + "wasi_logging", + "wasi:logging/logging", + ) + .unwrap(); + instances.insert(wasi_logger.name().unwrap(), &mut wasi_logger); + + let store = Store::new(None, instances).unwrap(); + let mut vm = Vm::new(store); + vm.register_module(None, wasm_module.clone()).unwrap(); + + let _ = vm.run_func(None, "_start", []); + + log::debug!("wasm exit"); +} + +fn stop_chunk(reason: StopReason) -> ChatResponseChunkData { + ChatResponseChunkData { + id: String::new(), + choices: vec![ChunkChoiceData { + finish_reason: Some(reason), + index: 0, + delta: MessageData { + content: String::new(), + role: Role::Assistant, + }, + logprobs: None, + }], + created: 0, + model: String::new(), + system_fingerprint: String::new(), + object: "chat.completion.chunk".to_string(), + } +} + +impl BackendModel for LLamaEdgeApiServer { + fn new_or_reload( + async_rt: &tokio::runtime::Runtime, + old_model: Option, + file: crate::store::download_files::DownloadedFile, + options: moxin_protocol::protocol::LoadModelOptions, + tx: std::sync::mpsc::Sender>, + ) -> Self { + let mut need_reload = true; + let (wasm_module, listen_addr) = if let Some(old_model) = &old_model { + if old_model.id == file.id.as_str() { + need_reload = false; + } + (old_model.wasm_module.clone(), old_model.listen_addr) + } else { + ( + Module::from_bytes(None, WASM).unwrap(), + ([0, 0, 0, 0], 8080).into(), + ) + }; + + if !need_reload { + let _ = tx.send(Ok(moxin_protocol::protocol::LoadModelResponse::Completed( + moxin_protocol::protocol::LoadedModelInfo { + file_id: file.id.to_string(), + model_id: file.model_id, + information: "".to_string(), + }, + ))); + return old_model.unwrap(); + } + + let wasm_module_ = wasm_module.clone(); + + let file_id = file.id.to_string(); + let model_thread = std::thread::spawn(move || { + run_wasm_by_downloaded_file(listen_addr, wasm_module_, file, options) + }); + + let running_controller = tokio::sync::broadcast::channel(1).0; + + let new_model = Self { + id: file_id, + wasm_module, + listen_addr, + running_controller, + model_thread, + }; + + if let Some(old_model) = old_model { + old_model.stop(async_rt); + } + + new_model + } + + fn chat( + &self, + async_rt: &tokio::runtime::Runtime, + data: moxin_protocol::open_ai::ChatRequestData, + tx: std::sync::mpsc::Sender>, + ) -> bool { + let is_stream = data.stream.unwrap_or(false); + let url = format!("http://{}/v1/chat/completions", self.listen_addr); + let mut cancel = self.running_controller.subscribe(); + + async_rt.spawn(async move { + let request_body = 
serde_json::to_string(&data).unwrap(); + let resp = reqwest::Client::new() + .post(url) + .body(request_body) + .send() + .await + .map_err(|e| anyhow!(e)); + + match resp { + Ok(resp) => { + if is_stream { + let mut stream = resp.bytes_stream(); + + while let Some(chunk) = tokio::select! { + chunk = stream.next() => chunk, + _ = cancel.recv() => None, + } { + match chunk { + Ok(chunk) => { + if chunk.starts_with(b"data: [DONE]") { + break; + } + let resp: Result = + serde_json::from_slice(&chunk[5..]).map_err(|e| anyhow!(e)); + let _ = tx.send(resp.map(ChatResponse::ChatResponseChunk)); + } + Err(e) => { + let _ = tx.send(Err(anyhow!(e))); + return; + } + } + } + + let _ = tx.send(Ok(ChatResponse::ChatResponseChunk(stop_chunk( + StopReason::Stop, + )))); + } else { + let resp: Result = + resp.json().await.map_err(|e| anyhow!(e)); + let _ = tx.send(resp.map(ChatResponse::ChatFinalResponseData)); + } + } + Err(e) => { + let _ = tx.send(Err(e)); + let _ = tx.send(Ok(ChatResponse::ChatResponseChunk(stop_chunk( + StopReason::Stop, + )))); + } + } + }); + + true + } + + fn stop_chat(&self, _async_rt: &tokio::runtime::Runtime) { + let _ = self.running_controller.send(()); + } + + fn stop(self, _async_rt: &tokio::runtime::Runtime) { + // TODO + } +} diff --git a/moxin-backend/src/backend_impls/chat_ui.rs b/moxin-backend/src/backend_impls/chat_ui.rs new file mode 100644 index 00000000..653e93fd --- /dev/null +++ b/moxin-backend/src/backend_impls/chat_ui.rs @@ -0,0 +1,501 @@ +#[derive(Debug)] +pub enum TokenError { + EndOfSequence = 1, + ContextFull, + PromptTooLong, + TooLarge, + InvalidEncoding, + Other, +} + +impl Into for TokenError { + fn into(self) -> StopReason { + match self { + TokenError::EndOfSequence => StopReason::Stop, + TokenError::ContextFull => StopReason::Length, + TokenError::PromptTooLong => StopReason::Length, + TokenError::TooLarge => StopReason::Length, + TokenError::InvalidEncoding => StopReason::Stop, + TokenError::Other => StopReason::Stop, + } + } +} + +use std::{ + collections::HashMap, + io::Read, + sync::{ + atomic::{AtomicBool, Ordering}, + mpsc::{Receiver, Sender}, + Arc, + }, + thread::JoinHandle, +}; + +use moxin_protocol::{ + open_ai::{ + ChatRequestData, ChatResponse, ChatResponseChunkData, ChatResponseData, ChoiceData, + ChunkChoiceData, MessageData, Role, StopReason, UsageData, + }, + protocol::{LoadModelOptions, LoadModelResponse, LoadedModelInfo}, +}; +use wasmedge_sdk::{ + error::{CoreError, CoreExecutionError}, + wasi::WasiModule, + CallingFrame, ImportObject, Instance, Module, Store, Vm, WasmValue, +}; + +use crate::store::download_files::DownloadedFile; + +#[derive(Debug)] +pub struct ChatBotUi { + pub current_req: std::io::Cursor>, + pub request_rx: Receiver<(ChatRequestData, Sender>)>, + request_id: uuid::Uuid, + chat_completion_message: Option>, + pub token_tx: Option>>, + running_controller: Arc, + pub load_model_state: Option<( + DownloadedFile, + LoadModelOptions, + Sender>, + )>, +} + +impl ChatBotUi { + pub fn new( + request_rx: Receiver<(ChatRequestData, Sender>)>, + running_controller: Arc, + file: DownloadedFile, + load_model: LoadModelOptions, + tx: Sender>, + ) -> Self { + Self { + request_rx, + request_id: uuid::Uuid::new_v4(), + token_tx: None, + running_controller, + current_req: std::io::Cursor::new(vec![]), + load_model_state: Some((file, load_model, tx)), + chat_completion_message: None, + } + } + + fn init_request(&mut self) -> Result<(), ()> { + if let Ok((req, tx)) = self.request_rx.recv() { + // Init current_req + if 
!req.stream.unwrap_or_default() { + self.chat_completion_message = Some(Vec::with_capacity( + (req.max_tokens.unwrap_or(512) * 8) as usize, + )) + } + *self.current_req.get_mut() = serde_json::to_vec(&req).unwrap(); + self.current_req.set_position(0); + self.request_id = uuid::Uuid::new_v4(); + self.token_tx = Some(tx); + self.running_controller.store(true, Ordering::Release); + Ok(()) + } else { + Err(()) + } + } + + pub fn read_data(&mut self, buf: &mut [u8]) -> std::io::Result { + let n = self.current_req.read(buf)?; + if n == 0 { + self.current_req.get_mut().clear(); + self.current_req.set_position(0); + } + Ok(n) + } + + fn send_completion_output( + token_tx: &mut Sender>, + id: String, + stop_reason: StopReason, + chat_completion_message: &mut Option>, + ) -> bool { + if let Some(chat_completion_message) = chat_completion_message.take() { + let _ = token_tx.send(Ok(ChatResponse::ChatFinalResponseData(ChatResponseData { + id, + choices: vec![ChoiceData { + finish_reason: stop_reason, + index: 0, + message: MessageData { + content: String::from_utf8_lossy(&chat_completion_message).to_string(), + role: Role::Assistant, + }, + logprobs: None, + }], + created: 0, + model: String::new(), + system_fingerprint: String::new(), + usage: UsageData { + completion_tokens: 0, + prompt_tokens: 0, + total_tokens: 0, + }, + object: "chat.completion".to_string(), + }))); + } else { + let _ = token_tx.send(Ok(ChatResponse::ChatResponseChunk(ChatResponseChunkData { + id: String::new(), + choices: vec![ChunkChoiceData { + finish_reason: Some(stop_reason), + index: 0, + delta: MessageData { + content: String::new(), + role: Role::Assistant, + }, + logprobs: None, + }], + created: 0, + model: String::new(), + system_fingerprint: String::new(), + object: "chat.completion.chunk".to_string(), + }))); + }; + true + } + + fn send_streamed_output( + token_tx: &mut Sender>, + id: String, + token: &[u8], + ) -> bool { + let _ = token_tx.send(Ok(ChatResponse::ChatResponseChunk(ChatResponseChunkData { + id, + choices: vec![ChunkChoiceData { + finish_reason: None, + index: 0, + delta: MessageData { + content: String::from_utf8_lossy(token).to_string(), + role: Role::Assistant, + }, + logprobs: None, + }], + created: 0, + model: String::new(), + system_fingerprint: String::new(), + object: "chat.completion.chunk".to_string(), + }))); + true + } + + fn send_output(&mut self, output: Result<&[u8], TokenError>) -> bool { + let id = self.request_id.to_string(); + match ( + output, + &mut self.chat_completion_message, + &mut self.token_tx, + ) { + (Ok(token), Some(chat_completion_message), Some(_tx)) => { + chat_completion_message.extend_from_slice(token); + true + } + (Ok(token), None, Some(tx)) => Self::send_streamed_output(tx, id, token), + (Err(token_error), chat_completion_message, Some(tx)) => { + Self::send_completion_output(tx, id, token_error.into(), chat_completion_message) + } + (_, _, None) => false, + } + } +} + +fn get_input( + data: &mut ChatBotUi, + _inst: &mut Instance, + frame: &mut CallingFrame, + args: Vec, +) -> Result, CoreError> { + let mem = frame + .memory_mut(0) + .ok_or(CoreError::Execution(CoreExecutionError::MemoryOutOfBounds))?; + + if let Some([buf_ptr, buf_size]) = args.get(0..2) { + let buf_ptr = buf_ptr.to_i32() as usize; + let buf_size = buf_size.to_i32() as usize; + + let buf = mem + .mut_slice::(buf_ptr, buf_size) + .ok_or(CoreError::Execution(CoreExecutionError::MemoryOutOfBounds))?; + + if data.current_req.get_ref().is_empty() { + if let Some((file, _, tx)) = 
data.load_model_state.take() { + let file_id = file.id.as_ref().clone(); + let model_id = file.model_id; + let _ = tx.send(Ok(LoadModelResponse::Completed(LoadedModelInfo { + file_id, + model_id, + information: String::new(), + }))); + } + + data.init_request().or(Err(CoreError::Common( + wasmedge_sdk::error::CoreCommonError::Interrupted, + )))?; + } + + let n = data.read_data(buf).unwrap(); + + Ok(vec![WasmValue::from_i32(n as i32)]) + } else { + Err(CoreError::Execution(CoreExecutionError::FuncTypeMismatch)) + } +} + +fn push_token( + data: &mut ChatBotUi, + _inst: &mut Instance, + frame: &mut CallingFrame, + args: Vec, +) -> Result, CoreError> { + if !data.running_controller.load(Ordering::Acquire) { + return Ok(vec![WasmValue::from_i32(-1)]); + } + + let mem = frame + .memory_mut(0) + .ok_or(CoreError::Execution(CoreExecutionError::MemoryOutOfBounds))?; + + if let Some([buf_ptr, buf_size]) = args.get(0..2) { + let buf_ptr = buf_ptr.to_i32() as usize; + let buf_size = buf_size.to_i32() as usize; + + let r = if buf_ptr != 0 { + let buf = mem + .mut_slice::(buf_ptr, buf_size) + .ok_or(CoreError::Execution(CoreExecutionError::MemoryOutOfBounds))?; + + data.send_output(Ok(buf)) + } else { + data.send_output(Err(TokenError::EndOfSequence)) + }; + + Ok(vec![WasmValue::from_i32(if r { 0 } else { -1 })]) + } else { + Err(CoreError::Execution(CoreExecutionError::FuncTypeMismatch)) + } +} + +fn return_token_error( + data: &mut ChatBotUi, + _inst: &mut Instance, + _frame: &mut CallingFrame, + args: Vec, +) -> Result, CoreError> { + if let Some(error_code) = args.get(0) { + let error_code = error_code.to_i32(); + let token_err = match error_code { + 1 => TokenError::EndOfSequence, + 2 => TokenError::ContextFull, + 3 => TokenError::PromptTooLong, + 4 => TokenError::TooLarge, + 5 => TokenError::InvalidEncoding, + _ => TokenError::Other, + }; + + data.send_output(Err(token_err)); + + Ok(vec![]) + } else { + Err(CoreError::Execution(CoreExecutionError::FuncTypeMismatch)) + } +} + +pub fn module(data: ChatBotUi) -> wasmedge_sdk::WasmEdgeResult> { + let mut module_builder = wasmedge_sdk::ImportObjectBuilder::new("chat_ui", data)?; + module_builder.with_func::<(i32, i32), i32>("get_input", get_input)?; + module_builder.with_func::<(i32, i32), i32>("push_token", push_token)?; + module_builder.with_func::("return_token_error", return_token_error)?; + + Ok(module_builder.build()) +} + +fn create_wasi( + file: &DownloadedFile, + load_model: &LoadModelOptions, +) -> wasmedge_sdk::WasmEdgeResult { + let ctx_size = if load_model.n_ctx > 0 { + Some(load_model.n_ctx.to_string()) + } else { + None + }; + + let n_gpu_layers = match load_model.gpu_layers { + moxin_protocol::protocol::GPULayers::Specific(n) => Some(n.to_string()), + moxin_protocol::protocol::GPULayers::Max => None, + }; + + let batch_size = if load_model.n_batch > 0 { + Some(load_model.n_batch.to_string()) + } else { + None + }; + + let mut prompt_template = load_model.prompt_template.clone(); + if prompt_template.is_none() && !file.prompt_template.is_empty() { + prompt_template = Some(file.prompt_template.clone()); + } + + let reverse_prompt = if file.reverse_prompt.is_empty() { + None + } else { + Some(file.reverse_prompt.clone()) + }; + + let module_alias = file.name.as_ref(); + + let mut args = vec!["chat_ui.wasm", "-a", module_alias]; + + macro_rules! 
add_args { + ($flag:expr, $value:expr) => { + if let Some(ref value) = $value { + args.push($flag); + args.push(value.as_str()); + } + }; + } + + add_args!("-c", ctx_size); + add_args!("-g", n_gpu_layers); + add_args!("-b", batch_size); + add_args!("-p", prompt_template); + add_args!("-r", reverse_prompt); + + WasiModule::create(Some(args), None, None) +} + +pub fn run_wasm_by_downloaded_file( + wasm_module: Module, + request_rx: Receiver<(ChatRequestData, Sender>)>, + model_running_controller: Arc, + file: DownloadedFile, + load_model: LoadModelOptions, + tx: Sender>, +) { + use wasmedge_sdk::vm::SyncInst; + use wasmedge_sdk::AsInstance; + + let mut instances: HashMap = HashMap::new(); + + let mut wasi = create_wasi(&file, &load_model).unwrap(); + let mut chatui = module(ChatBotUi::new( + request_rx, + model_running_controller, + file, + load_model, + tx, + )) + .unwrap(); + + instances.insert(wasi.name().to_string(), wasi.as_mut()); + let mut wasi_nn = wasmedge_sdk::plugin::PluginManager::load_plugin_wasi_nn().unwrap(); + instances.insert(wasi_nn.name().unwrap(), &mut wasi_nn); + instances.insert(chatui.name().unwrap(), &mut chatui); + + let store = Store::new(None, instances).unwrap(); + let mut vm = Vm::new(store); + vm.register_module(None, wasm_module.clone()).unwrap(); + + let _ = vm.run_func(None, "_start", []); + + log::debug!("wasm exit"); +} + +pub struct ChatBotModel { + id: String, + wasm_module: Module, + pub model_tx: Sender<(ChatRequestData, Sender>)>, + pub model_running_controller: Arc, + pub model_thread: JoinHandle<()>, +} + +static WASM: &[u8] = include_bytes!("../../wasm/chat_ui.wasm"); + +impl super::BackendModel for ChatBotModel { + fn new_or_reload( + async_rt: &tokio::runtime::Runtime, + old_model: Option, + file: DownloadedFile, + options: LoadModelOptions, + tx: Sender>, + ) -> Self { + let mut need_reload = true; + + let wasm_module = if let Some(old_model) = &old_model { + if old_model.id == file.id.as_str() { + need_reload = false; + } + old_model.wasm_module.clone() + } else { + Module::from_bytes(None, WASM).unwrap() + }; + + if !need_reload { + let _ = tx.send(Ok(LoadModelResponse::Completed(LoadedModelInfo { + file_id: file.id.to_string(), + model_id: file.model_id, + information: "".to_string(), + }))); + return old_model.unwrap(); + } + + let (model_tx, request_rx) = std::sync::mpsc::channel(); + let model_running_controller = Arc::new(AtomicBool::new(false)); + let model_running_controller_ = model_running_controller.clone(); + + let wasm_module_ = wasm_module.clone(); + + let file_id = file.id.to_string(); + + let model_thread = std::thread::spawn(move || { + run_wasm_by_downloaded_file( + wasm_module_, + request_rx, + model_running_controller_, + file, + options, + tx, + ) + }); + + let new_model = Self { + id: file_id, + model_tx, + model_thread, + model_running_controller, + wasm_module, + }; + + if let Some(old_model) = old_model { + old_model.stop(async_rt); + } + + new_model + } + + fn chat( + &self, + _async_rt: &tokio::runtime::Runtime, + data: ChatRequestData, + tx: Sender>, + ) -> bool { + self.model_tx.send((data, tx)).is_ok() + } + + fn stop_chat(&self, _async_rt: &tokio::runtime::Runtime) { + self.model_running_controller + .store(false, Ordering::Release); + } + + fn stop(self, _async_rt: &tokio::runtime::Runtime) { + let Self { + model_tx, + model_thread, + .. 
+ } = self; + drop(model_tx); + let _ = model_thread.join(); + } +} diff --git a/moxin-backend/src/backend_impls.rs b/moxin-backend/src/backend_impls/mod.rs similarity index 55% rename from moxin-backend/src/backend_impls.rs rename to moxin-backend/src/backend_impls/mod.rs index f350ddbc..257b0da8 100644 --- a/moxin-backend/src/backend_impls.rs +++ b/moxin-backend/src/backend_impls/mod.rs @@ -15,495 +15,11 @@ use moxin_protocol::{ LocalServerResponse, }, }; -use wasmedge_sdk::Module; use crate::store::{self, ModelFileDownloader, RemoteModel}; -mod chat_ui { - - #[derive(Debug)] - pub enum TokenError { - EndOfSequence = 1, - ContextFull, - PromptTooLong, - TooLarge, - InvalidEncoding, - Other, - } - - impl Into for TokenError { - fn into(self) -> StopReason { - match self { - TokenError::EndOfSequence => StopReason::Stop, - TokenError::ContextFull => StopReason::Length, - TokenError::PromptTooLong => StopReason::Length, - TokenError::TooLarge => StopReason::Length, - TokenError::InvalidEncoding => StopReason::Stop, - TokenError::Other => StopReason::Stop, - } - } - } - - use std::{ - collections::HashMap, - io::Read, - path::Path, - sync::{ - atomic::{AtomicBool, Ordering}, - mpsc::{Receiver, Sender}, - Arc, - }, - thread::JoinHandle, - }; - - use moxin_protocol::{ - open_ai::{ - ChatRequestData, ChatResponse, ChatResponseChunkData, ChatResponseData, ChoiceData, - ChunkChoiceData, MessageData, Role, StopReason, UsageData, - }, - protocol::{LoadModelOptions, LoadModelResponse, LoadedModelInfo}, - }; - use wasmedge_sdk::{ - error::{CoreError, CoreExecutionError}, - wasi::WasiModule, - CallingFrame, ImportObject, Instance, Module, Store, Vm, WasmValue, - }; - - use crate::store::download_files::DownloadedFile; - - #[derive(Debug)] - pub struct ChatBotUi { - pub current_req: std::io::Cursor>, - pub request_rx: Receiver<(ChatRequestData, Sender>)>, - request_id: uuid::Uuid, - chat_completion_message: Option>, - pub token_tx: Option>>, - running_controller: Arc, - pub load_model_state: Option<( - DownloadedFile, - LoadModelOptions, - Sender>, - )>, - } - - impl ChatBotUi { - pub fn new( - request_rx: Receiver<(ChatRequestData, Sender>)>, - running_controller: Arc, - file: DownloadedFile, - load_model: LoadModelOptions, - tx: Sender>, - ) -> Self { - Self { - request_rx, - request_id: uuid::Uuid::new_v4(), - token_tx: None, - running_controller, - current_req: std::io::Cursor::new(vec![]), - load_model_state: Some((file, load_model, tx)), - chat_completion_message: None, - } - } - - fn init_request(&mut self) -> Result<(), ()> { - if let Ok((req, tx)) = self.request_rx.recv() { - // Init current_req - if !req.stream.unwrap_or_default() { - self.chat_completion_message = Some(Vec::with_capacity( - (req.max_tokens.unwrap_or(512) * 8) as usize, - )) - } - *self.current_req.get_mut() = serde_json::to_vec(&req).unwrap(); - self.current_req.set_position(0); - self.request_id = uuid::Uuid::new_v4(); - self.token_tx = Some(tx); - self.running_controller.store(true, Ordering::Release); - Ok(()) - } else { - Err(()) - } - } - - pub fn read_data(&mut self, buf: &mut [u8]) -> std::io::Result { - let n = self.current_req.read(buf)?; - if n == 0 { - self.current_req.get_mut().clear(); - self.current_req.set_position(0); - } - Ok(n) - } - - fn send_completion_output( - token_tx: &mut Sender>, - id: String, - stop_reason: StopReason, - chat_completion_message: &mut Option>, - ) -> bool { - if let Some(chat_completion_message) = chat_completion_message.take() { - let _ = 
token_tx.send(Ok(ChatResponse::ChatFinalResponseData(ChatResponseData { - id, - choices: vec![ChoiceData { - finish_reason: stop_reason, - index: 0, - message: MessageData { - content: String::from_utf8_lossy(&chat_completion_message).to_string(), - role: Role::Assistant, - }, - logprobs: None, - }], - created: 0, - model: String::new(), - system_fingerprint: String::new(), - usage: UsageData { - completion_tokens: 0, - prompt_tokens: 0, - total_tokens: 0, - }, - object: "chat.completion".to_string(), - }))); - } else { - let _ = token_tx.send(Ok(ChatResponse::ChatResponseChunk(ChatResponseChunkData { - id: String::new(), - choices: vec![ChunkChoiceData { - finish_reason: Some(stop_reason), - index: 0, - delta: MessageData { - content: String::new(), - role: Role::Assistant, - }, - logprobs: None, - }], - created: 0, - model: String::new(), - system_fingerprint: String::new(), - object: "chat.completion.chunk".to_string(), - }))); - }; - true - } - - fn send_streamed_output( - token_tx: &mut Sender>, - id: String, - token: &[u8], - ) -> bool { - let _ = token_tx.send(Ok(ChatResponse::ChatResponseChunk(ChatResponseChunkData { - id, - choices: vec![ChunkChoiceData { - finish_reason: None, - index: 0, - delta: MessageData { - content: String::from_utf8_lossy(token).to_string(), - role: Role::Assistant, - }, - logprobs: None, - }], - created: 0, - model: String::new(), - system_fingerprint: String::new(), - object: "chat.completion.chunk".to_string(), - }))); - true - } - - fn send_output(&mut self, output: Result<&[u8], TokenError>) -> bool { - let id = self.request_id.to_string(); - match ( - output, - &mut self.chat_completion_message, - &mut self.token_tx, - ) { - (Ok(token), Some(chat_completion_message), Some(_tx)) => { - chat_completion_message.extend_from_slice(token); - true - } - (Ok(token), None, Some(tx)) => Self::send_streamed_output(tx, id, token), - (Err(token_error), chat_completion_message, Some(tx)) => { - Self::send_completion_output( - tx, - id, - token_error.into(), - chat_completion_message, - ) - } - (_, _, None) => false, - } - } - } - - fn get_input( - data: &mut ChatBotUi, - _inst: &mut Instance, - frame: &mut CallingFrame, - args: Vec, - ) -> Result, CoreError> { - let mem = frame - .memory_mut(0) - .ok_or(CoreError::Execution(CoreExecutionError::MemoryOutOfBounds))?; - - if let Some([buf_ptr, buf_size]) = args.get(0..2) { - let buf_ptr = buf_ptr.to_i32() as usize; - let buf_size = buf_size.to_i32() as usize; - - let buf = mem - .mut_slice::(buf_ptr, buf_size) - .ok_or(CoreError::Execution(CoreExecutionError::MemoryOutOfBounds))?; - - if data.current_req.get_ref().is_empty() { - if let Some((file, _, tx)) = data.load_model_state.take() { - let file_id = file.id.as_ref().clone(); - let model_id = file.model_id; - let _ = tx.send(Ok(LoadModelResponse::Completed(LoadedModelInfo { - file_id, - model_id, - information: String::new(), - }))); - } - - data.init_request().or(Err(CoreError::Common( - wasmedge_sdk::error::CoreCommonError::Interrupted, - )))?; - } - - let n = data.read_data(buf).unwrap(); - - Ok(vec![WasmValue::from_i32(n as i32)]) - } else { - Err(CoreError::Execution(CoreExecutionError::FuncTypeMismatch)) - } - } - - fn push_token( - data: &mut ChatBotUi, - _inst: &mut Instance, - frame: &mut CallingFrame, - args: Vec, - ) -> Result, CoreError> { - if !data.running_controller.load(Ordering::Acquire) { - return Ok(vec![WasmValue::from_i32(-1)]); - } - - let mem = frame - .memory_mut(0) - .ok_or(CoreError::Execution(CoreExecutionError::MemoryOutOfBounds))?; - - 
if let Some([buf_ptr, buf_size]) = args.get(0..2) { - let buf_ptr = buf_ptr.to_i32() as usize; - let buf_size = buf_size.to_i32() as usize; - - let r = if buf_ptr != 0 { - let buf = mem - .mut_slice::(buf_ptr, buf_size) - .ok_or(CoreError::Execution(CoreExecutionError::MemoryOutOfBounds))?; - - data.send_output(Ok(buf)) - } else { - data.send_output(Err(TokenError::EndOfSequence)) - }; - - Ok(vec![WasmValue::from_i32(if r { 0 } else { -1 })]) - } else { - Err(CoreError::Execution(CoreExecutionError::FuncTypeMismatch)) - } - } - - fn return_token_error( - data: &mut ChatBotUi, - _inst: &mut Instance, - _frame: &mut CallingFrame, - args: Vec, - ) -> Result, CoreError> { - if let Some(error_code) = args.get(0) { - let error_code = error_code.to_i32(); - let token_err = match error_code { - 1 => TokenError::EndOfSequence, - 2 => TokenError::ContextFull, - 3 => TokenError::PromptTooLong, - 4 => TokenError::TooLarge, - 5 => TokenError::InvalidEncoding, - _ => TokenError::Other, - }; - - data.send_output(Err(token_err)); - - Ok(vec![]) - } else { - Err(CoreError::Execution(CoreExecutionError::FuncTypeMismatch)) - } - } - - pub fn module(data: ChatBotUi) -> wasmedge_sdk::WasmEdgeResult> { - let mut module_builder = wasmedge_sdk::ImportObjectBuilder::new("chat_ui", data)?; - module_builder.with_func::<(i32, i32), i32>("get_input", get_input)?; - module_builder.with_func::<(i32, i32), i32>("push_token", push_token)?; - module_builder.with_func::("return_token_error", return_token_error)?; - - Ok(module_builder.build()) - } - - fn create_wasi( - file: &DownloadedFile, - load_model: &LoadModelOptions, - ) -> wasmedge_sdk::WasmEdgeResult { - let ctx_size = if load_model.n_ctx > 0 { - Some(load_model.n_ctx.to_string()) - } else { - None - }; - - let n_gpu_layers = match load_model.gpu_layers { - moxin_protocol::protocol::GPULayers::Specific(n) => Some(n.to_string()), - moxin_protocol::protocol::GPULayers::Max => None, - }; - - let batch_size = if load_model.n_batch > 0 { - Some(load_model.n_batch.to_string()) - } else { - None - }; - - let mut prompt_template = load_model.prompt_template.clone(); - if prompt_template.is_none() && !file.prompt_template.is_empty() { - prompt_template = Some(file.prompt_template.clone()); - } - - let reverse_prompt = if file.reverse_prompt.is_empty() { - None - } else { - Some(file.reverse_prompt.clone()) - }; - - let module_alias = file.name.as_ref(); - - let mut args = vec!["chat_ui.wasm", "-a", module_alias]; - - macro_rules! 
add_args { - ($flag:expr, $value:expr) => { - if let Some(ref value) = $value { - args.push($flag); - args.push(value.as_str()); - } - }; - } - - add_args!("-c", ctx_size); - add_args!("-g", n_gpu_layers); - add_args!("-b", batch_size); - add_args!("-p", prompt_template); - add_args!("-r", reverse_prompt); - - WasiModule::create(Some(args), None, None) - } - - pub fn nn_preload_file(file: &DownloadedFile) { - let file_path = Path::new(&file.download_dir) - .join(&file.model_id) - .join(&file.name); - - let preloads = wasmedge_sdk::plugin::NNPreload::new( - file.name.clone(), - wasmedge_sdk::plugin::GraphEncoding::GGML, - wasmedge_sdk::plugin::ExecutionTarget::AUTO, - &file_path, - ); - wasmedge_sdk::plugin::PluginManager::nn_preload(vec![preloads]); - } - - pub fn run_wasm_by_downloaded_file( - wasm_module: Module, - request_rx: Receiver<(ChatRequestData, Sender>)>, - model_running_controller: Arc, - file: DownloadedFile, - load_model: LoadModelOptions, - tx: Sender>, - ) { - use wasmedge_sdk::vm::SyncInst; - use wasmedge_sdk::AsInstance; - - let mut instances: HashMap = HashMap::new(); - - let mut wasi = create_wasi(&file, &load_model).unwrap(); - let mut chatui = module(ChatBotUi::new( - request_rx, - model_running_controller, - file, - load_model, - tx, - )) - .unwrap(); - - instances.insert(wasi.name().to_string(), wasi.as_mut()); - let mut wasi_nn = wasmedge_sdk::plugin::PluginManager::load_plugin_wasi_nn().unwrap(); - instances.insert(wasi_nn.name().unwrap(), &mut wasi_nn); - instances.insert(chatui.name().unwrap(), &mut chatui); - - let store = Store::new(None, instances).unwrap(); - let mut vm = Vm::new(store); - vm.register_module(None, wasm_module.clone()).unwrap(); - - let _ = vm.run_func(None, "_start", []); - - log::debug!("wasm exit"); - } - - pub struct Model { - pub model_tx: Sender<(ChatRequestData, Sender>)>, - pub model_running_controller: Arc, - pub model_thread: JoinHandle<()>, - } - - impl Model { - pub fn new_by_downloaded_file( - wasm_module: Module, - file: DownloadedFile, - options: LoadModelOptions, - tx: Sender>, - ) -> Self { - let (model_tx, request_rx) = std::sync::mpsc::channel(); - let model_running_controller = Arc::new(AtomicBool::new(false)); - let model_running_controller_ = model_running_controller.clone(); - - let model_thread = std::thread::spawn(move || { - run_wasm_by_downloaded_file( - wasm_module, - request_rx, - model_running_controller_, - file, - options, - tx, - ) - }); - Self { - model_tx, - model_thread, - model_running_controller, - } - } - - pub fn chat( - &self, - data: ChatRequestData, - tx: Sender>, - ) -> bool { - self.model_tx.send((data, tx)).is_ok() - } - - pub fn stop_chat(&self) { - self.model_running_controller - .store(false, Ordering::Release); - } - - pub fn stop(self) { - let Self { - model_tx, - model_thread, - .. 
- } = self; - drop(model_tx); - let _ = model_thread.join(); - } - } -} +mod api_server; +mod chat_ui; #[derive(Clone, Debug)] enum ModelManagementCommand { @@ -598,7 +114,7 @@ fn test_chat() { use moxin_protocol::open_ai::*; let home = std::env::var("HOME").unwrap(); - let bk = BackendImpl::build_command_sender( + let bk = BackendImpl::::build_command_sender( format!("{home}/ai/models"), format!("{home}/ai/models"), 3, @@ -673,7 +189,7 @@ fn test_chat_stop() { use moxin_protocol::open_ai::*; let home = std::env::var("HOME").unwrap(); - let bk = BackendImpl::build_command_sender( + let bk = BackendImpl::::build_command_sender( format!("{home}/ai/models"), format!("{home}/ai/models"), 3, @@ -761,7 +277,7 @@ fn test_chat_stop() { #[test] fn test_download_file() { let home = std::env::var("HOME").unwrap(); - let bk = BackendImpl::build_command_sender( + let bk = BackendImpl::::build_command_sender( format!("{home}/ai/models"), format!("{home}/ai/models"), 3, @@ -809,7 +325,7 @@ fn test_download_file() { #[test] fn test_get_download_file() { let home = std::env::var("HOME").unwrap(); - let bk = BackendImpl::build_command_sender( + let bk = BackendImpl::::build_command_sender( format!("{home}/ai/models"), format!("{home}/ai/models"), 3, @@ -829,7 +345,28 @@ pub enum DownloadControlCommand { Stop(FileID), } -pub struct BackendImpl { +pub type ChatModelBackend = BackendImpl; +pub type LlamaEdgeApiServerBackend = BackendImpl; + +pub trait BackendModel: Sized { + fn new_or_reload( + async_rt: &tokio::runtime::Runtime, + old_model: Option, + file: store::download_files::DownloadedFile, + options: LoadModelOptions, + tx: Sender>, + ) -> Self; + fn chat( + &self, + async_rt: &tokio::runtime::Runtime, + data: ChatRequestData, + tx: Sender>, + ) -> bool; + fn stop_chat(&self, async_rt: &tokio::runtime::Runtime); + fn stop(self, async_rt: &tokio::runtime::Runtime); +} + +pub struct BackendImpl { sql_conn: Arc>, #[allow(unused)] app_data_dir: PathBuf, @@ -840,14 +377,14 @@ pub struct BackendImpl { store::download_files::DownloadedFile, Sender>, )>, - model: Option, + model: Option, #[allow(unused)] async_rt: tokio::runtime::Runtime, control_tx: tokio::sync::broadcast::Sender, } -impl BackendImpl { +impl BackendImpl { /// # Arguments /// * `app_data_dir` - The directory where application data should be stored. /// * `models_dir` - The directory where models should be downloaded. 
@@ -910,7 +447,7 @@ impl BackendImpl {
         tx
     }
 
-    fn handle_command(&mut self, wasm_module: &Module, built_in_cmd: BuiltInCommand) {
+    fn handle_command(&mut self, built_in_cmd: BuiltInCommand) {
         match built_in_cmd {
             BuiltInCommand::Model(file) => match file {
                 ModelManagementCommand::GetFeaturedModels(tx) => {
@@ -1073,16 +610,12 @@ impl BackendImpl {
 
                     match download_file {
                         Ok(file) => {
-                            chat_ui::nn_preload_file(&file);
-                            let model = chat_ui::Model::new_by_downloaded_file(
-                                wasm_module.clone(),
-                                file,
-                                options,
-                                tx,
-                            );
-                            if let Some(old_model) = self.model.replace(model) {
-                                old_model.stop();
-                            }
+                            nn_preload_file(&file);
+                            let old_model = self.model.take();
+
+                            let model =
+                                Model::new_or_reload(&self.async_rt, old_model, file, options, tx);
+                            self.model = Some(model);
                         }
                         Err(e) => {
                             let _ = tx.send(Err(anyhow::anyhow!("Load model error: {e}")));
@@ -1091,19 +624,21 @@ impl BackendImpl {
                 }
                 ModelInteractionCommand::EjectModel(tx) => {
                     if let Some(model) = self.model.take() {
-                        model.stop();
+                        model.stop(&self.async_rt);
                     }
                     let _ = tx.send(Ok(()));
                 }
                 ModelInteractionCommand::Chat(data, tx) => {
                     if let Some(model) = &self.model {
-                        model.chat(data, tx);
+                        model.chat(&self.async_rt, data, tx);
                     } else {
                         let _ = tx.send(Err(anyhow::anyhow!("Model not loaded")));
                     }
                 }
                 ModelInteractionCommand::StopChatCompletion(tx) => {
-                    self.model.as_ref().map(|model| model.stop_chat());
+                    self.model
+                        .as_ref()
+                        .map(|model| model.stop_chat(&self.async_rt));
                     let _ = tx.send(Ok(()));
                 }
                 ModelInteractionCommand::StartLocalServer(_, _) => todo!(),
@@ -1117,12 +652,9 @@ impl BackendImpl {
     }
 
     fn run_loop(&mut self) {
-        static WASM: &[u8] = include_bytes!("../chat_ui.wasm");
-        let wasm_module = Module::from_bytes(None, WASM).unwrap();
-
         loop {
             if let Ok(cmd) = self.rx.recv() {
-                self.handle_command(&wasm_module, cmd.into());
+                self.handle_command(cmd.into());
             } else {
                 break;
             }
@@ -1131,3 +663,17 @@ impl BackendImpl {
         log::debug!("BackendImpl stop");
     }
 }
+
+pub fn nn_preload_file(file: &store::download_files::DownloadedFile) {
+    let file_path = Path::new(&file.download_dir)
+        .join(&file.model_id)
+        .join(&file.name);
+
+    let preloads = wasmedge_sdk::plugin::NNPreload::new(
+        file.name.clone(),
+        wasmedge_sdk::plugin::GraphEncoding::GGML,
+        wasmedge_sdk::plugin::ExecutionTarget::AUTO,
+        &file_path,
+    );
+    wasmedge_sdk::plugin::PluginManager::nn_preload(vec![preloads]);
+}
diff --git a/moxin-backend/src/lib.rs b/moxin-backend/src/lib.rs
index 9a086eca..58bbfa34 100644
--- a/moxin-backend/src/lib.rs
+++ b/moxin-backend/src/lib.rs
@@ -18,7 +18,7 @@ impl Backend {
         models_dir: M,
         max_download_threads: usize,
     ) -> Backend {
-        let command_sender = backend_impls::BackendImpl::build_command_sender(
+        let command_sender = backend_impls::ChatModelBackend::build_command_sender(
             app_data_dir,
             models_dir,
             max_download_threads,
diff --git a/moxin-backend/chat_ui.wasm b/moxin-backend/wasm/chat_ui.wasm
similarity index 100%
rename from moxin-backend/chat_ui.wasm
rename to moxin-backend/wasm/chat_ui.wasm
diff --git a/moxin-backend/wasm/llama-api-server.wasm b/moxin-backend/wasm/llama-api-server.wasm
new file mode 100644
index 00000000..36fb1d82
Binary files /dev/null and b/moxin-backend/wasm/llama-api-server.wasm differ
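Note: the Backend constructor in lib.rs stays on the existing wasm chat backend via the new ChatModelBackend alias; the LlamaEdgeApiServerBackend alias added in backend_impls/mod.rs is not wired into lib.rs by this patch. A minimal sketch (not part of the patch, assuming build_command_sender keeps the signature shown in the lib.rs hunk above) of how the constructor could opt into the OpenAI-compatible backend instead:

    // Hypothetical alternative inside the Backend constructor, not applied by this patch:
    // route commands through the llama-api-server backend, which serves chat over HTTP
    // (by default on 0.0.0.0:8080 at /v1/chat/completions) rather than through the
    // chat_ui host functions.
    let command_sender = backend_impls::LlamaEdgeApiServerBackend::build_command_sender(
        app_data_dir,
        models_dir,
        max_download_threads,
    );

Either choice satisfies the new BackendModel trait, so the BackendImpl command loop above is unchanged regardless of which backend is selected.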