diff --git a/src/bindings/rust/src/descriptors/reg.rs b/src/bindings/rust/src/descriptors/reg.rs index f2462128c..8986fb901 100644 --- a/src/bindings/rust/src/descriptors/reg.rs +++ b/src/bindings/rust/src/descriptors/reg.rs @@ -15,6 +15,7 @@ use super::*; use super::sync_manager::{BackendSyncable, SyncManager}; +use std::ops::{Index, IndexMut}; /// Public registration descriptor used for indexing and comparisons #[derive(Debug, Clone, PartialEq)] @@ -108,7 +109,7 @@ impl<'a> RegDescList<'a> { pub fn get_type(&self) -> Result { Ok(self.mem_type) } /// Adds a descriptor to the list - pub fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) -> Result<(), NixlError> { + pub fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) { self.add_desc_with_meta(addr, len, dev_id, &[]) } @@ -119,16 +120,13 @@ impl<'a> RegDescList<'a> { len: usize, dev_id: u64, metadata: &[u8], - ) -> Result<(), NixlError> { - self.sync_mgr.modify(|data| { - data.descriptors.push(RegDescriptor { - addr, - len, - dev_id, - metadata: metadata.to_vec(), - }); + ) { + self.sync_mgr.data_mut().descriptors.push(RegDescriptor { + addr, + len, + dev_id, + metadata: metadata.to_vec(), }); - Ok(()) } /// Returns true if the list is empty @@ -143,11 +141,8 @@ impl<'a> RegDescList<'a> { pub fn len(&self) -> Result { Ok(self.sync_mgr.data().descriptors.len()) } /// Trims the list to the given size - pub fn trim(&mut self) -> Result<(), NixlError> { - self.sync_mgr.modify(|data| { - data.descriptors.shrink_to_fit(); - }); - Ok(()) + pub fn trim(&mut self) { + self.sync_mgr.data_mut().descriptors.shrink_to_fit(); } /// Removes the descriptor at the given index @@ -155,44 +150,52 @@ impl<'a> RegDescList<'a> { if index < 0 { return Err(NixlError::InvalidParam); } let idx = index as usize; - self.sync_mgr.modify(|data| { - if idx >= data.descriptors.len() { return Err(NixlError::InvalidParam); } - data.descriptors.remove(idx); - Ok(()) - }) + let data = self.sync_mgr.data_mut(); + if idx >= data.descriptors.len() { + return Err(NixlError::InvalidParam); + } + data.descriptors.remove(idx); + Ok(()) } /// Prints the list contents pub fn print(&self) -> Result<(), NixlError> { - self.sync_mgr.with_backend(|_data, backend| { - let status = unsafe { nixl_capi_reg_dlist_print(backend.as_ptr()) }; - match status { - NIXL_CAPI_SUCCESS => Ok(()), - NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam), - _ => Err(NixlError::BackendError), - } - })? + let backend = self.sync_mgr.backend()?; + let status = unsafe { nixl_capi_reg_dlist_print(backend.as_ptr()) }; + match status { + NIXL_CAPI_SUCCESS => Ok(()), + NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam), + _ => Err(NixlError::BackendError), + } } /// Clears all descriptors from the list - pub fn clear(&mut self) -> Result<(), NixlError> { - self.sync_mgr.modify(|data| { - data.descriptors.clear(); - }); - Ok(()) + pub fn clear(&mut self) { + self.sync_mgr.data_mut().descriptors.clear(); } /// Resizes the list to the given size - pub fn resize(&mut self, new_size: usize) -> Result<(), NixlError> { - self.sync_mgr.modify(|data| { - data.descriptors.resize(new_size, RegDescriptor { - addr: 0, - len: 0, - dev_id: 0, - metadata: Vec::new(), - }); + pub fn resize(&mut self, new_size: usize) { + self.sync_mgr.data_mut().descriptors.resize(new_size, RegDescriptor { + addr: 0, + len: 0, + dev_id: 0, + metadata: Vec::new(), }); - Ok(()) + } + + /// Safe immutable access to descriptor by index + pub fn get(&self, index: usize) -> Result<&RegDescriptor, NixlError> { + self.sync_mgr.data().descriptors + .get(index) + .ok_or(NixlError::InvalidParam) + } + + /// Safe mutable access to descriptor by index + pub fn get_mut(&mut self, index: usize) -> Result<&mut RegDescriptor, NixlError> { + self.sync_mgr.data_mut().descriptors + .get_mut(index) + .ok_or(NixlError::InvalidParam) } /// Add a descriptor from a type implementing NixlDescriptor @@ -220,7 +223,8 @@ impl<'a> RegDescList<'a> { let dev_id = desc.device_id(); // Add to list - self.add_desc(addr, len, dev_id) + self.add_desc(addr, len, dev_id); + Ok(()) } pub(crate) fn handle(&self) -> *mut bindings::nixl_capi_reg_dlist_s { @@ -254,6 +258,23 @@ impl PartialEq for RegDescList<'_> { } } +// Implement Index trait for immutable indexing (list[i]) +impl Index for RegDescList<'_> { + type Output = RegDescriptor; + + fn index(&self, index: usize) -> &Self::Output { + &self.sync_mgr.data().descriptors[index] + } +} + +// Implement IndexMut trait for mutable indexing (list[i] = value) +impl IndexMut for RegDescList<'_> { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + // data_mut() automatically marks dirty + &mut self.sync_mgr.data_mut().descriptors[index] + } +} + impl Drop for RegDescList<'_> { fn drop(&mut self) { tracing::trace!("Dropping registration descriptor list"); diff --git a/src/bindings/rust/src/descriptors/sync_manager.rs b/src/bindings/rust/src/descriptors/sync_manager.rs index 5337bf245..91435e4a4 100644 --- a/src/bindings/rust/src/descriptors/sync_manager.rs +++ b/src/bindings/rust/src/descriptors/sync_manager.rs @@ -29,22 +29,10 @@ impl SyncManager { } } - /// Mutates the frontend data (marks as dirty) - pub fn modify(&mut self, f: F) -> R - where - F: FnOnce(&mut T) -> R, - { + /// Provides mutable access to the frontend data + pub fn data_mut(&mut self) -> &mut T { self.dirty.set(true); - f(&mut self.data) - } - - /// Provides access to both data and backend after ensuring synchronization - pub fn with_backend(&self, f: F) -> Result - where - F: FnOnce(&T, &T::Backend) -> R, - { - self.ensure_synced()?; - Ok(f(&self.data, &self.backend)) + &mut self.data } /// Provides read-only access to the frontend data (no sync) diff --git a/src/bindings/rust/src/descriptors/xfer.rs b/src/bindings/rust/src/descriptors/xfer.rs index b34cb6e27..956f2eb77 100644 --- a/src/bindings/rust/src/descriptors/xfer.rs +++ b/src/bindings/rust/src/descriptors/xfer.rs @@ -15,6 +15,7 @@ use super::*; use super::sync_manager::{BackendSyncable, SyncManager}; +use std::ops::{Index, IndexMut}; /// Public transfer descriptor used for indexing and comparisons #[derive(Debug, Clone, PartialEq)] @@ -102,11 +103,8 @@ impl<'a> XferDescList<'a> { pub fn get_type(&self) -> Result { Ok(self.mem_type) } /// Adds a descriptor to the list - pub fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) -> Result<(), NixlError> { - self.sync_mgr.modify(|data| { - data.descriptors.push(XferDescriptor { addr, len, dev_id }); - }); - Ok(()) + pub fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) { + self.sync_mgr.data_mut().descriptors.push(XferDescriptor { addr, len, dev_id }); } /// Returns true if the list is empty @@ -121,11 +119,8 @@ impl<'a> XferDescList<'a> { pub fn len(&self) -> Result { Ok(self.sync_mgr.data().descriptors.len()) } /// Trims the list to the given size - pub fn trim(&mut self) -> Result<(), NixlError> { - self.sync_mgr.modify(|data| { - data.descriptors.shrink_to_fit(); - }); - Ok(()) + pub fn trim(&mut self) { + self.sync_mgr.data_mut().descriptors.shrink_to_fit(); } /// Removes the descriptor at the given index @@ -133,43 +128,51 @@ impl<'a> XferDescList<'a> { if index < 0 { return Err(NixlError::InvalidParam); } let idx = index as usize; - self.sync_mgr.modify(|data| { - if idx >= data.descriptors.len() { return Err(NixlError::InvalidParam); } - data.descriptors.remove(idx); - Ok(()) - }) + let data = self.sync_mgr.data_mut(); + if idx >= data.descriptors.len() { + return Err(NixlError::InvalidParam); + } + data.descriptors.remove(idx); + Ok(()) } /// Clears all descriptors from the list - pub fn clear(&mut self) -> Result<(), NixlError> { - self.sync_mgr.modify(|data| { - data.descriptors.clear(); - }); - Ok(()) + pub fn clear(&mut self) { + self.sync_mgr.data_mut().descriptors.clear(); } /// Prints the list contents pub fn print(&self) -> Result<(), NixlError> { - self.sync_mgr.with_backend(|_data, backend| { - let status = unsafe { nixl_capi_xfer_dlist_print(backend.as_ptr()) }; - match status { - NIXL_CAPI_SUCCESS => Ok(()), - NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam), - _ => Err(NixlError::BackendError), - } - })? + let backend = self.sync_mgr.backend()?; + let status = unsafe { nixl_capi_xfer_dlist_print(backend.as_ptr()) }; + match status { + NIXL_CAPI_SUCCESS => Ok(()), + NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam), + _ => Err(NixlError::BackendError), + } } /// Resizes the list to the given size - pub fn resize(&mut self, new_size: usize) -> Result<(), NixlError> { - self.sync_mgr.modify(|data| { - data.descriptors.resize(new_size, XferDescriptor { - addr: 0, - len: 0, - dev_id: 0, - }); + pub fn resize(&mut self, new_size: usize) { + self.sync_mgr.data_mut().descriptors.resize(new_size, XferDescriptor { + addr: 0, + len: 0, + dev_id: 0, }); - Ok(()) + } + + /// Safe immutable access to descriptor by index + pub fn get(&self, index: usize) -> Result<&XferDescriptor, NixlError> { + self.sync_mgr.data().descriptors + .get(index) + .ok_or(NixlError::InvalidParam) + } + + /// Safe mutable access to descriptor by index + pub fn get_mut(&mut self, index: usize) -> Result<&mut XferDescriptor, NixlError> { + self.sync_mgr.data_mut().descriptors + .get_mut(index) + .ok_or(NixlError::InvalidParam) } /// Add a descriptor from a type implementing NixlDescriptor @@ -196,7 +199,8 @@ impl<'a> XferDescList<'a> { let dev_id = desc.device_id(); // Add to list - self.add_desc(addr, len, dev_id) + self.add_desc(addr, len, dev_id); + Ok(()) } pub(crate) fn handle(&self) -> *mut bindings::nixl_capi_xfer_dlist_s { @@ -230,6 +234,23 @@ impl PartialEq for XferDescList<'_> { } } +// Implement Index trait for immutable indexing (list[i]) +impl Index for XferDescList<'_> { + type Output = XferDescriptor; + + fn index(&self, index: usize) -> &Self::Output { + &self.sync_mgr.data().descriptors[index] + } +} + +// Implement IndexMut trait for mutable indexing (list[i] = value) +impl IndexMut for XferDescList<'_> { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + // data_mut() automatically marks dirty + &mut self.sync_mgr.data_mut().descriptors[index] + } +} + impl Drop for XferDescList<'_> { fn drop(&mut self) { if let Ok(backend) = self.sync_mgr.backend() { diff --git a/src/bindings/rust/src/lib.rs b/src/bindings/rust/src/lib.rs index adce69543..48d0daedb 100644 --- a/src/bindings/rust/src/lib.rs +++ b/src/bindings/rust/src/lib.rs @@ -159,7 +159,7 @@ impl RegistrationHandle { ); let mut reg_dlist = RegDescList::new(self.mem_type)?; unsafe { - reg_dlist.add_desc(self.ptr, self.size, self.dev_id)?; + reg_dlist.add_desc(self.ptr, self.size, self.dev_id); let _opt_args = OptArgs::new().unwrap(); nixl_capi_deregister_mem( agent.write().unwrap().handle.as_ptr(), diff --git a/src/bindings/rust/tests/tests.rs b/src/bindings/rust/tests/tests.rs index c8a30bedc..88ea08da6 100644 --- a/src/bindings/rust/tests/tests.rs +++ b/src/bindings/rust/tests/tests.rs @@ -52,6 +52,41 @@ fn create_agent_with_backend(name: &str) -> Result<(Agent, OptArgs), NixlError> Ok((agent, opt_args)) } +// Trait for testing common descriptor list operations +trait DescListTestTrait: PartialEq + std::fmt::Debug { + fn new(mem_type: MemType) -> Result where Self: Sized; + fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64); + #[allow(dead_code)] + fn len(&self) -> Result; +} + +impl<'a> DescListTestTrait for RegDescList<'a> { + fn new(mem_type: MemType) -> Result { + RegDescList::new(mem_type) + } + + fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) { + RegDescList::add_desc(self, addr, len, dev_id) + } + + fn len(&self) -> Result { + RegDescList::len(self) + } +} + +impl<'a> DescListTestTrait for XferDescList<'a> { + fn new(mem_type: MemType) -> Result { + XferDescList::new(mem_type) + } + + fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) { + XferDescList::add_desc(self, addr, len, dev_id) + } + + fn len(&self) -> Result { + XferDescList::len(self) + } +} fn create_storage_list(agent: &Agent, opt_args: &OptArgs, size: usize) -> Vec { let mut storage_list = Vec::new(); @@ -281,18 +316,18 @@ fn test_xfer_dlist() { let mut dlist = XferDescList::new(MemType::Dram).unwrap(); // Add some descriptors - dlist.add_desc(0x1000, 0x100, 0).unwrap(); - dlist.add_desc(0x2000, 0x200, 1).unwrap(); + dlist.add_desc(0x1000, 0x100, 0); + dlist.add_desc(0x2000, 0x200, 1); // Check length assert_eq!(dlist.len().unwrap(), 2); // Clear list - dlist.clear().unwrap(); + dlist.clear(); assert_eq!(dlist.len().unwrap(), 0); // Resize list - dlist.resize(5).unwrap(); + dlist.resize(5); } #[test] @@ -300,18 +335,18 @@ fn test_reg_dlist() { let mut dlist = RegDescList::new(MemType::Dram).unwrap(); // Add some descriptors - dlist.add_desc(0x1000, 0x100, 0).unwrap(); - dlist.add_desc(0x2000, 0x200, 1).unwrap(); + dlist.add_desc(0x1000, 0x100, 0); + dlist.add_desc(0x2000, 0x200, 1); // Check length assert_eq!(dlist.len().unwrap(), 2); // Clear list - dlist.clear().unwrap(); + dlist.clear(); assert_eq!(dlist.len().unwrap(), 0); // Resize list - dlist.resize(5).unwrap(); + dlist.resize(5); } #[test] @@ -740,8 +775,7 @@ fn test_check_remote_metadata() { unsafe { storage.as_ptr() } as usize, storage.size(), storage.device_id(), - ) - .expect("Failed to add descriptor"); + ); // Update metadata after registration let metadata = agent2 @@ -758,8 +792,7 @@ fn test_check_remote_metadata() { let mut invalid_desc_list = XferDescList::new(mem_type).expect("Failed to create invalid desc list"); invalid_desc_list - .add_desc(0xdeadbeef, 1024, 0) - .expect("Failed to add invalid descriptor"); + .add_desc(0xdeadbeef, 1024, 0); // Check with invalid descriptor list - should return false assert!(!agent1.check_remote_metadata("agent2", Some(&invalid_desc_list))); @@ -788,7 +821,7 @@ fn test_xfer_desc_list_get_type() { #[test] fn test_xfer_desc_list_get_type_after_add() { let mut dlist = XferDescList::new(MemType::Block).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); + dlist.add_desc(0x1000, 0x100, 0); assert_eq!(dlist.get_type().unwrap(), MemType::Block); } @@ -796,15 +829,15 @@ fn test_xfer_desc_list_get_type_after_add() { fn test_xfer_desc_list_desc_count_basic() { let mut dlist = XferDescList::new(MemType::Dram).unwrap(); assert_eq!(dlist.desc_count().unwrap(), 0); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); + dlist.add_desc(0x1000, 0x100, 0); assert_eq!(dlist.desc_count().unwrap(), 1); } #[test] fn test_xfer_desc_list_desc_count_after_clear() { let mut dlist = XferDescList::new(MemType::Dram).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); - dlist.clear().unwrap(); + dlist.add_desc(0x1000, 0x100, 0); + dlist.clear(); assert_eq!(dlist.desc_count().unwrap(), 0); } @@ -817,29 +850,29 @@ fn test_xfer_desc_list_is_empty_true() { #[test] fn test_xfer_desc_list_is_empty_false() { let mut dlist = XferDescList::new(MemType::Dram).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); + dlist.add_desc(0x1000, 0x100, 0); assert!(!dlist.is_empty().unwrap()); } #[test] fn test_xfer_desc_list_trim_basic() { let mut dlist = XferDescList::new(MemType::Dram).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); - dlist.trim().unwrap(); + dlist.add_desc(0x1000, 0x100, 0); + dlist.trim(); assert!(dlist.desc_count().unwrap() <= 1); } #[test] fn test_xfer_desc_list_trim_empty() { let mut dlist = XferDescList::new(MemType::Dram).unwrap(); - assert!(dlist.trim().is_ok()); + dlist.trim(); assert!(dlist.is_empty().unwrap()); } #[test] fn test_xfer_desc_list_rem_desc_basic() { let mut dlist = XferDescList::new(MemType::Dram).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); + dlist.add_desc(0x1000, 0x100, 0); assert!(dlist.rem_desc(0).is_ok()); assert!(dlist.is_empty().unwrap()); } @@ -853,15 +886,15 @@ fn test_xfer_desc_list_rem_desc_out_of_bounds() { #[test] fn test_xfer_desc_list_clear_basic() { let mut dlist = XferDescList::new(MemType::Dram).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); - dlist.clear().unwrap(); + dlist.add_desc(0x1000, 0x100, 0); + dlist.clear(); assert!(dlist.is_empty().unwrap()); } #[test] fn test_xfer_desc_list_clear_empty() { let mut dlist = XferDescList::new(MemType::Dram).unwrap(); - assert!(dlist.clear().is_ok()); + dlist.clear(); assert!(dlist.is_empty().unwrap()); } @@ -874,7 +907,7 @@ fn test_xfer_desc_list_print_basic() { #[test] fn test_xfer_desc_list_print_after_add() { let mut dlist = XferDescList::new(MemType::Dram).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); + dlist.add_desc(0x1000, 0x100, 0); assert!(dlist.print().is_ok()); } @@ -895,7 +928,7 @@ fn test_reg_desc_list_get_type() { #[test] fn test_reg_desc_list_get_type_after_add() { let mut dlist = RegDescList::new(MemType::Block).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); + dlist.add_desc(0x1000, 0x100, 0); assert_eq!(dlist.get_type().unwrap(), MemType::Block); } @@ -903,15 +936,15 @@ fn test_reg_desc_list_get_type_after_add() { fn test_reg_desc_list_desc_count_basic() { let mut dlist = RegDescList::new(MemType::Dram).unwrap(); assert_eq!(dlist.desc_count().unwrap(), 0); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); + dlist.add_desc(0x1000, 0x100, 0); assert_eq!(dlist.desc_count().unwrap(), 1); } #[test] fn test_reg_desc_list_desc_count_after_clear() { let mut dlist = RegDescList::new(MemType::Dram).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); - dlist.clear().unwrap(); + dlist.add_desc(0x1000, 0x100, 0); + dlist.clear(); assert_eq!(dlist.desc_count().unwrap(), 0); } @@ -924,29 +957,29 @@ fn test_reg_desc_list_is_empty_true() { #[test] fn test_reg_desc_list_is_empty_false() { let mut dlist = RegDescList::new(MemType::Dram).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); + dlist.add_desc(0x1000, 0x100, 0); assert!(!dlist.is_empty().unwrap()); } #[test] fn test_reg_desc_list_trim_basic() { let mut dlist = RegDescList::new(MemType::Dram).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); - dlist.trim().unwrap(); + dlist.add_desc(0x1000, 0x100, 0); + dlist.trim(); assert!(dlist.desc_count().unwrap() <= 1); } #[test] fn test_reg_desc_list_trim_empty() { let mut dlist = RegDescList::new(MemType::Dram).unwrap(); - assert!(dlist.trim().is_ok()); + dlist.trim(); assert!(dlist.is_empty().unwrap()); } #[test] fn test_reg_desc_list_rem_desc_basic() { let mut dlist = RegDescList::new(MemType::Dram).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); + dlist.add_desc(0x1000, 0x100, 0); assert!(dlist.rem_desc(0).is_ok()); assert!(dlist.is_empty().unwrap()); } @@ -960,15 +993,15 @@ fn test_reg_desc_list_rem_desc_out_of_bounds() { #[test] fn test_reg_desc_list_clear_basic() { let mut dlist = RegDescList::new(MemType::Dram).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); - dlist.clear().unwrap(); + dlist.add_desc(0x1000, 0x100, 0); + dlist.clear(); assert!(dlist.is_empty().unwrap()); } #[test] fn test_reg_desc_list_clear_empty() { let mut dlist = RegDescList::new(MemType::Dram).unwrap(); - assert!(dlist.clear().is_ok()); + dlist.clear(); assert!(dlist.is_empty().unwrap()); } @@ -981,7 +1014,7 @@ fn test_reg_desc_list_print_basic() { #[test] fn test_reg_desc_list_print_after_add() { let mut dlist = RegDescList::new(MemType::Dram).unwrap(); - dlist.add_desc(0x1000, 0x100, 0).unwrap(); + dlist.add_desc(0x1000, 0x100, 0); assert!(dlist.print().is_ok()); } @@ -1038,15 +1071,13 @@ fn test_query_mem_with_files() { RegDescList::new(MemType::File).expect("Failed to create descriptor list"); // Add blob descriptors with filenames as metadata - for (i, file_path) in file_paths.iter().enumerate() { - descs - .add_desc_with_meta( - DESCRIPTOR_ADDR, - DESCRIPTOR_SIZE, - DESCRIPTOR_DEV_ID, - file_path.to_string_lossy().as_bytes(), - ) - .expect(&format!("Failed to add descriptor for file {}", i + 1)); + for file_path in &file_paths { + descs.add_desc_with_meta( + DESCRIPTOR_ADDR, + DESCRIPTOR_SIZE, + DESCRIPTOR_DEV_ID, + file_path.to_string_lossy().as_bytes(), + ); } // Query memory @@ -1259,8 +1290,7 @@ fn test_get_local_partial_md_success() { // Create a registration descriptor list let mut reg_descs = RegDescList::new(MemType::Dram) .expect("Failed to create registration descriptor list"); - reg_descs.add_desc(0x1000, 0x100, 0) - .expect("Failed to add descriptor"); + reg_descs.add_desc(0x1000, 0x100, 0); // Get local partial metadata let result = agent.get_local_partial_md(®_descs, Some(&opt_args)); // Should succeed and return metadata @@ -1587,133 +1617,243 @@ fn test_get_xfer_telemetry_before_posting() { } // Tests for equality operators on RegDescList and XferDescList + #[test] -fn test_descriptor_list_equality() { - // Test RegDescList equality scenarios - { - // 1. Empty lists of same type should be equal - let reg_list1 = RegDescList::new(MemType::Dram).unwrap(); - let reg_list2 = RegDescList::new(MemType::Dram).unwrap(); - assert_eq!(reg_list1, reg_list2); - assert!(!(reg_list1 != reg_list2)); +fn test_desc_list_equality_empty() { + fn test_impl() { + let list1 = T::new(MemType::Dram).unwrap(); + let list2 = T::new(MemType::Dram).unwrap(); + assert_eq!(list1, list2); + assert!(!(list1 != list2)); + } - // 2. Lists with different memory types should not be equal - let reg_list_vram = RegDescList::new(MemType::Vram).unwrap(); - assert_ne!(reg_list1, reg_list_vram); - assert!(!(reg_list1 == reg_list_vram)); + test_impl::(); + test_impl::(); +} - // 3. Lists with same descriptors should be equal - let mut reg_list3 = RegDescList::new(MemType::Dram).unwrap(); - let mut reg_list4 = RegDescList::new(MemType::Dram).unwrap(); +#[test] +fn test_desc_list_equality_memory_types() { + fn test_impl() { + let list_dram = T::new(MemType::Dram).unwrap(); + let list_vram = T::new(MemType::Vram).unwrap(); + assert_ne!(list_dram, list_vram); + assert!(!(list_dram == list_vram)); + } - reg_list3.add_desc(0x1000, 0x100, 0).unwrap(); - reg_list3.add_desc(0x2000, 0x200, 1).unwrap(); + test_impl::(); + test_impl::(); +} - reg_list4.add_desc(0x1000, 0x100, 0).unwrap(); - reg_list4.add_desc(0x2000, 0x200, 1).unwrap(); +#[test] +fn test_desc_list_equality_same_descriptors() { + fn test_impl() { + let mut list1 = T::new(MemType::Dram).unwrap(); + let mut list2 = T::new(MemType::Dram).unwrap(); - assert_eq!(reg_list3, reg_list4); - assert!(!(reg_list3 != reg_list4)); + list1.add_desc(0x1000, 0x100, 0); + list1.add_desc(0x2000, 0x200, 1); - // 4. Lists with different descriptors should not be equal - let mut reg_list5 = RegDescList::new(MemType::Dram).unwrap(); - let mut reg_list6 = RegDescList::new(MemType::Dram).unwrap(); + list2.add_desc(0x1000, 0x100, 0); + list2.add_desc(0x2000, 0x200, 1); - reg_list5.add_desc(0x1000, 0x100, 0).unwrap(); - reg_list6.add_desc(0x2000, 0x200, 1).unwrap(); + assert_eq!(list1, list2); + assert!(!(list1 != list2)); + } - assert_ne!(reg_list5, reg_list6); - assert!(!(reg_list5 == reg_list6)); + test_impl::(); + test_impl::(); +} - // 5. Lists with different lengths should not be equal - let mut reg_list7 = RegDescList::new(MemType::Dram).unwrap(); - let reg_list8 = RegDescList::new(MemType::Dram).unwrap(); +#[test] +fn test_desc_list_equality_different_descriptors() { + fn test_impl() { + let mut list1 = T::new(MemType::Dram).unwrap(); + let mut list2 = T::new(MemType::Dram).unwrap(); - reg_list7.add_desc(0x1000, 0x100, 0).unwrap(); + list1.add_desc(0x1000, 0x100, 0); + list2.add_desc(0x2000, 0x200, 1); - assert_ne!(reg_list7, reg_list8); - assert!(!(reg_list7 == reg_list8)); + assert_ne!(list1, list2); + assert!(!(list1 == list2)); + } - // 6. Lists with same descriptors but different order should not be equal - let mut reg_list9 = RegDescList::new(MemType::Dram).unwrap(); - let mut reg_list10 = RegDescList::new(MemType::Dram).unwrap(); + test_impl::(); + test_impl::(); +} - reg_list9.add_desc(0x1000, 0x100, 0).unwrap(); - reg_list9.add_desc(0x2000, 0x200, 1).unwrap(); +#[test] +fn test_desc_list_equality_different_lengths() { + fn test_impl() { + let mut list1 = T::new(MemType::Dram).unwrap(); + let list2 = T::new(MemType::Dram).unwrap(); - reg_list10.add_desc(0x2000, 0x200, 1).unwrap(); - reg_list10.add_desc(0x1000, 0x100, 0).unwrap(); + list1.add_desc(0x1000, 0x100, 0); - assert_ne!(reg_list9, reg_list10); - assert!(!(reg_list9 == reg_list10)); + assert_ne!(list1, list2); + assert!(!(list1 == list2)); + } - // 7. Lists with same descriptors but different metadata should not be equal - let mut reg_list11 = RegDescList::new(MemType::Dram).unwrap(); - let mut reg_list12 = RegDescList::new(MemType::Dram).unwrap(); + test_impl::(); + test_impl::(); +} - reg_list11.add_desc_with_meta(0x1000, 0x100, 0, b"metadata1").unwrap(); - reg_list12.add_desc_with_meta(0x1000, 0x100, 0, b"metadata2").unwrap(); +#[test] +fn test_desc_list_equality_order_matters() { + fn test_impl() { + let mut list1 = T::new(MemType::Dram).unwrap(); + let mut list2 = T::new(MemType::Dram).unwrap(); - assert_ne!(reg_list11, reg_list12); - assert!(!(reg_list11 == reg_list12)); + list1.add_desc(0x1000, 0x100, 0); + list1.add_desc(0x2000, 0x200, 1); + + list2.add_desc(0x2000, 0x200, 1); + list2.add_desc(0x1000, 0x100, 0); + + assert_ne!(list1, list2); + assert!(!(list1 == list2)); } - // Test XferDescList equality scenarios (same tests as RegDescList) - { - // 1. Empty lists of same type should be equal - let xfer_list1 = XferDescList::new(MemType::Dram).unwrap(); - let xfer_list2 = XferDescList::new(MemType::Dram).unwrap(); - assert_eq!(xfer_list1, xfer_list2); - assert!(!(xfer_list1 != xfer_list2)); + test_impl::(); + test_impl::(); +} + +// RegDescList-specific test: metadata affects equality +#[test] +fn test_reg_desc_list_equality_metadata() { + let mut list1 = RegDescList::new(MemType::Dram).unwrap(); + let mut list2 = RegDescList::new(MemType::Dram).unwrap(); + + list1.add_desc_with_meta(0x1000, 0x100, 0, b"metadata1"); + list2.add_desc_with_meta(0x1000, 0x100, 0, b"metadata2"); + + assert_ne!(list1, list2); + assert!(!(list1 == list2)); +} + +// Tests for Index trait (immutable indexing) +#[test] +fn test_desc_list_immutable_index_access() { + macro_rules! test_impl { + ($list_type:ty, $mem_type:expr) => {{ + let mut list = <$list_type>::new($mem_type).unwrap(); + list.add_desc(0x1000, 0x100, 0); + list.add_desc(0x2000, 0x200, 1); + + // Test indexing - direct field access + assert_eq!(list[0].addr, 0x1000); + assert_eq!(list[0].len, 0x100); + assert_eq!(list[0].dev_id, 0); + + assert_eq!(list[1].addr, 0x2000); + assert_eq!(list[1].len, 0x200); + assert_eq!(list[1].dev_id, 1); + }}; + } + + test_impl!(RegDescList, MemType::Dram); + test_impl!(XferDescList, MemType::Vram); +} + +// Test for IndexMut trait (mutable indexing) +#[test] +fn test_desc_list_mutable_index_modification() { + macro_rules! test_impl { + ($list_type:ty, $mem_type:expr, $new_addr:expr, $new_len:expr, $new_dev_id:expr) => {{ + let mut list = <$list_type>::new($mem_type).unwrap(); + list.add_desc(0x1000, 0x100, 0); + list.add_desc(0x2000, 0x200, 1); - // 2. Lists with different memory types should not be equal - let xfer_list_vram = XferDescList::new(MemType::Vram).unwrap(); - assert_ne!(xfer_list1, xfer_list_vram); - assert!(!(xfer_list1 == xfer_list_vram)); + // Mutate via index - direct field access + list[0].addr = $new_addr; + list[0].len = $new_len; + list[1].dev_id = $new_dev_id; - // 3. Lists with same descriptors should be equal - let mut xfer_list3 = XferDescList::new(MemType::Dram).unwrap(); - let mut xfer_list4 = XferDescList::new(MemType::Dram).unwrap(); + // Verify changes + assert_eq!(list[0].addr, $new_addr); + assert_eq!(list[0].len, $new_len); + assert_eq!(list[1].dev_id, $new_dev_id); - xfer_list3.add_desc(0x1000, 0x100, 0).unwrap(); - xfer_list3.add_desc(0x2000, 0x200, 1).unwrap(); + // Verify the list still has 2 elements + assert_eq!(list.len().unwrap(), 2); + }}; + } - xfer_list4.add_desc(0x1000, 0x100, 0).unwrap(); - xfer_list4.add_desc(0x2000, 0x200, 1).unwrap(); + test_impl!(RegDescList, MemType::Dram, 0x3000, 0x300, 42); + test_impl!(XferDescList, MemType::Vram, 0x4000, 0x400, 99); +} - assert_eq!(xfer_list3, xfer_list4); - assert!(!(xfer_list3 != xfer_list4)); +// Test out-of-bounds indexing panics (expected behavior) +#[test] +#[should_panic] +fn test_desc_list_immutable_index_panics_on_out_of_bounds() { + let list = RegDescList::new(MemType::Dram).unwrap(); + let _ = &list[0]; // Should panic - empty list +} - // 4. Lists with different descriptors should not be equal - let mut xfer_list5 = XferDescList::new(MemType::Dram).unwrap(); - let mut xfer_list6 = XferDescList::new(MemType::Dram).unwrap(); +#[test] +#[should_panic] +fn test_desc_list_mutable_index_panics_on_out_of_bounds() { + let mut list = XferDescList::new(MemType::Dram).unwrap(); + list.add_desc(0x1000, 0x100, 0); + list[5].addr = 0x9999; // Should panic - index 5 doesn't exist +} - xfer_list5.add_desc(0x1000, 0x100, 0).unwrap(); - xfer_list6.add_desc(0x2000, 0x200, 1).unwrap(); +// Tests for safe get() method +#[test] +fn test_desc_list_safe_get_method() { + // Test RegDescList + let mut reg_list = RegDescList::new(MemType::Dram).unwrap(); + reg_list.add_desc(0x1000, 0x100, 0); + reg_list.add_desc(0x2000, 0x200, 1); - assert_ne!(xfer_list5, xfer_list6); - assert!(!(xfer_list5 == xfer_list6)); + // Valid access + let desc = reg_list.get(0).unwrap(); + assert_eq!(desc.addr, 0x1000); + assert_eq!(desc.len, 0x100); - // 5. Lists with different lengths should not be equal - let mut xfer_list7 = XferDescList::new(MemType::Dram).unwrap(); - let xfer_list8 = XferDescList::new(MemType::Dram).unwrap(); + // Out of bounds should return error + assert!(reg_list.get(10).is_err()); - xfer_list7.add_desc(0x1000, 0x100, 0).unwrap(); + // Test XferDescList + let mut xfer_list = XferDescList::new(MemType::Vram).unwrap(); + xfer_list.add_desc(0x3000, 0x300, 2); - assert_ne!(xfer_list7, xfer_list8); - assert!(!(xfer_list7 == xfer_list8)); + let desc = xfer_list.get(0).unwrap(); + assert_eq!(desc.addr, 0x3000); + assert_eq!(desc.len, 0x300); - // 6. Lists with same descriptors but different order should not be equal - let mut xfer_list9 = XferDescList::new(MemType::Dram).unwrap(); - let mut xfer_list10 = XferDescList::new(MemType::Dram).unwrap(); + assert!(xfer_list.get(5).is_err()); +} - xfer_list9.add_desc(0x1000, 0x100, 0).unwrap(); - xfer_list9.add_desc(0x2000, 0x200, 1).unwrap(); +// Tests for safe get_mut() method +#[test] +fn test_desc_list_safe_get_mut_method() { + // Test RegDescList + let mut reg_list = RegDescList::new(MemType::Dram).unwrap(); + reg_list.add_desc(0x1000, 0x100, 0); + + // Valid mutable access + { + let desc = reg_list.get_mut(0).unwrap(); + desc.addr = 0x5000; + desc.len = 0x500; + } - xfer_list10.add_desc(0x2000, 0x200, 1).unwrap(); - xfer_list10.add_desc(0x1000, 0x100, 0).unwrap(); + assert_eq!(reg_list[0].addr, 0x5000); + assert_eq!(reg_list[0].len, 0x500); - assert_ne!(xfer_list9, xfer_list10); - assert!(!(xfer_list9 == xfer_list10)); + // Out of bounds should return error + assert!(reg_list.get_mut(10).is_err()); + + // Test XferDescList + let mut xfer_list = XferDescList::new(MemType::Vram).unwrap(); + xfer_list.add_desc(0x2000, 0x200, 1); + + { + let desc = xfer_list.get_mut(0).unwrap(); + desc.dev_id = 99; } + + assert_eq!(xfer_list[0].dev_id, 99); + assert!(xfer_list.get_mut(5).is_err()); }