Skip to content

Commit

Permalink
Merge branch '6470-brillig-unroll-small-loops' of github.com:noir-lan…
Browse files Browse the repository at this point in the history
…g/noir into 6470-brillig-unroll-small-loops
  • Loading branch information
aakoshh committed Nov 14, 2024
2 parents 7cd6f3a + c78c07c commit a42c643
Show file tree
Hide file tree
Showing 13 changed files with 161 additions and 39 deletions.
41 changes: 12 additions & 29 deletions compiler/noirc_evaluator/src/ssa/opt/array_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,21 @@ impl Ssa {

impl Function {
pub(crate) fn array_set_optimization(&mut self) {
if matches!(self.runtime(), RuntimeType::Brillig(_)) {
// Brillig is supposed to use refcounting to decide whether to mutate an array;

Check warning on line 32 in compiler/noirc_evaluator/src/ssa/opt/array_set.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (refcounting)
// array mutation was only meant for ACIR. We could use it with Brillig as well,
// but then some of the optimizations that we can do in ACIR around shared
// references have to be skipped, which makes it more cumbersome.
return;
}

let reachable_blocks = self.reachable_blocks();

if !self.runtime().is_entry_point() {
assert_eq!(reachable_blocks.len(), 1, "Expected there to be 1 block remaining in Acir function for array_set optimization");
}

let mut context = Context::new(
&self.dfg,
self.parameters(),
matches!(self.runtime(), RuntimeType::Brillig(_)),
);
let mut context = Context::new(&self.dfg);

for block in reachable_blocks.iter() {
context.analyze_last_uses(*block);
Expand All @@ -53,8 +57,6 @@ impl Function {

struct Context<'f> {
dfg: &'f DataFlowGraph,
function_parameters: &'f [ValueId],
is_brillig_runtime: bool,
array_to_last_use: HashMap<ValueId, InstructionId>,
instructions_that_can_be_made_mutable: HashSet<InstructionId>,
// Mapping of an array that comes from a load and whether the address
Expand All @@ -64,15 +66,9 @@ struct Context<'f> {
}

impl<'f> Context<'f> {
fn new(
dfg: &'f DataFlowGraph,
function_parameters: &'f [ValueId],
is_brillig_runtime: bool,
) -> Self {
fn new(dfg: &'f DataFlowGraph) -> Self {
Context {
dfg,
function_parameters,
is_brillig_runtime,
array_to_last_use: HashMap::default(),
instructions_that_can_be_made_mutable: HashSet::default(),
arrays_from_load: HashMap::default(),
Expand All @@ -94,21 +90,12 @@ impl<'f> Context<'f> {
self.instructions_that_can_be_made_mutable.remove(&existing);
}
}
Instruction::ArraySet { array, value, .. } => {
Instruction::ArraySet { array, .. } => {
let array = self.dfg.resolve(*array);

if let Some(existing) = self.array_to_last_use.insert(array, *instruction_id) {
self.instructions_that_can_be_made_mutable.remove(&existing);
}
if self.is_brillig_runtime {
let value = self.dfg.resolve(*value);

if let Some(existing) = self.inner_nested_arrays.get(&value) {
self.instructions_that_can_be_made_mutable.remove(existing);
}
let result = self.dfg.instruction_results(*instruction_id)[0];
self.inner_nested_arrays.insert(result, *instruction_id);
}

// If the array we are setting does not come from a load we can safely mark it mutable.
// If the array comes from a load we may potentially being mutating an array at a reference
Expand All @@ -128,17 +115,13 @@ impl<'f> Context<'f> {
}
});

// We cannot safely mutate slices that are inputs to the function, as they might be shared with the caller.
// NB checking the block parameters is not enough, as we might have jumped into a parameterless blocks inside the function.
let is_function_param = self.function_parameters.contains(&array);

let can_mutate = if let Some(is_from_param) = self.arrays_from_load.get(&array)
{
// If the array was loaded from a reference parameter, we cannot
// safely mark that array mutable as it may be shared by another value.
!is_from_param && is_return_block
} else {
!is_array_in_terminator && !is_function_param
!is_array_in_terminator
};

if can_mutate {
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_evaluator/src/ssa/opt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub(crate) fn assert_normalized_ssa_equals(mut ssa: super::Ssa, expected: &str)
let expected = trim_comments_from_lines(&expected);

if ssa != expected {
println!("Got:\n~~~\n{}\n~~~\nExpected:\n~~~\n{}\n~~~", ssa, expected);
similar_asserts::assert_eq!(ssa, expected);
println!("Expected:\n~~~\n{expected}\n~~~\nGot:\n~~~\n{ssa}\n~~~");
similar_asserts::assert_eq!(expected, ssa);
}
}
1 change: 0 additions & 1 deletion compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ impl Ssa {
/// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found.
#[tracing::instrument(level = "trace", skip(ssa))]
pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result<Ssa, RuntimeError> {
ssa.normalize_ids();
for (_, function) in ssa.functions.iter_mut() {
// Try to unroll loops first:
let mut unroll_errors = function.try_unroll_loops();
Expand Down
9 changes: 8 additions & 1 deletion compiler/noirc_frontend/src/elaborator/comptime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,14 @@ impl<'context> Elaborator<'context> {
}
ItemKind::Impl(r#impl) => {
let module = self.module_id();
dc_mod::collect_impl(self.interner, generated_items, r#impl, self.file, module);
dc_mod::collect_impl(
self.interner,
generated_items,
r#impl,
self.file,
module,
&mut self.errors,
);
}

ItemKind::ModuleDecl(_)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ use crate::{
};

use self::builtin_helpers::{eq_item, get_array, get_ctstring, get_str, get_u8, hash_item, lex};
use super::Interpreter;
use super::{foreign, Interpreter};

pub(crate) mod builtin_helpers;

Expand All @@ -57,6 +57,7 @@ impl<'local, 'context> Interpreter<'local, 'context> {
let interner = &mut self.elaborator.interner;
let call_stack = &self.elaborator.interpreter_call_stack;
match name {
"apply_range_constraint" => foreign::apply_range_constraint(arguments, location),
"array_as_str_unchecked" => array_as_str_unchecked(interner, arguments, location),
"array_len" => array_len(interner, arguments, location),
"assert_constant" => Ok(Value::Bool(true)),
Expand Down
26 changes: 25 additions & 1 deletion compiler/noirc_frontend/src/hir/comptime/interpreter/foreign.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use acvm::blackbox_solver::BlackBoxFunctionSolver;
use acvm::{
acir::BlackBoxFunc, blackbox_solver::BlackBoxFunctionSolver, AcirField, BlackBoxResolutionError,
};
use bn254_blackbox_solver::Bn254BlackBoxSolver;
use im::Vector;
use iter_extended::try_vecmap;
Expand Down Expand Up @@ -29,6 +31,28 @@ pub(super) fn call_foreign(
}
}

pub(super) fn apply_range_constraint(
arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
let (value, num_bits) = check_two_arguments(arguments, location)?;

let input = get_field(value)?;
let num_bits = get_u32(num_bits)?;

if input.num_bits() < num_bits {
Ok(Value::Unit)
} else {
Err(InterpreterError::BlackBoxError(
BlackBoxResolutionError::Failed(
BlackBoxFunc::RANGE,
"value exceeds range check bounds".to_owned(),
),
location,
))
}
}

// poseidon2_permutation<let N: u32>(_input: [Field; N], _state_length: u32) -> [Field; N]
fn poseidon2_permutation(
interner: &mut NodeInterner,
Expand Down
38 changes: 34 additions & 4 deletions compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ pub fn collect_defs(

errors.extend(collector.collect_functions(context, ast.functions, crate_id));

collector.collect_trait_impls(context, ast.trait_impls, crate_id);
errors.extend(collector.collect_trait_impls(context, ast.trait_impls, crate_id));

collector.collect_impls(context, ast.impls, crate_id);
errors.extend(collector.collect_impls(context, ast.impls, crate_id));

collector.collect_attributes(
ast.inner_attributes,
Expand Down Expand Up @@ -163,7 +163,13 @@ impl<'a> ModCollector<'a> {
errors
}

fn collect_impls(&mut self, context: &mut Context, impls: Vec<TypeImpl>, krate: CrateId) {
fn collect_impls(
&mut self,
context: &mut Context,
impls: Vec<TypeImpl>,
krate: CrateId,
) -> Vec<(CompilationError, FileId)> {
let mut errors = Vec::new();
let module_id = ModuleId { krate, local_id: self.module_id };

for r#impl in impls {
Expand All @@ -173,16 +179,21 @@ impl<'a> ModCollector<'a> {
r#impl,
self.file_id,
module_id,
&mut errors,
);
}

errors
}

fn collect_trait_impls(
&mut self,
context: &mut Context,
impls: Vec<NoirTraitImpl>,
krate: CrateId,
) {
) -> Vec<(CompilationError, FileId)> {
let mut errors = Vec::new();

for mut trait_impl in impls {
let trait_name = trait_impl.trait_name.clone();

Expand All @@ -198,6 +209,13 @@ impl<'a> ModCollector<'a> {
let module = ModuleId { krate, local_id: self.module_id };

for (_, func_id, noir_function) in &mut unresolved_functions.functions {
if noir_function.def.attributes.is_test_function() {
let error = DefCollectorErrorKind::TestOnAssociatedFunction {
span: noir_function.name_ident().span(),
};
errors.push((error.into(), self.file_id));
}

let location = Location::new(noir_function.def.span, self.file_id);
context.def_interner.push_function(*func_id, &noir_function.def, module, location);
}
Expand All @@ -224,6 +242,8 @@ impl<'a> ModCollector<'a> {

self.def_collector.items.trait_impls.push(unresolved_trait_impl);
}

errors
}

fn collect_functions(
Expand Down Expand Up @@ -1051,13 +1071,23 @@ pub fn collect_impl(
r#impl: TypeImpl,
file_id: FileId,
module_id: ModuleId,
errors: &mut Vec<(CompilationError, FileId)>,
) {
let mut unresolved_functions =
UnresolvedFunctions { file_id, functions: Vec::new(), trait_id: None, self_type: None };

for (method, _) in r#impl.methods {
let doc_comments = method.doc_comments;
let mut method = method.item;

if method.def.attributes.is_test_function() {
let error = DefCollectorErrorKind::TestOnAssociatedFunction {
span: method.name_ident().span(),
};
errors.push((error.into(), file_id));
continue;
}

let func_id = interner.push_empty_fn();
method.def.where_clause.extend(r#impl.where_clause.clone());
let location = Location::new(method.span(), file_id);
Expand Down
8 changes: 8 additions & 0 deletions compiler/noirc_frontend/src/hir/def_collector/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ pub enum DefCollectorErrorKind {
},
#[error("{0}")]
UnsupportedNumericGenericType(#[from] UnsupportedNumericGenericType),
#[error("The `#[test]` attribute may only be used on a non-associated function")]
TestOnAssociatedFunction { span: Span },
}

impl DefCollectorErrorKind {
Expand Down Expand Up @@ -291,6 +293,12 @@ impl<'a> From<&'a DefCollectorErrorKind> for Diagnostic {
diag
}
DefCollectorErrorKind::UnsupportedNumericGenericType(err) => err.into(),
DefCollectorErrorKind::TestOnAssociatedFunction { span } => Diagnostic::simple_error(
"The `#[test]` attribute is disallowed on `impl` methods".into(),
String::new(),
*span,
),

}
}
}
48 changes: 48 additions & 0 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3692,3 +3692,51 @@ fn allows_struct_with_generic_infix_type_as_main_input_3() {
"#;
assert_no_errors(src);
}

#[test]
fn disallows_test_attribute_on_impl_method() {
let src = r#"
pub struct Foo {}
impl Foo {
#[test]
fn foo() {}
}
fn main() {}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);

assert!(matches!(
errors[0].0,
CompilationError::DefinitionError(DefCollectorErrorKind::TestOnAssociatedFunction {
span: _
})
));
}

#[test]
fn disallows_test_attribute_on_trait_impl_method() {
let src = r#"
pub trait Trait {
fn foo() {}
}
pub struct Foo {}
impl Trait for Foo {
#[test]
fn foo() {}
}
fn main() {}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);

assert!(matches!(
errors[0].0,
CompilationError::DefinitionError(DefCollectorErrorKind::TestOnAssociatedFunction {
span: _
})
));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "comptime_apply_failing_range_constraint"
type = "bin"
authors = [""]

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fn main() {
comptime {
256.assert_max_bit_size::<8>()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "comptime_apply_range_constraint"
type = "bin"
authors = [""]

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fn main() {
comptime {
2.assert_max_bit_size::<8>()
}
}

0 comments on commit a42c643

Please sign in to comment.