Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ssa refactor): Implement first-class references #1849

Merged
merged 13 commits into from
Jul 5, 2023
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
authors = [""]
compiler_version = "0.5.1"

[dependencies]
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
fn main() {
let mut x = 2;
add1(&mut x);
assert(x == 3);

let mut s = S { y: x };
s.add2();
assert(s.y == 5);

// Test that normal mutable variables are still copied
let mut a = 0;
mutate_copy(a);
assert(a == 0);

// Test something 3 allocations deep
let mut nested_allocations = Nested { y: &mut &mut 0 };
add1(*nested_allocations.y);
assert(**nested_allocations.y == 1);

// Test nested struct allocations with a mutable reference to an array.
let mut c = C {
foo: 0,
bar: &mut C2 {
array: &mut [1, 2],
},
};
*c.bar.array = [3, 4];
assert(*c.bar.array == [3, 4]);
}

fn add1(x: &mut Field) {
*x += 1;
}

struct S { y: Field }

struct Nested { y: &mut &mut Field }

struct C {
foo: Field,
bar: &mut C2,
}

struct C2 {
array: &mut [Field; 2]
}

impl S {
fn add2(&mut self) {
self.y += 2;
}
}

fn mutate_copy(mut a: Field) {
a = 7;
}
28 changes: 14 additions & 14 deletions crates/nargo_cli/tests/test_data_ssa_refactor/tuples/src/main.nr
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
use dep::std;

fn main(x: Field, y: Field) {
let pair = (x, y);
assert(pair.0 == 1);
assert(pair.1 == 0);
// let pair = (x, y);
// assert(pair.0 == 1);
// assert(pair.1 == 0);

let (a, b) = if true { (0, 1) } else { (2, 3) };
assert(a == 0);
assert(b == 1);
// let (a, b) = if true { (0, 1) } else { (2, 3) };
// assert(a == 0);
// assert(b == 1);

let (u,v) = if x as u32 < 1 {
(x, x + 1)
} else {
(x + 1, x)
};
assert(u == x+1);
assert(v == x);
// let (u,v) = if x as u32 < 1 {
// (x, x + 1)
// } else {
// (x + 1, x)
// };
// assert(u == x+1);
// assert(v == x);
jfecher marked this conversation as resolved.
Show resolved Hide resolved

// Test mutating tuples
let mut mutable = ((0, 0), 1, 2, 3);
mutable.0 = pair;
mutable.0 = (x, y);
mutable.2 = 7;
assert(mutable.0.0 == 1);
assert(mutable.0.1 == 0);
Expand Down
1 change: 1 addition & 0 deletions crates/noirc_evaluator/src/ssa/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,7 @@ impl SsaContext {
}
}
Type::Array(..) => panic!("Cannot convert an array type {t} into an ObjectType since it is unknown which array it refers to"),
Type::MutableReference(..) => panic!("Mutable reference types are unimplemented in the old ssa backend"),
Type::Unit => ObjectType::NotAnObject,
Type::Function(..) => ObjectType::Function,
Type::Tuple(_) => todo!("Conversion to ObjectType is unimplemented for tuples"),
Expand Down
10 changes: 10 additions & 0 deletions crates/noirc_evaluator/src/ssa/ssa_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ impl IrGenerator {
self.context.new_instruction(op, rhs_type)
}
UnaryOp::Not => self.context.new_instruction(Operation::Not(rhs), rhs_type),
UnaryOp::MutableReference | UnaryOp::Dereference => {
unimplemented!("Mutable references are unimplemented in the old ssa backend")
}
}
}

Expand Down Expand Up @@ -248,6 +251,9 @@ impl IrGenerator {
let val = self.find_variable(ident_def).unwrap();
val.get_field_member(*field_index)
}
LValue::Dereference { .. } => {
unreachable!("Mutable references are unsupported in the old ssa backend")
}
}
}

Expand All @@ -256,6 +262,7 @@ impl IrGenerator {
LValue::Ident(ident) => &ident.definition,
LValue::Index { array, .. } => Self::lvalue_ident_def(array.as_ref()),
LValue::MemberAccess { object, .. } => Self::lvalue_ident_def(object.as_ref()),
LValue::Dereference { reference, .. } => Self::lvalue_ident_def(reference.as_ref()),
}
}

Expand Down Expand Up @@ -462,6 +469,9 @@ impl IrGenerator {
let value = val.get_field_member(*field_index).clone();
self.assign_pattern(&value, rhs)?;
}
LValue::Dereference { .. } => {
unreachable!("Mutable references are unsupported in the old ssa backend")
}
}
Ok(Value::dummy())
}
Expand Down
3 changes: 2 additions & 1 deletion crates/noirc_evaluator/src/ssa/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ impl Value {
| Type::String(..)
| Type::Integer(..)
| Type::Bool
| Type::Field => Value::Node(*iter.next().unwrap()),
| Type::Field
| Type::MutableReference(_) => Value::Node(*iter.next().unwrap()),
}
}

Expand Down
1 change: 1 addition & 0 deletions crates/noirc_evaluator/src/ssa_refactor/ir/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub(crate) enum RuntimeType {
// Unconstrained function, to be compiled to brillig and executed by the Brillig VM
Brillig,
}

/// A function holds a list of instructions.
/// These instructions are further grouped into Basic blocks
///
Expand Down
40 changes: 28 additions & 12 deletions crates/noirc_evaluator/src/ssa_refactor/opt/mem2reg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,39 @@ impl PerBlockContext {
dfg: &mut DataFlowGraph,
) -> HashSet<AllocId> {
let mut protected_allocations = HashSet::new();
let mut loads_to_substitute = HashMap::new();
let block = &dfg[self.block_id];

// Maps Load instruction id -> value to replace the result of the load with
let mut loads_to_substitute = HashMap::new();

// Maps Load result id -> value to replace the result of the load with
let mut load_values_to_substitute = HashMap::new();

for instruction_id in block.instructions() {
match &dfg[*instruction_id] {
Instruction::Store { address, value } => {
self.last_stores.insert(*address, *value);
Instruction::Store { mut address, value } => {
if let Some(value) = load_values_to_substitute.get(&address) {
address = *value;
}

self.last_stores.insert(address, *value);
self.store_ids.push(*instruction_id);
}
Instruction::Load { address } => {
if let Some(last_value) = self.last_stores.get(address) {
Instruction::Load { mut address } => {
if let Some(value) = load_values_to_substitute.get(&address) {
address = *value;
}

if let Some(last_value) = self.last_stores.get(&address) {
let result_value = *dfg
.instruction_results(*instruction_id)
.first()
.expect("ICE: Load instructions should have single result");

loads_to_substitute.insert(*instruction_id, *last_value);
load_values_to_substitute.insert(result_value, *last_value);
} else {
protected_allocations.insert(*address);
protected_allocations.insert(address);
}
}
Instruction::Call { arguments, .. } => {
Expand All @@ -103,12 +122,9 @@ impl PerBlockContext {
}

// Substitute load result values
for (instruction_id, new_value) in &loads_to_substitute {
let result_value = *dfg
.instruction_results(*instruction_id)
.first()
.expect("ICE: Load instructions should have single result");
dfg.set_value_from_id(result_value, *new_value);
for (result_value, new_value) in load_values_to_substitute {
let result_value = dfg.resolve(result_value);
dfg.set_value_from_id(result_value, new_value);
}

// Delete load instructions
Expand Down
50 changes: 43 additions & 7 deletions crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,17 @@ impl<'a> FunctionContext<'a> {

// This helper is needed because we need to take f by mutable reference,
// otherwise we cannot move it multiple times each loop of vecmap.
fn map_type_helper<T>(typ: &ast::Type, f: &mut impl FnMut(Type) -> T) -> Tree<T> {
fn map_type_helper<T>(typ: &ast::Type, f: &mut dyn FnMut(Type) -> T) -> Tree<T> {
match typ {
ast::Type::Tuple(fields) => {
Tree::Branch(vecmap(fields, |field| Self::map_type_helper(field, f)))
}
ast::Type::Unit => Tree::empty(),
// A mutable reference wraps each element into a reference.
// This can be multiple values if the element type is a tuple.
ast::Type::MutableReference(element) => {
Self::map_type_helper(element, &mut |_| f(Type::Reference))
}
other => Tree::Leaf(f(Self::convert_non_tuple_type(other))),
}
}
Expand Down Expand Up @@ -201,6 +206,11 @@ impl<'a> FunctionContext<'a> {
ast::Type::Unit => panic!("convert_non_tuple_type called on a unit type"),
ast::Type::Tuple(_) => panic!("convert_non_tuple_type called on a tuple: {typ}"),
ast::Type::Function(_, _) => Type::Function,
ast::Type::MutableReference(element) => {
// Recursive call to panic if element is a tuple
Self::convert_non_tuple_type(element);
Type::Reference
}

// How should we represent Vecs?
// Are they a struct of array + length + capacity?
Expand Down Expand Up @@ -473,11 +483,23 @@ impl<'a> FunctionContext<'a> {
let object_lvalue = Box::new(object_lvalue);
LValue::MemberAccess { old_object, object_lvalue, index: *field_index }
}
ast::LValue::Dereference { reference, .. } => {
let (reference, _) = self.extract_current_value_recursive(reference);
LValue::Dereference { reference }
}
}
}

fn dereference(&mut self, values: &Values, element_type: &ast::Type) -> Values {
let element_types = Self::convert_type(element_type);
values.map_both(element_types, |value, element_type| {
let reference = value.eval(self);
self.builder.insert_load(reference, element_type).into()
})
jfecher marked this conversation as resolved.
Show resolved Hide resolved
}

/// Compile the given identifier as a reference - ie. avoid calling .eval()
fn ident_lvalue(&self, ident: &ast::Ident) -> Values {
fn ident_lvalue(&mut self, ident: &ast::Ident) -> Values {
jfecher marked this conversation as resolved.
Show resolved Hide resolved
match &ident.definition {
ast::Definition::Local(id) => self.lookup(*id),
other => panic!("Unexpected definition found for mutable value: {other}"),
Expand Down Expand Up @@ -516,16 +538,19 @@ impl<'a> FunctionContext<'a> {
let element = Self::get_field_ref(&old_object, *index).clone();
(element, LValue::MemberAccess { old_object, object_lvalue, index: *index })
}
ast::LValue::Dereference { reference, element_type } => {
let (reference, _) = self.extract_current_value_recursive(reference);
let dereferenced = self.dereference(&reference, element_type);
(dereferenced, LValue::Dereference { reference })
}
}
}

/// Assigns a new value to the given LValue.
/// The LValue can be created via a previous call to extract_current_value.
/// This method recurs on the given LValue to create a new value to assign an allocation
/// instruction within an LValue::Ident - see the comment on `extract_current_value` for more
/// details.
/// When first-class references are supported the nearest reference may be in any LValue
/// variant rather than just LValue::Ident.
/// instruction within an LValue::Ident or LValue::Dereference - see the comment on
/// `extract_current_value` for more details.
pub(super) fn assign_new_value(&mut self, lvalue: LValue, new_value: Values) {
match lvalue {
LValue::Ident(references) => self.assign(references, new_value),
Expand All @@ -535,9 +560,13 @@ impl<'a> FunctionContext<'a> {
self.assign_new_value(*array_lvalue, new_array.into());
}
LValue::MemberAccess { old_object, index, object_lvalue } => {
let old_object = old_object.map(|value| value.eval(self).into());
jfecher marked this conversation as resolved.
Show resolved Hide resolved
let new_object = Self::replace_field(old_object, index, new_value);
self.assign_new_value(*object_lvalue, new_object);
}
LValue::Dereference { reference } => {
self.assign(reference, new_value);
}
}
}

Expand All @@ -553,7 +582,12 @@ impl<'a> FunctionContext<'a> {
}
}
(Tree::Leaf(lhs), Tree::Leaf(rhs)) => {
let (lhs, rhs) = (lhs.eval_reference(), rhs.eval(self));
let (lhs2, rhs2) = (lhs.clone().eval_reference(), rhs.clone().eval(self));

println!("Created store {lhs2} <- {rhs2} from {lhs:?} <- {rhs:?}");

let lhs = lhs2;
let rhs = rhs2;
jfecher marked this conversation as resolved.
Show resolved Hide resolved
self.builder.insert_store(lhs, rhs);
}
(lhs, rhs) => {
Expand Down Expand Up @@ -705,8 +739,10 @@ impl SharedContext {
}

/// Used to remember the results of each step of extracting a value from an ast::LValue
#[derive(Debug)]
pub(super) enum LValue {
Ident(Values),
Index { old_array: ValueId, index: ValueId, array_lvalue: Box<LValue> },
MemberAccess { old_object: Values, index: usize, object_lvalue: Box<LValue> },
Dereference { reference: Values },
}
Loading