Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(traits): Implement trait bounds def collector + resolver passes #2716

Merged
merged 7 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::fmt::Display;

use crate::token::{Attributes, Token};
use crate::{
Distinctness, Ident, Path, Pattern, Recoverable, Statement, TraitConstraint, UnresolvedType,
UnresolvedTypeData, Visibility,
Distinctness, Ident, Path, Pattern, Recoverable, Statement, UnresolvedTraitConstraint,
UnresolvedType, UnresolvedTypeData, Visibility,
};
use acvm::FieldElement;
use iter_extended::vecmap;
Expand Down Expand Up @@ -368,7 +368,7 @@ pub struct FunctionDefinition {
pub parameters: Vec<(Pattern, UnresolvedType, Visibility)>,
pub body: BlockExpression,
pub span: Span,
pub where_clause: Vec<TraitConstraint>,
pub where_clause: Vec<UnresolvedTraitConstraint>,
pub return_type: FunctionReturnType,
pub return_visibility: Visibility,
pub return_distinctness: Distinctness,
Expand Down Expand Up @@ -634,7 +634,7 @@ impl FunctionDefinition {
generics: &UnresolvedGenerics,
parameters: &[(Ident, UnresolvedType)],
body: &BlockExpression,
where_clause: &[TraitConstraint],
where_clause: &[UnresolvedTraitConstraint],
return_type: &FunctionReturnType,
) -> FunctionDefinition {
let p = parameters
Expand Down
19 changes: 10 additions & 9 deletions compiler/noirc_frontend/src/ast/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use iter_extended::vecmap;
use noirc_errors::Span;

use crate::{
BlockExpression, Expression, FunctionReturnType, Ident, NoirFunction, UnresolvedGenerics,
UnresolvedType,
node_interner::TraitId, BlockExpression, Expression, FunctionReturnType, Ident, NoirFunction,
UnresolvedGenerics, UnresolvedType,
};

/// AST node for trait definitions:
Expand All @@ -14,7 +14,7 @@ use crate::{
pub struct NoirTrait {
pub name: Ident,
pub generics: Vec<Ident>,
pub where_clause: Vec<TraitConstraint>,
pub where_clause: Vec<UnresolvedTraitConstraint>,
pub span: Span,
pub items: Vec<TraitItem>,
}
Expand All @@ -28,7 +28,7 @@ pub enum TraitItem {
generics: Vec<Ident>,
parameters: Vec<(Ident, UnresolvedType)>,
return_type: FunctionReturnType,
where_clause: Vec<TraitConstraint>,
where_clause: Vec<UnresolvedTraitConstraint>,
body: Option<BlockExpression>,
},
Constant {
Expand All @@ -54,7 +54,7 @@ pub struct TypeImpl {
/// Ast node for an implementation of a trait for a particular type
/// `impl trait_name<trait_generics> for object_type where where_clauses { ... items ... }`
#[derive(Clone, Debug)]
pub struct TraitImpl {
pub struct NoirTraitImpl {
pub impl_generics: UnresolvedGenerics,

pub trait_name: Ident,
Expand All @@ -63,7 +63,7 @@ pub struct TraitImpl {
pub object_type: UnresolvedType,
pub object_type_span: Span,

pub where_clause: Vec<TraitConstraint>,
pub where_clause: Vec<UnresolvedTraitConstraint>,

pub items: Vec<TraitImplItem>,
}
Expand All @@ -75,7 +75,7 @@ pub struct TraitImpl {
/// `Foo: TraitX`
/// `Foo: TraitY<U, V>`
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TraitConstraint {
pub struct UnresolvedTraitConstraint {
pub typ: UnresolvedType,
pub trait_bound: TraitBound,
}
Expand All @@ -84,6 +84,7 @@ pub struct TraitConstraint {
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TraitBound {
pub trait_name: Ident,
pub trait_id: Option<TraitId>, // initially None, gets assigned during DC
pub trait_generics: Vec<UnresolvedType>,
}

Expand Down Expand Up @@ -167,7 +168,7 @@ impl Display for TraitItem {
}
}

impl Display for TraitConstraint {
impl Display for UnresolvedTraitConstraint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}: {}", self.typ, self.trait_bound)
}
Expand All @@ -184,7 +185,7 @@ impl Display for TraitBound {
}
}

impl Display for TraitImpl {
impl Display for NoirTraitImpl {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let generics = vecmap(&self.trait_generics, |generic| generic.to_string());
let generics = generics.join(", ");
Expand Down
23 changes: 16 additions & 7 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ use crate::hir::resolution::{
};
use crate::hir::type_check::{type_check_func, TypeCheckError, TypeChecker};
use crate::hir::Context;
use crate::hir_def::traits::{TraitConstant, TraitFunction, TraitType};
use crate::node_interner::{FuncId, NodeInterner, StmtId, StructId, TraitId, TypeAliasId};
use crate::hir_def::traits::{TraitConstant, TraitFunction, TraitImpl, TraitType};
use crate::node_interner::{
FuncId, NodeInterner, StmtId, StructId, TraitId, TraitImplKey, TypeAliasId,
};
use crate::{
ExpressionKind, Generics, Ident, LetStatement, Literal, NoirFunction, NoirStruct, NoirTrait,
NoirTypeAlias, ParsedModule, Shared, StructType, TraitItem, Type, TypeBinding,
Expand Down Expand Up @@ -696,24 +698,31 @@ fn resolve_trait_impls(
errors,
);

let resolved_trait_impl = Shared::new(TraitImpl {
ident: trait_impl.trait_impl_ident.clone(),
typ: self_type.clone(),
trait_id,
methods: vecmap(&impl_methods, |(_, func_id)| *func_id),
});

let mut new_resolver =
Resolver::new(interner, &path_resolver, &context.def_maps, trait_impl.file_id);
new_resolver.set_self_type(Some(self_type.clone()));

check_methods_signatures(&mut new_resolver, &impl_methods, trait_id, errors);

let trait_definition_ident = &trait_impl.trait_impl_ident;
let key = (self_type.clone(), trait_id);
if let Some(prev_trait_impl_ident) = interner.get_previous_trait_implementation(&key) {
let key = TraitImplKey { typ: self_type.clone(), trait_id };

if let Some(prev_trait_impl_ident) = interner.get_trait_implementation(&key) {
let err = DefCollectorErrorKind::Duplicate {
typ: DuplicateType::TraitImplementation,
first_def: prev_trait_impl_ident.clone(),
first_def: prev_trait_impl_ident.borrow().ident.clone(),
second_def: trait_definition_ident.clone(),
};
errors.push(err.into_file_diagnostic(trait_impl.methods.file_id));
} else {
let _func_ids =
interner.add_trait_implementaion(&key, trait_definition_ident, &trait_impl.methods);
interner.add_trait_implementation(&key, resolved_trait_impl.clone());
jfecher marked this conversation as resolved.
Show resolved Hide resolved
}

methods.append(&mut impl_methods);
Expand Down
124 changes: 69 additions & 55 deletions compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ use noirc_errors::{FileDiagnostic, Location};

use crate::{
graph::CrateId,
hir::def_collector::dc_crate::{UnresolvedStruct, UnresolvedTrait},
hir::{
def_collector::dc_crate::{UnresolvedStruct, UnresolvedTrait},
def_map::ScopeResolveError,
},
node_interner::TraitId,
parser::SubModule,
FunctionDefinition, Ident, LetStatement, NoirFunction, NoirStruct, NoirTrait, NoirTypeAlias,
ParsedModule, TraitImpl, TraitImplItem, TraitItem, TypeImpl,
FunctionDefinition, Ident, LetStatement, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl,
NoirTypeAlias, ParsedModule, TraitImplItem, TraitItem, TypeImpl,
};

use super::{
Expand All @@ -19,7 +22,7 @@ use super::{
},
errors::{DefCollectorErrorKind, DuplicateType},
};
use crate::hir::def_map::{parse_file, LocalModuleId, ModuleData, ModuleDefId, ModuleId};
use crate::hir::def_map::{parse_file, LocalModuleId, ModuleData, ModuleId};
use crate::hir::resolution::import::ImportDirective;
use crate::hir::Context;

Expand Down Expand Up @@ -131,73 +134,71 @@ impl<'a> ModCollector<'a> {
fn collect_trait_impls(
&mut self,
context: &mut Context,
impls: Vec<TraitImpl>,
impls: Vec<NoirTraitImpl>,
errors: &mut Vec<FileDiagnostic>,
) {
for trait_impl in impls {
let trait_name = trait_impl.trait_name.clone();
let trait_name = &trait_impl.trait_name;
let module = &self.def_collector.def_map.modules[self.module_id.0];
match module.find_name(&trait_name).types {
Some((module_def_id, _visibility)) => {
if let Some(collected_trait) = self.get_unresolved_trait(module_def_id) {
let unresolved_functions = self.collect_trait_implementations(
context,
&trait_impl,
&collected_trait.trait_def,
errors,
);

for (_, func_id, noir_function) in &unresolved_functions.functions {
let name = noir_function.name().to_owned();

context.def_interner.push_function_definition(name, *func_id);
}

let unresolved_trait_impl = UnresolvedTraitImpl {
file_id: self.file_id,
module_id: self.module_id,
the_trait: collected_trait,
methods: unresolved_functions,
trait_impl_ident: trait_impl.trait_name.clone(),
};

let trait_id = match module_def_id {
ModuleDefId::TraitId(trait_id) => trait_id,
_ => unreachable!(),
};

let key = (trait_impl.object_type, self.module_id, trait_id);
self.def_collector
.collected_traits_impls
.insert(key, unresolved_trait_impl);
} else {
let error = DefCollectorErrorKind::NotATrait {
not_a_trait_name: trait_name.clone(),
};
errors.push(error.into_file_diagnostic(self.file_id));
}
}
None => {
let error = DefCollectorErrorKind::TraitNotFound { trait_ident: trait_name };
errors.push(error.into_file_diagnostic(self.file_id));
if let Some(trait_id) = self.find_trait_or_emit_error(module, trait_name, errors) {
let collected_trait =
self.def_collector.collected_traits.get(&trait_id).cloned().unwrap();

let unresolved_functions = self.collect_trait_implementations(
context,
&trait_impl,
&collected_trait.trait_def,
errors,
);

for (_, func_id, noir_function) in &unresolved_functions.functions {
let name = noir_function.name().to_owned();

context.def_interner.push_function_definition(name, *func_id);
}

let unresolved_trait_impl = UnresolvedTraitImpl {
file_id: self.file_id,
module_id: self.module_id,
the_trait: collected_trait,
methods: unresolved_functions,
trait_impl_ident: trait_impl.trait_name.clone(),
};

let key = (trait_impl.object_type, self.module_id, trait_id);
self.def_collector.collected_traits_impls.insert(key, unresolved_trait_impl);
}
}
}

fn get_unresolved_trait(&self, module_def_id: ModuleDefId) -> Option<UnresolvedTrait> {
match module_def_id {
ModuleDefId::TraitId(trait_id) => {
self.def_collector.collected_traits.get(&trait_id).cloned()
fn find_trait_or_emit_error(
&self,
module: &ModuleData,
trait_name: &Ident,
errors: &mut Vec<FileDiagnostic>,
) -> Option<TraitId> {
match module.find_trait_with_name(trait_name) {
Ok(trait_id) => Some(trait_id),
Err(ScopeResolveError::WrongKind) => {
let error =
DefCollectorErrorKind::NotATrait { not_a_trait_name: trait_name.clone() };
errors.push(error.into_file_diagnostic(self.file_id));
None
}
Err(ScopeResolveError::NotFound) => {
let error =
DefCollectorErrorKind::TraitNotFound { trait_ident: trait_name.clone() };
errors.push(error.into_file_diagnostic(self.file_id));
None
}
_ => None,
}
}

fn collect_trait_implementations(
&mut self,
context: &mut Context,
trait_impl: &TraitImpl,
trait_impl: &NoirTraitImpl,
trait_def: &NoirTrait,
errors: &mut Vec<FileDiagnostic>,
) -> UnresolvedFunctions {
Expand Down Expand Up @@ -296,14 +297,27 @@ impl<'a> ModCollector<'a> {
let mut unresolved_functions =
UnresolvedFunctions { file_id: self.file_id, functions: Vec::new() };

for function in functions {
for mut function in functions {
let name = function.name_ident().clone();

// First create dummy function in the DefInterner
// So that we can get a FuncId
let func_id = context.def_interner.push_empty_fn();
context.def_interner.push_function_definition(name.0.contents.clone(), func_id);

// Then go over the where clause and assign trait_ids to the constraints
for constraint in &mut function.def.where_clause {
let module = &self.def_collector.def_map.modules[self.module_id.0];

if let Some(trait_id) = self.find_trait_or_emit_error(
module,
&constraint.trait_bound.trait_name,
errors,
) {
constraint.trait_bound.trait_id = Some(trait_id);
}
}

// Now link this func_id to a crate level map with the noir function and the module id
// Encountering a NoirFunction, we retrieve it's module_data to get the namespace
// Once we have lowered it to a HirFunction, we retrieve it's Id from the DefInterner
Expand Down
21 changes: 20 additions & 1 deletion compiler/noirc_frontend/src/hir/def_map/item_scope.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::{namespace::PerNs, ModuleDefId, ModuleId};
use crate::{node_interner::FuncId, Ident};
use crate::{
node_interner::{FuncId, TraitId},
Ident,
};
use std::collections::{hash_map::Entry, HashMap};

#[derive(Debug, PartialEq, Eq, Copy, Clone)]
Expand All @@ -15,6 +18,14 @@ pub struct ItemScope {
defs: Vec<ModuleDefId>,
}

pub enum ScopeResolveError {
/// the ident we attempted to resolve isn't declared
NotFound,

/// The ident we attempted to resolve is declared, but is the wrong kind - e.g. we want a trait, but it's a function
WrongKind,
}

impl ItemScope {
pub fn add_definition(
&mut self,
Expand Down Expand Up @@ -69,6 +80,14 @@ impl ItemScope {
_ => None,
}
}
pub fn find_trait_with_name(&self, trait_name: &Ident) -> Result<TraitId, ScopeResolveError> {
let (module_def, _) = self.types.get(trait_name).ok_or(ScopeResolveError::NotFound)?;

match module_def {
ModuleDefId::TraitId(id) => Ok(*id),
_ => Err(ScopeResolveError::WrongKind),
}
}

pub fn find_name(&self, name: &Ident) -> PerNs {
PerNs { types: self.types.get(name).cloned(), values: self.values.get(name).cloned() }
Expand Down
6 changes: 5 additions & 1 deletion compiler/noirc_frontend/src/hir/def_map/module_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
Ident,
};

use super::{ItemScope, LocalModuleId, ModuleDefId, ModuleId, PerNs};
use super::{ItemScope, LocalModuleId, ModuleDefId, ModuleId, PerNs, ScopeResolveError};

/// Contains the actual contents of a module: its parent (if one exists),
/// children, and scope with all definitions defined within the scope.
Expand Down Expand Up @@ -85,6 +85,10 @@ impl ModuleData {
self.scope.find_func_with_name(name)
}

pub fn find_trait_with_name(&self, name: &Ident) -> Result<TraitId, ScopeResolveError> {
self.scope.find_trait_with_name(name)
}

pub fn import(&mut self, name: Ident, id: ModuleDefId) -> Result<(), (Ident, Ident)> {
self.scope.add_item_to_namespace(name, id)
}
Expand Down
Loading