Skip to content

Commit 51f36b7

Browse files
RUST: RegDescList and XferDescList APIs (#828)
- equality operators - set/get - get index Signed-off-by: Evgeny Leksikov <evgenylek@nvidia.com>
1 parent acc7d8e commit 51f36b7

File tree

9 files changed

+476
-215
lines changed

9 files changed

+476
-215
lines changed

src/bindings/rust/src/descriptors.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ use super::*;
1717

1818
mod query;
1919
mod reg;
20+
mod sync_manager;
2021
mod xfer;
2122
mod xfer_dlist_handle;
2223

2324
pub use query::{QueryResponse, QueryResponseIterator, QueryResponseList};
24-
pub use reg::RegDescList;
25-
pub use xfer::XferDescList;
25+
pub use reg::{RegDescList, RegDescriptor};
26+
pub use sync_manager::{BackendSyncable, SyncManager};
27+
pub use xfer::{XferDescList, XferDescriptor};
2628
pub use xfer_dlist_handle::XferDlistHandle;
2729

2830
/// Memory types supported by NIXL

src/bindings/rust/src/descriptors/reg.rs

Lines changed: 139 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,64 @@
1414
// limitations under the License.
1515

1616
use super::*;
17+
use super::sync_manager::{BackendSyncable, SyncManager};
18+
19+
/// Public registration descriptor used for indexing and comparisons
20+
#[derive(Debug, Clone, PartialEq)]
21+
pub struct RegDescriptor {
22+
pub addr: usize,
23+
pub len: usize,
24+
pub dev_id: u64,
25+
pub metadata: Vec<u8>,
26+
}
27+
28+
/// Internal data structure for registration descriptors
29+
#[derive(Debug)]
30+
struct RegDescData {
31+
descriptors: Vec<RegDescriptor>,
32+
}
33+
34+
impl BackendSyncable for RegDescData {
35+
type Backend = NonNull<bindings::nixl_capi_reg_dlist_s>;
36+
type Error = NixlError;
37+
38+
fn sync_to_backend(&self, backend: &Self::Backend) -> Result<(), Self::Error> {
39+
// Clear backend
40+
let status = unsafe { nixl_capi_reg_dlist_clear(backend.as_ptr()) };
41+
match status {
42+
NIXL_CAPI_SUCCESS => {}
43+
NIXL_CAPI_ERROR_INVALID_PARAM => return Err(NixlError::InvalidParam),
44+
_ => return Err(NixlError::BackendError),
45+
}
46+
47+
// Re-add all descriptors
48+
for desc in &self.descriptors {
49+
let status = unsafe {
50+
nixl_capi_reg_dlist_add_desc(
51+
backend.as_ptr(),
52+
desc.addr as uintptr_t,
53+
desc.len,
54+
desc.dev_id,
55+
desc.metadata.as_ptr() as *const std::ffi::c_void,
56+
desc.metadata.len(),
57+
)
58+
};
59+
match status {
60+
NIXL_CAPI_SUCCESS => {}
61+
NIXL_CAPI_ERROR_INVALID_PARAM => return Err(NixlError::InvalidParam),
62+
_ => return Err(NixlError::BackendError),
63+
}
64+
}
65+
66+
Ok(())
67+
}
68+
}
1769

1870
/// A safe wrapper around a NIXL registration descriptor list
1971
pub struct RegDescList<'a> {
20-
inner: NonNull<bindings::nixl_capi_reg_dlist_s>,
72+
sync_mgr: SyncManager<RegDescData>,
2173
_phantom: PhantomData<&'a dyn NixlDescriptor>,
74+
mem_type: MemType,
2275
}
2376

2477
impl<'a> RegDescList<'a> {
@@ -35,26 +88,24 @@ impl<'a> RegDescList<'a> {
3588
tracing::error!("Failed to create registration descriptor list");
3689
return Err(NixlError::RegDescListCreationFailed);
3790
}
38-
let ptr = NonNull::new(dlist).ok_or(NixlError::RegDescListCreationFailed)?;
91+
let backend = NonNull::new(dlist).ok_or(NixlError::RegDescListCreationFailed)?;
92+
93+
let data = RegDescData {
94+
descriptors: Vec::new(),
95+
};
96+
let sync_mgr = SyncManager::new(data, backend);
3997

4098
Ok(Self {
41-
inner: ptr,
99+
sync_mgr,
42100
_phantom: PhantomData,
101+
mem_type,
43102
})
44103
}
45104
_ => Err(NixlError::RegDescListCreationFailed),
46105
}
47106
}
48107

49-
pub fn get_type(&self) -> Result<MemType, NixlError> {
50-
let mut mem_type = 0;
51-
let status = unsafe { nixl_capi_reg_dlist_get_type(self.inner.as_ptr(), &mut mem_type) };
52-
53-
match status {
54-
NIXL_CAPI_SUCCESS => Ok(MemType::from(mem_type)),
55-
_ => Err(NixlError::BackendError),
56-
}
57-
}
108+
pub fn get_type(&self) -> Result<MemType, NixlError> { Ok(self.mem_type) }
58109

59110
/// Adds a descriptor to the list
60111
pub fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) -> Result<(), NixlError> {
@@ -69,22 +120,15 @@ impl<'a> RegDescList<'a> {
69120
dev_id: u64,
70121
metadata: &[u8],
71122
) -> Result<(), NixlError> {
72-
let status = unsafe {
73-
nixl_capi_reg_dlist_add_desc(
74-
self.inner.as_ptr(),
75-
addr as uintptr_t,
123+
self.sync_mgr.modify(|data| {
124+
data.descriptors.push(RegDescriptor {
125+
addr,
76126
len,
77127
dev_id,
78-
metadata.as_ptr() as *const std::ffi::c_void,
79-
metadata.len(),
80-
)
81-
};
82-
83-
match status {
84-
NIXL_CAPI_SUCCESS => Ok(()),
85-
NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
86-
_ => Err(NixlError::BackendError),
87-
}
128+
metadata: metadata.to_vec(),
129+
});
130+
});
131+
Ok(())
88132
}
89133

90134
/// Returns true if the list is empty
@@ -93,81 +137,62 @@ impl<'a> RegDescList<'a> {
93137
}
94138

95139
/// Returns the number of descriptors in the list
96-
pub fn desc_count(&self) -> Result<usize, NixlError> {
97-
let mut count = 0;
98-
let status = unsafe { nixl_capi_reg_dlist_desc_count(self.inner.as_ptr(), &mut count) };
99-
100-
match status {
101-
NIXL_CAPI_SUCCESS => Ok(count),
102-
_ => Err(NixlError::BackendError),
103-
}
104-
}
140+
pub fn desc_count(&self) -> Result<usize, NixlError> { Ok(self.sync_mgr.data().descriptors.len()) }
105141

106142
/// Returns the number of descriptors in the list
107-
pub fn len(&self) -> Result<usize, NixlError> {
108-
let mut len = 0;
109-
let status = unsafe { nixl_capi_reg_dlist_len(self.inner.as_ptr(), &mut len) };
110-
111-
match status {
112-
NIXL_CAPI_SUCCESS => Ok(len),
113-
NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
114-
_ => Err(NixlError::BackendError),
115-
}
116-
}
143+
pub fn len(&self) -> Result<usize, NixlError> { Ok(self.sync_mgr.data().descriptors.len()) }
117144

118145
/// Trims the list to the given size
119146
pub fn trim(&mut self) -> Result<(), NixlError> {
120-
let status = unsafe { nixl_capi_reg_dlist_trim(self.inner.as_ptr()) };
121-
122-
match status {
123-
NIXL_CAPI_SUCCESS => Ok(()),
124-
NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
125-
_ => Err(NixlError::BackendError),
126-
}
147+
self.sync_mgr.modify(|data| {
148+
data.descriptors.shrink_to_fit();
149+
});
150+
Ok(())
127151
}
128152

129153
/// Removes the descriptor at the given index
130154
pub fn rem_desc(&mut self, index: i32) -> Result<(), NixlError> {
131-
let status = unsafe { nixl_capi_reg_dlist_rem_desc(self.inner.as_ptr(), index) };
155+
if index < 0 { return Err(NixlError::InvalidParam); }
156+
let idx = index as usize;
132157

133-
match status {
134-
NIXL_CAPI_SUCCESS => Ok(()),
135-
NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
136-
_ => Err(NixlError::BackendError),
137-
}
158+
self.sync_mgr.modify(|data| {
159+
if idx >= data.descriptors.len() { return Err(NixlError::InvalidParam); }
160+
data.descriptors.remove(idx);
161+
Ok(())
162+
})
138163
}
139164

140165
/// Prints the list contents
141166
pub fn print(&self) -> Result<(), NixlError> {
142-
let status = unsafe { nixl_capi_reg_dlist_print(self.inner.as_ptr()) };
143-
144-
match status {
145-
NIXL_CAPI_SUCCESS => Ok(()),
146-
NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
147-
_ => Err(NixlError::BackendError),
148-
}
167+
self.sync_mgr.with_backend(|_data, backend| {
168+
let status = unsafe { nixl_capi_reg_dlist_print(backend.as_ptr()) };
169+
match status {
170+
NIXL_CAPI_SUCCESS => Ok(()),
171+
NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
172+
_ => Err(NixlError::BackendError),
173+
}
174+
})?
149175
}
150176

151177
/// Clears all descriptors from the list
152178
pub fn clear(&mut self) -> Result<(), NixlError> {
153-
let status = unsafe { nixl_capi_reg_dlist_clear(self.inner.as_ptr()) };
154-
155-
match status {
156-
NIXL_CAPI_SUCCESS => Ok(()),
157-
NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
158-
_ => Err(NixlError::BackendError),
159-
}
179+
self.sync_mgr.modify(|data| {
180+
data.descriptors.clear();
181+
});
182+
Ok(())
160183
}
161184

162185
/// Resizes the list to the given size
163186
pub fn resize(&mut self, new_size: usize) -> Result<(), NixlError> {
164-
let status = unsafe { nixl_capi_reg_dlist_resize(self.inner.as_ptr(), new_size) };
165-
166-
match status {
167-
NIXL_CAPI_SUCCESS => Ok(()),
168-
NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
169-
_ => Err(NixlError::BackendError),
170-
}
187+
self.sync_mgr.modify(|data| {
188+
data.descriptors.resize(new_size, RegDescriptor {
189+
addr: 0,
190+
len: 0,
191+
dev_id: 0,
192+
metadata: Vec::new(),
193+
});
194+
});
195+
Ok(())
171196
}
172197

173198
/// Add a descriptor from a type implementing NixlDescriptor
@@ -179,19 +204,10 @@ impl<'a> RegDescList<'a> {
179204
pub fn add_storage_desc(&mut self, desc: &'a dyn NixlDescriptor) -> Result<(), NixlError> {
180205
// Validate memory type matches
181206
let desc_mem_type = desc.mem_type();
182-
let list_mem_type = unsafe {
183-
// Get the memory type from the list by checking first descriptor
184-
let mut len = 0;
185-
match nixl_capi_reg_dlist_len(self.inner.as_ptr(), &mut len) {
186-
0 => Ok(()),
187-
-1 => Err(NixlError::InvalidParam),
188-
_ => Err(NixlError::BackendError),
189-
}?;
190-
if len > 0 {
191-
self.get_type()?
192-
} else {
193-
desc_mem_type
194-
}
207+
let list_mem_type = if self.len()? > 0 {
208+
self.get_type()?
209+
} else {
210+
desc_mem_type
195211
};
196212

197213
if desc_mem_type != list_mem_type && list_mem_type != MemType::Unknown {
@@ -208,15 +224,43 @@ impl<'a> RegDescList<'a> {
208224
}
209225

210226
pub(crate) fn handle(&self) -> *mut bindings::nixl_capi_reg_dlist_s {
211-
self.inner.as_ptr()
227+
self.sync_mgr.backend().map(|b| b.as_ptr()).unwrap_or(ptr::null_mut())
228+
}
229+
}
230+
231+
impl std::fmt::Debug for RegDescList<'_> {
232+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233+
let mem_type = self.get_type().unwrap_or(MemType::Unknown);
234+
let len = self.len().unwrap_or(0);
235+
let desc_count = self.desc_count().unwrap_or(0);
236+
237+
f.debug_struct("RegDescList")
238+
.field("mem_type", &mem_type)
239+
.field("len", &len)
240+
.field("desc_count", &desc_count)
241+
.finish()
242+
}
243+
}
244+
245+
impl PartialEq for RegDescList<'_> {
246+
fn eq(&self, other: &Self) -> bool {
247+
// Compare memory types first
248+
if self.mem_type != other.mem_type {
249+
return false;
250+
}
251+
252+
// Compare internal descriptor tracking
253+
self.sync_mgr.data().descriptors == other.sync_mgr.data().descriptors
212254
}
213255
}
214256

215257
impl Drop for RegDescList<'_> {
216258
fn drop(&mut self) {
217259
tracing::trace!("Dropping registration descriptor list");
218-
unsafe {
219-
nixl_capi_destroy_reg_dlist(self.inner.as_ptr());
260+
if let Ok(backend) = self.sync_mgr.backend() {
261+
unsafe {
262+
nixl_capi_destroy_reg_dlist(backend.as_ptr());
263+
}
220264
}
221265
tracing::trace!("Registration descriptor list dropped");
222266
}

0 commit comments

Comments
 (0)