Skip to content

Commit

Permalink
Update example CLI for benchmarking. (#27)
Browse files Browse the repository at this point in the history
* Update FFTW bindings
* Move thread control to cli
* Add write_image flag for large solves.
  • Loading branch information
SallySoul authored Jan 1, 2025
1 parent 6c1151f commit 4008681
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 34 deletions.
14 changes: 11 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = "2021"
[dependencies]
bytemuck = "1.19.0"
colorous = "1.0.14"
fftw = { git = "https://github.com/SallySoul/fftw3-rs.git", tag = "fftw3-v0.8.1" }
fftw = { git = "https://github.com/SallySoul/fftw3-rs.git", tag = "fftw3-0.0.1" }

image = "0.25.2"
nalgebra = "0.33.2"
Expand Down
17 changes: 12 additions & 5 deletions examples/heat_1d_ap_direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ fn main() {
let bc = ConstantCheck::new(1.0, grid_bound);

// Make image
let mut img = nhls::image::Image1D::new(grid_bound, args.lines as u32);
img.add_line(0, input_domain.buffer());
let mut img = None;
if args.write_image {
let mut i = nhls::image::Image1D::new(grid_bound, args.lines as u32);
i.add_line(0, input_domain.buffer());
img = Some(i);
}
for t in 1..args.lines as u32 {
box_apply(
&bc,
Expand All @@ -28,8 +32,11 @@ fn main() {
args.chunk_size,
);
std::mem::swap(&mut input_domain, &mut output_domain);
img.add_line(t, input_domain.buffer());
if let Some(i) = img.as_mut() {
i.add_line(t, input_domain.buffer());
}
}
if let Some(i) = img {
i.write(&output_image_path);
}

img.write(&output_image_path);
}
17 changes: 12 additions & 5 deletions examples/heat_1d_ap_fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,25 @@ fn main() {
args.chunk_size,
);

// Make image
let mut img = nhls::image::Image1D::new(grid_bound, args.lines as u32);
img.add_line(0, input_domain.buffer());
let mut img = None;
if args.write_image {
let mut i = nhls::image::Image1D::new(grid_bound, args.lines as u32);
i.add_line(0, input_domain.buffer());
img = Some(i);
}
for t in 1..args.lines as u32 {
solver.loop_solve(
&mut input_domain,
&mut output_domain,
args.steps_per_line,
);
std::mem::swap(&mut input_domain, &mut output_domain);
img.add_line(t, input_domain.buffer());
if let Some(i) = img.as_mut() {
i.add_line(t, input_domain.buffer());
}
}

img.write(&output_image_path);
if let Some(i) = img {
i.write(&output_image_path);
}
}
18 changes: 12 additions & 6 deletions examples/heat_1d_p_direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ fn main() {
};
input_domain.par_set_values(ic_gen, args.chunk_size);

// Make image
let mut img = nhls::image::Image1D::new(grid_bound, args.lines as u32);
img.add_line(0, input_domain.buffer());
let mut img = None;
if args.write_image {
let mut i = nhls::image::Image1D::new(grid_bound, args.lines as u32);
i.add_line(0, input_domain.buffer());
img = Some(i);
}
for t in 1..args.lines as u32 {
direct_periodic_apply(
&stencil,
Expand All @@ -36,8 +39,11 @@ fn main() {
args.chunk_size,
);
std::mem::swap(&mut input_domain, &mut output_domain);
img.add_line(t, input_domain.buffer());
if let Some(i) = img.as_mut() {
i.add_line(t, input_domain.buffer());
}
}
if let Some(i) = img {
i.write(&output_image_path);
}

img.write(&output_image_path);
}
21 changes: 14 additions & 7 deletions examples/heat_1d_p_fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ fn main() {
// Create domains
let grid_bound = args.grid_bounds();
let mut input_domain = OwnedDomain::new(grid_bound);

let mut output_domain = OwnedDomain::new(grid_bound);

// Fill in with IC values (use normal dist for spike in the middle)
Expand All @@ -23,14 +22,19 @@ fn main() {
};
input_domain.par_set_values(ic_gen, args.chunk_size);

// Make image
let mut img = nhls::image::Image1D::new(grid_bound, args.lines as u32);
let mut periodic_library =
nhls::solver::periodic_plan::PeriodicPlanLibrary::new(
&grid_bound,
&stencil,
);
img.add_line(0, input_domain.buffer());

let mut img = None;
if args.write_image {
let mut i = nhls::image::Image1D::new(grid_bound, args.lines as u32);
i.add_line(0, input_domain.buffer());
img = Some(i);
}

for t in 1..args.lines as u32 {
periodic_library.apply(
&mut input_domain,
Expand All @@ -39,8 +43,11 @@ fn main() {
args.chunk_size,
);
std::mem::swap(&mut input_domain, &mut output_domain);
img.add_line(t, input_domain.buffer());
if let Some(i) = img.as_mut() {
i.add_line(t, input_domain.buffer());
}
}
if let Some(i) = img {
i.write(&output_image_path);
}

img.write(&output_image_path);
}
8 changes: 6 additions & 2 deletions examples/heat_2d_p_direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ fn main() {
exp.exp()
};
input_domain.par_set_values(ic_gen, args.chunk_size);
image2d(&input_domain, &args.frame_name(0));
if args.write_images {
image2d(&input_domain, &args.frame_name(0));
}

// Create boundary condition
// TODO: Why is a boundary condition being created in the periodic example? Shouldn't this use the periodic direct solve instead?
Expand All @@ -47,6 +49,8 @@ fn main() {
);

std::mem::swap(&mut input_domain, &mut output_domain);
image2d(&input_domain, &args.frame_name(t));
if args.write_images {
image2d(&input_domain, &args.frame_name(t));
}
}
}
12 changes: 7 additions & 5 deletions examples/heat_2d_p_fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@ use nhls::util::*;
fn main() {
let args = Args::cli_parse("heat_2d_p_direct");

// Grid size
let grid_bound = args.grid_bounds();

let stencil = nhls::standard_stencils::heat_2d(1.0, 1.0, 1.0, 0.2, 0.2);

// Create domains
let grid_bound = args.grid_bounds();
let mut input_domain = OwnedDomain::new(grid_bound);
let mut output_domain = OwnedDomain::new(grid_bound);

Expand All @@ -29,7 +27,9 @@ fn main() {
exp.exp()
};
input_domain.par_set_values(ic_gen, args.chunk_size);
image2d(&input_domain, &args.frame_name(0));
if args.write_images {
image2d(&input_domain, &args.frame_name(0));
}

// Apply periodic solver
let mut periodic_library =
Expand All @@ -45,6 +45,8 @@ fn main() {
args.chunk_size,
);
std::mem::swap(&mut input_domain, &mut output_domain);
image2d(&input_domain, &args.frame_name(t));
if args.write_images {
image2d(&input_domain, &args.frame_name(t));
}
}
}
15 changes: 15 additions & 0 deletions src/image_1d_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ pub struct Args {
/// Domain size
#[arg(short, long, default_value = "1000")]
pub domain_size: usize,

/// Write out image, WARNING: we do not check image size, so be reasonable.
#[arg(short, long)]
pub write_image: bool,

/// The number of threads to use.
#[arg(short, long, default_value = "8")]
pub threads: usize,
}

impl Args {
Expand All @@ -43,6 +51,13 @@ impl Args {
let mut output_image_path = args.output_dir.clone();
output_image_path.push(format!("{}.png", name));

rayon::ThreadPoolBuilder::new()
.num_threads(args.threads)
.build_global()
.unwrap();
fftw::threading::init_threads_f64().unwrap();
fftw::threading::plan_with_nthreads_f64(args.threads);

(args, output_image_path)
}

Expand Down
15 changes: 15 additions & 0 deletions src/image_2d_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ pub struct Args {
/// Domain size, assume square
#[arg(short, long, default_value = "1000")]
pub domain_size: usize,

/// Write out image, WARNING: we do not check image size, so be reasonable.
#[arg(short, long)]
pub write_images: bool,

/// The number of threads to use.
#[arg(short, long, default_value = "8")]
pub threads: usize,
}

impl Args {
Expand All @@ -40,6 +48,13 @@ impl Args {
let _ = std::fs::remove_dir_all(output_dir);
std::fs::create_dir(output_dir).unwrap();

rayon::ThreadPoolBuilder::new()
.num_threads(args.threads)
.build_global()
.unwrap();
fftw::threading::init_threads_f64().unwrap();
fftw::threading::plan_with_nthreads_f64(args.threads);

args
}

Expand Down

0 comments on commit 4008681

Please sign in to comment.