diff --git a/compiler/noirc_arena/src/lib.rs b/compiler/noirc_arena/src/lib.rs index 2d117304e16..9a25299d6c8 100644 --- a/compiler/noirc_arena/src/lib.rs +++ b/compiler/noirc_arena/src/lib.rs @@ -3,6 +3,8 @@ #![warn(unreachable_pub)] #![warn(clippy::semicolon_if_nothing_returned)] +use std::fmt; + #[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, Hash)] pub struct Index(usize); @@ -25,6 +27,12 @@ impl Index { } } +impl fmt::Display for Index { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + #[derive(Clone, Debug)] pub struct Arena { pub vec: Vec, diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index b13ffe05c2d..bef0ebdaacc 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -29,7 +29,7 @@ use crate::ast::{ ArrayLiteral, BinaryOpKind, BlockExpression, Distinctness, Expression, ExpressionKind, ForRange, FunctionDefinition, FunctionKind, FunctionReturnType, Ident, ItemVisibility, LValue, LetStatement, Literal, NoirFunction, NoirStruct, NoirTypeAlias, Param, Path, PathKind, Pattern, - Statement, StatementKind, UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint, + Statement, StatementKind, TraitBound, UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, Visibility, ERROR_IDENT, }; use crate::graph::CrateId; @@ -202,6 +202,52 @@ impl<'a> Resolver<'a> { self.errors.push(err); } + /// This turns function parameters of the form: + /// fn foo(x: impl Bar) + /// + /// into + /// fn foo(x: T0_impl_Bar) where T0_impl_Bar: Bar + fn desugar_impl_trait_args(&mut self, func: &mut NoirFunction, func_id: FuncId) { + let mut impl_trait_generics = HashSet::new(); + let mut counter: usize = 0; + for parameter in func.def.parameters.iter_mut() { + if let UnresolvedTypeData::TraitAsType(path, args) = ¶meter.typ.typ { + let mut new_generic_ident: Ident = + format!("T{}_impl_{}", func_id, path.as_string()).into(); + let mut new_generic_path = Path::from_ident(new_generic_ident.clone()); + while impl_trait_generics.contains(&new_generic_ident) + || self.lookup_generic_or_global_type(&new_generic_path).is_some() + { + new_generic_ident = + format!("T{}_impl_{}_{}", func_id, path.as_string(), counter).into(); + new_generic_path = Path::from_ident(new_generic_ident.clone()); + counter += 1; + } + impl_trait_generics.insert(new_generic_ident.clone()); + + let is_synthesized = true; + let new_generic_type_data = + UnresolvedTypeData::Named(new_generic_path, vec![], is_synthesized); + let new_generic_type = + UnresolvedType { typ: new_generic_type_data.clone(), span: None }; + let new_trait_bound = TraitBound { + trait_path: path.clone(), + trait_id: None, + trait_generics: args.to_vec(), + }; + let new_trait_constraint = UnresolvedTraitConstraint { + typ: new_generic_type, + trait_bound: new_trait_bound, + }; + + parameter.typ.typ = new_generic_type_data; + func.def.generics.push(new_generic_ident); + func.def.where_clause.push(new_trait_constraint); + } + } + self.add_generics(&impl_trait_generics.into_iter().collect()); + } + /// Resolving a function involves interning the metadata /// interning any statements inside of the function /// and interning the function itself @@ -209,7 +255,7 @@ impl<'a> Resolver<'a> { /// Since lowering would require scope data, unless we add an extra resolution field to the AST pub fn resolve_function( mut self, - func: NoirFunction, + mut func: NoirFunction, func_id: FuncId, ) -> (HirFunction, FuncMeta, Vec) { self.scopes.start_function(); @@ -217,8 +263,9 @@ impl<'a> Resolver<'a> { // Check whether the function has globals in the local module and add them to the scope self.resolve_local_globals(); - self.add_generics(&func.def.generics); + + self.desugar_impl_trait_args(&mut func, func_id); self.trait_bounds = func.def.where_clause.clone(); let is_low_level_or_oracle = func @@ -1150,10 +1197,15 @@ impl<'a> Resolver<'a> { | Type::TypeVariable(_, _) | Type::Constant(_) | Type::NamedGeneric(_, _) - | Type::TraitAsType(..) | Type::Code | Type::Forall(_, _) => (), + Type::TraitAsType(_, _, args) => { + for arg in args { + Self::find_numeric_generics_in_type(arg, found); + } + } + Type::Array(length, element_type) => { if let Type::NamedGeneric(type_variable, name) = length.as_ref() { found.insert(name.to_string(), type_variable.clone()); diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 8c6e3d48fca..637f3c99e89 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -643,9 +643,11 @@ impl Type { | Type::Constant(_) | Type::NamedGeneric(_, _) | Type::Forall(_, _) - | Type::Code - | Type::TraitAsType(..) => false, + | Type::Code => false, + Type::TraitAsType(_, _, args) => { + args.iter().any(|generic| generic.contains_numeric_typevar(target_id)) + } Type::Array(length, elem) => { elem.contains_numeric_typevar(target_id) || named_generic_id_matches_target(length) } @@ -1591,11 +1593,17 @@ impl Type { element.substitute_helper(type_bindings, substitute_bound_typevars), )), + Type::TraitAsType(s, name, args) => { + let args = vecmap(args, |arg| { + arg.substitute_helper(type_bindings, substitute_bound_typevars) + }); + Type::TraitAsType(*s, name.clone(), args) + } + Type::FieldElement | Type::Integer(_, _) | Type::Bool | Type::Constant(_) - | Type::TraitAsType(..) | Type::Error | Type::Code | Type::Unit => self.clone(), @@ -1613,7 +1621,9 @@ impl Type { let field_occurs = fields.occurs(target_id); len_occurs || field_occurs } - Type::Struct(_, generic_args) | Type::Alias(_, generic_args) => { + Type::Struct(_, generic_args) + | Type::Alias(_, generic_args) + | Type::TraitAsType(_, _, generic_args) => { generic_args.iter().any(|arg| arg.occurs(target_id)) } Type::Tuple(fields) => fields.iter().any(|field| field.occurs(target_id)), @@ -1637,7 +1647,6 @@ impl Type { | Type::Integer(_, _) | Type::Bool | Type::Constant(_) - | Type::TraitAsType(..) | Type::Error | Type::Code | Type::Unit => false, @@ -1689,16 +1698,14 @@ impl Type { MutableReference(element) => MutableReference(Box::new(element.follow_bindings())), + TraitAsType(s, name, args) => { + let args = vecmap(args, |arg| arg.follow_bindings()); + TraitAsType(*s, name.clone(), args) + } + // Expect that this function should only be called on instantiated types Forall(..) => unreachable!(), - TraitAsType(..) - | FieldElement - | Integer(_, _) - | Bool - | Constant(_) - | Unit - | Code - | Error => self.clone(), + FieldElement | Integer(_, _) | Bool | Constant(_) | Unit | Code | Error => self.clone(), } } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 7ec32613cb7..88adc7a9414 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; use std::collections::HashMap; +use std::fmt; use std::ops::Deref; use fm::FileId; @@ -314,6 +315,12 @@ impl FuncId { } } +impl fmt::Display for FuncId { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + #[derive(Debug, Eq, PartialEq, Hash, Copy, Clone, PartialOrd, Ord)] pub struct StructId(ModuleId); diff --git a/compiler/noirc_frontend/src/parser/parser/function.rs b/compiler/noirc_frontend/src/parser/parser/function.rs index f39b2ad6292..ec4728fba4f 100644 --- a/compiler/noirc_frontend/src/parser/parser/function.rs +++ b/compiler/noirc_frontend/src/parser/parser/function.rs @@ -193,6 +193,11 @@ mod test { "fn func_name(f: Field, y : T) where T: SomeTrait + {}", // The following should produce compile error on later stage. From the parser's perspective it's fine "fn func_name(f: Field, y : Field, z : Field) where T: SomeTrait {}", + // TODO: this fails with known EOF != EOF error + // https://github.com/noir-lang/noir/issues/4763 + // fn func_name(x: impl Eq) {} with error Expected an end of input but found end of input + // "fn func_name(x: impl Eq) {}", + "fn func_name(x: impl Eq, y : T) where T: SomeTrait + Eq {}", ], ); diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index e017ea9e97b..cf2d7dbe153 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -746,6 +746,70 @@ mod test { } } + #[test] + fn check_trait_as_type_as_fn_parameter() { + let src = " + trait Eq { + fn eq(self, other: Self) -> bool; + } + + struct Foo { + a: u64, + } + + impl Eq for Foo { + fn eq(self, other: Foo) -> bool { self.a == other.a } + } + + fn test_eq(x: impl Eq) -> bool { + x.eq(x) + } + + fn main(a: Foo) -> pub bool { + test_eq(a) + }"; + + let errors = get_program_errors(src); + errors.iter().for_each(|err| println!("{:?}", err)); + assert!(errors.is_empty()); + } + + #[test] + fn check_trait_as_type_as_two_fn_parameters() { + let src = " + trait Eq { + fn eq(self, other: Self) -> bool; + } + + trait Test { + fn test(self) -> bool; + } + + struct Foo { + a: u64, + } + + impl Eq for Foo { + fn eq(self, other: Foo) -> bool { self.a == other.a } + } + + impl Test for u64 { + fn test(self) -> bool { self == self } + } + + fn test_eq(x: impl Eq, y: impl Test) -> bool { + x.eq(x) == y.test() + } + + fn main(a: Foo, b: u64) -> pub bool { + test_eq(a, b) + }"; + + let errors = get_program_errors(src); + errors.iter().for_each(|err| println!("{:?}", err)); + assert!(errors.is_empty()); + } + fn get_program_captures(src: &str) -> Vec> { let (program, context, _errors) = get_program(src); let interner = context.def_interner; diff --git a/tooling/nargo_fmt/tests/expected/impl_trait_fn_parameter.nr b/tooling/nargo_fmt/tests/expected/impl_trait_fn_parameter.nr new file mode 100644 index 00000000000..5ace5c60dcf --- /dev/null +++ b/tooling/nargo_fmt/tests/expected/impl_trait_fn_parameter.nr @@ -0,0 +1,3 @@ +fn func_name(x: impl Eq) {} + +fn func_name(x: impl Eq, y: T) where T: SomeTrait + Eq {} diff --git a/tooling/nargo_fmt/tests/input/impl_trait_fn_parameter.nr b/tooling/nargo_fmt/tests/input/impl_trait_fn_parameter.nr new file mode 100644 index 00000000000..5ace5c60dcf --- /dev/null +++ b/tooling/nargo_fmt/tests/input/impl_trait_fn_parameter.nr @@ -0,0 +1,3 @@ +fn func_name(x: impl Eq) {} + +fn func_name(x: impl Eq, y: T) where T: SomeTrait + Eq {}