Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 45 additions & 8 deletions rust/lance-core/src/datatypes/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,18 @@ impl FieldRef<'_> {
Ok(id)
}
FieldRef::ByPath(path) => {
let field = schema.field(path).ok_or_else(|| Error::InvalidInput {
source: format!("Field '{}' not found in schema", path).into(),
location: location!(),
let field = schema.field(path).ok_or_else(|| {
let paths = schema.field_paths();
let field_paths: Vec<&str> = paths.iter().map(|s| s.as_str()).collect();
let suggestion = crate::levenshtein::find_best_suggestion(path, &field_paths);
let mut error_msg = format!("Field '{}' not found in schema", path);
if let Some(suggestion) = suggestion {
error_msg = format!("{}. Did you mean '{}'?", error_msg, suggestion);
}
Error::InvalidInput {
source: error_msg.into(),
location: location!(),
}
})?;
Ok(field.id)
}
Expand Down Expand Up @@ -347,6 +356,27 @@ impl Schema {
SchemaFieldIterPreOrder::new(self)
}

/// Collect the dotted path of every field in the schema.
///
/// Fields are visited via [`Self::fields_pre_order`]. For each one, the names
/// along its ancestry (as reported by `field_ancestry_by_id`) are joined with
/// `.`, so a struct field "user" with a child "name" yields "user.name".
/// Fields for which no ancestry is found are left out of the result.
pub fn field_paths(&self) -> Vec<String> {
    self.fields_pre_order()
        .filter_map(|field| self.field_ancestry_by_id(field.id))
        .map(|ancestry| {
            let names: Vec<&str> = ancestry.iter().map(|f| f.name.as_str()).collect();
            names.join(".")
        })
        .collect()
}

/// Returns a new schema that only contains the fields in `column_ids`.
///
/// This projection can filter out both top-level and nested fields
Expand Down Expand Up @@ -507,12 +537,19 @@ impl Schema {

// TODO: This is not a public API, change to pub(crate) after refactor is done.
pub fn field_id(&self, column: &str) -> Result<i32> {
self.field(column)
.map(|f| f.id)
.ok_or_else(|| Error::Schema {
message: "Vector column not in schema".to_string(),
self.field(column).map(|f| f.id).ok_or_else(|| {
let paths = self.field_paths();
let field_paths: Vec<&str> = paths.iter().map(|s| s.as_str()).collect();
let suggestion = crate::levenshtein::find_best_suggestion(column, &field_paths);
let mut error_msg = format!("Vector column '{}' not in schema", column);
if let Some(suggestion) = suggestion {
error_msg = format!("{}. Did you mean '{}'?", error_msg, suggestion);
}
Error::Schema {
message: error_msg.to_string(),
location: location!(),
})
}
})
}

pub fn top_level_field_ids(&self) -> Vec<i32> {
Expand Down
149 changes: 149 additions & 0 deletions rust/lance-core/src/levenshtein.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

/// Calculate the Levenshtein distance between two strings.
///
/// The Levenshtein distance is the minimum number of single-character edits
/// (insertions, deletions, or substitutions) required to change `s1` into `s2`.
/// Characters are compared as Unicode scalar values (`char`), not as bytes.
///
/// # Examples
///
/// ```
/// use lance_core::levenshtein::levenshtein_distance;
///
/// assert_eq!(levenshtein_distance("kitten", "sitting"), 3);
/// assert_eq!(levenshtein_distance("hello", "hello"), 0);
/// assert_eq!(levenshtein_distance("hello", "world"), 4);
/// ```
pub fn levenshtein_distance(s1: &str, s2: &str) -> usize {
    // Collect once; the lengths come for free from the vectors.
    let s1_chars: Vec<char> = s1.chars().collect();
    let s2_chars: Vec<char> = s2.chars().collect();

    // If one of the strings is empty, the distance is the length of the other.
    if s1_chars.is_empty() {
        return s2_chars.len();
    }
    if s2_chars.is_empty() {
        return s1_chars.len();
    }

    // Each cell depends only on the previous row and the cell to its left, so
    // a single rolling row replaces the full (len1 + 1) x (len2 + 1) matrix.
    // Before processing any of `s1`, row[j] is the cost of inserting j chars.
    let mut row: Vec<usize> = (0..=s2_chars.len()).collect();

    for (i, &c1) in s1_chars.iter().enumerate() {
        // `diagonal` is the previous row's value one column to the left.
        let mut diagonal = row[0];
        row[0] = i + 1;

        for (j, &c2) in s2_chars.iter().enumerate() {
            let substitution = diagonal + usize::from(c1 != c2);
            let deletion = row[j + 1] + 1; // still holds the previous row's value
            let insertion = row[j] + 1; // already holds the current row's value

            diagonal = row[j + 1];
            row[j + 1] = substitution.min(deletion).min(insertion);
        }
    }

    row[s2_chars.len()]
}

/// Find the best suggestion from a list of options based on Levenshtein distance.
///
/// An option qualifies when its distance to `input` is at most one third of the
/// number of characters in `input` (integer division). Among qualifying options
/// the one with the smallest distance wins; on a tie the option that appears
/// first in `options` is returned.
///
/// Returns `None` when `input` is empty or when no option qualifies. Because of
/// the integer division, inputs shorter than three characters only match an
/// option that is exactly equal to them.
///
/// # Examples
///
/// ```
/// use lance_core::levenshtein::find_best_suggestion;
///
/// let options = ["vector", "column", "table"];
/// assert_eq!(
///     find_best_suggestion("vacter", &options),
///     Some("vector".to_string())
/// );
/// assert_eq!(find_best_suggestion("hello", &options), None);
/// ```
pub fn find_best_suggestion(input: &str, options: &[&str]) -> Option<String> {
    let input_len = input.chars().count();
    if input_len == 0 {
        return None;
    }

    let threshold = input_len / 3;

    options
        .iter()
        .map(|&option| (levenshtein_distance(input, option), option))
        .filter(|&(distance, _)| distance <= threshold)
        // `min_by_key` keeps the first of several equal minima, matching a
        // strict "only replace on a smaller distance" scan.
        .min_by_key(|&(distance, _)| distance)
        // Allocate only once, for the winning candidate.
        .map(|(_, option)| option.to_string())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(levenshtein_distance("", ""), 0);
        assert_eq!(levenshtein_distance("a", ""), 1);
        assert_eq!(levenshtein_distance("", "a"), 1);
        assert_eq!(levenshtein_distance("abc", "abc"), 0);
        assert_eq!(levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(levenshtein_distance("hello", "world"), 4);
        // Identical strings are at distance zero.
        assert_eq!(levenshtein_distance("vector", "vector"), 0);
        // One deletion and one insertion respectively.
        assert_eq!(levenshtein_distance("vector", "vectr"), 1);
        assert_eq!(levenshtein_distance("vector", "vectors"), 1);
        // Two substitutions.
        assert_eq!(levenshtein_distance("vacter", "vector"), 2);
    }

    #[test]
    fn test_find_best_suggestion() {
        let options = ["vector", "column", "table"];

        assert_eq!(
            find_best_suggestion("vacter", &options),
            Some("vector".to_string())
        );
        assert_eq!(
            find_best_suggestion("vectr", &options),
            Some("vector".to_string())
        );
        assert_eq!(
            find_best_suggestion("column", &options),
            Some("column".to_string())
        );
        assert_eq!(
            find_best_suggestion("tble", &options),
            Some("table".to_string())
        );

        // Should return None if no good match
        assert_eq!(find_best_suggestion("hello", &options), None);
        assert_eq!(find_best_suggestion("world", &options), None);

        // Should return None if input is too short: the threshold is
        // `len / 3`, which is 0 here, so only an exact match would qualify.
        assert_eq!(find_best_suggestion("v", &options), None);
        assert_eq!(find_best_suggestion("", &options), None);
    }
}
1 change: 1 addition & 0 deletions rust/lance-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod cache;
pub mod container;
pub mod datatypes;
pub mod error;
pub mod levenshtein;
pub mod traits;
pub mod utils;

Expand Down
37 changes: 33 additions & 4 deletions rust/lance-index/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,39 @@ impl TryFrom<&str> for IndexType {
"IVF_HNSW_FLAT" => Ok(Self::IvfHnswFlat),
"IVF_HNSW_SQ" => Ok(Self::IvfHnswSq),
"IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),
_ => Err(Error::invalid_input(
format!("invalid index type: {}", value),
location!(),
)),
_ => {
    // Candidate names offered as "Did you mean" hints.
    // NOTE(review): presumably mirrors the strings accepted by the arms
    // above; keep the two in sync when adding an index type.
    let valid_index_types = [
        "BTree",
        "Bitmap",
        "LabelList",
        "Inverted",
        "NGram",
        "FragmentReuse",
        "MemWal",
        "ZoneMap",
        "Vector",
        "IVF_FLAT",
        "IVF_SQ",
        "IVF_PQ",
        "IVF_RQ",
        "IVF_HNSW_FLAT",
        "IVF_HNSW_SQ",
        "IVF_HNSW_PQ",
    ];

    // Build the message exactly once and return it as the arm's value.
    let mut error_msg = format!("invalid index type: {}", value);
    if let Some(suggestion) =
        lance_core::levenshtein::find_best_suggestion(value, &valid_index_types)
    {
        error_msg.push_str(&format!(". Did you mean '{}'?", suggestion));
    }
    Err(Error::invalid_input(error_msg, location!()))
}
}
}
}
Expand Down
13 changes: 10 additions & 3 deletions rust/lance-linalg/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,16 @@ impl TryFrom<&str> for DistanceType {
"cosine" => Ok(Self::Cosine),
"dot" => Ok(Self::Dot),
"hamming" => Ok(Self::Hamming),
_ => Err(ArrowError::InvalidArgumentError(format!(
"Metric type '{s}' is not supported"
))),
_ => {
    // Candidate names offered as "Did you mean" hints; a fixed list needs
    // no heap allocation, so a plain array is used.
    let valid_distance_types = ["l2", "euclidean", "cosine", "dot", "hamming"];

    let mut error_msg = format!("Metric type '{s}' is not supported");
    if let Some(suggestion) =
        lance_core::levenshtein::find_best_suggestion(s, &valid_distance_types)
    {
        error_msg.push_str(&format!(". Did you mean '{}'?", suggestion));
    }
    Err(ArrowError::InvalidArgumentError(error_msg))
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion rust/lance-linalg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ pub(crate) mod test_utils;

pub use clustering::Clustering;

type Error = ArrowError;
use lance_core::Error;
pub type Result<T> = std::result::Result<T, Error>;
Loading
Loading