Skip to content

Commit

Permalink
feat!: Allow impls on primitive types (#847)
Browse files Browse the repository at this point in the history
* Allow impls on primitives

* Update extra tests

* Fix pedersen calls

* Fix remaining test errors

* Turn hash methods back into functions

* Fix stdlib

* Revert nargo tests

* Format and update stdlib syntax

* Remove 'if true'
  • Loading branch information
jfecher authored Feb 16, 2023
1 parent 595e3c3 commit 479da0e
Show file tree
Hide file tree
Showing 20 changed files with 262 additions and 237 deletions.
4 changes: 1 addition & 3 deletions crates/nargo/tests/target_tests_data/pass/import/src/main.nr
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
mod import;
use dep::std;

use crate::import::hello;

fn main(x : Field, y : Field) {
let _k = std::hash::pedersen([x]);
let _k = dep::std::hash::pedersen([x]);
let _l = hello(x);

constrain x != import::hello(y);
Expand Down
2 changes: 1 addition & 1 deletion crates/nargo/tests/test_data/7_function/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fn test2(z: Field, t: u32 ) {

fn pow(base: Field, exponent: Field) -> Field {
let mut r = 1 as Field;
let b = std::field::to_le_bits(exponent, 32 as u32);
let b = exponent.to_le_bits(32 as u32);
for i in 1..33 {
r = r*r;
r = (b[32-i] as Field) * (r * base) + (1 - b[32-i] as Field) * r;
Expand Down
5 changes: 2 additions & 3 deletions crates/nargo/tests/test_data/9_conditional/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ fn test4() -> [u32; 4] {
}

fn main(a: u32, mut c: [u32; 4], x: [u8; 5], result: pub [u8; 32]){

// Regression test for issue #547
// Warning: it must be kept at the start of main
let arr: [u8; 2] = [1, 2];
Expand All @@ -49,7 +48,7 @@ fn main(a: u32, mut c: [u32; 4], x: [u8; 5], result: pub [u8; 32]){
//Issue reported in #421
if a == c[0] {
constrain c[0] == 0;
} else {
} else {
if a == c[1] {
constrain c[1] == 0;
} else {
Expand All @@ -64,7 +63,7 @@ fn main(a: u32, mut c: [u32; 4], x: [u8; 5], result: pub [u8; 32]){
let as_bits_hardcode_1 = [1, 0];
let mut c1 = 0;
for i in 0..2 {
let mut as_bits = std::field::to_le_bits(arr[i] as Field, 2);
let mut as_bits = (arr[i] as Field).to_le_bits(2);
c1 = c1 + as_bits[0] as Field;

if i == 0 {
Expand Down
8 changes: 4 additions & 4 deletions crates/nargo/tests/test_data/array_len/src/main.nr
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use dep::std;

fn len_plus_1<T>(array: [T]) -> Field {
std::array::len(array) + 1
array.len() + 1
}

fn add_lens<T>(a: [T], b: [Field]) -> Field {
std::array::len(a) + std::array::len(b)
a.len() + b.len()
}

fn nested_call(b: [Field]) -> Field {
Expand All @@ -19,13 +19,13 @@ fn main(len3: [u8; 3], len4: [Field; 4]) {
constrain nested_call(len4) == 5;

// std::array::len returns a comptime value
constrain len4[std::array::len(len3)] == 4;
constrain len4[len3.len()] == 4;

// test for std::array::sort
let mut unsorted = len3;
unsorted[0] = len3[1];
unsorted[1] = len3[0];
constrain unsorted[0] > unsorted[1];
let sorted = std::array::sort(unsorted);
let sorted = unsorted.sort();
constrain sorted[0] < sorted[1];
}
2 changes: 1 addition & 1 deletion crates/nargo/tests/test_data/hash_to_field/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ use dep::std;

fn main(input : Field) -> pub Field {
std::hash::hash_to_field([input])
}
}
10 changes: 5 additions & 5 deletions crates/nargo/tests/test_data/higher-order-functions/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ fn main() -> pub Field {
/// Test the array functions in std::array
fn test_array_functions() {
let myarray: [i32; 3] = [1, 2, 3];
constrain std::array::any(myarray, |n| n > 2);
constrain myarray.any(|n| n > 2);

let evens: [i32; 3] = [2, 4, 6];
constrain std::array::all(evens, |n| n > 1);
constrain evens.all(|n| n > 1);

constrain std::array::fold(evens, 0, |a, b| a + b) == 12;
constrain std::array::reduce(evens, |a, b| a + b) == 12;
constrain evens.fold(0, |a, b| a + b) == 12;
constrain evens.reduce(|a, b| a + b) == 12;

let descending = std::array::sort_via(myarray, |a, b| a > b);
let descending = myarray.sort_via(|a, b| a > b);
constrain descending == [3, 2, 1];
}

Expand Down
2 changes: 1 addition & 1 deletion crates/nargo/tests/test_data/strings/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ struct Test {
a: Field,
b: Field,
c: [Field; 2],
}
}
4 changes: 2 additions & 2 deletions crates/nargo/tests/test_data/struct_inputs/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn main(x : Field, y : pub myStruct, z: pub foo::bar::barStruct, a: pub foo::foo

check_inner_struct(a, z);

for i in 0..std::array::len(struct_from_bar.array) {
for i in 0 .. struct_from_bar.array.len() {
constrain struct_from_bar.array[i] == z.array[i];
}
constrain z.val == struct_from_bar.val;
Expand All @@ -30,7 +30,7 @@ fn main(x : Field, y : pub myStruct, z: pub foo::bar::barStruct, a: pub foo::foo

fn check_inner_struct(a: foo::fooStruct, z: foo::bar::barStruct) {
constrain a.bar_struct.val == z.val;
for i in 0..std::array::len(a.bar_struct.array) {
for i in 0.. a.bar_struct.array.len() {
constrain a.bar_struct.array[i] == z.array[i];
}
}
4 changes: 2 additions & 2 deletions crates/nargo/tests/test_data/to_le_bytes/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use dep::std;

fn main(x : Field) -> pub [u8; 4] {
// The result of this byte array will be little-endian
let byte_array = std::field::to_le_bytes(x, 31);
let byte_array = x.to_le_bytes(31);
let mut first_four_bytes = [0; 4];
for i in 0..4 {
first_four_bytes[i] = byte_array[i];
Expand All @@ -11,4 +11,4 @@ fn main(x : Field) -> pub [u8; 4] {
// We were incorrectly mapping our output array from bit decomposition functions during acir generation
first_four_bytes[3] = byte_array[31];
first_four_bytes
}
}
7 changes: 5 additions & 2 deletions crates/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::dc_mod::collect_defs;
use super::errors::DefCollectorErrorKind;
use crate::graph::CrateId;
use crate::graph::{CrateId, LOCAL_CRATE};
use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleId};
use crate::hir::resolution::errors::ResolverError;
use crate::hir::resolution::resolver::Resolver;
Expand Down Expand Up @@ -226,7 +226,10 @@ fn collect_impls(
errors.push(err.into_file_diagnostic(unresolved.file_id));
}
}
} else if typ != Type::Error {
// Prohibit defining impls for primitive types if we're in the local crate.
// We should really prevent it for all crates that aren't the noir stdlib but
// there is no way of checking if the current crate is the stdlib currently.
} else if typ != Type::Error && crate_id == LOCAL_CRATE {
let span = *span;
let error = DefCollectorErrorKind::NonStructTypeInImpl { span };
errors.push(error.into_file_diagnostic(unresolved.file_id))
Expand Down
17 changes: 10 additions & 7 deletions crates/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,16 @@ fn lookup_method(

// In the future we could support methods for non-struct types if we have a context
// (in the interner?) essentially resembling HashMap<Type, Methods>
other => {
errors.push(TypeCheckError::Unstructured {
span: interner.expr_span(expr_id),
msg: format!("Type '{other}' must be a struct type to call methods on it"),
});
None
}
other => match interner.lookup_primitive_method(other, method_name) {
Some(method_id) => Some(method_id),
None => {
errors.push(TypeCheckError::Unstructured {
span: interner.expr_span(expr_id),
msg: format!("No method named '{method_name}' found for type '{other}'",),
});
None
}
},
}
}

Expand Down
69 changes: 61 additions & 8 deletions crates/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,11 @@ pub struct NodeInterner {

delayed_type_checks: Vec<TypeCheckFn>,

// A map from a struct type and method name to a function id for the method
// along with any generic on the struct it may require. E.g. if the impl is
// only for `impl Foo<String>` rather than all Foo, the generics will be `vec![String]`.
struct_methods: HashMap<(StructId, String), (Vec<Type>, FuncId)>,
/// A map from a struct type and method name to a function id for the method.
struct_methods: HashMap<(StructId, String), FuncId>,

/// Methods on primitive types defined in the stdlib.
primitive_methods: HashMap<(TypeMethodKey, String), FuncId>,
}

type TypeCheckFn = Box<dyn FnOnce() -> Result<(), TypeCheckError>>;
Expand Down Expand Up @@ -241,6 +242,7 @@ impl Default for NodeInterner {
language: Language::R1CS,
delayed_type_checks: vec![],
struct_methods: HashMap::new(),
primitive_methods: HashMap::new(),
};

// An empty block expression is used often, we add this into the `node` on startup
Expand Down Expand Up @@ -585,16 +587,67 @@ impl NodeInterner {
method_id: FuncId,
) -> Option<FuncId> {
match self_type {
Type::Struct(struct_type, generics) => {
Type::Struct(struct_type, _generics) => {
let key = (struct_type.borrow().id, method_name);
self.struct_methods.insert(key, (generics.clone(), method_id)).map(|(_, id)| id)
self.struct_methods.insert(key, method_id)
}
Type::Error => None,

other => {
let key = get_type_method_key(self_type).unwrap_or_else(|| {
unreachable!("Cannot add a method to the unsupported type '{}'", other)
});
self.primitive_methods.insert((key, method_name), method_id)
}
other => unreachable!("Tried adding method to non-struct type '{}'", other),
}
}

/// Search by name for a method on the given struct
pub fn lookup_method(&self, id: StructId, method_name: &str) -> Option<FuncId> {
self.struct_methods.get(&(id, method_name.to_owned())).map(|(_, id)| *id)
self.struct_methods.get(&(id, method_name.to_owned())).copied()
}

/// Looks up a given method name on the given primitive type.
pub fn lookup_primitive_method(&self, typ: &Type, method_name: &str) -> Option<FuncId> {
get_type_method_key(typ)
.and_then(|key| self.primitive_methods.get(&(key, method_name.to_owned())).copied())
}
}

/// These are the primitive type variants that we support adding methods to
#[derive(Copy, Clone, Hash, PartialEq, Eq)]
enum TypeMethodKey {
/// Fields and integers share methods for ease of use. These methods may still
/// accept only fields or integers, it is just that their names may not clash.
FieldOrInt,
Array,
Bool,
String,
Unit,
Tuple,
Function,
}

fn get_type_method_key(typ: &Type) -> Option<TypeMethodKey> {
use TypeMethodKey::*;
let typ = typ.follow_bindings();
match &typ {
Type::FieldElement(_) => Some(FieldOrInt),
Type::Array(_, _) => Some(Array),
Type::Integer(_, _, _) => Some(FieldOrInt),
Type::PolymorphicInteger(_, _) => Some(FieldOrInt),
Type::Bool(_) => Some(Bool),
Type::String(_) => Some(String),
Type::Unit => Some(Unit),
Type::Tuple(_) => Some(Tuple),
Type::Function(_, _) => Some(Function),

// We do not support adding methods to these types
Type::TypeVariable(_)
| Type::NamedGeneric(_, _)
| Type::Forall(_, _)
| Type::Constant(_)
| Type::Error
| Type::Struct(_, _) => None,
}
}
18 changes: 7 additions & 11 deletions crates/noirc_frontend/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use std::sync::atomic::{AtomicU32, Ordering};
use crate::token::{Keyword, Token};
use crate::{ast::ImportStatement, Expression, NoirStruct};
use crate::{
BlockExpression, CallExpression, ExpressionKind, ForExpression, Ident, IndexExpression,
LetStatement, NoirFunction, NoirImpl, Path, PathKind, Pattern, Recoverable, Statement,
BlockExpression, ExpressionKind, ForExpression, Ident, IndexExpression, LetStatement,
MethodCallExpression, NoirFunction, NoirImpl, Path, PathKind, Pattern, Recoverable, Statement,
UnresolvedType,
};

Expand Down Expand Up @@ -371,19 +371,15 @@ impl ForRange {
expression: array,
});

let ident = |name: &str| Ident::new(name.to_string(), array_span);

// std::array::len(array)
// array.len()
let segments = vec![array_ident];
let array_ident =
ExpressionKind::Variable(Path { segments, kind: PathKind::Plain });

let segments = vec![ident("std"), ident("array"), ident("len")];
let func_ident = ExpressionKind::Variable(Path { segments, kind: PathKind::Dep });

let end_range = ExpressionKind::Call(Box::new(CallExpression {
func: Box::new(Expression::new(func_ident, array_span)),
arguments: vec![Expression::new(array_ident.clone(), array_span)],
let end_range = ExpressionKind::MethodCall(Box::new(MethodCallExpression {
object: Expression::new(array_ident.clone(), array_span),
method_name: Ident::new("len".to_string(), array_span),
arguments: vec![],
}));
let end_range = Expression::new(end_range, array_span);

Expand Down
Loading

0 comments on commit 479da0e

Please sign in to comment.