Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 64 additions & 43 deletions src/bindings/rust/src/descriptors/reg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

use super::*;
use super::sync_manager::{BackendSyncable, SyncManager};
use std::ops::{Index, IndexMut};

/// Public registration descriptor used for indexing and comparisons
#[derive(Debug, Clone, PartialEq)]
Expand Down Expand Up @@ -108,7 +109,7 @@ impl<'a> RegDescList<'a> {
pub fn get_type(&self) -> Result<MemType, NixlError> { Ok(self.mem_type) }

/// Adds a descriptor to the list
pub fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) -> Result<(), NixlError> {
pub fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) {
self.add_desc_with_meta(addr, len, dev_id, &[])
}

Expand All @@ -119,16 +120,13 @@ impl<'a> RegDescList<'a> {
len: usize,
dev_id: u64,
metadata: &[u8],
) -> Result<(), NixlError> {
self.sync_mgr.modify(|data| {
data.descriptors.push(RegDescriptor {
addr,
len,
dev_id,
metadata: metadata.to_vec(),
});
) {
self.sync_mgr.data_mut().descriptors.push(RegDescriptor {
addr,
len,
dev_id,
metadata: metadata.to_vec(),
});
Ok(())
}

/// Returns true if the list is empty
Expand All @@ -143,56 +141,61 @@ impl<'a> RegDescList<'a> {
pub fn len(&self) -> Result<usize, NixlError> { Ok(self.sync_mgr.data().descriptors.len()) }

/// Trims the list to the given size
pub fn trim(&mut self) -> Result<(), NixlError> {
self.sync_mgr.modify(|data| {
data.descriptors.shrink_to_fit();
});
Ok(())
pub fn trim(&mut self) {
self.sync_mgr.data_mut().descriptors.shrink_to_fit();
}

/// Removes the descriptor at the given index
pub fn rem_desc(&mut self, index: i32) -> Result<(), NixlError> {
if index < 0 { return Err(NixlError::InvalidParam); }
let idx = index as usize;

self.sync_mgr.modify(|data| {
if idx >= data.descriptors.len() { return Err(NixlError::InvalidParam); }
data.descriptors.remove(idx);
Ok(())
})
let data = self.sync_mgr.data_mut();
if idx >= data.descriptors.len() {
return Err(NixlError::InvalidParam);
}
data.descriptors.remove(idx);
Ok(())
}

/// Prints the list contents
pub fn print(&self) -> Result<(), NixlError> {
self.sync_mgr.with_backend(|_data, backend| {
let status = unsafe { nixl_capi_reg_dlist_print(backend.as_ptr()) };
match status {
NIXL_CAPI_SUCCESS => Ok(()),
NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
_ => Err(NixlError::BackendError),
}
})?
let backend = self.sync_mgr.backend()?;
let status = unsafe { nixl_capi_reg_dlist_print(backend.as_ptr()) };
match status {
NIXL_CAPI_SUCCESS => Ok(()),
NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
_ => Err(NixlError::BackendError),
}
}

/// Clears all descriptors from the list
pub fn clear(&mut self) -> Result<(), NixlError> {
self.sync_mgr.modify(|data| {
data.descriptors.clear();
});
Ok(())
pub fn clear(&mut self) {
self.sync_mgr.data_mut().descriptors.clear();
}

/// Resizes the list to the given size
pub fn resize(&mut self, new_size: usize) -> Result<(), NixlError> {
self.sync_mgr.modify(|data| {
data.descriptors.resize(new_size, RegDescriptor {
addr: 0,
len: 0,
dev_id: 0,
metadata: Vec::new(),
});
pub fn resize(&mut self, new_size: usize) {
self.sync_mgr.data_mut().descriptors.resize(new_size, RegDescriptor {
addr: 0,
len: 0,
dev_id: 0,
metadata: Vec::new(),
});
Ok(())
}

/// Safe immutable access to descriptor by index
pub fn get(&self, index: usize) -> Result<&RegDescriptor, NixlError> {
self.sync_mgr.data().descriptors
.get(index)
.ok_or(NixlError::InvalidParam)
}

/// Safe mutable access to descriptor by index
pub fn get_mut(&mut self, index: usize) -> Result<&mut RegDescriptor, NixlError> {
self.sync_mgr.data_mut().descriptors
.get_mut(index)
.ok_or(NixlError::InvalidParam)
}

/// Add a descriptor from a type implementing NixlDescriptor
Expand Down Expand Up @@ -220,7 +223,8 @@ impl<'a> RegDescList<'a> {
let dev_id = desc.device_id();

// Add to list
self.add_desc(addr, len, dev_id)
self.add_desc(addr, len, dev_id);
Ok(())
}

pub(crate) fn handle(&self) -> *mut bindings::nixl_capi_reg_dlist_s {
Expand Down Expand Up @@ -254,6 +258,23 @@ impl PartialEq for RegDescList<'_> {
}
}

// Implement Index trait for immutable indexing (list[i])
impl Index<usize> for RegDescList<'_> {
type Output = RegDescriptor;

fn index(&self, index: usize) -> &Self::Output {
&self.sync_mgr.data().descriptors[index]
}
}

// Implement IndexMut trait for mutable indexing (list[i] = value)
impl IndexMut<usize> for RegDescList<'_> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
// data_mut() automatically marks dirty
&mut self.sync_mgr.data_mut().descriptors[index]
}
}

impl Drop for RegDescList<'_> {
fn drop(&mut self) {
tracing::trace!("Dropping registration descriptor list");
Expand Down
18 changes: 3 additions & 15 deletions src/bindings/rust/src/descriptors/sync_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,10 @@ impl<T: BackendSyncable> SyncManager<T> {
}
}

/// Mutates the frontend data (marks as dirty)
pub fn modify<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut T) -> R,
{
/// Provides mutable access to the frontend data
pub fn data_mut(&mut self) -> &mut T {
self.dirty.set(true);
f(&mut self.data)
}

/// Provides access to both data and backend after ensuring synchronization
pub fn with_backend<F, R>(&self, f: F) -> Result<R, T::Error>
where
F: FnOnce(&T, &T::Backend) -> R,
{
self.ensure_synced()?;
Ok(f(&self.data, &self.backend))
&mut self.data
}

/// Provides read-only access to the frontend data (no sync)
Expand Down
95 changes: 58 additions & 37 deletions src/bindings/rust/src/descriptors/xfer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

use super::*;
use super::sync_manager::{BackendSyncable, SyncManager};
use std::ops::{Index, IndexMut};

/// Public transfer descriptor used for indexing and comparisons
#[derive(Debug, Clone, PartialEq)]
Expand Down Expand Up @@ -102,11 +103,8 @@ impl<'a> XferDescList<'a> {
pub fn get_type(&self) -> Result<MemType, NixlError> { Ok(self.mem_type) }

/// Adds a descriptor to the list
pub fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) -> Result<(), NixlError> {
self.sync_mgr.modify(|data| {
data.descriptors.push(XferDescriptor { addr, len, dev_id });
});
Ok(())
pub fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) {
self.sync_mgr.data_mut().descriptors.push(XferDescriptor { addr, len, dev_id });
}

/// Returns true if the list is empty
Expand All @@ -121,55 +119,60 @@ impl<'a> XferDescList<'a> {
pub fn len(&self) -> Result<usize, NixlError> { Ok(self.sync_mgr.data().descriptors.len()) }

/// Trims the list to the given size
pub fn trim(&mut self) -> Result<(), NixlError> {
self.sync_mgr.modify(|data| {
data.descriptors.shrink_to_fit();
});
Ok(())
pub fn trim(&mut self) {
self.sync_mgr.data_mut().descriptors.shrink_to_fit();
}

/// Removes the descriptor at the given index
pub fn rem_desc(&mut self, index: i32) -> Result<(), NixlError> {
if index < 0 { return Err(NixlError::InvalidParam); }
let idx = index as usize;

self.sync_mgr.modify(|data| {
if idx >= data.descriptors.len() { return Err(NixlError::InvalidParam); }
data.descriptors.remove(idx);
Ok(())
})
let data = self.sync_mgr.data_mut();
if idx >= data.descriptors.len() {
return Err(NixlError::InvalidParam);
}
data.descriptors.remove(idx);
Ok(())
}

/// Clears all descriptors from the list
pub fn clear(&mut self) -> Result<(), NixlError> {
self.sync_mgr.modify(|data| {
data.descriptors.clear();
});
Ok(())
pub fn clear(&mut self) {
self.sync_mgr.data_mut().descriptors.clear();
}

/// Prints the list contents
pub fn print(&self) -> Result<(), NixlError> {
self.sync_mgr.with_backend(|_data, backend| {
let status = unsafe { nixl_capi_xfer_dlist_print(backend.as_ptr()) };
match status {
NIXL_CAPI_SUCCESS => Ok(()),
NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
_ => Err(NixlError::BackendError),
}
})?
let backend = self.sync_mgr.backend()?;
let status = unsafe { nixl_capi_xfer_dlist_print(backend.as_ptr()) };
match status {
NIXL_CAPI_SUCCESS => Ok(()),
NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
_ => Err(NixlError::BackendError),
}
}

/// Resizes the list to the given size
pub fn resize(&mut self, new_size: usize) -> Result<(), NixlError> {
self.sync_mgr.modify(|data| {
data.descriptors.resize(new_size, XferDescriptor {
addr: 0,
len: 0,
dev_id: 0,
});
pub fn resize(&mut self, new_size: usize) {
self.sync_mgr.data_mut().descriptors.resize(new_size, XferDescriptor {
addr: 0,
len: 0,
dev_id: 0,
});
Ok(())
}

/// Safe immutable access to descriptor by index
pub fn get(&self, index: usize) -> Result<&XferDescriptor, NixlError> {
self.sync_mgr.data().descriptors
.get(index)
.ok_or(NixlError::InvalidParam)
}

/// Safe mutable access to descriptor by index
pub fn get_mut(&mut self, index: usize) -> Result<&mut XferDescriptor, NixlError> {
self.sync_mgr.data_mut().descriptors
.get_mut(index)
.ok_or(NixlError::InvalidParam)
}

/// Add a descriptor from a type implementing NixlDescriptor
Expand All @@ -196,7 +199,8 @@ impl<'a> XferDescList<'a> {
let dev_id = desc.device_id();

// Add to list
self.add_desc(addr, len, dev_id)
self.add_desc(addr, len, dev_id);
Ok(())
}

pub(crate) fn handle(&self) -> *mut bindings::nixl_capi_xfer_dlist_s {
Expand Down Expand Up @@ -230,6 +234,23 @@ impl PartialEq for XferDescList<'_> {
}
}

// Implement Index trait for immutable indexing (list[i])
impl Index<usize> for XferDescList<'_> {
type Output = XferDescriptor;

fn index(&self, index: usize) -> &Self::Output {
&self.sync_mgr.data().descriptors[index]
}
}

// Implement IndexMut trait for mutable indexing (list[i] = value)
impl IndexMut<usize> for XferDescList<'_> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
// data_mut() automatically marks dirty
&mut self.sync_mgr.data_mut().descriptors[index]
}
}

impl Drop for XferDescList<'_> {
fn drop(&mut self) {
if let Ok(backend) = self.sync_mgr.backend() {
Expand Down
2 changes: 1 addition & 1 deletion src/bindings/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl RegistrationHandle {
);
let mut reg_dlist = RegDescList::new(self.mem_type)?;
unsafe {
reg_dlist.add_desc(self.ptr, self.size, self.dev_id)?;
reg_dlist.add_desc(self.ptr, self.size, self.dev_id);
let _opt_args = OptArgs::new().unwrap();
nixl_capi_deregister_mem(
agent.write().unwrap().handle.as_ptr(),
Expand Down
Loading