Skip to content

Commit

Permalink
Merge pull request facebookresearch#10 from Enet4/imp/error-handling
Browse files Browse the repository at this point in the history
Update error handling
  • Loading branch information
Enet4 authored Sep 21, 2020
2 parents 30da2c4 + d96ac7d commit 0bb3e95
Show file tree
Hide file tree
Showing 12 changed files with 157 additions and 146 deletions.
21 changes: 11 additions & 10 deletions src/cluster.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Vector clustering interface and implementation.
use crate::error::Result;
use crate::faiss_try;
use crate::index::NativeIndex;
use faiss_sys::*;
use std::os::raw::c_int;
Expand Down Expand Up @@ -175,7 +176,7 @@ impl Clustering {
let d = d as c_int;
let k = k as c_int;
let mut inner: *mut FaissClustering = ptr::null_mut();
faiss_try!(faiss_Clustering_new(&mut inner, d, k));
faiss_try(faiss_Clustering_new(&mut inner, d, k))?;
Ok(Clustering { inner })
}
}
Expand All @@ -188,12 +189,12 @@ impl Clustering {
let d = d as c_int;
let k = k as c_int;
let mut inner: *mut FaissClustering = ptr::null_mut();
faiss_try!(faiss_Clustering_new_with_params(
faiss_try(faiss_Clustering_new_with_params(
&mut inner,
d,
k,
&params.inner
));
&params.inner,
))?;
Ok(Clustering { inner })
}
}
Expand All @@ -208,12 +209,12 @@ impl Clustering {
{
unsafe {
let n = x.len() / self.d() as usize;
faiss_try!(faiss_Clustering_train(
faiss_try(faiss_Clustering_train(
self.inner,
n as idx_t,
x.as_ptr(),
index.inner_ptr()
));
index.inner_ptr(),
))?;
Ok(())
}
}
Expand Down Expand Up @@ -357,14 +358,14 @@ pub fn kmeans_clustering(d: u32, k: u32, x: &[f32]) -> Result<KMeansResult> {
let n = x.len() / d as usize;
let mut centroids = vec![0_f32; (d * k) as usize];
let mut q_error: f32 = 0.;
faiss_try!(faiss_kmeans_clustering(
faiss_try(faiss_kmeans_clustering(
d as usize,
n,
k as usize,
x.as_ptr(),
centroids.as_mut_ptr(),
&mut q_error
));
&mut q_error,
))?;
Ok(KMeansResult { centroids, q_error })
}
}
Expand Down
20 changes: 12 additions & 8 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,21 @@ pub enum Error {

impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.write_str(self.description())
match self {
Error::Native(e) => write!(fmt, "Native faiss error: {}", e.msg),
Error::BadCast => fmt.write_str("Invalid index type cast"),
Error::IndexDescription => fmt.write_str("Invalid index description"),
Error::BadFilePath => fmt.write_str("Invalid file path"),
}
}
}

impl StdError for Error {
fn description(&self) -> &str {
match *self {
Error::Native(ref e) => &e.msg,
Error::BadCast => "Invalid index type cast",
Error::IndexDescription => "Invalid index description",
Error::BadFilePath => "Invalid file path",
fn source(&self) -> Option<&(dyn StdError + 'static)> {
if let Error::Native(err) = self {
Some(err)
} else {
None
}
}
}
Expand Down Expand Up @@ -70,7 +74,7 @@ impl NativeError {
/// a operation which returned a non-zero error code.
/// This function might panic if no operation was made
/// or the last operation was successful.
pub fn from_last_error(code: c_int) -> Self {
pub(crate) fn from_last_error(code: c_int) -> Self {
unsafe {
let e: *const _ = faiss_get_last_error();
assert!(!e.is_null());
Expand Down
9 changes: 5 additions & 4 deletions src/gpu.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Contents for GPU support
use crate::error::Result;
use crate::faiss_try;
use faiss_sys::*;
use std::ptr;

Expand Down Expand Up @@ -97,7 +98,7 @@ impl StandardGpuResources {
pub fn new() -> Result<Self> {
unsafe {
let mut ptr = ptr::null_mut();
faiss_try!(faiss_StandardGpuResources_new(&mut ptr));
faiss_try(faiss_StandardGpuResources_new(&mut ptr))?;
Ok(StandardGpuResources { inner: ptr })
}
}
Expand All @@ -110,21 +111,21 @@ impl GpuResources for StandardGpuResources {

fn no_temp_memory(&mut self) -> Result<()> {
unsafe {
faiss_try!(faiss_StandardGpuResources_noTempMemory(self.inner));
faiss_try(faiss_StandardGpuResources_noTempMemory(self.inner))?;
Ok(())
}
}

fn set_temp_memory(&mut self, size: usize) -> Result<()> {
unsafe {
faiss_try!(faiss_StandardGpuResources_setTempMemory(self.inner, size));
faiss_try(faiss_StandardGpuResources_setTempMemory(self.inner, size))?;
Ok(())
}
}

fn set_pinned_memory(&mut self, size: usize) -> Result<()> {
unsafe {
faiss_try!(faiss_StandardGpuResources_setPinnedMemory(self.inner, size));
faiss_try(faiss_StandardGpuResources_setPinnedMemory(self.inner, size))?;
Ok(())
}
}
Expand Down
33 changes: 17 additions & 16 deletions src/index/flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use super::*;

use crate::error::{Error, Result};
use crate::faiss_try;
use std::mem;
use std::ptr;

Expand Down Expand Up @@ -34,11 +35,11 @@ impl FlatIndexImpl {
unsafe {
let metric = metric as c_uint;
let mut inner = ptr::null_mut();
faiss_try!(faiss_IndexFlat_new_with(
faiss_try(faiss_IndexFlat_new_with(
&mut inner,
(d & 0x7FFF_FFFF) as idx_t,
metric
));
metric,
))?;
Ok(FlatIndexImpl { inner })
}
}
Expand Down Expand Up @@ -74,14 +75,14 @@ impl FlatIndexImpl {
let n = x.len() / self.d() as usize;
let k = labels.len() / n;
let mut distances = vec![0.; n * k];
faiss_try!(faiss_IndexFlat_compute_distance_subset(
faiss_try(faiss_IndexFlat_compute_distance_subset(
self.inner,
n as idx_t,
x.as_ptr(),
k as idx_t,
distances.as_mut_ptr(),
labels.as_ptr() as *const _
));
labels.as_ptr() as *const _,
))?;
Ok(distances)
}
}
Expand Down Expand Up @@ -134,13 +135,13 @@ impl ConcurrentIndex for FlatIndexImpl {
unsafe {
let nq = query.len() / self.d() as usize;
let mut out_labels = vec![Idx::none(); k * nq];
faiss_try!(faiss_Index_assign(
faiss_try(faiss_Index_assign(
self.inner,
nq as idx_t,
query.as_ptr(),
out_labels.as_mut_ptr() as *mut _,
k as i64
));
k as i64,
))?;
Ok(AssignSearchResult { labels: out_labels })
}
}
Expand All @@ -149,29 +150,29 @@ impl ConcurrentIndex for FlatIndexImpl {
let nq = query.len() / self.d() as usize;
let mut distances = vec![0_f32; k * nq];
let mut labels = vec![Idx::none(); k * nq];
faiss_try!(faiss_Index_search(
faiss_try(faiss_Index_search(
self.inner,
nq as idx_t,
query.as_ptr(),
k as idx_t,
distances.as_mut_ptr(),
labels.as_mut_ptr() as *mut _
));
labels.as_mut_ptr() as *mut _,
))?;
Ok(SearchResult { distances, labels })
}
}
fn range_search(&self, query: &[f32], radius: f32) -> Result<RangeSearchResult> {
unsafe {
let nq = (query.len() / self.d() as usize) as idx_t;
let mut p_res: *mut FaissRangeSearchResult = ptr::null_mut();
faiss_try!(faiss_RangeSearchResult_new(&mut p_res, nq));
faiss_try!(faiss_Index_range_search(
faiss_try(faiss_RangeSearchResult_new(&mut p_res, nq))?;
faiss_try(faiss_Index_range_search(
self.inner,
nq,
query.as_ptr(),
radius,
p_res
));
p_res,
))?;
Ok(RangeSearchResult { inner: p_res })
}
}
Expand Down
47 changes: 24 additions & 23 deletions src/index/gpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use super::{
RangeSearchResult, SearchResult,
};
use crate::error::Result;
use crate::faiss_try;
use crate::gpu::GpuResources;
use crate::metric::MetricType;
use crate::selector::IdSelector;
Expand Down Expand Up @@ -76,12 +77,12 @@ where
{
unsafe {
let mut gpuindex_ptr = ptr::null_mut();
faiss_try!(faiss_index_cpu_to_gpu(
faiss_try(faiss_index_cpu_to_gpu(
gpu_res.inner_ptr(),
device,
index.inner_ptr(),
&mut gpuindex_ptr
));
&mut gpuindex_ptr,
))?;
Ok(GpuIndexImpl {
inner: gpuindex_ptr,
phantom: PhantomData,
Expand Down Expand Up @@ -135,7 +136,7 @@ where
pub fn to_cpu(&self) -> Result<I> {
unsafe {
let mut cpuindex_ptr = ptr::null_mut();
faiss_try!(faiss_index_gpu_to_cpu(self.inner, &mut cpuindex_ptr));
faiss_try(faiss_index_gpu_to_cpu(self.inner, &mut cpuindex_ptr))?;
Ok(I::from_inner_ptr(cpuindex_ptr))
}
}
Expand Down Expand Up @@ -171,28 +172,28 @@ where
fn add(&mut self, x: &[f32]) -> Result<()> {
unsafe {
let n = x.len() / self.d() as usize;
faiss_try!(faiss_Index_add(self.inner, n as i64, x.as_ptr()));
faiss_try(faiss_Index_add(self.inner, n as i64, x.as_ptr()))?;
Ok(())
}
}

fn add_with_ids(&mut self, x: &[f32], xids: &[Idx]) -> Result<()> {
unsafe {
let n = x.len() / self.d() as usize;
faiss_try!(faiss_Index_add_with_ids(
faiss_try(faiss_Index_add_with_ids(
self.inner,
n as i64,
x.as_ptr(),
xids.as_ptr() as *const _
));
xids.as_ptr() as *const _,
))?;
Ok(())
}
}

fn train(&mut self, x: &[f32]) -> Result<()> {
unsafe {
let n = x.len() / self.d() as usize;
faiss_try!(faiss_Index_train(self.inner, n as i64, x.as_ptr()));
faiss_try(faiss_Index_train(self.inner, n as i64, x.as_ptr()))?;
Ok(())
}
}
Expand All @@ -201,13 +202,13 @@ where
unsafe {
let nq = query.len() / self.d() as usize;
let mut out_labels = vec![Idx::none(); k * nq];
faiss_try!(faiss_Index_assign(
faiss_try(faiss_Index_assign(
self.inner,
nq as idx_t,
query.as_ptr(),
out_labels.as_mut_ptr() as *mut _,
k as i64
));
k as i64,
))?;
Ok(AssignSearchResult { labels: out_labels })
}
}
Expand All @@ -217,14 +218,14 @@ where
let nq = query.len() / self.d() as usize;
let mut distances = vec![0_f32; k * nq];
let mut labels = vec![Idx::none(); k * nq];
faiss_try!(faiss_Index_search(
faiss_try(faiss_Index_search(
self.inner,
nq as idx_t,
query.as_ptr(),
k as idx_t,
distances.as_mut_ptr(),
labels.as_mut_ptr() as *mut _
));
labels.as_mut_ptr() as *mut _,
))?;
Ok(SearchResult { distances, labels })
}
}
Expand All @@ -233,33 +234,33 @@ where
unsafe {
let nq = (query.len() / self.d() as usize) as idx_t;
let mut p_res: *mut FaissRangeSearchResult = ptr::null_mut();
faiss_try!(faiss_RangeSearchResult_new(&mut p_res, nq));
faiss_try!(faiss_Index_range_search(
faiss_try(faiss_RangeSearchResult_new(&mut p_res, nq))?;
faiss_try(faiss_Index_range_search(
self.inner,
nq,
query.as_ptr(),
radius,
p_res
));
p_res,
))?;
Ok(RangeSearchResult { inner: p_res })
}
}

fn reset(&mut self) -> Result<()> {
unsafe {
faiss_try!(faiss_Index_reset(self.inner));
faiss_try(faiss_Index_reset(self.inner))?;
Ok(())
}
}

fn remove_ids(&mut self, sel: &IdSelector) -> Result<usize> {
unsafe {
let mut n_removed = 0;
faiss_try!(faiss_Index_remove_ids(
faiss_try(faiss_Index_remove_ids(
self.inner,
sel.inner_ptr(),
&mut n_removed
));
&mut n_removed,
))?;
Ok(n_removed)
}
}
Expand Down
Loading

0 comments on commit 0bb3e95

Please sign in to comment.