Skip to content

Commit

Permalink
using std::sync::Mutex, waiting for rust-lang/rust#96469 to use clear…
Browse files Browse the repository at this point in the history
…_poison()
  • Loading branch information
randommm committed Nov 13, 2023
1 parent 3bf4713 commit b53d02c
Showing 1 changed file with 39 additions and 3 deletions.
42 changes: 39 additions & 3 deletions src/routes/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ use candle_transformers::generation::LogitsProcessor;

use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;
use std::sync::Arc;
use tokio::sync::Mutex;
use std::sync::{Arc, Mutex, TryLockError};
pub struct ModelBuilder {
sample_len: usize,
temperature: f64,
Expand Down Expand Up @@ -136,7 +135,21 @@ impl Model {
prompt_str: String,
pre_prompt_tokens: &Vec<u32>,
) -> Result<(String, Vec<u32>), Box<dyn std::error::Error>> {
let mut model_weights = self.model_weights.lock().await;
let mut model_weights = loop {
match self.model_weights.try_lock() {
Ok(model_weights) => break model_weights,
Err(TryLockError::Poisoned(e)) => {
let guard = e.into_inner();
// waiting for https://github.com/rust-lang/rust/issues/96469
// *guard = build_model_weights()?;
// self.model_weights.clear_poison();
// println!("Note: model_weights mutex was poisoned, will try to rebuild");
break guard;
}
Err(TryLockError::WouldBlock) => {}
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
};

tokio::task::block_in_place(move || {
let prompt_str = format!("[INST] {prompt_str} [/INST]");
Expand Down Expand Up @@ -272,4 +285,27 @@ mod tests {
let (output, _) = model.interact(prompt, &pre_prompt_tokens).await.unwrap();
println!("{output}");
}

// waiting for https://github.com/rust-lang/rust/issues/96469
// #[tokio::test]
// async fn poisoning_rebuild() {
// let model = ModelBuilder::default().build().unwrap();
// let c_model = model.clone();

// #[allow(unused_variables, unreachable_code)]
// std::thread::spawn(move || {
// let lock = c_model.model_weights.lock().unwrap();
// panic!();
// drop(lock);
// })
// .join()
// .unwrap_or_default();

// assert!(model.model_weights.is_poisoned());

// let prompt = "Create a basic Rust program".to_string();
// let pre_prompt_tokens = vec![];
// let (output, _) = model.interact(prompt, &pre_prompt_tokens).await.unwrap();
// println!("{output}");
// }
}

0 comments on commit b53d02c

Please sign in to comment.