Skip to content

Commit

Permalink
Add Metal GPU trace capture in tract cli
Browse files: browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos committed Jul 19, 2024
1 parent c1b4380 commit 59c260e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
27 changes: 26 additions & 1 deletion cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ fn main() -> TractResult<()> {
.arg(Arg::new("f32-to-f16").long("f32-to-f16").alias("half-floats").long_help("Convert the decluttered network from f32 to f16"))
.arg(arg!(--"f16-to-f32" "Convert the decluttered network from f16 to f32"))
.arg(arg!(--"metal" "Convert metal compatible operator in the decluttered network. Only available on MacOS and iOS"))
.arg(Arg::new("metal-gpu-trace").long("metal-gpu-trace").takes_value(true).help("Capture Metal GPU trace at given path. Only available on MacOS and iOS"))
.arg(Arg::new("transform").short('t').long("transform").multiple_occurrences(true).takes_value(true).help("Apply a built-in transformation to the model"))
.arg(Arg::new("set").long("set").multiple_occurrences(true).takes_value(true)
.long_help("Set a symbol to a concrete value after decluttering"))
Expand Down Expand Up @@ -291,7 +292,31 @@ fn main() -> TractResult<()> {
env_logger::Builder::from_env(env).format_timestamp_nanos().init();
info_usage("init", probe.as_ref());

if let Err(e) = handle(matches, probe.as_ref()) {
let res = if matches.is_present("metal-gpu-trace") {
#[cfg(any(target_os = "macos", target_os = "ios"))]
{
let gpu_trace_path = std::path::Path::new(matches.value_of("metal-gpu-trace").unwrap()).to_path_buf();
ensure!(gpu_trace_path.is_absolute(), "Metal GPU trace file has to be absolute");
ensure!(!gpu_trace_path.exists(), format!("Given Metal GPU trace file {:?} already exists.", gpu_trace_path));
log::info!("Capturing Metal GPU trace at : {:?}", gpu_trace_path);
std::env::set_var("METAL_CAPTURE_ENABLED", "1");
std::env::set_var("METAL_DEVICE_WRAPPER_TYPE", "1");
let probe_ref = probe.as_ref();
tract_metal::METAL_CONTEXT.with_borrow(move |context| {
context.capture_trace(gpu_trace_path, move |_ctxt| {
handle(matches, probe_ref)
})
})
}
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
{
bail!("`--metal-gpu-trace` present but it is only available on MacOS and iOS")
}
} else {
handle(matches, probe.as_ref())
};

if let Err(e) = res {
error!("{:?}", e);
std::process::exit(1);
}
Expand Down
2 changes: 1 addition & 1 deletion metal/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ impl MetalContext {
pub fn capture_trace<P, F>(&self, path: P, compute: F) -> Result<()>
where
P: AsRef<Path>,
F: Fn(&Self) -> Result<()>,
F: FnOnce(&Self) -> Result<()>,
{
self.wait_until_completed()?;

Expand Down

0 comments on commit 59c260e

Please sign in to comment.