Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add de-sugaring for impl Trait in function parameters #4919

Merged
merged 9 commits into from
Apr 29, 2024
8 changes: 8 additions & 0 deletions compiler/noirc_arena/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#![warn(unreachable_pub)]
#![warn(clippy::semicolon_if_nothing_returned)]

use std::fmt;

#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, Hash)]
pub struct Index(usize);

Expand All @@ -25,6 +27,12 @@ impl Index {
}
}

impl fmt::Display for Index {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}

#[derive(Clone, Debug)]
pub struct Arena<T> {
pub vec: Vec<T>,
Expand Down
60 changes: 56 additions & 4 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::ast::{
ArrayLiteral, BinaryOpKind, BlockExpression, Distinctness, Expression, ExpressionKind,
ForRange, FunctionDefinition, FunctionKind, FunctionReturnType, Ident, ItemVisibility, LValue,
LetStatement, Literal, NoirFunction, NoirStruct, NoirTypeAlias, Param, Path, PathKind, Pattern,
Statement, StatementKind, UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint,
Statement, StatementKind, TraitBound, UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint,
UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, Visibility, ERROR_IDENT,
};
use crate::graph::CrateId;
Expand Down Expand Up @@ -202,23 +202,70 @@ impl<'a> Resolver<'a> {
self.errors.push(err);
}

/// This turns function parameters of the form:
/// fn foo(x: impl Bar)
///
/// into
/// fn foo<T0_impl_Bar>(x: T0_impl_Bar) where T0_impl_Bar: Bar
fn desugar_impl_trait_args(&mut self, func: &mut NoirFunction, func_id: FuncId) {
let mut impl_trait_generics = HashSet::new();
let mut counter: usize = 0;
for parameter in func.def.parameters.iter_mut() {
if let UnresolvedTypeData::TraitAsType(path, args) = &parameter.typ.typ {
let mut new_generic_ident: Ident =
format!("T{}_impl_{}", func_id, path.as_string()).into();
let mut new_generic_path = Path::from_ident(new_generic_ident.clone());
while impl_trait_generics.contains(&new_generic_ident)
|| self.lookup_generic_or_global_type(&new_generic_path).is_some()
{
new_generic_ident =
format!("T{}_impl_{}_{}", func_id, path.as_string(), counter).into();
new_generic_path = Path::from_ident(new_generic_ident.clone());
counter += 1;
}
impl_trait_generics.insert(new_generic_ident.clone());

let is_synthesized = true;
let new_generic_type_data =
UnresolvedTypeData::Named(new_generic_path, vec![], is_synthesized);
let new_generic_type =
UnresolvedType { typ: new_generic_type_data.clone(), span: None };
let new_trait_bound = TraitBound {
trait_path: path.clone(),
trait_id: None,
trait_generics: args.to_vec(),
};
let new_trait_constraint = UnresolvedTraitConstraint {
typ: new_generic_type,
trait_bound: new_trait_bound,
};

parameter.typ.typ = new_generic_type_data;
func.def.generics.push(new_generic_ident);
func.def.where_clause.push(new_trait_constraint);
}
}
self.add_generics(&impl_trait_generics.into_iter().collect());
}

/// Resolving a function involves interning the metadata
/// interning any statements inside of the function
/// and interning the function itself
/// We resolve and lower the function at the same time
/// Since lowering would require scope data, unless we add an extra resolution field to the AST
pub fn resolve_function(
mut self,
func: NoirFunction,
mut func: NoirFunction,
func_id: FuncId,
) -> (HirFunction, FuncMeta, Vec<ResolverError>) {
self.scopes.start_function();
self.current_item = Some(DependencyId::Function(func_id));

// Check whether the function has globals in the local module and add them to the scope
self.resolve_local_globals();

self.add_generics(&func.def.generics);

self.desugar_impl_trait_args(&mut func, func_id);
self.trait_bounds = func.def.where_clause.clone();

let is_low_level_or_oracle = func
Expand Down Expand Up @@ -1150,10 +1197,15 @@ impl<'a> Resolver<'a> {
| Type::TypeVariable(_, _)
| Type::Constant(_)
| Type::NamedGeneric(_, _)
| Type::TraitAsType(..)
| Type::Code
| Type::Forall(_, _) => (),

Type::TraitAsType(_, _, args) => {
for arg in args {
Self::find_numeric_generics_in_type(arg, found);
}
}

Type::Array(length, element_type) => {
if let Type::NamedGeneric(type_variable, name) = length.as_ref() {
found.insert(name.to_string(), type_variable.clone());
Expand Down
33 changes: 20 additions & 13 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -643,9 +643,11 @@
| Type::Constant(_)
| Type::NamedGeneric(_, _)
| Type::Forall(_, _)
| Type::Code
| Type::TraitAsType(..) => false,
| Type::Code => false,

Type::TraitAsType(_, _, args) => {
args.iter().any(|generic| generic.contains_numeric_typevar(target_id))
}
Type::Array(length, elem) => {
elem.contains_numeric_typevar(target_id) || named_generic_id_matches_target(length)
}
Expand Down Expand Up @@ -1571,7 +1573,7 @@
Type::Tuple(fields)
}
Type::Forall(typevars, typ) => {
// Trying to substitute_helper a variable de, substitute_bound_typevarsfined within a nested Forall

Check warning on line 1576 in compiler/noirc_frontend/src/hir_def/types.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (typevarsfined)
// is usually impossible and indicative of an error in the type checker somewhere.
for var in typevars {
assert!(!type_bindings.contains_key(&var.id()));
Expand All @@ -1591,11 +1593,17 @@
element.substitute_helper(type_bindings, substitute_bound_typevars),
)),

Type::TraitAsType(s, name, args) => {
let args = vecmap(args, |arg| {
arg.substitute_helper(type_bindings, substitute_bound_typevars)
});
Type::TraitAsType(*s, name.clone(), args)
}

Type::FieldElement
| Type::Integer(_, _)
| Type::Bool
| Type::Constant(_)
| Type::TraitAsType(..)
| Type::Error
| Type::Code
| Type::Unit => self.clone(),
Expand All @@ -1613,7 +1621,9 @@
let field_occurs = fields.occurs(target_id);
len_occurs || field_occurs
}
Type::Struct(_, generic_args) | Type::Alias(_, generic_args) => {
Type::Struct(_, generic_args)
| Type::Alias(_, generic_args)
| Type::TraitAsType(_, _, generic_args) => {
generic_args.iter().any(|arg| arg.occurs(target_id))
}
Type::Tuple(fields) => fields.iter().any(|field| field.occurs(target_id)),
Expand All @@ -1637,7 +1647,6 @@
| Type::Integer(_, _)
| Type::Bool
| Type::Constant(_)
| Type::TraitAsType(..)
| Type::Error
| Type::Code
| Type::Unit => false,
Expand Down Expand Up @@ -1689,16 +1698,14 @@

MutableReference(element) => MutableReference(Box::new(element.follow_bindings())),

TraitAsType(s, name, args) => {
let args = vecmap(args, |arg| arg.follow_bindings());
TraitAsType(*s, name.clone(), args)
}

// Expect that this function should only be called on instantiated types
Forall(..) => unreachable!(),
TraitAsType(..)
| FieldElement
| Integer(_, _)
| Bool
| Constant(_)
| Unit
| Code
| Error => self.clone(),
FieldElement | Integer(_, _) | Bool | Constant(_) | Unit | Code | Error => self.clone(),
}
}

Expand Down
7 changes: 7 additions & 0 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt;
use std::ops::Deref;

use fm::FileId;
Expand Down Expand Up @@ -314,6 +315,12 @@ impl FuncId {
}
}

impl fmt::Display for FuncId {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}

#[derive(Debug, Eq, PartialEq, Hash, Copy, Clone, PartialOrd, Ord)]
pub struct StructId(ModuleId);

Expand Down
5 changes: 5 additions & 0 deletions compiler/noirc_frontend/src/parser/parser/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ mod test {
"fn func_name<T>(f: Field, y : T) where T: SomeTrait + {}",
// The following should produce compile error on later stage. From the parser's perspective it's fine
"fn func_name<A>(f: Field, y : Field, z : Field) where T: SomeTrait {}",
// TODO: this fails with known EOF != EOF error
// https://github.com/noir-lang/noir/issues/4763
// fn func_name(x: impl Eq) {} with error Expected an end of input but found end of input
// "fn func_name(x: impl Eq) {}",
"fn func_name<T>(x: impl Eq, y : T) where T: SomeTrait + Eq {}",
],
);

Expand Down
64 changes: 64 additions & 0 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,70 @@ mod test {
}
}

#[test]
fn check_trait_as_type_as_fn_parameter() {
let src = "
trait Eq {
fn eq(self, other: Self) -> bool;
}

struct Foo {
a: u64,
}

impl Eq for Foo {
fn eq(self, other: Foo) -> bool { self.a == other.a }
}

fn test_eq(x: impl Eq) -> bool {
x.eq(x)
}

fn main(a: Foo) -> pub bool {
test_eq(a)
}";

let errors = get_program_errors(src);
errors.iter().for_each(|err| println!("{:?}", err));
assert!(errors.is_empty());
}

#[test]
fn check_trait_as_type_as_two_fn_parameters() {
let src = "
trait Eq {
fn eq(self, other: Self) -> bool;
}

trait Test {
fn test(self) -> bool;
}

struct Foo {
a: u64,
}

impl Eq for Foo {
fn eq(self, other: Foo) -> bool { self.a == other.a }
}

impl Test for u64 {
fn test(self) -> bool { self == self }
}

fn test_eq(x: impl Eq, y: impl Test) -> bool {
x.eq(x) == y.test()
}

fn main(a: Foo, b: u64) -> pub bool {
test_eq(a, b)
}";

let errors = get_program_errors(src);
errors.iter().for_each(|err| println!("{:?}", err));
assert!(errors.is_empty());
}

fn get_program_captures(src: &str) -> Vec<Vec<String>> {
let (program, context, _errors) = get_program(src);
let interner = context.def_interner;
Expand Down
3 changes: 3 additions & 0 deletions tooling/nargo_fmt/tests/expected/impl_trait_fn_parameter.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
fn func_name(x: impl Eq) {}

fn func_name<T>(x: impl Eq, y: T) where T: SomeTrait + Eq {}
3 changes: 3 additions & 0 deletions tooling/nargo_fmt/tests/input/impl_trait_fn_parameter.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
fn func_name(x: impl Eq) {}

fn func_name<T>(x: impl Eq, y: T) where T: SomeTrait + Eq {}
Loading