From 88f174b785916f02232e38c0fca61f94effa96bf Mon Sep 17 00:00:00 2001 From: Yordan Madzhunkov Date: Thu, 12 Oct 2023 13:13:10 +0300 Subject: [PATCH] feat(traits): Add impl Trait as function return type #2397 --- compiler/noirc_frontend/src/ast/mod.rs | 11 ++++ .../src/hir/resolution/resolver.rs | 27 +++++++++ .../noirc_frontend/src/hir/type_check/expr.rs | 7 ++- .../noirc_frontend/src/hir/type_check/mod.rs | 59 ++++++++++++------- compiler/noirc_frontend/src/hir_def/traits.rs | 2 +- compiler/noirc_frontend/src/hir_def/types.rs | 29 ++++++--- .../src/monomorphization/mod.rs | 29 +++++---- compiler/noirc_frontend/src/node_interner.rs | 19 ++++++ compiler/noirc_frontend/src/parser/parser.rs | 8 +++ compiler/noirc_printable_type/src/lib.rs | 5 ++ .../trait_as_return_type/Nargo.toml | 7 +++ .../trait_as_return_type/Prover.toml | 1 + .../trait_as_return_type/src/main.nr | 55 +++++++++++++++++ tooling/noirc_abi/src/lib.rs | 1 + 14 files changed, 217 insertions(+), 43 deletions(-) create mode 100644 tooling/nargo_cli/tests/execution_success/trait_as_return_type/Nargo.toml create mode 100644 tooling/nargo_cli/tests/execution_success/trait_as_return_type/Prover.toml create mode 100644 tooling/nargo_cli/tests/execution_success/trait_as_return_type/src/main.nr diff --git a/compiler/noirc_frontend/src/ast/mod.rs b/compiler/noirc_frontend/src/ast/mod.rs index 662c3b28bef..c7dc48fb7d7 100644 --- a/compiler/noirc_frontend/src/ast/mod.rs +++ b/compiler/noirc_frontend/src/ast/mod.rs @@ -44,6 +44,9 @@ pub enum UnresolvedTypeData { /// A Named UnresolvedType can be a struct type or a type variable Named(Path, Vec), + /// A Trait as return type or parameter of function, including it's generics + TraitAsType(Path, Vec), + /// &mut T MutableReference(Box), @@ -112,6 +115,14 @@ impl std::fmt::Display for UnresolvedTypeData { write!(f, "{}<{}>", s, args.join(", ")) } } + TraitAsType(s, args) => { + let args = vecmap(args, |arg| ToString::to_string(&arg.typ)); + if args.is_empty() { + write!(f, "impl {s}") + } else { + write!(f, "impl {}<{}>", s, args.join(", ")) + } + } Tuple(elements) => { let elements = vecmap(elements, ToString::to_string); write!(f, "({})", elements.join(", ")) diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index f9bf01c0957..ca30b31e78d 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -376,6 +376,8 @@ impl<'a> Resolver<'a> { Unspecified => Type::Error, Error => Type::Error, Named(path, args) => self.resolve_named_type(path, args, new_variables), + TraitAsType(path, args) => self.resolve_trait_as_type(path, args, new_variables), + Tuple(fields) => { Type::Tuple(vecmap(fields, |field| self.resolve_type_inner(field, new_variables))) } @@ -479,6 +481,19 @@ impl<'a> Resolver<'a> { } } + fn resolve_trait_as_type( + &mut self, + path: Path, + _args: Vec, + _new_variables: &mut Generics, + ) -> Type { + if let Some(t) = self.lookup_trait_or_error(path) { + Type::TraitAsType(t) + } else { + Type::Error + } + } + fn verify_generics_count( &mut self, expected_count: usize, @@ -874,6 +889,7 @@ impl<'a> Resolver<'a> { | Type::Constant(_) | Type::NamedGeneric(_, _) | Type::NotConstant + | Type::TraitAsType(_) | Type::Forall(_, _) => (), Type::Array(length, element_type) => { @@ -1430,6 +1446,17 @@ impl<'a> Resolver<'a> { } } + /// Lookup a given trait by name/path. + fn lookup_trait_or_error(&mut self, path: Path) -> Option { + match self.lookup(path) { + Ok(trait_id) => Some(self.get_trait(trait_id)), + Err(error) => { + self.push_err(error); + None + } + } + } + /// Looks up a given type by name. /// This will also instantiate any struct types found. fn lookup_type_or_error(&mut self, path: Path) -> Option { diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 27820922100..bda2e64bde4 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -35,6 +35,7 @@ impl<'interner> TypeChecker<'interner> { } } } + /// Infers a type for a given expression, and return this type. /// As a side-effect, this function will also remember this type in the NodeInterner /// for the given expr_id key. @@ -50,7 +51,7 @@ impl<'interner> TypeChecker<'interner> { // E.g. `fn foo(t: T, field: Field) -> T` has type `forall T. fn(T, Field) -> T`. // We must instantiate identifiers at every call site to replace this T with a new type // variable to handle generic functions. - let t = self.interner.id_type(ident.id); + let t = self.interner.id_type_substitute_trait_as_type(ident.id); let (typ, bindings) = t.instantiate(self.interner); self.interner.store_instantiation_bindings(*expr_id, bindings); typ @@ -131,7 +132,6 @@ impl<'interner> TypeChecker<'interner> { HirExpression::Index(index_expr) => self.check_index_expression(expr_id, index_expr), HirExpression::Call(call_expr) => { self.check_if_deprecated(&call_expr.func); - let function = self.check_expression(&call_expr.func); let args = vecmap(&call_expr.arguments, |arg| { let typ = self.check_expression(arg); @@ -839,6 +839,9 @@ impl<'interner> TypeChecker<'interner> { } } } + Type::TraitAsType(_trait) => { + unreachable!("unexpected lookup on trait as return type") + } Type::NamedGeneric(_, _) => { let func_meta = self.interner.function_meta( &self.current_function.expect("unexpected method outside a function"), diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index a3c5a5eb98d..a88dcdcbf63 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -15,7 +15,7 @@ pub use errors::TypeCheckError; use crate::{ hir_def::{expr::HirExpression, stmt::HirStatement}, - node_interner::{ExprId, FuncId, NodeInterner, StmtId}, + node_interner::{ExprId, FuncId, NodeInterner, StmtId, TraitImplKey}, Type, }; @@ -63,30 +63,45 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec {} + None => { + let error = TypeCheckError::TypeMismatchWithSource { + expected: declared_return_type.clone(), + actual: function_last_type.clone(), + span: func_span, + source: Source::Return(meta.return_type, expr_span), + }; + errors.push(error); + } + } + } else { + function_last_type.unify_with_coercions( + &declared_return_type, + *function_body_id, + interner, + &mut errors, + || { + let mut error = TypeCheckError::TypeMismatchWithSource { + expected: declared_return_type.clone(), + actual: function_last_type.clone(), + span: func_span, + source: Source::Return(meta.return_type, expr_span), + }; + + if empty_function { + error = error.add_context( "implicitly returns `()` as its body has no tail or `return` expression", ); - } - - error - }, - ); + } + error + }, + ); + } } errors diff --git a/compiler/noirc_frontend/src/hir_def/traits.rs b/compiler/noirc_frontend/src/hir_def/traits.rs index 11e9dde6846..3f01df4426a 100644 --- a/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/compiler/noirc_frontend/src/hir_def/traits.rs @@ -34,7 +34,7 @@ pub struct TraitType { /// Represents a trait in the type system. Each instance of this struct /// will be shared across all Type::Trait variants that represent /// the same trait. -#[derive(Clone, Debug)] +#[derive(Debug, Eq, Clone)] pub struct Trait { /// A unique id representing this trait type. Used to check if two /// struct traits are equal. diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index ef321ee2f71..13e2bc9e607 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -14,7 +14,10 @@ use noirc_printable_type::PrintableType; use crate::{node_interner::StructId, Ident, Signedness}; -use super::expr::{HirCallExpression, HirExpression, HirIdent}; +use super::{ + expr::{HirCallExpression, HirExpression, HirIdent}, + traits::Trait, +}; #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub enum Type { @@ -62,6 +65,8 @@ pub enum Type { /// different argument types each time. TypeVariable(TypeVariable, TypeVariableKind), + TraitAsType(Trait), + /// NamedGenerics are the 'T' or 'U' in a user-defined generic function /// like `fn foo(...) {}`. Unlike TypeVariables, they cannot be bound over. NamedGeneric(TypeVariable, Rc), @@ -483,7 +488,8 @@ impl Type { | Type::Constant(_) | Type::NamedGeneric(_, _) | Type::NotConstant - | Type::Forall(_, _) => false, + | Type::Forall(_, _) + | Type::TraitAsType(_) => false, Type::Array(length, elem) => { elem.contains_numeric_typevar(target_id) || named_generic_id_matches_target(length) @@ -560,6 +566,9 @@ impl std::fmt::Display for Type { write!(f, "{}<{}>", s.borrow(), args.join(", ")) } } + Type::TraitAsType(tr) => { + write!(f, "impl {}", tr.name) + } Type::Tuple(elements) => { let elements = vecmap(elements, ToString::to_string); write!(f, "({})", elements.join(", ")) @@ -1057,6 +1066,7 @@ impl Type { let fields = vecmap(fields, |field| field.substitute(type_bindings)); Type::Tuple(fields) } + Type::TraitAsType(_) => todo!(), Type::Forall(typevars, typ) => { // Trying to substitute a variable defined within a nested Forall // is usually impossible and indicative of an error in the type checker somewhere. @@ -1096,6 +1106,7 @@ impl Type { let field_occurs = fields.occurs(target_id); len_occurs || field_occurs } + Type::TraitAsType(_) => todo!(), Type::Struct(_, generic_args) => generic_args.iter().any(|arg| arg.occurs(target_id)), Type::Tuple(fields) => fields.iter().any(|field| field.occurs(target_id)), Type::NamedGeneric(binding, _) | Type::TypeVariable(binding, _) => { @@ -1147,7 +1158,6 @@ impl Type { Struct(def.clone(), args) } Tuple(args) => Tuple(vecmap(args, |arg| arg.follow_bindings())), - TypeVariable(var, _) | NamedGeneric(var, _) => { if let TypeBinding::Bound(typ) = &*var.borrow() { return typ.follow_bindings(); @@ -1166,10 +1176,14 @@ impl Type { // Expect that this function should only be called on instantiated types Forall(..) => unreachable!(), - - FieldElement | Integer(_, _) | Bool | Constant(_) | Unit | Error | NotConstant => { - self.clone() - } + TraitAsType(_) + | FieldElement + | Integer(_, _) + | Bool + | Constant(_) + | Unit + | Error + | NotConstant => self.clone(), } } } @@ -1270,6 +1284,7 @@ impl From<&Type> for PrintableType { let fields = vecmap(fields, |(name, typ)| (name, typ.into())); PrintableType::Struct { fields, name: struct_type.name.to_string() } } + Type::TraitAsType(name) => PrintableType::Trait { name: name.name.to_string() }, Type::Tuple(_) => todo!("printing tuple types is not yet implemented"), Type::TypeVariable(_, _) => unreachable!(), Type::NamedGeneric(..) => unreachable!(), diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index b2c12746a8c..39250be619a 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -222,9 +222,15 @@ impl<'interner> Monomorphizer<'interner> { let modifiers = self.interner.function_modifiers(&f); let name = self.interner.function_name(&f).to_owned(); - let return_type = self.convert_type(meta.return_type()); + let body_expr_id = *self.interner.function(&f).as_expr(); + let body_return_type = self.interner.id_type(body_expr_id); + let return_type = self.convert_type(match meta.return_type() { + Type::TraitAsType(_) => &body_return_type, + _ => meta.return_type(), + }); + let parameters = self.parameters(meta.parameters); - let body = self.expr(*self.interner.function(&f).as_expr()); + let body = self.expr(body_expr_id); let unconstrained = modifiers.is_unconstrained || matches!(modifiers.contract_function_type, Some(ContractFunctionType::Open)); @@ -381,8 +387,8 @@ impl<'interner> Monomorphizer<'interner> { } } - HirExpression::MethodCall(_) => { - unreachable!("Encountered HirExpression::MethodCall during monomorphization") + HirExpression::MethodCall(hir_method_call) => { + unreachable!("Encountered HirExpression::MethodCall during monomorphization {hir_method_call:?}") } HirExpression::Error => unreachable!("Encountered Error node during monomorphization"), } @@ -635,7 +641,6 @@ impl<'interner> Monomorphizer<'interner> { let location = Some(ident.location); let name = definition.name.clone(); let typ = self.interner.id_type(expr_id); - let definition = self.lookup_function(*func_id, expr_id, &typ); let typ = self.convert_type(&typ); let ident = ast::Ident { location, mutable, definition, name, typ: typ.clone() }; @@ -686,7 +691,6 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::FmtString(size, fields) } HirType::Unit => ast::Type::Unit, - HirType::Array(length, element) => { let element = Box::new(self.convert_type(element.as_ref())); @@ -696,7 +700,9 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::Slice(element) } } - + HirType::TraitAsType(_) => { + unreachable!("All TraitAsType should be replaced before calling convert_type"); + } HirType::NamedGeneric(binding, _) => { if let TypeBinding::Bound(binding) = &*binding.borrow() { return self.convert_type(binding); @@ -780,8 +786,7 @@ impl<'interner> Monomorphizer<'interner> { } } - fn is_function_closure(&self, raw_func_id: node_interner::ExprId) -> bool { - let t = self.convert_type(&self.interner.id_type(raw_func_id)); + fn is_function_closure(&self, t: ast::Type) -> bool { if self.is_function_closure_type(&t) { true } else if let ast::Type::Tuple(elements) = t { @@ -850,6 +855,7 @@ impl<'interner> Monomorphizer<'interner> { let func: Box; let return_type = self.interner.id_type(id); let return_type = self.convert_type(&return_type); + let location = call.location; if let ast::Expression::Ident(ident) = original_func.as_ref() { @@ -863,8 +869,9 @@ impl<'interner> Monomorphizer<'interner> { } let mut block_expressions = vec![]; - - let is_closure = self.is_function_closure(call.func); + let func_type = self.interner.id_type(call.func); + let func_type = self.convert_type(&func_type); + let is_closure = self.is_function_closure(func_type); if is_closure { let local_id = self.next_local_id(); diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index c4497331fae..8512e38fec0 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -820,6 +820,23 @@ impl NodeInterner { self.id_to_type.get(&index.into()).cloned().unwrap_or(Type::Error) } + pub fn id_type_substitute_trait_as_type(&self, def_id: DefinitionId) -> Type { + let typ = self.id_type(def_id); + if let Type::Function(args, ret, env) = &typ { + let def = self.definition(def_id); + if let Type::TraitAsType(_trait) = ret.as_ref() { + if let DefinitionKind::Function(func_id) = def.kind { + let f = self.function(&func_id); + let func_body = f.as_expr(); + let ret_type = self.id_type(func_body); + let new_type = Type::Function(args.clone(), Box::new(ret_type), env.clone()); + return new_type; + } + } + } + typ + } + /// Returns the span of an item stored in the Interner pub fn id_location(&self, index: impl Into) -> Location { self.id_to_location.get(&index.into()).copied().unwrap() @@ -943,6 +960,7 @@ impl NodeInterner { | Type::Forall(..) | Type::NotConstant | Type::Constant(..) + | Type::TraitAsType(..) | Type::Error => false, } } @@ -1051,6 +1069,7 @@ fn get_type_method_key(typ: &Type) -> Option { | Type::Error | Type::NotConstant | Type::Struct(_, _) + | Type::TraitAsType(_) | Type::FmtString(_, _) => None, } } diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index 6b43bd003f3..9b2e8343680 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -271,6 +271,7 @@ fn lambda_return_type() -> impl NoirParser { } fn function_return_type() -> impl NoirParser<((Distinctness, Visibility), FunctionReturnType)> { + //tuka just(Token::Arrow) .ignore_then(optional_distinctness()) .then(optional_visibility()) @@ -993,6 +994,7 @@ fn parse_type_inner( string_type(), format_string_type(recursive_type_parser.clone()), named_type(recursive_type_parser.clone()), + named_trait(recursive_type_parser.clone()), array_type(recursive_type_parser.clone()), recursive_type_parser.clone().delimited_by(just(Token::LeftParen), just(Token::RightParen)), tuple_type(recursive_type_parser.clone()), @@ -1082,6 +1084,12 @@ fn named_type(type_parser: impl NoirParser) -> impl NoirParser) -> impl NoirParser { + keyword(Keyword::Impl).then(path()).then(generic_type_args(type_parser)).map_with_span( + |((_token, path), args), span| UnresolvedTypeData::TraitAsType(path, args).with_span(span), + ) +} + fn generic_type_args( type_parser: impl NoirParser, ) -> impl NoirParser> { diff --git a/compiler/noirc_printable_type/src/lib.rs b/compiler/noirc_printable_type/src/lib.rs index 348f5ef3274..03b126ff008 100644 --- a/compiler/noirc_printable_type/src/lib.rs +++ b/compiler/noirc_printable_type/src/lib.rs @@ -26,6 +26,9 @@ pub enum PrintableType { name: String, fields: Vec<(String, PrintableType)>, }, + Trait { + name: String, + }, String { length: u64, }, @@ -44,6 +47,7 @@ impl PrintableType { fields.iter().fold(0, |acc, (_, field_type)| acc + field_type.field_count()) } Self::String { length } => *length as u32, + Self::Trait { .. } => 0, } } } @@ -315,6 +319,7 @@ fn decode_value( PrintableValue::Struct(struct_map) } + PrintableType::Trait { .. } => todo!(), } } diff --git a/tooling/nargo_cli/tests/execution_success/trait_as_return_type/Nargo.toml b/tooling/nargo_cli/tests/execution_success/trait_as_return_type/Nargo.toml new file mode 100644 index 00000000000..1a7ef7ee8a2 --- /dev/null +++ b/tooling/nargo_cli/tests/execution_success/trait_as_return_type/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "trait_as_return_type" +type = "bin" +authors = [""] +compiler_version = "0.10.5" + +[dependencies] \ No newline at end of file diff --git a/tooling/nargo_cli/tests/execution_success/trait_as_return_type/Prover.toml b/tooling/nargo_cli/tests/execution_success/trait_as_return_type/Prover.toml new file mode 100644 index 00000000000..a0cd58138b6 --- /dev/null +++ b/tooling/nargo_cli/tests/execution_success/trait_as_return_type/Prover.toml @@ -0,0 +1 @@ +x = "5" \ No newline at end of file diff --git a/tooling/nargo_cli/tests/execution_success/trait_as_return_type/src/main.nr b/tooling/nargo_cli/tests/execution_success/trait_as_return_type/src/main.nr new file mode 100644 index 00000000000..e84fe15aba7 --- /dev/null +++ b/tooling/nargo_cli/tests/execution_success/trait_as_return_type/src/main.nr @@ -0,0 +1,55 @@ +trait SomeTrait { + fn magic_number(self) -> Field; +} + +struct A {} +struct B {} +struct C { + x: Field +} + + +impl SomeTrait for A { + fn magic_number(self) -> Field { + 2 + } +} + +impl SomeTrait for B { + fn magic_number(self) -> Field { + 4 + } +} + +impl SomeTrait for C { + fn magic_number(self) -> Field { + self.x + } +} + + + +fn factory_a() -> impl SomeTrait { + A {} +} + +fn factory_b() -> impl SomeTrait { + B {} +} + +fn factory_c(x: Field) -> impl SomeTrait { + C {x:x} +} + +// x = 15 +fn main(x: u32) { + let a = factory_a(); + let b = B {}; + let b2 = factory_b(); + assert(a.magic_number() == 2); + assert(b.magic_number() == 4); + assert(b2.magic_number() == 4); + let c = factory_c(10); + assert(c.magic_number() == 10); + assert(factory_c(13).magic_number() == 13); +} \ No newline at end of file diff --git a/tooling/noirc_abi/src/lib.rs b/tooling/noirc_abi/src/lib.rs index d5c2314b3a6..9753dc21f14 100644 --- a/tooling/noirc_abi/src/lib.rs +++ b/tooling/noirc_abi/src/lib.rs @@ -158,6 +158,7 @@ impl AbiType { Type::Error => unreachable!(), Type::Unit => unreachable!(), Type::Constant(_) => unreachable!(), + Type::TraitAsType(_) => unreachable!(), Type::Struct(def, ref args) => { let struct_type = def.borrow(); let fields = struct_type.get_fields(args);