Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(traits): Improve support for traits static method resolution #2958

Merged
merged 7 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
pub struct UnresolvedFunctions {
pub file_id: FileId,
pub functions: Vec<(LocalModuleId, FuncId, NoirFunction)>,
pub trait_id: Option<TraitId>,
}

impl UnresolvedFunctions {
Expand Down Expand Up @@ -302,7 +303,7 @@
errors.extend(collect_impls(context, crate_id, &def_collector.collected_impls));

// Bind trait impls to their trait. Collect trait functions, that have a
// default implementation, which hasn't been overriden.

Check warning on line 306 in compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (overriden)
errors.extend(collect_trait_impls(
context,
crate_id,
Expand Down Expand Up @@ -485,7 +486,9 @@
errors.push((error.into(), trait_impl.file_id));
}
}

trait_impl.methods.functions = ordered_methods;
trait_impl.methods.trait_id = Some(trait_id);
errors
}

Expand Down Expand Up @@ -796,7 +799,7 @@
.collect();
let default_impl = if !default_impl_list.is_empty() {
if default_impl_list.len() > 1 {
// TODO(nickysn): Add check for method duplicates in the trait and emit proper error messages. This is planned in a future PR.

Check warning on line 802 in compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (nickysn)
panic!("Too many functions with the same name!");
}
Some(Box::new(default_impl_list[0].2.clone()))
Expand Down Expand Up @@ -993,6 +996,12 @@
errors,
);

if let Some(trait_id) = maybe_trait_id {
for (_, func) in &impl_methods {
interner.set_function_trait(*func, self_type.clone(), trait_id);
}
}

let mut new_resolver =
Resolver::new(interner, &path_resolver, &context.def_maps, trait_impl.file_id);
new_resolver.set_self_type(Some(self_type.clone()));
Expand Down Expand Up @@ -1025,7 +1034,7 @@
methods
}

// TODO(vitkov): Move this out of here and into type_check

Check warning on line 1037 in compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (vitkov)
fn check_methods_signatures(
resolver: &mut Resolver,
impl_methods: &Vec<(FileId, FuncId)>,
Expand All @@ -1043,7 +1052,7 @@
let meta = resolver.interner.function_meta(func_id);
let func_name = resolver.interner.function_name(func_id).to_owned();

let mut typecheck_errors = Vec::new();

Check warning on line 1055 in compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (typecheck)

// `method` is None in the case where the impl block has a method that's not part of the trait.
// If that's the case, a `MethodNotInTrait` error has already been thrown, and we can ignore
Expand All @@ -1059,7 +1068,7 @@
for (parameter_index, ((expected, actual), (hir_pattern, _, _))) in
method.arguments.iter().zip(&params).zip(&meta.parameters.0).enumerate()
{
expected.unify(actual, &mut typecheck_errors, || {

Check warning on line 1071 in compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (typecheck)
TypeCheckError::TraitMethodParameterTypeMismatch {
method_name: func_name.to_string(),
expected_typ: expected.to_string(),
Expand Down Expand Up @@ -1088,7 +1097,7 @@
let resolved_return_type =
resolver.resolve_type(meta.return_type.get_type().into_owned());

method.return_type.unify(&resolved_return_type, &mut typecheck_errors, || {

Check warning on line 1100 in compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (typecheck)
let ret_type_span =
meta.return_type.get_type().span.expect("return type must always have a span");

Expand All @@ -1099,7 +1108,7 @@
}
});

errors.extend(typecheck_errors.iter().cloned().map(|e| (e.into(), *file_id)));

Check warning on line 1111 in compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (typecheck)
}
}

Expand Down Expand Up @@ -1156,6 +1165,7 @@
// TypeVariables for the same generic, causing it to instantiate incorrectly.
resolver.set_generics(impl_generics.clone());
resolver.set_self_type(self_type.clone());
resolver.set_trait_id(unresolved_functions.trait_id);

let (hir_func, func_meta, errs) = resolver.resolve_function(func, func_id);
interner.push_fn_meta(func_meta, func_id);
Expand Down
18 changes: 12 additions & 6 deletions compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,11 @@
let module_id = ModuleId { krate, local_id: self.module_id };

for r#impl in impls {
let mut unresolved_functions =
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() };
let mut unresolved_functions = UnresolvedFunctions {
file_id: self.file_id,
functions: Vec::new(),
trait_id: None,
};

for method in r#impl.methods {
let func_id = context.def_interner.push_empty_fn();
Expand Down Expand Up @@ -171,7 +174,7 @@
krate: CrateId,
) -> UnresolvedFunctions {
let mut unresolved_functions =
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() };
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new(), trait_id: None };

let module = ModuleId { krate, local_id: self.module_id };

Expand All @@ -193,7 +196,7 @@
krate: CrateId,
) -> Vec<(CompilationError, FileId)> {
let mut unresolved_functions =
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() };
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new(), trait_id: None };
let mut errors = vec![];

let module = ModuleId { krate, local_id: self.module_id };
Expand Down Expand Up @@ -240,7 +243,7 @@
types: Vec<NoirStruct>,
krate: CrateId,
) -> Vec<(CompilationError, FileId)> {
let mut definiton_errors = vec![];

Check warning on line 246 in compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (definiton)
for struct_definition in types {
let name = struct_definition.name.clone();

Expand Down Expand Up @@ -351,8 +354,11 @@
}

// Add all functions that have a default implementation in the trait
let mut unresolved_functions =
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() };
let mut unresolved_functions = UnresolvedFunctions {
file_id: self.file_id,
functions: Vec::new(),
trait_id: None,
};
for trait_item in &trait_definition.items {
// TODO(Maddiaa): Investigate trait implementations with attributes see: https://github.com/noir-lang/noir/issues/2629
if let TraitItem::Function {
Expand Down
133 changes: 103 additions & 30 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ use crate::{
};
use crate::{
ArrayLiteral, ContractFunctionType, Distinctness, Generics, LValue, NoirStruct, NoirTypeAlias,
Path, Pattern, Shared, StructType, Type, TypeAliasType, TypeBinding, TypeVariable, UnaryOp,
UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData,
Path, PathKind, Pattern, Shared, StructType, Type, TypeAliasType, TypeBinding, TypeVariable,
UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData,
UnresolvedTypeExpression, Visibility, ERROR_IDENT,
};
use fm::FileId;
Expand Down Expand Up @@ -78,6 +78,8 @@ pub struct Resolver<'a> {
scopes: ScopeForest,
path_resolver: &'a dyn PathResolver,
def_maps: &'a BTreeMap<CrateId, CrateDefMap>,
trait_id: Option<TraitId>,
trait_bounds: Vec<UnresolvedTraitConstraint>,
pub interner: &'a mut NodeInterner,
errors: Vec<ResolverError>,
file: FileId,
Expand Down Expand Up @@ -120,6 +122,8 @@ impl<'a> Resolver<'a> {
Self {
path_resolver,
def_maps,
trait_id: None,
trait_bounds: Vec::new(),
scopes: ScopeForest::default(),
interner,
self_type: None,
Expand All @@ -134,6 +138,10 @@ impl<'a> Resolver<'a> {
self.self_type = self_type;
}

pub fn set_trait_id(&mut self, trait_id: Option<TraitId>) {
self.trait_id = trait_id;
}

pub fn get_self_type(&mut self) -> Option<&Type> {
self.self_type.as_ref()
}
Expand All @@ -158,12 +166,14 @@ impl<'a> Resolver<'a> {
self.resolve_local_globals();

self.add_generics(&func.def.generics);
self.trait_bounds = func.def.where_clause.clone();

let (hir_func, func_meta) = self.intern_function(func, func_id);
let func_scope_tree = self.scopes.end_function();

self.check_for_unused_variables_in_scope_tree(func_scope_tree);

self.trait_bounds.clear();
(hir_func, func_meta, self.errors)
}

Expand Down Expand Up @@ -1075,39 +1085,43 @@ impl<'a> Resolver<'a> {
Literal::Unit => HirLiteral::Unit,
}),
ExpressionKind::Variable(path) => {
// If the Path is being used as an Expression, then it is referring to a global from a separate module
// Otherwise, then it is referring to an Identifier
// This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10;
// If the expression is a singular indent, we search the resolver's current scope as normal.
let (hir_ident, var_scope_index) = self.get_ident_from_path(path);

if hir_ident.id != DefinitionId::dummy_id() {
match self.interner.definition(hir_ident.id).kind {
DefinitionKind::Function(id) => {
if self.interner.function_visibility(id) == Visibility::Private {
let span = hir_ident.location.span;
self.check_can_reference_private_function(id, span);
if let Some(expr) = self.resolve_trait_generic_path(&path) {
expr
} else {
// If the Path is being used as an Expression, then it is referring to a global from a separate module
// Otherwise, then it is referring to an Identifier
// This lookup allows support of such statements: let x = foo::bar::SOME_GLOBAL + 10;
// If the expression is a singular indent, we search the resolver's current scope as normal.
let (hir_ident, var_scope_index) = self.get_ident_from_path(path);

if hir_ident.id != DefinitionId::dummy_id() {
match self.interner.definition(hir_ident.id).kind {
DefinitionKind::Function(id) => {
if self.interner.function_visibility(id) == Visibility::Private {
let span = hir_ident.location.span;
self.check_can_reference_private_function(id, span);
}
}
}
DefinitionKind::Global(_) => {}
DefinitionKind::GenericType(_) => {
// Initialize numeric generics to a polymorphic integer type in case
// they're used in expressions. We must do this here since the type
// checker does not check definition kinds and otherwise expects
// parameters to already be typed.
if self.interner.id_type(hir_ident.id) == Type::Error {
let typ = Type::polymorphic_integer(self.interner);
self.interner.push_definition_type(hir_ident.id, typ);
DefinitionKind::Global(_) => {}
DefinitionKind::GenericType(_) => {
// Initialize numeric generics to a polymorphic integer type in case
// they're used in expressions. We must do this here since the type
// checker does not check definition kinds and otherwise expects
// parameters to already be typed.
if self.interner.id_type(hir_ident.id) == Type::Error {
let typ = Type::polymorphic_integer(self.interner);
self.interner.push_definition_type(hir_ident.id, typ);
}
}
DefinitionKind::Local(_) => {
// only local variables can be captured by closures.
self.resolve_local_variable(hir_ident, var_scope_index);
}
}
DefinitionKind::Local(_) => {
// only local variables can be captured by closures.
self.resolve_local_variable(hir_ident, var_scope_index);
}
}
}

HirExpression::Ident(hir_ident)
HirExpression::Ident(hir_ident)
}
}
ExpressionKind::Prefix(prefix) => {
let operator = prefix.operator;
Expand Down Expand Up @@ -1445,6 +1459,65 @@ impl<'a> Resolver<'a> {
self.lookup(path).ok().map(|id| self.interner.get_type_alias(id))
}

// this resolves Self::some_static_method, inside an impl block (where we don't have a concrete self_type)
fn resolve_trait_static_method_by_self(&mut self, path: &Path) -> Option<HirExpression> {
if let Some(trait_id) = self.trait_id {
if path.kind == PathKind::Plain && path.segments.len() == 2 {
let name = &path.segments[0].0.contents;
let method = &path.segments[1];

if name == SELF_TYPE_NAME {
let the_trait = self.interner.get_trait(trait_id);

if let Some(method) = the_trait.find_method(method.clone()) {
let self_type = Type::TypeVariable(
the_trait.self_type_typevar,
crate::TypeVariableKind::Normal,
);
return Some(HirExpression::TraitMethodReference(self_type, method));
}
}
}
}
None
}

// this resolves a static trait method T::trait_method by iterating over the where clause
fn resolve_trait_method_by_named_generic(&mut self, path: &Path) -> Option<HirExpression> {
if path.segments.len() != 2 {
return None;
}

for UnresolvedTraitConstraint { typ, trait_bound } in self.trait_bounds.clone() {
if let UnresolvedTypeData::Named(constraint_path, _) = &typ.typ {
// if `path` is `T::method_name`, we're looking for constraint of the form `T: SomeTrait`
if constraint_path.segments.len() == 1
&& path.segments[0] != constraint_path.last_segment()
{
continue;
}

if let Ok(ModuleDefId::TraitId(trait_id)) =
self.path_resolver.resolve(self.def_maps, trait_bound.trait_path.clone())
{
let the_trait = self.interner.get_trait(trait_id);
if let Some(method) =
the_trait.find_method(path.segments.last().unwrap().clone())
{
let self_type = self.resolve_type(typ.clone());
return Some(HirExpression::TraitMethodReference(self_type, method));
}
}
}
}
None
}

fn resolve_trait_generic_path(&mut self, path: &Path) -> Option<HirExpression> {
self.resolve_trait_static_method_by_self(path)
.or_else(|| self.resolve_trait_method_by_named_generic(path))
}

fn resolve_path(&mut self, path: Path) -> Result<ModuleDefId, ResolverError> {
self.path_resolver.resolve(self.def_maps, path).map_err(ResolverError::PathResolutionError)
}
Expand Down
23 changes: 19 additions & 4 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ impl<'interner> TypeChecker<'interner> {
}

let (function_id, function_call) = method_call.into_function_call(
method_ref,
method_ref.clone(),
location,
self.interner,
);
Expand Down Expand Up @@ -291,7 +291,19 @@ impl<'interner> TypeChecker<'interner> {

Type::Function(params, Box::new(lambda.return_type), Box::new(env_type))
}
HirExpression::TraitMethodReference(_) => unreachable!("unexpected TraitMethodReference - they should be added after initial type checking"),
HirExpression::TraitMethodReference(_, method) => {
let the_trait = self.interner.get_trait(method.trait_id);
let method = &the_trait.methods[method.method_index];

let typ = Type::Function(
method.arguments.clone(),
Box::new(method.return_type.clone()),
Box::new(Type::Unit),
);
let (typ, bindings) = typ.instantiate(self.interner);
jfecher marked this conversation as resolved.
Show resolved Hide resolved
self.interner.store_instantiation_bindings(*expr_id, bindings);
typ
}
};

self.interner.push_expr_type(expr_id, typ.clone());
Expand Down Expand Up @@ -498,7 +510,7 @@ impl<'interner> TypeChecker<'interner> {

(func_meta.typ, param_len)
}
HirMethodReference::TraitMethodId(method) => {
HirMethodReference::TraitMethodId(_, method) => {
let the_trait = self.interner.get_trait(method.trait_id);
let method = &the_trait.methods[method.method_index];

Expand Down Expand Up @@ -863,7 +875,10 @@ impl<'interner> TypeChecker<'interner> {
if method.name.0.contents == method_name {
let trait_method =
TraitMethodId { trait_id: constraint.trait_id, method_index };
return Some(HirMethodReference::TraitMethodId(trait_method));
return Some(HirMethodReference::TraitMethodId(
object_type.clone(),
trait_method,
));
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
If(HirIfExpression),
Tuple(Vec<ExprId>),
Lambda(HirLambda),
TraitMethodReference(TraitMethodId),
TraitMethodReference(Type, TraitMethodId),
Error,
}

Expand Down Expand Up @@ -108,7 +108,7 @@
pub rhs: ExprId,
}

/// This is always a struct field access `mystruct.field`

Check warning on line 111 in compiler/noirc_frontend/src/hir_def/expr.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (mystruct)
/// and never a method call. The later is represented by HirMethodCallExpression.
#[derive(Debug, Clone)]
pub struct HirMemberAccess {
Expand Down Expand Up @@ -151,7 +151,7 @@
pub location: Location,
}

#[derive(Debug, Copy, Clone)]
#[derive(Debug, Clone)]
pub enum HirMethodReference {
/// A method can be defined in a regular `impl` block, in which case
/// it's syntax sugar for a normal function call, and can be
Expand All @@ -160,8 +160,8 @@

/// Or a method can come from a Trait impl block, in which case
/// the actual function called will depend on the instantiated type,
/// which can be only known during monomorphizaiton.

Check warning on line 163 in compiler/noirc_frontend/src/hir_def/expr.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (monomorphizaiton)
TraitMethodId(TraitMethodId),
TraitMethodId(Type, TraitMethodId),
}

impl HirMethodCallExpression {
Expand All @@ -179,8 +179,8 @@
let id = interner.function_definition_id(func_id);
HirExpression::Ident(HirIdent { location, id })
}
HirMethodReference::TraitMethodId(method_id) => {
HirExpression::TraitMethodReference(method_id)
HirMethodReference::TraitMethodId(typ, method_id) => {
HirExpression::TraitMethodReference(typ, method_id)
}
};
let func = interner.push_expr(expr);
Expand Down
11 changes: 10 additions & 1 deletion compiler/noirc_frontend/src/hir_def/traits.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
graph::CrateId,
node_interner::{FuncId, TraitId},
node_interner::{FuncId, TraitId, TraitMethodId},
Generics, Ident, NoirFunction, Type, TypeVariable, TypeVariableId,
};
use noirc_errors::Span;
Expand Down Expand Up @@ -111,6 +111,15 @@ impl Trait {
pub fn set_methods(&mut self, methods: Vec<TraitFunction>) {
self.methods = methods;
}

pub fn find_method(&self, name: Ident) -> Option<TraitMethodId> {
for (idx, method) in self.methods.iter().enumerate() {
if method.name == name {
return Some(TraitMethodId { trait_id: self.id, method_index: idx });
}
}
None
}
}

impl std::fmt::Display for Trait {
Expand Down
Loading