Skip to content

Commit

Permalink
feat: Add generic count check for trait methods (#3382)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfecher authored Oct 31, 2023
1 parent 7d7d632 commit a9f9717
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 30 deletions.
76 changes: 48 additions & 28 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,13 @@ fn resolve_trait_impls(
new_resolver.set_self_type(Some(self_type.clone()));

if let Some(trait_id) = maybe_trait_id {
check_methods_signatures(&mut new_resolver, &impl_methods, trait_id, errors);
check_methods_signatures(
&mut new_resolver,
&impl_methods,
trait_id,
trait_impl.generics.len(),
errors,
);

let resolved_trait_impl = Shared::new(TraitImpl {
ident: trait_impl.trait_path.last_segment().clone(),
Expand Down Expand Up @@ -1033,6 +1039,7 @@ fn check_methods_signatures(
resolver: &mut Resolver,
impl_methods: &Vec<(FileId, FuncId)>,
trait_id: TraitId,
trait_impl_generic_count: usize,
errors: &mut Vec<(CompilationError, FileId)>,
) {
let the_trait = resolver.interner.get_trait(trait_id);
Expand All @@ -1043,24 +1050,42 @@ fn check_methods_signatures(
let _ = the_trait.self_type_typevar.borrow_mut().bind_to(self_type.clone(), the_trait.span);

for (file_id, func_id) in impl_methods {
let meta = resolver.interner.function_meta(func_id);
let impl_method = resolver.interner.function_meta(func_id);
let func_name = resolver.interner.function_name(func_id).to_owned();

let mut typecheck_errors = Vec::new();

// `method` is None in the case where the impl block has a method that's not part of the trait.
// This is None in the case where the impl block has a method that's not part of the trait.
// If that's the case, a `MethodNotInTrait` error has already been thrown, and we can ignore
// the impl method, since there's nothing in the trait to match its signature against.
if let Some(method) =
if let Some(trait_method) =
the_trait.methods.iter().find(|method| method.name.0.contents == func_name)
{
let function_typ = meta.typ.instantiate(resolver.interner);
let impl_function_type = impl_method.typ.instantiate(resolver.interner);

let impl_method_generic_count =
impl_method.typ.generic_count() - trait_impl_generic_count;
let trait_method_generic_count = trait_method.generics.len();

if impl_method_generic_count != trait_method_generic_count {
let error = DefCollectorErrorKind::MismatchTraitImplementationNumGenerics {
impl_method_generic_count,
trait_method_generic_count,
trait_name: the_trait.name.to_string(),
method_name: func_name.to_string(),
span: impl_method.location.span,
};
errors.push((error.into(), *file_id));
}

if let Type::Function(params, _, _) = function_typ.0 {
if method.arguments.len() == params.len() {
if let Type::Function(impl_params, _, _) = impl_function_type.0 {
if trait_method.arguments.len() == impl_params.len() {
// Check the parameters of the impl method against the parameters of the trait method
let args = trait_method.arguments.iter();
let args_and_params = args.zip(&impl_params).zip(&impl_method.parameters.0);

for (parameter_index, ((expected, actual), (hir_pattern, _, _))) in
method.arguments.iter().zip(&params).zip(&meta.parameters.0).enumerate()
args_and_params.enumerate()
{
expected.unify(actual, &mut typecheck_errors, || {
TypeCheckError::TraitMethodParameterTypeMismatch {
Expand All @@ -1073,33 +1098,28 @@ fn check_methods_signatures(
});
}
} else {
errors.push((
DefCollectorErrorKind::MismatchTraitImplementationNumParameters {
actual_num_parameters: meta.parameters.0.len(),
expected_num_parameters: method.arguments.len(),
trait_name: the_trait.name.to_string(),
method_name: func_name.to_string(),
span: meta.location.span,
}
.into(),
*file_id,
));
let error = DefCollectorErrorKind::MismatchTraitImplementationNumParameters {
actual_num_parameters: impl_method.parameters.0.len(),
expected_num_parameters: trait_method.arguments.len(),
trait_name: the_trait.name.to_string(),
method_name: func_name.to_string(),
span: impl_method.location.span,
};
errors.push((error.into(), *file_id));
}
}

// Check that impl method return type matches trait return type:
let resolved_return_type =
resolver.resolve_type(meta.return_type.get_type().into_owned());
resolver.resolve_type(impl_method.return_type.get_type().into_owned());

method.return_type.unify(&resolved_return_type, &mut typecheck_errors, || {
let ret_type_span =
meta.return_type.get_type().span.expect("return type must always have a span");
trait_method.return_type.unify(&resolved_return_type, &mut typecheck_errors, || {
let ret_type_span = impl_method.return_type.get_type().span;
let expr_span = ret_type_span.expect("return type must always have a span");

TypeCheckError::TypeMismatch {
expected_typ: method.return_type.to_string(),
expr_typ: meta.return_type().to_string(),
expr_span: ret_type_span,
}
let expected_typ = trait_method.return_type.to_string();
let expr_typ = impl_method.return_type().to_string();
TypeCheckError::TypeMismatch { expr_typ, expected_typ, expr_span }
});

errors.extend(typecheck_errors.iter().cloned().map(|e| (e.into(), *file_id)));
Expand Down
25 changes: 23 additions & 2 deletions compiler/noirc_frontend/src/hir/def_collector/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,22 @@ pub enum DefCollectorErrorKind {
OverlappingImplNote { span: Span },
#[error("Cannot `impl` a type defined outside the current crate")]
ForeignImpl { span: Span, type_name: String },
#[error("Mismatch number of parameters in of trait implementation")]
#[error("Mismatched number of parameters in trait implementation")]
MismatchTraitImplementationNumParameters {
actual_num_parameters: usize,
expected_num_parameters: usize,
trait_name: String,
method_name: String,
span: Span,
},
#[error("Mismatched number of generics in impl method")]
MismatchTraitImplementationNumGenerics {
impl_method_generic_count: usize,
trait_method_generic_count: usize,
trait_name: String,
method_name: String,
span: Span,
},
#[error("Method is not defined in trait")]
MethodNotInTrait { trait_name: Ident, impl_method: Ident },
#[error("Only traits can be implemented")]
Expand Down Expand Up @@ -174,8 +182,21 @@ impl From<DefCollectorErrorKind> for Diagnostic {
method_name,
span,
} => {
let plural = if expected_num_parameters == 1 { "" } else { "s" };
let primary_message = format!(
"`{trait_name}::{method_name}` expects {expected_num_parameters} parameter{plural}, but this method has {actual_num_parameters}");
Diagnostic::simple_error(primary_message, "".to_string(), span)
}
DefCollectorErrorKind::MismatchTraitImplementationNumGenerics {
impl_method_generic_count,
trait_method_generic_count,
trait_name,
method_name,
span,
} => {
let plural = if trait_method_generic_count == 1 { "" } else { "s" };
let primary_message = format!(
"Method `{method_name}` of trait `{trait_name}` needs {expected_num_parameters} parameters, but has {actual_num_parameters}");
"`{trait_name}::{method_name}` expects {trait_method_generic_count} generic{plural}, but this method has {impl_method_generic_count}");
Diagnostic::simple_error(primary_message, "".to_string(), span)
}
DefCollectorErrorKind::MethodNotInTrait { trait_name, impl_method } => {
Expand Down
15 changes: 15 additions & 0 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,21 @@ impl Type {
.all(|(_, field)| field.is_valid_for_program_input()),
}
}

/// Returns the number of `Forall`-quantified type variables on this type.
/// Returns 0 if this is not a Type::Forall
pub fn generic_count(&self) -> usize {
match self {
Type::Forall(generics, _) => generics.len(),
Type::TypeVariable(type_variable, _) | Type::NamedGeneric(type_variable, _) => {
match &*type_variable.borrow() {
TypeBinding::Bound(binding) => binding.generic_count(),
TypeBinding::Unbound(_) => 0,
}
}
_ => 0,
}
}
}

impl std::fmt::Display for Type {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "trait_incorrect_generic_count"
type = "bin"
authors = [""]
compiler_version = ">=0.18.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

fn main(){
let x: u32 = 0;
x.trait_fn();
}

trait Trait {
fn trait_fn<T>(x: T) -> T {}
}

impl Trait for u32 {
fn trait_fn<A, B>(x: A) -> A { x }
}

0 comments on commit a9f9717

Please sign in to comment.