Skip to content

Commit

Permalink
feat!: type-check trait default methods
Browse files Browse the repository at this point in the history
  • Loading branch information
asterite committed Dec 23, 2024
1 parent 011fbc1 commit 52c1332
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 73 deletions.
8 changes: 4 additions & 4 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -821,8 +821,8 @@ impl FunctionDefinition {
is_unconstrained: bool,
generics: &UnresolvedGenerics,
parameters: &[(Ident, UnresolvedType)],
body: &BlockExpression,
where_clause: &[UnresolvedTraitConstraint],
body: BlockExpression,
where_clause: Vec<UnresolvedTraitConstraint>,
return_type: &FunctionReturnType,
) -> FunctionDefinition {
let p = parameters
Expand All @@ -843,9 +843,9 @@ impl FunctionDefinition {
visibility: ItemVisibility::Private,
generics: generics.clone(),
parameters: p,
body: body.clone(),
body,
span: name.span(),
where_clause: where_clause.to_vec(),
where_clause,
return_type: return_type.clone(),
return_visibility: Visibility::Private,
}
Expand Down
6 changes: 6 additions & 0 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,12 @@ impl<'context> Elaborator<'context> {
self.elaborate_functions(functions);
}

for (trait_id, unresolved_trait) in items.traits {
self.current_trait = Some(trait_id);
self.elaborate_functions(unresolved_trait.fns_with_default_impl);
}
self.current_trait = None;

for impls in items.impls.into_values() {
self.elaborate_impls(impls);
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/elaborator/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ impl<'context> Elaborator<'context> {

let impl_kind = match method {
HirMethodReference::FuncId(_) => ImplKind::NotATraitMethod,
HirMethodReference::TraitMethodId(method_id, generics) => {
HirMethodReference::TraitMethodId(method_id, generics, _) => {
let mut constraint =
self.interner.get_trait(method_id.trait_id).as_constraint(span);
constraint.trait_bound.trait_generics = generics;
Expand Down
39 changes: 25 additions & 14 deletions compiler/noirc_frontend/src/elaborator/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ impl<'context> Elaborator<'context> {
self.recover_generics(|this| {
this.current_trait = Some(*trait_id);

let the_trait = this.interner.get_trait(*trait_id);
let self_typevar = the_trait.self_type_typevar.clone();
let self_type = Type::TypeVariable(self_typevar.clone());
this.self_type = Some(self_type.clone());

let resolved_generics = this.interner.get_trait(*trait_id).generics.clone();
this.add_existing_generics(
&unresolved_trait.trait_def.generics,
Expand All @@ -48,12 +53,15 @@ impl<'context> Elaborator<'context> {
.add_trait_dependency(DependencyId::Trait(bound.trait_id), *trait_id);
}

this.interner.update_trait(*trait_id, |trait_def| {
trait_def.set_trait_bounds(resolved_trait_bounds);
trait_def.set_where_clause(where_clause);
});

let methods = this.resolve_trait_methods(*trait_id, unresolved_trait);

this.interner.update_trait(*trait_id, |trait_def| {
trait_def.set_methods(methods);
trait_def.set_trait_bounds(resolved_trait_bounds);
trait_def.set_where_clause(where_clause);
});
});

Expand Down Expand Up @@ -94,16 +102,17 @@ impl<'context> Elaborator<'context> {
parameters,
return_type,
where_clause,
body: _,
body,
is_unconstrained,
visibility: _,
is_comptime: _,
} = &item.item
{
self.recover_generics(|this| {
let the_trait = this.interner.get_trait(trait_id);
let the_trait_where_clause = the_trait.where_clause.clone();
let the_trait_constraint = the_trait.as_constraint(the_trait.name.span());
let self_typevar = the_trait.self_type_typevar.clone();
let self_type = Type::TypeVariable(self_typevar.clone());
let name_span = the_trait.name.span();

this.add_existing_generic(
Expand All @@ -115,9 +124,12 @@ impl<'context> Elaborator<'context> {
span: name_span,
},
);
this.self_type = Some(self_type.clone());

let func_id = unresolved_trait.method_ids[&name.0.contents];
let mut where_clause = where_clause.to_vec();

// Attach any trait constraints on the trait to the function
where_clause.extend(unresolved_trait.trait_def.where_clause.clone());

this.resolve_trait_function(
trait_id,
Expand All @@ -127,6 +139,7 @@ impl<'context> Elaborator<'context> {
parameters,
return_type,
where_clause,
body,
func_id,
);

Expand Down Expand Up @@ -188,20 +201,22 @@ impl<'context> Elaborator<'context> {
generics: &UnresolvedGenerics,
parameters: &[(Ident, UnresolvedType)],
return_type: &FunctionReturnType,
where_clause: &[UnresolvedTraitConstraint],
where_clause: Vec<UnresolvedTraitConstraint>,
body: &Option<BlockExpression>,
func_id: FuncId,
) {
let old_generic_count = self.generics.len();

self.scopes.start_function();
let body = match body {
Some(body) => body.clone(),
None => BlockExpression { statements: Vec::new() },
};

let kind = FunctionKind::Normal;
let mut def = FunctionDefinition::normal(
name,
is_unconstrained,
generics,
parameters,
&BlockExpression { statements: Vec::new() },
body,
where_clause,
return_type,
);
Expand All @@ -210,10 +225,6 @@ impl<'context> Elaborator<'context> {

let mut function = NoirFunction { kind, def };
self.define_function_meta(&mut function, func_id, Some(trait_id));
self.elaborate_function(func_id);
let _ = self.scopes.end_function();
// Don't check the scope tree for unused variables, they can't be used in a declaration anyway.
self.generics.truncate(old_generic_count);
}
}

Expand Down
29 changes: 27 additions & 2 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,12 +566,17 @@ impl<'context> Elaborator<'context> {
}

// this resolves Self::some_static_method, inside an impl block (where we don't have a concrete self_type)
// or inside a trait default method.
//
// Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not
// E.g. `t.method()` with `where T: Foo<Bar>` in scope will return `(Foo::method, T, vec![Bar])`
fn resolve_trait_static_method_by_self(&mut self, path: &Path) -> Option<TraitPathResolution> {
let trait_impl = self.current_trait_impl?;
let trait_id = self.interner.try_get_trait_implementation(trait_impl)?.borrow().trait_id;
let trait_id = if let Some(current_trait) = self.current_trait {
current_trait
} else {
let trait_impl = self.current_trait_impl?;
self.interner.try_get_trait_implementation(trait_impl)?.borrow().trait_id
};

if path.kind == PathKind::Plain && path.segments.len() == 2 {
let name = &path.segments[0].ident.0.contents;
Expand Down Expand Up @@ -1395,6 +1400,25 @@ impl<'context> Elaborator<'context> {
};
let func_meta = self.interner.function_meta(&func_id);

// If inside a trait method, check if it's a method on `self`
if let Some(trait_id) = func_meta.trait_id {
if Some(object_type) == self.self_type.as_ref() {
let the_trait = self.interner.get_trait(trait_id);
let constraint = the_trait.as_constraint(the_trait.name.span());
if let Some(HirMethodReference::TraitMethodId(method_id, generics, _)) = self
.lookup_method_in_trait(
the_trait,
method_name,
&constraint.trait_bound,
the_trait.id,
)
{
// If it is, it's an assumed trait
return Some(HirMethodReference::TraitMethodId(method_id, generics, true));
}
}
}

for constraint in &func_meta.trait_constraints {
if *object_type == constraint.typ {
if let Some(the_trait) =
Expand Down Expand Up @@ -1432,6 +1456,7 @@ impl<'context> Elaborator<'context> {
return Some(HirMethodReference::TraitMethodId(
trait_method,
trait_bound.trait_generics.clone(),
false,
));
}

Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,8 @@ impl<'a> ModCollector<'a> {
*is_unconstrained,
generics,
parameters,
body,
where_clause,
body.clone(),
where_clause.clone(),
return_type,
));
unresolved_functions.push_fn(
Expand Down
9 changes: 5 additions & 4 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,14 @@ pub enum HirMethodReference {
/// Or a method can come from a Trait impl block, in which case
/// the actual function called will depend on the instantiated type,
/// which can be only known during monomorphization.
TraitMethodId(TraitMethodId, TraitGenerics),
TraitMethodId(TraitMethodId, TraitGenerics, bool /* assumed */),
}

impl HirMethodReference {
pub fn func_id(&self, interner: &NodeInterner) -> Option<FuncId> {
match self {
HirMethodReference::FuncId(func_id) => Some(*func_id),
HirMethodReference::TraitMethodId(method_id, _) => {
HirMethodReference::TraitMethodId(method_id, _, _) => {
let id = interner.trait_method_id(*method_id);
match &interner.try_definition(id)?.kind {
DefinitionKind::Function(func_id) => Some(*func_id),
Expand Down Expand Up @@ -246,7 +246,7 @@ impl HirMethodCallExpression {
HirMethodReference::FuncId(func_id) => {
(interner.function_definition_id(func_id), ImplKind::NotATraitMethod)
}
HirMethodReference::TraitMethodId(method_id, trait_generics) => {
HirMethodReference::TraitMethodId(method_id, trait_generics, assumed) => {
let id = interner.trait_method_id(method_id);
let constraint = TraitConstraint {
typ: object_type,
Expand All @@ -256,7 +256,8 @@ impl HirMethodCallExpression {
span: location.span,
},
};
(id, ImplKind::TraitMethod(TraitMethod { method_id, constraint, assumed: false }))

(id, ImplKind::TraitMethod(TraitMethod { method_id, constraint, assumed }))
}
};
let func_var = HirIdent { location, id, impl_kind };
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/hir_def/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,12 @@ pub enum FunctionBody {

impl FuncMeta {
/// A stub function does not have a body. This includes Builtin, LowLevel,
/// and Oracle functions in addition to method declarations within a trait.
/// and Oracle functions.
///
/// We don't check the return type of these functions since it will always have
/// an empty body, and we don't check for unused parameters.
pub fn is_stub(&self) -> bool {
self.kind.can_ignore_return_type() || self.trait_id.is_some()
self.kind.can_ignore_return_type()
}

pub fn function_signature(&self) -> FunctionSignature {
Expand Down
20 changes: 0 additions & 20 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::borrow::Cow;
use std::fmt;
use std::hash::Hash;
use std::marker::Copy;
use std::ops::Deref;

use fm::FileId;
use iter_extended::vecmap;
Expand Down Expand Up @@ -1478,25 +1477,6 @@ impl NodeInterner {
Ok(impl_kind)
}

/// Given a `ObjectType: TraitId` pair, find all implementations without taking constraints into account or
/// applying any type bindings. Useful to look for a specific trait in a type that is used in a macro.
pub fn lookup_all_trait_implementations(
&self,
object_type: &Type,
trait_id: TraitId,
) -> Vec<&TraitImplKind> {
let trait_impl = self.trait_implementation_map.get(&trait_id);

let trait_impl = trait_impl.map(|trait_impl| {
let impls = trait_impl.iter().filter_map(|(typ, impl_kind)| match &typ {
Type::Forall(_, typ) => (typ.deref() == object_type).then_some(impl_kind),
_ => None,
});
impls.collect()
});
trait_impl.unwrap_or_default()
}

/// Similar to `lookup_trait_implementation` but does not apply any type bindings on success.
/// On error returns either:
/// - 1+ failing trait constraints, including the original.
Expand Down
3 changes: 2 additions & 1 deletion compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2942,7 +2942,7 @@ fn uses_self_type_inside_trait() {
fn uses_self_type_in_trait_where_clause() {
let src = r#"
pub trait Trait {
fn trait_func() -> bool;
fn trait_func(self) -> bool;
}
pub trait Foo where Self: Trait {
Expand All @@ -2963,6 +2963,7 @@ fn uses_self_type_in_trait_where_clause() {
"#;

let errors = get_program_errors(src);
dbg!(&errors);
assert_eq!(errors.len(), 2);

let CompilationError::ResolverError(ResolverError::TraitNotImplemented { .. }) = &errors[0].0
Expand Down
Loading

0 comments on commit 52c1332

Please sign in to comment.