Skip to content

Commit

Permalink
Normalize nvcc subcommand order for CTK <12.0, ensuring the DAG is pa…
Browse files Browse the repository at this point in the history
…rsed by inputs/outputs even if the preprocessor, cicc, and ptxas commands are out of order.
  • Loading branch information
trxcllnt committed Nov 14, 2024
1 parent edab2b0 commit 9c0e681
Showing 1 changed file with 118 additions and 50 deletions.
168 changes: 118 additions & 50 deletions src/compiler/nvcc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -655,41 +655,66 @@ where
// but can optionally be run in parallel to other groups if the user requested via
// `nvcc --threads`.

let mut no_more_groups = false;
let mut command_groups: Vec<Vec<NvccGeneratedSubcommand>> = vec![];

let preprocessor_flag = match host_compiler {
NvccHostCompiler::Msvc => "-P",
_ => "-E",
}
.to_owned();

for (_, dir, exe, args) in all_commands {
if log_enabled!(log::Level::Trace) {
trace!(
"transformed nvcc command: {:?}",
[
&[format!("cd {} &&", dir.to_string_lossy()).to_string()],
&[exe.to_str().unwrap_or_default().to_string()][..],
&args[..]
]
.concat()
.join(" ")
);
}
let gen_module_id_file_flag = "--gen_module_id_file".to_owned();
let mut cuda_front_end_group = Vec::<NvccGeneratedSubcommand>::new();
let mut final_assembly_group = Vec::<NvccGeneratedSubcommand>::new();
let mut device_compile_groups = HashMap::<String, Vec<NvccGeneratedSubcommand>>::new();

let (env_vars, cacheable) = match exe.file_stem().and_then(|s| s.to_str()) {
for (_, dir, exe, args) in all_commands {
let mut args = args.clone();

if let (env_vars, cacheable, Some(group)) = match exe.file_stem().and_then(|s| s.to_str()) {
// fatbinary and nvlink are not cacheable
Some("fatbinary") | Some("nvlink") => (
env_vars.clone(),
Cacheable::No,
Some(&mut final_assembly_group),
),
// cicc and ptxas are cacheable
Some("cicc") | Some("ptxas") => (env_vars.clone(), Cacheable::Yes),
// cudafe++, nvlink, and fatbinary are not cacheable
Some("cudafe++") | Some("nvlink") => (env_vars.clone(), Cacheable::No),
Some("fatbinary") => {
// The fatbinary command represents the start of the last group
if !no_more_groups {
command_groups.push(vec![]);
Some("cicc") => {
// Remove the `--gen_module_id_file` flag
if let Some(idx) = args.iter().position(|x| x == &gen_module_id_file_flag) {
args.splice(idx..idx + 1, []);
}
no_more_groups = true;
(env_vars.clone(), Cacheable::No)
let group = device_compile_groups.get_mut(&args[args.len() - 3]);
(env_vars.clone(), Cacheable::Yes, group)
}
Some("ptxas") => {
// Remove the `--gen_module_id_file` flag
if let Some(idx) = args.iter().position(|x| x == &gen_module_id_file_flag) {
args.splice(idx..idx + 1, []);
}
let group = device_compile_groups.values_mut().find(|cmds| {
if let Some(cicc) = cmds.last() {
if let Some(cicc_out) = cicc.args.last() {
return cicc_out == &args[args.len() - 3];
}
}
false
});
(env_vars.clone(), Cacheable::Yes, group)
}
// cudafe++ is not cacheable
Some("cudafe++") => {
// Fix for CTK < 12.0:
// Add `--gen_module_id_file` if the cudafe++ args include `--module_id_file_name`
if !args.contains(&gen_module_id_file_flag) {
if let Some(idx) = args.iter().position(|x| x == "--module_id_file_name") {
// Insert `--gen_module_id_file` just before `--module_id_file_name` to match nvcc behavior
args.splice(idx..idx, [gen_module_id_file_flag.clone()]);
}
}
(
env_vars.clone(),
Cacheable::No,
Some(&mut cuda_front_end_group),
)
}
_ => {
// All generated host compiler commands include one of these defines.
Expand All @@ -705,13 +730,47 @@ where
continue;
}
if args.contains(&preprocessor_flag) {
// Each preprocessor step represents the start of a new command
// group, unless it comes after a call to fatbinary.
if !no_more_groups {
command_groups.push(vec![]);
// Each preprocessor step represents the start of a new command group
if let Some(out_file) = if cfg!(target_os = "windows") {
args.iter()
.find(|x| x.starts_with("-Fi"))
.and_then(|x| x.strip_prefix("-Fi"))
} else {
args.iter()
.position(|x| x == "-o")
.and_then(|i| args.get(i + 1).map(|o| o.as_str()))
}
.map(PathBuf::from)
.and_then(|out_path| {
out_path
.file_name()
.and_then(|out_name| out_name.to_str())
.map(|out_name| out_name.to_owned())
})
.and_then(|out_name| {
// If the output file ends with...
// * .cpp1.ii - cicc/ptxas input
// * .cpp4.ii - cudafe++ input
if out_name.ends_with(".cpp1.ii") {
Some(out_name.to_owned())
} else {
None
}
}) {
let new_device_compile_group = vec![];
device_compile_groups.insert(out_file.clone(), new_device_compile_group);
(
env_vars.clone(),
Cacheable::No,
device_compile_groups.get_mut(&out_file),
)
} else {
(
env_vars.clone(),
Cacheable::No,
Some(&mut cuda_front_end_group),
)
}
// Do not run preprocessor calls through sccache
(env_vars.clone(), Cacheable::No)
} else {
// Returns Cacheable::Yes to indicate we _do_ want to run this host
// compiler call through sccache (because it may be distributed),
Expand All @@ -732,31 +791,40 @@ where
.cloned()
.collect::<Vec<_>>(),
Cacheable::Yes,
Some(&mut final_assembly_group),
)
}
}
};
} {
if log_enabled!(log::Level::Trace) {
trace!(
"transformed nvcc command: {:?}",
[
&[format!("cd {} &&", dir.to_string_lossy()).to_string()],
&[exe.to_str().unwrap_or_default().to_string()][..],
&args[..]
]
.concat()
.join(" ")
);
}

// Initialize the first group in case the first command isn't a call to the host preprocessor,
// i.e. `nvcc -o test.o -c test.c`
if command_groups.is_empty() {
command_groups.push(vec![]);
group.push(NvccGeneratedSubcommand {
exe: exe.clone(),
args: args.clone(),
cwd: dir.into(),
env_vars,
cacheable,
});
}

match command_groups.last_mut() {
None => {}
Some(group) => {
group.push(NvccGeneratedSubcommand {
exe: exe.clone(),
args: args.clone(),
cwd: dir.into(),
env_vars,
cacheable,
});
}
};
}

let mut command_groups = vec![];

command_groups.push(cuda_front_end_group);
command_groups.extend(device_compile_groups.into_values());
command_groups.push(final_assembly_group);

Ok(command_groups)
}

Expand Down

0 comments on commit 9c0e681

Please sign in to comment.