Skip to content

Commit

Permalink
Merge pull request #134 from moxin-org/switch-loaded-model-upon-chat-…
Browse files Browse the repository at this point in the history
…change

Switch loaded model upon chat change
  • Loading branch information
jmbejar authored Jul 8, 2024
2 parents 2b1c7e0 + fe7652b commit 09961f7
Show file tree
Hide file tree
Showing 12 changed files with 123 additions and 150 deletions.
2 changes: 1 addition & 1 deletion src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ impl MatchEvent for App {
discover_radio_button.select(cx, &mut Scope::empty());
}

if let ChatAction::Resume(_) = action.as_widget_action().cast() {
if let ChatAction::Resume = action.as_widget_action().cast() {
let chat_radio_button = self.ui.radio_button(id!(chat_tab));
chat_radio_button.select(cx, &mut Scope::empty());
}
Expand Down
8 changes: 2 additions & 6 deletions src/chat/chat_history_card.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,8 @@ impl Widget for ChatHistoryCard {
let title_label = self.view.label(id!(title));
title_label.set_text(chat.borrow_mut().get_title());

let initial_letter = chat
.borrow()
.model_filename
.chars()
.next()
.unwrap_or_default()
let initial_letter = store.get_last_used_file_initial_letter(self.chat_id)
.unwrap_or('A')
.to_uppercase()
.to_string();

Expand Down
4 changes: 2 additions & 2 deletions src/chat/chat_line.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,10 @@ impl ChatLineRef {
}
}

pub fn set_regenerate_enabled(&mut self, enabled: bool) {
pub fn set_regenerate_button_visible(&mut self, visible: bool) {
let Some(mut inner) = self.borrow_mut() else {
return;
};
inner.view(id!(save_and_regenerate)).set_visible(enabled);
inner.button(id!(save_and_regenerate)).set_visible(visible);
}
}
31 changes: 17 additions & 14 deletions src/chat/chat_panel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,6 @@ impl Widget for ChatPanel {
// Redraw because we expect to see new or updated chat entries
self.redraw(cx);
}
State::NoModelSelected => {
self.unload_model(cx);
}
_ => {}
}
}
Expand Down Expand Up @@ -481,6 +478,7 @@ impl WidgetMatchEvent for ChatPanel {
.filter_map(|action| action.as_widget_action())
{
if let ChatHistoryCardAction::ChatSelected(_) = action.cast() {
self.reset_scroll_messages(&store);
self.redraw(cx);
}

Expand All @@ -501,10 +499,7 @@ impl WidgetMatchEvent for ChatPanel {

store
.chats
.create_empty_chat_with_model_file(&downloaded_file.file);
}
ChatAction::Resume(file_id) => {
store.ensure_model_loaded_in_current_chat(file_id);
.create_empty_chat_and_load_file(&downloaded_file.file);
}
_ => {}
}
Expand All @@ -524,11 +519,12 @@ impl WidgetMatchEvent for ChatPanel {
if let ChatPanelAction::UnloadIfActive(file_id) = action.cast() {
if store
.chats
.get_current_chat()
.map_or(false, |chat| chat.borrow().file_id == file_id)
.loaded_model
.as_ref()
.map_or(false, |file| file.id == file_id)
{
self.unload_model(cx);
store.chats.eject_model().expect("Failed to eject model");
self.unload_model(cx);
}
}
}
Expand Down Expand Up @@ -715,6 +711,13 @@ impl ChatPanel {
list.smooth_scroll_to_end(cx, 10, 80.0);
}

fn reset_scroll_messages(&mut self, store: &Store) {
let list = self.portal_list(id!(chat));
let messages = get_chat_messages(store).unwrap();
let index = messages.len().saturating_sub(1);
list.set_first_id(index);
}

fn unload_model(&mut self, cx: &mut Cx) {
self.model_selector(id!(model_selector)).deselect(cx);
self.view.redraw(cx);
Expand Down Expand Up @@ -765,7 +768,7 @@ impl ChatPanel {

self.view(id!(empty_conversation))
.label(id!(avatar_label))
.set_text(&get_model_initial_letter(store).unwrap().to_string());
.set_text(&get_model_initial_letter(store).unwrap_or('A').to_string());
}
_ => {}
}
Expand Down Expand Up @@ -835,13 +838,13 @@ impl ChatPanel {

let username = chat_line_data.username.as_ref().map_or("", String::as_str);
chat_line_item.set_sender_name(&username);
chat_line_item.set_regenerate_enabled(false);
chat_line_item.set_regenerate_button_visible(false);
chat_line_item
.set_avatar_text(&get_initial_letter(username).unwrap().to_string());
} else {
item = list.item(cx, item_id, live_id!(UserChatLine)).unwrap();
chat_line_item = item.as_chat_line();
chat_line_item.set_regenerate_enabled(true);
chat_line_item.set_regenerate_button_visible(true);
};

chat_line_item.set_message_text(cx, &chat_line_data.content);
Expand Down Expand Up @@ -884,7 +887,7 @@ fn get_initial_letter(word: &str) -> Option<char> {

fn get_model_initial_letter(store: &Store) -> Option<char> {
let chat = get_chat(store)?;
let initial_letter = get_initial_letter(&chat.borrow().model_filename)?;
let initial_letter = store.get_last_used_file_initial_letter(chat.borrow().id)?;
Some(initial_letter.to_ascii_uppercase())
}

Expand Down
52 changes: 26 additions & 26 deletions src/chat/model_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use crate::{
shared::{actions::ChatAction, utils::format_model_size},
};
use makepad_widgets::*;
use moxin_protocol::data::DownloadedFile;

use super::model_selector_list::{ModelSelectorAction, ModelSelectorListWidgetExt};

Expand Down Expand Up @@ -137,13 +136,6 @@ pub struct ModelSelector {

impl Widget for ModelSelector {
fn handle_event(&mut self, cx: &mut Cx, event: &Event, scope: &mut Scope) {
if let Event::Startup = event {
let store = scope.data.get::<Store>().unwrap();
if let Some(downloaded_file) = store.get_loaded_downloaded_file() {
self.update_ui_with_file(cx, downloaded_file);
}
}

self.view.handle_event(cx, event, scope);
self.widget_match_event(cx, event, scope);

Expand Down Expand Up @@ -175,9 +167,10 @@ impl Widget for ModelSelector {
}

fn draw_walk(&mut self, cx: &mut Cx2d, scope: &mut Scope, walk: Walk) -> DrawStep {
let store = scope.data.get::<Store>().unwrap();
let choose_label = self.label(id!(choose.label));

if no_options_to_display(scope) {
if no_options_to_display(store) {
choose_label.set_text("No Available Models");
let color = vec3(0.596, 0.635, 0.702);
choose_label.apply_over(
Expand All @@ -188,7 +181,7 @@ impl Widget for ModelSelector {
}
},
);
} else {
} else if no_active_model(store) {
choose_label.set_text("Choose a Model");
let color = vec3(0.0, 0.0, 0.0);
choose_label.apply_over(
Expand All @@ -199,6 +192,8 @@ impl Widget for ModelSelector {
}
},
);
} else {
self.update_selected_model_info(cx, store);
}

self.view.draw_walk(cx, scope, walk)
Expand All @@ -209,8 +204,12 @@ const MAX_OPTIONS_HEIGHT: f64 = 400.0;

impl WidgetMatchEvent for ModelSelector {
fn handle_actions(&mut self, cx: &mut Cx, actions: &Actions, scope: &mut Scope) {
let store = scope.data.get::<Store>().unwrap();

if let Some(fd) = self.view(id!(button)).finger_down(&actions) {
if no_options_to_display(scope) { return };
if no_options_to_display(store) {
return;
};
if fd.tap_count == 1 {
self.open = !self.open;

Expand Down Expand Up @@ -239,24 +238,16 @@ impl WidgetMatchEvent for ModelSelector {
}

for action in actions {
let store = scope.data.get_mut::<Store>().unwrap();
match action.as_widget_action().cast() {
ModelSelectorAction::Selected(downloaded_file) => {
self.update_ui_with_file(cx, downloaded_file);
ModelSelectorAction::Selected(_) => {
self.hide_options(cx);
}
_ => {}
}

match action.as_widget_action().cast() {
ChatAction::Start(file_id) => {
let downloaded_file = store
.downloads
.downloaded_files
.iter()
.find(|file| file.file.id == file_id)
.expect("Attempted to start chat with a no longer existing file")
.clone();
self.update_ui_with_file(cx, downloaded_file);
ChatAction::Start(_) => {
self.hide_options(cx);
}
_ => {}
}
Expand All @@ -265,10 +256,16 @@ impl WidgetMatchEvent for ModelSelector {
}

impl ModelSelector {
fn update_ui_with_file(&mut self, cx: &mut Cx, downloaded_file: DownloadedFile) {
fn hide_options(&mut self, cx: &mut Cx) {
self.open = false;
self.view(id!(options)).apply_over(cx, live! { height: 0 });
self.animator_cut(cx, id!(open.hide));
}

fn update_selected_model_info(&mut self, cx: &mut Cx, store: &Store) {
let Some(downloaded_file) = store.get_loaded_downloaded_file() else {
return;
};

self.view(id!(choose)).apply_over(
cx,
Expand Down Expand Up @@ -327,7 +324,10 @@ impl ModelSelectorRef {
}
}

fn no_options_to_display(scope: &mut Scope) -> bool {
let store = scope.data.get::<Store>().unwrap();
fn no_options_to_display(store: &Store) -> bool {
store.downloads.downloaded_files.is_empty()
}

fn no_active_model(store: &Store) -> bool {
store.get_loaded_downloaded_file().is_none()
}
25 changes: 11 additions & 14 deletions src/data/chats/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ enum TitleState {
#[derive(Serialize, Deserialize)]
struct ChatData {
id: ChatID,
model_filename: String,
file_id: FileID,
last_used_file_id: Option<FileID>,
messages: Vec<ChatMessage>,
title: String,
#[serde(default)]
Expand Down Expand Up @@ -82,8 +81,7 @@ impl Default for ChatInferenceParams {
pub struct Chat {
/// Unix timestamp in ms.
pub id: ChatID,
pub model_filename: String,
pub file_id: FileID,
pub last_used_file_id: Option<FileID>,
pub messages: Vec<ChatMessage>,
pub messages_update_sender: Sender<ChatTokenArrivalAction>,
pub messages_update_receiver: Receiver<ChatTokenArrivalAction>,
Expand All @@ -97,7 +95,7 @@ pub struct Chat {
}

impl Chat {
pub fn new(filename: String, file_id: FileID, chats_dir: PathBuf) -> Self {
pub fn new(chats_dir: PathBuf) -> Self {
let (tx, rx) = channel();

// Get Unix timestamp in ms for id.
Expand All @@ -109,11 +107,10 @@ impl Chat {
Self {
id,
title: String::from("New Chat"),
model_filename: filename,
file_id,
messages: vec![],
messages_update_sender: tx,
messages_update_receiver: rx,
last_used_file_id: None,
is_streaming: false,
title_state: TitleState::default(),
chats_dir,
Expand All @@ -129,8 +126,7 @@ impl Chat {
let data: ChatData = serde_json::from_str(&json)?;
let chat = Chat {
id: data.id,
model_filename: data.model_filename,
file_id: data.file_id,
last_used_file_id: data.last_used_file_id,
messages: data.messages,
title: data.title,
title_state: data.title_state,
Expand All @@ -149,8 +145,7 @@ impl Chat {
pub fn save(&self) {
let data = ChatData {
id: self.id,
model_filename: self.model_filename.clone(),
file_id: self.file_id.clone(),
last_used_file_id: self.last_used_file_id.clone(),
messages: self.messages.clone(),
title: self.title.clone(),
title_state: self.title_state,
Expand Down Expand Up @@ -225,7 +220,7 @@ impl Chat {
let cmd = Command::Chat(
ChatRequestData {
messages,
model: self.model_filename.clone(),
model: loaded_file.name.clone(),
frequency_penalty: Some(ip.frequency_penalty),
logprobs: None,
top_logprobs: None,
Expand Down Expand Up @@ -256,11 +251,13 @@ impl Chat {
username: None,
content: prompt.clone(),
});
self.model_filename = loaded_file.name.clone();

self.last_used_file_id = Some(loaded_file.id.clone());

self.messages.push(ChatMessage {
id: next_id + 1,
role: Role::Assistant,
username: Some(self.model_filename.clone()),
username: Some(loaded_file.name.clone()),
content: "".to_string(),
});

Expand Down
Loading

0 comments on commit 09961f7

Please sign in to comment.