Skip to content

Commit

Permalink
Add DomainStack and Counter mock for pre-computation (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
SallySoul authored Jan 8, 2025
1 parent efeb8d5 commit 44cb6be
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/domain/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
mod bc;
mod gather_args;
mod stack;
mod view;

pub use bc::*;
pub use gather_args::*;
pub use stack::*;
pub use view::*;
87 changes: 87 additions & 0 deletions src/domain/stack/counter_stack.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use crate::domain::*;
use crate::util::*;

/// Mimics the domain stack during runtime,
/// used to compute the size of the runtime stack.
pub struct CounterStack<const GRID_DIMENSION: usize> {
next_id: usize,
size: usize,
max_size: usize,
}

impl<const GRID_DIMENSION: usize> CounterStack<GRID_DIMENSION> {
pub fn blank() -> Self {
CounterStack {
next_id: 0,
size: 0,
max_size: 0,
}
}

pub fn pop_domain(&mut self, aabb: &AABB<GRID_DIMENSION>) -> DomainId {
let buffer_size = aabb.buffer_size();
let result = self.next_id;
self.next_id += 1;
self.size += buffer_size;
if self.size > self.max_size {
self.max_size = self.size;
}
result
}

pub fn push_domain(&mut self, id: DomainId, aabb: &AABB<GRID_DIMENSION>) {
let buffer_size = aabb.buffer_size();
debug_assert!(self.next_id != 0);
debug_assert!(self.size >= buffer_size);
debug_assert_eq!(self.next_id - 1, id);
self.size -= buffer_size;
self.next_id -= 1;
}

pub fn size(&self) -> usize {
self.size
}

pub fn max_size(&self) -> usize {
self.max_size
}

pub fn finish(self) -> usize {
self.max_size
}
}

#[cfg(test)]
mod unit_tests {
use super::*;

#[test]
fn counter() {
let mut counter = CounterStack::blank();
let aabb_1 = AABB::new(matrix![0, 1]);
let aabb_1_s = aabb_1.buffer_size();
debug_assert_eq!(counter.size(), 0);
debug_assert_eq!(counter.max_size(), 0);
let id_1 = counter.pop_domain(&aabb_1);
debug_assert_eq!(id_1, 0);
debug_assert_eq!(counter.size(), aabb_1_s);
debug_assert_eq!(counter.max_size(), aabb_1_s);

let aabb_2 = AABB::new(matrix![0, 3]);
let aabb_2_s = aabb_2.buffer_size();
let id_2 = counter.pop_domain(&aabb_2);
debug_assert_eq!(id_2, 1);
debug_assert_eq!(counter.max_size(), aabb_1_s + aabb_2_s);

counter.push_domain(id_2, &aabb_2);
debug_assert_eq!(counter.size(), aabb_1_s);
debug_assert_eq!(counter.max_size(), aabb_1_s + aabb_2_s);

let aabb_3 = AABB::new(matrix![0, 5]);
let aabb_3_s = aabb_3.buffer_size();
let id_3 = counter.pop_domain(&aabb_3);
debug_assert_eq!(id_3, 1);
debug_assert_eq!(counter.size(), aabb_1_s + aabb_3_s);
debug_assert_eq!(counter.max_size(), aabb_1_s + aabb_3_s);
}
}
135 changes: 135 additions & 0 deletions src/domain/stack/domain_stack.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
use crate::domain::*;
use crate::util::*;

/// Dynamically create and drop SliceDomains from single allocaiton
pub struct DomainStack<'a, const GRID_DIMENSION: usize> {
buffer: AlignedVec<f64>,
remainder: &'a mut [f64],
next_id: usize,
}

impl<'a, const GRID_DIMENSION: usize> DomainStack<'a, GRID_DIMENSION> {
pub fn buffer(&'a self) -> &'a [f64] {
self.buffer.as_slice()
}

pub fn remainder(&'a self) -> &'a [f64] {
self.remainder
}

pub fn with_capacity(size: usize) -> Self {
// create remainder from buffer parts
let buffer = AlignedVec::new(size);
let data_ptr = buffer.as_slice().as_ptr();
let next_id = 0;
let remainder = unsafe {
let data_ptr_mut = data_ptr as *mut f64;
std::slice::from_raw_parts_mut(data_ptr_mut, size)
};

DomainStack {
remainder,
buffer,
next_id,
}
}

pub fn pop_domain(
&mut self,
aabb: AABB<GRID_DIMENSION>,
) -> (DomainId, SliceDomain<'a, GRID_DIMENSION>) {
let buffer_size = aabb.buffer_size();
let remainder_len = self.remainder.len();
debug_assert!(remainder_len >= buffer_size);
let remainder_ptr = self.remainder.as_ptr();

let new_remainder_len = remainder_len - buffer_size;

let slice = unsafe {
let slice_ptr = remainder_ptr.add(new_remainder_len);
let slice_ptr_mut = slice_ptr as *mut f64;
std::slice::from_raw_parts_mut(slice_ptr_mut, buffer_size)
};

self.remainder = unsafe {
let remainder_ptr_mut = remainder_ptr as *mut f64;
std::slice::from_raw_parts_mut(remainder_ptr_mut, new_remainder_len)
};

let result = (self.next_id, SliceDomain::new(aabb, slice));
self.next_id += 1;
result
}

pub fn push_domain(
&mut self,
id: DomainId,
domain: SliceDomain<'a, GRID_DIMENSION>,
) {
debug_assert!(self.next_id != 0);
debug_assert_eq!(self.next_id - 1, id);
self.next_id -= 1;
let buffer_size = domain.aabb().buffer_size();
let remainder_size = self.remainder.len();
self.remainder = unsafe {
let remainder_ptr = self.remainder.as_ptr();
let remainder_ptr_mut = remainder_ptr as *mut f64;
std::slice::from_raw_parts_mut(
remainder_ptr_mut,
buffer_size + remainder_size,
)
};
}
}

#[cfg(test)]
mod unit_test {
use super::*;

#[test]
fn counter() {
let size = 100;
let mut counter = DomainStack::with_capacity(size);
let aabb_1 = AABB::new(matrix![0, 1]);
let aabb_1_s = aabb_1.buffer_size();
let (id_1, domain_1) = counter.pop_domain(aabb_1);
debug_assert_eq!(id_1, 0);

debug_assert_eq!(counter.buffer().len(), size);
debug_assert_eq!(domain_1.buffer().len(), aabb_1_s);
debug_assert_eq!(aabb_1_s + counter.remainder().len(), size);
unsafe {
let expected_domain_1_ptr =
counter.buffer().as_ptr().add(size - aabb_1_s);
debug_assert_eq!(domain_1.buffer().as_ptr(), expected_domain_1_ptr);
}

let aabb_2 = AABB::new(matrix![0, 3]);
let aabb_2_s = aabb_2.buffer_size();
let (id_2, domain_2) = counter.pop_domain(aabb_2);
debug_assert_eq!(id_2, 1);
debug_assert_eq!(counter.buffer().len(), size);
debug_assert_eq!(domain_2.buffer().len(), aabb_2_s);
debug_assert_eq!(aabb_2_s + aabb_1_s + counter.remainder().len(), size);
unsafe {
let expected_domain_2_ptr =
counter.buffer().as_ptr().add(size - (aabb_1_s + aabb_2_s));
debug_assert_eq!(domain_2.buffer().as_ptr(), expected_domain_2_ptr);
}

counter.push_domain(id_2, domain_2);

let aabb_3 = AABB::new(matrix![0, 5]);
let aabb_3_s = aabb_3.buffer_size();
let (id_3, domain_3) = counter.pop_domain(aabb_3);
debug_assert_eq!(id_3, 1);
debug_assert_eq!(counter.buffer().len(), size);
debug_assert_eq!(domain_3.buffer().len(), aabb_3_s);
debug_assert_eq!(aabb_3_s + aabb_1_s + counter.remainder().len(), size);
unsafe {
let expected_domain_3_ptr =
counter.buffer().as_ptr().add(size - (aabb_1_s + aabb_3_s));
debug_assert_eq!(domain_3.buffer().as_ptr(), expected_domain_3_ptr);
}
}
}
7 changes: 7 additions & 0 deletions src/domain/stack/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod counter_stack;
mod domain_stack;

pub use counter_stack::*;
pub use domain_stack::*;

pub type DomainId = usize;
1 change: 1 addition & 0 deletions src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub trait NumTrait = Num + Copy + Send + Sync;
mod aabb;
pub mod indexing;
pub use aabb::*;
pub use fftw::array::AlignedVec;
pub use nalgebra::{matrix, vector};

pub type Coord<const GRID_DIMENSION: usize> =
Expand Down
1 change: 0 additions & 1 deletion tests/base_solver_compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use nhls::solver::*;
use nhls::stencil::*;
use nhls::util::*;

use fftw::array::*;
use float_cmp::assert_approx_eq;
use nalgebra::matrix;

Expand Down

0 comments on commit 44cb6be

Please sign in to comment.