[rust] Avoid panic in error case #3133

Merged · 1 commit · Apr 26, 2024
9 changes: 6 additions & 3 deletions extensions/tokenizers/rust/src/models/mod.rs
@@ -4,7 +4,7 @@ mod distilbert;
 use crate::ndarray::as_data_type;
 use crate::{cast_handle, to_handle, to_string_array};
 use bert::{BertConfig, BertModel};
-use candle_core::DType;
+use candle_core::{DType, Error};
 use candle_core::{Device, Result, Tensor};
 use candle_nn::VarBuilder;
 use distilbert::{DistilBertConfig, DistilBertModel};
@@ -43,7 +43,10 @@ fn load_model<'local>(

     // Load config
     let config: String = std::fs::read_to_string(model_path.join("config.json"))?;
-    let config: Config = serde_json::from_str(&config).unwrap();
+    let config: Config = match serde_json::from_str(&config) {
+        Ok(conf) => conf,
+        Err(err) => return Err(Error::wrap(err)),
+    };

     // Get candle device
     let device = if candle_core::utils::cuda_is_available() {
@@ -55,7 +58,7 @@
     }?;

     // Get candle dtype
-    let dtype = as_data_type(dtype).unwrap();
+    let dtype = as_data_type(dtype)?;

     let safetensors_path = model_path.join("model.safetensors");
     let vb = if safetensors_path.exists() {
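The change above follows one pattern throughout: return an error to the caller instead of panicking. Errors that already convert into candle_core::Error (such as the one from as_data_type) can be propagated with ?, while foreign errors such as serde_json::Error are converted explicitly with Error::wrap. Below is a minimal standalone sketch of the same idea; parse_config is a hypothetical helper, not part of the PR.

use candle_core::{Error, Result};

// Sketch of the panic-avoidance pattern applied in this diff: convert the
// foreign serde_json error into a candle_core::Error and let it flow through
// `Result` instead of calling `.unwrap()`.
fn parse_config(json: &str) -> Result<serde_json::Value> {
    serde_json::from_str(json).map_err(|err| Error::wrap(err))
}

fn main() -> Result<()> {
    // A malformed config now surfaces as an Err to the caller rather than a panic.
    let config = parse_config("{\"hidden_size\": 384}")?;
    println!("hidden_size = {}", config["hidden_size"]);
    Ok(())
}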
36 changes: 22 additions & 14 deletions extensions/tokenizers/rust/src/ndarray/mod.rs
@@ -295,22 +295,30 @@ fn as_device<'local>(env: &mut JNIEnv<'local>, device_type: JString, _: usize) -
     match device_type.as_str() {
         "cpu" => Ok(Device::Cpu),
         "gpu" => {
-            let mut device = CUDA_DEVICE.lock().unwrap();
-            if let Some(device) = device.as_ref() {
-                return Ok(device.clone());
-            };
-            let d = Device::new_cuda(0).unwrap();
-            *device = Some(d.clone());
-            Ok(d)
+            if candle_core::utils::cuda_is_available() {
+                let mut device = CUDA_DEVICE.lock().unwrap();
+                if let Some(device) = device.as_ref() {
+                    return Ok(device.clone());
+                };
+                let d = Device::new_cuda(0).unwrap();
+                *device = Some(d.clone());
+                Ok(d)
+            } else {
+                Err(Error::Msg(String::from("CUDA is not available.")))
+            }
         }
         "mps" => {
-            let mut device = METAL_DEVICE.lock().unwrap();
-            if let Some(device) = device.as_ref() {
-                return Ok(device.clone());
-            };
-            let d = Device::new_metal(0).unwrap();
-            *device = Some(d.clone());
-            Ok(d)
+            if candle_core::utils::metal_is_available() {
+                let mut device = METAL_DEVICE.lock().unwrap();
+                if let Some(device) = device.as_ref() {
+                    return Ok(device.clone());
+                };
+                let d = Device::new_metal(0).unwrap();
+                *device = Some(d.clone());
+                Ok(d)
+            } else {
+                Err(Error::Msg(String::from("metal is not available.")))
+            }
         }
         _ => Err(Error::Msg(format!("Invalid device type: {}", device_type))),
     }
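The new "gpu" and "mps" arms put an availability check in front of the existing lazily-cached device handle, so a missing backend surfaces as an Error::Msg instead of a panic from Device::new_cuda(0).unwrap(). Here is a standalone sketch of that check-then-cache shape for the CUDA case, assuming a CUDA_DEVICE static like the one referenced above; unlike the diff, this sketch also propagates the device-construction error with ? rather than unwrapping.

use std::sync::Mutex;

use candle_core::{Device, Error, Result};

// Hypothetical cached handle mirroring the CUDA_DEVICE static used in the
// diff; Mutex::new is const, so it can live in a plain static.
static CUDA_DEVICE: Mutex<Option<Device>> = Mutex::new(None);

// Guarded pattern: report a descriptive error when the backend is missing,
// otherwise create the device once and hand out clones of the cached handle.
fn cuda_device() -> Result<Device> {
    if !candle_core::utils::cuda_is_available() {
        return Err(Error::Msg(String::from("CUDA is not available.")));
    }
    let mut cached = CUDA_DEVICE.lock().unwrap();
    if let Some(device) = cached.as_ref() {
        return Ok(device.clone());
    }
    let device = Device::new_cuda(0)?;
    *cached = Some(device.clone());
    Ok(device)
}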
@@ -49,6 +49,14 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
"Model directory doesn't exist: " + modelPath.toAbsolutePath());
}
modelDir = modelPath.toAbsolutePath();
Path config = modelDir.resolve("config.json");
if (!Files.isRegularFile(config)) {
throw new FileNotFoundException("config.json file not found");
}
Path file = modelDir.resolve("model.safetensors");
Review thread on the model.safetensors check:

Contributor:
Do we need to handle sharded checkpoints here?

@frankfliu (Contributor, Author), Apr 26, 2024:
We only support safetensors; we don't support LLMs for now.

@siddvenk (Contributor), Apr 26, 2024:
Is it always the case that text-embedding models are provided as a single checkpoint? Judging from Hugging Face models, that is overwhelmingly true, but there are some large text-embedding models in the single-digit-billion parameter range (5B-7B) that are typically sharded yet can still fit on a single GPU. Examples: https://huggingface.co/Alibaba-NLP/gte-Qwen1.5-7B-instruct, https://huggingface.co/intfloat/e5-mistral-7b-instruct

It's not super important for now, but at some point we'll probably need to handle sharded checkpoints even for embedding models, since the user is free to specify how many shards / the maximum file size of the checkpoints.

+        if (!Files.isRegularFile(file)) {
+            throw new FileNotFoundException("model.safetensors file not found");
+        }
         long handle = RustLibrary.loadModel(modelDir.toString(), dataType.ordinal());
         block = new RsSymbolBlock((RsNDManager) manager, handle);
     }
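On the sharded-checkpoint question raised in the review thread: sharded Hugging Face checkpoints ship a model.safetensors.index.json whose "weight_map" object maps tensor names to shard file names, so supporting them later would mostly mean collecting those shard paths before loading the weights. A rough Rust sketch of that step, not part of this PR; shard_files is a hypothetical helper.

use std::collections::BTreeSet;
use std::path::{Path, PathBuf};

// Hypothetical helper: list the shard files referenced by a sharded
// checkpoint's model.safetensors.index.json. Each entry in "weight_map"
// points a tensor name at the shard file that holds it.
fn shard_files(model_dir: &Path) -> std::io::Result<Vec<PathBuf>> {
    let index = std::fs::read_to_string(model_dir.join("model.safetensors.index.json"))?;
    let index: serde_json::Value = serde_json::from_str(&index)
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
    let mut files = BTreeSet::new();
    if let Some(map) = index.get("weight_map").and_then(|m| m.as_object()) {
        for shard in map.values().filter_map(|v| v.as_str()) {
            files.insert(model_dir.join(shard));
        }
    }
    // De-duplicated, deterministic order for the loader.
    Ok(files.into_iter().collect())
}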