Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: map entry point indexes after all ssa passes #6740

Merged
merged 5 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions compiler/noirc_evaluator/src/acir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -821,14 +821,12 @@ impl<'a> Context<'a> {
})
.sum();

let Some(acir_function_id) =
ssa.entry_point_to_generated_index.get(id)
else {
let Some(acir_function_id) = ssa.get_entry_point_index(id) else {
unreachable!("Expected an associated final index for call to acir function {id} with args {arguments:?}");
};

let output_vars = self.acir_context.call_acir_function(
AcirFunctionId(*acir_function_id),
AcirFunctionId(acir_function_id),
inputs,
output_count,
self.current_side_effects_enabled_var,
Expand Down Expand Up @@ -2979,7 +2977,7 @@ mod test {

build_basic_foo_with_return(&mut builder, foo_id, false, inline_type);

let ssa = builder.finish();
let ssa = builder.finish().generate_entry_point_index();

let (acir_functions, _, _, _) = ssa
.into_acir(&Brillig::default(), ExpressionWidth::default())
Expand Down Expand Up @@ -3087,6 +3085,7 @@ mod test {
let ssa = builder.finish();

let (acir_functions, _, _, _) = ssa
.generate_entry_point_index()
.into_acir(&Brillig::default(), ExpressionWidth::default())
Comment on lines +3088 to 3089
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it have made sense to move this into the into_acir method?

.expect("Should compile manually written SSA into ACIR");
// The expected result should look very similar to the above test expect that the input witnesses of the `Call`
Expand Down Expand Up @@ -3184,7 +3183,7 @@ mod test {

build_basic_foo_with_return(&mut builder, foo_id, false, inline_type);

let ssa = builder.finish();
let ssa = builder.finish().generate_entry_point_index();

let (acir_functions, _, _, _) = ssa
.into_acir(&Brillig::default(), ExpressionWidth::default())
Expand Down Expand Up @@ -3311,6 +3310,7 @@ mod test {
let brillig = ssa.to_brillig(false);

let (acir_functions, brillig_functions, _, _) = ssa
.generate_entry_point_index()
.into_acir(&brillig, ExpressionWidth::default())
.expect("Should compile manually written SSA into ACIR");

Expand Down Expand Up @@ -3375,6 +3375,7 @@ mod test {
// The Brillig bytecode we insert for the stdlib is hardcoded so we do not need to provide any
// Brillig artifacts to the ACIR gen pass.
let (acir_functions, brillig_functions, _, _) = ssa
.generate_entry_point_index()
.into_acir(&Brillig::default(), ExpressionWidth::default())
.expect("Should compile manually written SSA into ACIR");

Expand Down Expand Up @@ -3449,6 +3450,7 @@ mod test {
println!("{}", ssa);

let (acir_functions, brillig_functions, _, _) = ssa
.generate_entry_point_index()
.into_acir(&brillig, ExpressionWidth::default())
.expect("Should compile manually written SSA into ACIR");

Expand Down Expand Up @@ -3537,6 +3539,7 @@ mod test {
println!("{}", ssa);

let (acir_functions, brillig_functions, _, _) = ssa
.generate_entry_point_index()
.into_acir(&brillig, ExpressionWidth::default())
.expect("Should compile manually written SSA into ACIR");

Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ impl SsaBuilder {
}

fn finish(self) -> Ssa {
self.ssa
self.ssa.generate_entry_point_index()
}

/// Runs the given SSA pass and prints the SSA afterward if `print_ssa_passes` is true.
Expand Down
6 changes: 4 additions & 2 deletions compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,13 @@ impl Ssa {
}

// The ones that remain are never called: let's remove them.
for func_id in brillig_functions.keys() {
for (func_id, func) in &brillig_functions {
// We never want to remove the main function (it could be `unconstrained` or it
// could have been turned into brillig if `--force-brillig` was given).
// We also don't want to remove entry points.
if self.main_id == *func_id || self.entry_point_to_generated_index.contains_key(func_id)
let runtime = func.runtime();
if self.main_id == *func_id
|| (runtime.is_entry_point() && matches!(runtime, RuntimeType::Acir(_)))
{
continue;
}
Expand Down
45 changes: 29 additions & 16 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub(crate) struct Ssa {
/// This mapping is necessary to use the correct function pointer for an ACIR call,
/// as the final program artifact will be a list of only entry point functions.
#[serde(skip)]
pub(crate) entry_point_to_generated_index: BTreeMap<FunctionId, u32>,
entry_point_to_generated_index: BTreeMap<FunctionId, u32>,
// We can skip serializing this field as the error selector types end up as part of the
// ABI not the actual SSA IR.
#[serde(skip)]
Expand All @@ -47,25 +47,11 @@ impl Ssa {
(f.id(), f)
});

let entry_point_to_generated_index = btree_map(
functions
.iter()
.filter(|(_, func)| {
let runtime = func.runtime();
match func.runtime() {
RuntimeType::Acir(_) => runtime.is_entry_point() || func.id() == main_id,
RuntimeType::Brillig(_) => false,
}
})
.enumerate(),
|(i, (id, _))| (*id, i as u32),
);

Self {
functions,
main_id,
next_id: AtomicCounter::starting_after(max_id),
entry_point_to_generated_index,
entry_point_to_generated_index: BTreeMap::new(),
error_selector_to_type: error_types,
}
}
Expand Down Expand Up @@ -98,6 +84,33 @@ impl Ssa {
self.functions.insert(new_id, function);
new_id
}
pub(crate) fn generate_entry_point_index(mut self) -> Self {
self.entry_point_to_generated_index = btree_map(
self.functions
.iter()
.filter(|(_, func)| {
let runtime = func.runtime();
match func.runtime() {
RuntimeType::Acir(_) => {
runtime.is_entry_point() || func.id() == self.main_id
}
RuntimeType::Brillig(_) => false,
}
})
.enumerate(),
|(i, (id, _))| (*id, i as u32),
);
self
}

pub(crate) fn get_entry_point_index(&self, func_id: &FunctionId) -> Option<u32> {
// Ensure the map has been initialized
assert!(
!self.entry_point_to_generated_index.is_empty(),
"Trying to read uninitialized entry point index"
);
self.entry_point_to_generated_index.get(func_id).copied()
}
}

impl Display for Ssa {
Expand Down
Loading