Skip to content

Commit

Permalink
feat(traits): Implement trait bounds def collector + resolver passes (#…
Browse files Browse the repository at this point in the history
…2716)

Co-authored-by: Yordan Madzhunkov <ymadzhunkov@gmail.com>
  • Loading branch information
alexvitkov and ymadzhunkov authored Sep 20, 2023
1 parent fc48930 commit e3d18bb
Show file tree
Hide file tree
Showing 12 changed files with 191 additions and 119 deletions.
8 changes: 4 additions & 4 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::fmt::Display;

use crate::token::{Attributes, Token};
use crate::{
Distinctness, Ident, Path, Pattern, Recoverable, Statement, TraitConstraint, UnresolvedType,
UnresolvedTypeData, Visibility,
Distinctness, Ident, Path, Pattern, Recoverable, Statement, UnresolvedTraitConstraint,
UnresolvedType, UnresolvedTypeData, Visibility,
};
use acvm::FieldElement;
use iter_extended::vecmap;
Expand Down Expand Up @@ -368,7 +368,7 @@ pub struct FunctionDefinition {
pub parameters: Vec<(Pattern, UnresolvedType, Visibility)>,
pub body: BlockExpression,
pub span: Span,
pub where_clause: Vec<TraitConstraint>,
pub where_clause: Vec<UnresolvedTraitConstraint>,
pub return_type: FunctionReturnType,
pub return_visibility: Visibility,
pub return_distinctness: Distinctness,
Expand Down Expand Up @@ -634,7 +634,7 @@ impl FunctionDefinition {
generics: &UnresolvedGenerics,
parameters: &[(Ident, UnresolvedType)],
body: &BlockExpression,
where_clause: &[TraitConstraint],
where_clause: &[UnresolvedTraitConstraint],
return_type: &FunctionReturnType,
) -> FunctionDefinition {
let p = parameters
Expand Down
19 changes: 10 additions & 9 deletions compiler/noirc_frontend/src/ast/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use iter_extended::vecmap;
use noirc_errors::Span;

use crate::{
BlockExpression, Expression, FunctionReturnType, Ident, NoirFunction, UnresolvedGenerics,
UnresolvedType,
node_interner::TraitId, BlockExpression, Expression, FunctionReturnType, Ident, NoirFunction,
UnresolvedGenerics, UnresolvedType,
};

/// AST node for trait definitions:
Expand All @@ -14,7 +14,7 @@ use crate::{
pub struct NoirTrait {
pub name: Ident,
pub generics: Vec<Ident>,
pub where_clause: Vec<TraitConstraint>,
pub where_clause: Vec<UnresolvedTraitConstraint>,
pub span: Span,
pub items: Vec<TraitItem>,
}
Expand All @@ -28,7 +28,7 @@ pub enum TraitItem {
generics: Vec<Ident>,
parameters: Vec<(Ident, UnresolvedType)>,
return_type: FunctionReturnType,
where_clause: Vec<TraitConstraint>,
where_clause: Vec<UnresolvedTraitConstraint>,
body: Option<BlockExpression>,
},
Constant {
Expand All @@ -54,7 +54,7 @@ pub struct TypeImpl {
/// Ast node for an implementation of a trait for a particular type
/// `impl trait_name<trait_generics> for object_type where where_clauses { ... items ... }`
#[derive(Clone, Debug)]
pub struct TraitImpl {
pub struct NoirTraitImpl {
pub impl_generics: UnresolvedGenerics,

pub trait_name: Ident,
Expand All @@ -63,7 +63,7 @@ pub struct TraitImpl {
pub object_type: UnresolvedType,
pub object_type_span: Span,

pub where_clause: Vec<TraitConstraint>,
pub where_clause: Vec<UnresolvedTraitConstraint>,

pub items: Vec<TraitImplItem>,
}
Expand All @@ -75,7 +75,7 @@ pub struct TraitImpl {
/// `Foo: TraitX`
/// `Foo: TraitY<U, V>`
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TraitConstraint {
pub struct UnresolvedTraitConstraint {
pub typ: UnresolvedType,
pub trait_bound: TraitBound,
}
Expand All @@ -84,6 +84,7 @@ pub struct TraitConstraint {
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TraitBound {
pub trait_name: Ident,
pub trait_id: Option<TraitId>, // initially None, gets assigned during DC
pub trait_generics: Vec<UnresolvedType>,
}

Expand Down Expand Up @@ -167,7 +168,7 @@ impl Display for TraitItem {
}
}

impl Display for TraitConstraint {
impl Display for UnresolvedTraitConstraint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}: {}", self.typ, self.trait_bound)
}
Expand All @@ -184,7 +185,7 @@ impl Display for TraitBound {
}
}

impl Display for TraitImpl {
impl Display for NoirTraitImpl {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let generics = vecmap(&self.trait_generics, |generic| generic.to_string());
let generics = generics.join(", ");
Expand Down
23 changes: 16 additions & 7 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ use crate::hir::resolution::{
};
use crate::hir::type_check::{type_check_func, TypeCheckError, TypeChecker};
use crate::hir::Context;
use crate::hir_def::traits::{TraitConstant, TraitFunction, TraitType};
use crate::node_interner::{FuncId, NodeInterner, StmtId, StructId, TraitId, TypeAliasId};
use crate::hir_def::traits::{TraitConstant, TraitFunction, TraitImpl, TraitType};
use crate::node_interner::{
FuncId, NodeInterner, StmtId, StructId, TraitId, TraitImplKey, TypeAliasId,
};
use crate::{
ExpressionKind, Generics, Ident, LetStatement, Literal, NoirFunction, NoirStruct, NoirTrait,
NoirTypeAlias, ParsedModule, Shared, StructType, TraitItem, Type, TypeBinding,
Expand Down Expand Up @@ -696,24 +698,31 @@ fn resolve_trait_impls(
errors,
);

let resolved_trait_impl = Shared::new(TraitImpl {
ident: trait_impl.trait_impl_ident.clone(),
typ: self_type.clone(),
trait_id,
methods: vecmap(&impl_methods, |(_, func_id)| *func_id),
});

let mut new_resolver =
Resolver::new(interner, &path_resolver, &context.def_maps, trait_impl.file_id);
new_resolver.set_self_type(Some(self_type.clone()));

check_methods_signatures(&mut new_resolver, &impl_methods, trait_id, errors);

let trait_definition_ident = &trait_impl.trait_impl_ident;
let key = (self_type.clone(), trait_id);
if let Some(prev_trait_impl_ident) = interner.get_previous_trait_implementation(&key) {
let key = TraitImplKey { typ: self_type.clone(), trait_id };

if let Some(prev_trait_impl_ident) = interner.get_trait_implementation(&key) {
let err = DefCollectorErrorKind::Duplicate {
typ: DuplicateType::TraitImplementation,
first_def: prev_trait_impl_ident.clone(),
first_def: prev_trait_impl_ident.borrow().ident.clone(),
second_def: trait_definition_ident.clone(),
};
errors.push(err.into_file_diagnostic(trait_impl.methods.file_id));
} else {
let _func_ids =
interner.add_trait_implementaion(&key, trait_definition_ident, &trait_impl.methods);
interner.add_trait_implementation(&key, resolved_trait_impl.clone());
}

methods.append(&mut impl_methods);
Expand Down
124 changes: 69 additions & 55 deletions compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ use noirc_errors::{FileDiagnostic, Location};

use crate::{
graph::CrateId,
hir::def_collector::dc_crate::{UnresolvedStruct, UnresolvedTrait},
hir::{
def_collector::dc_crate::{UnresolvedStruct, UnresolvedTrait},
def_map::ScopeResolveError,
},
node_interner::TraitId,
parser::SubModule,
FunctionDefinition, Ident, LetStatement, NoirFunction, NoirStruct, NoirTrait, NoirTypeAlias,
ParsedModule, TraitImpl, TraitImplItem, TraitItem, TypeImpl,
FunctionDefinition, Ident, LetStatement, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl,
NoirTypeAlias, ParsedModule, TraitImplItem, TraitItem, TypeImpl,
};

use super::{
Expand All @@ -19,7 +22,7 @@ use super::{
},
errors::{DefCollectorErrorKind, DuplicateType},
};
use crate::hir::def_map::{parse_file, LocalModuleId, ModuleData, ModuleDefId, ModuleId};
use crate::hir::def_map::{parse_file, LocalModuleId, ModuleData, ModuleId};
use crate::hir::resolution::import::ImportDirective;
use crate::hir::Context;

Expand Down Expand Up @@ -131,73 +134,71 @@ impl<'a> ModCollector<'a> {
fn collect_trait_impls(
&mut self,
context: &mut Context,
impls: Vec<TraitImpl>,
impls: Vec<NoirTraitImpl>,
errors: &mut Vec<FileDiagnostic>,
) {
for trait_impl in impls {
let trait_name = trait_impl.trait_name.clone();
let trait_name = &trait_impl.trait_name;
let module = &self.def_collector.def_map.modules[self.module_id.0];
match module.find_name(&trait_name).types {
Some((module_def_id, _visibility)) => {
if let Some(collected_trait) = self.get_unresolved_trait(module_def_id) {
let unresolved_functions = self.collect_trait_implementations(
context,
&trait_impl,
&collected_trait.trait_def,
errors,
);

for (_, func_id, noir_function) in &unresolved_functions.functions {
let name = noir_function.name().to_owned();

context.def_interner.push_function_definition(name, *func_id);
}

let unresolved_trait_impl = UnresolvedTraitImpl {
file_id: self.file_id,
module_id: self.module_id,
the_trait: collected_trait,
methods: unresolved_functions,
trait_impl_ident: trait_impl.trait_name.clone(),
};

let trait_id = match module_def_id {
ModuleDefId::TraitId(trait_id) => trait_id,
_ => unreachable!(),
};

let key = (trait_impl.object_type, self.module_id, trait_id);
self.def_collector
.collected_traits_impls
.insert(key, unresolved_trait_impl);
} else {
let error = DefCollectorErrorKind::NotATrait {
not_a_trait_name: trait_name.clone(),
};
errors.push(error.into_file_diagnostic(self.file_id));
}
}
None => {
let error = DefCollectorErrorKind::TraitNotFound { trait_ident: trait_name };
errors.push(error.into_file_diagnostic(self.file_id));
if let Some(trait_id) = self.find_trait_or_emit_error(module, trait_name, errors) {
let collected_trait =
self.def_collector.collected_traits.get(&trait_id).cloned().unwrap();

let unresolved_functions = self.collect_trait_implementations(
context,
&trait_impl,
&collected_trait.trait_def,
errors,
);

for (_, func_id, noir_function) in &unresolved_functions.functions {
let name = noir_function.name().to_owned();

context.def_interner.push_function_definition(name, *func_id);
}

let unresolved_trait_impl = UnresolvedTraitImpl {
file_id: self.file_id,
module_id: self.module_id,
the_trait: collected_trait,
methods: unresolved_functions,
trait_impl_ident: trait_impl.trait_name.clone(),
};

let key = (trait_impl.object_type, self.module_id, trait_id);
self.def_collector.collected_traits_impls.insert(key, unresolved_trait_impl);
}
}
}

fn get_unresolved_trait(&self, module_def_id: ModuleDefId) -> Option<UnresolvedTrait> {
match module_def_id {
ModuleDefId::TraitId(trait_id) => {
self.def_collector.collected_traits.get(&trait_id).cloned()
fn find_trait_or_emit_error(
&self,
module: &ModuleData,
trait_name: &Ident,
errors: &mut Vec<FileDiagnostic>,
) -> Option<TraitId> {
match module.find_trait_with_name(trait_name) {
Ok(trait_id) => Some(trait_id),
Err(ScopeResolveError::WrongKind) => {
let error =
DefCollectorErrorKind::NotATrait { not_a_trait_name: trait_name.clone() };
errors.push(error.into_file_diagnostic(self.file_id));
None
}
Err(ScopeResolveError::NotFound) => {
let error =
DefCollectorErrorKind::TraitNotFound { trait_ident: trait_name.clone() };
errors.push(error.into_file_diagnostic(self.file_id));
None
}
_ => None,
}
}

fn collect_trait_implementations(
&mut self,
context: &mut Context,
trait_impl: &TraitImpl,
trait_impl: &NoirTraitImpl,
trait_def: &NoirTrait,
errors: &mut Vec<FileDiagnostic>,
) -> UnresolvedFunctions {
Expand Down Expand Up @@ -296,14 +297,27 @@ impl<'a> ModCollector<'a> {
let mut unresolved_functions =
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() };

for function in functions {
for mut function in functions {
let name = function.name_ident().clone();

// First create dummy function in the DefInterner
// So that we can get a FuncId
let func_id = context.def_interner.push_empty_fn();
context.def_interner.push_function_definition(name.0.contents.clone(), func_id);

// Then go over the where clause and assign trait_ids to the constraints
for constraint in &mut function.def.where_clause {
let module = &self.def_collector.def_map.modules[self.module_id.0];

if let Some(trait_id) = self.find_trait_or_emit_error(
module,
&constraint.trait_bound.trait_name,
errors,
) {
constraint.trait_bound.trait_id = Some(trait_id);
}
}

// Now link this func_id to a crate level map with the noir function and the module id
// Encountering a NoirFunction, we retrieve it's module_data to get the namespace
// Once we have lowered it to a HirFunction, we retrieve it's Id from the DefInterner
Expand Down
21 changes: 20 additions & 1 deletion compiler/noirc_frontend/src/hir/def_map/item_scope.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::{namespace::PerNs, ModuleDefId, ModuleId};
use crate::{node_interner::FuncId, Ident};
use crate::{
node_interner::{FuncId, TraitId},
Ident,
};
use std::collections::{hash_map::Entry, HashMap};

#[derive(Debug, PartialEq, Eq, Copy, Clone)]
Expand All @@ -15,6 +18,14 @@ pub struct ItemScope {
defs: Vec<ModuleDefId>,
}

pub enum ScopeResolveError {
/// the ident we attempted to resolve isn't declared
NotFound,

/// The ident we attempted to resolve is declared, but is the wrong kind - e.g. we want a trait, but it's a function
WrongKind,
}

impl ItemScope {
pub fn add_definition(
&mut self,
Expand Down Expand Up @@ -69,6 +80,14 @@ impl ItemScope {
_ => None,
}
}
pub fn find_trait_with_name(&self, trait_name: &Ident) -> Result<TraitId, ScopeResolveError> {
let (module_def, _) = self.types.get(trait_name).ok_or(ScopeResolveError::NotFound)?;

match module_def {
ModuleDefId::TraitId(id) => Ok(*id),
_ => Err(ScopeResolveError::WrongKind),
}
}

pub fn find_name(&self, name: &Ident) -> PerNs {
PerNs { types: self.types.get(name).cloned(), values: self.values.get(name).cloned() }
Expand Down
6 changes: 5 additions & 1 deletion compiler/noirc_frontend/src/hir/def_map/module_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
Ident,
};

use super::{ItemScope, LocalModuleId, ModuleDefId, ModuleId, PerNs};
use super::{ItemScope, LocalModuleId, ModuleDefId, ModuleId, PerNs, ScopeResolveError};

/// Contains the actual contents of a module: its parent (if one exists),
/// children, and scope with all definitions defined within the scope.
Expand Down Expand Up @@ -85,6 +85,10 @@ impl ModuleData {
self.scope.find_func_with_name(name)
}

pub fn find_trait_with_name(&self, name: &Ident) -> Result<TraitId, ScopeResolveError> {
self.scope.find_trait_with_name(name)
}

pub fn import(&mut self, name: Ident, id: ModuleDefId) -> Result<(), (Ident, Ident)> {
self.scope.add_item_to_namespace(name, id)
}
Expand Down
Loading

0 comments on commit e3d18bb

Please sign in to comment.