diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 942a8892fe1..1d010874fad 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -14,8 +14,7 @@ use crate::hir::type_check::{type_check_func, TypeCheckError, TypeChecker}; use crate::hir::Context; use crate::hir_def::traits::{Trait, TraitConstant, TraitFunction, TraitImpl, TraitType}; use crate::node_interner::{ - allow_trait_impl_for_type, FuncId, NodeInterner, StmtId, StructId, TraitId, TraitImplKey, - TypeAliasId, + FuncId, NodeInterner, StmtId, StructId, TraitId, TraitImplKey, TypeAliasId, }; use crate::parser::ParserError; @@ -551,12 +550,6 @@ fn collect_trait_impl( errors.push((err.into(), trait_impl.file_id)); } } - } else if !allow_trait_impl_for_type(&typ) { - let error = DefCollectorErrorKind::NonStructTraitImpl { - trait_path: trait_impl.trait_path.clone(), - span: trait_impl.trait_path.span(), - }; - errors.push((error.into(), trait_impl.file_id)); } } } @@ -1025,7 +1018,15 @@ fn resolve_trait_impls( trait_id, methods: vecmap(&impl_methods, |(_, func_id)| *func_id), }); - interner.add_trait_implementation(&key, resolved_trait_impl.clone()); + if !interner.add_trait_implementation(&key, resolved_trait_impl.clone()) { + // error + // unreachable!("Cannot add a method to the unsupported type '{}'", key.typ) + let error = DefCollectorErrorKind::TraitImplNotAllowedFor { + trait_path: trait_impl.trait_path.clone(), + span: trait_impl.trait_path.span(), + }; + errors.push((error.into(), trait_impl.file_id)); + } } methods.append(&mut impl_methods); diff --git a/compiler/noirc_frontend/src/hir/def_collector/errors.rs b/compiler/noirc_frontend/src/hir/def_collector/errors.rs index f959cdec598..d693b250b62 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/errors.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/errors.rs @@ -30,8 +30,8 @@ pub enum DefCollectorErrorKind { PathResolutionError(PathResolutionError), #[error("Non-struct type used in impl")] NonStructTypeInImpl { span: Span }, - #[error("Non-struct type used in trait impl")] - NonStructTraitImpl { trait_path: Path, span: Span }, + #[error("Trait implementation is not allowed for this")] + TraitImplNotAllowedFor { trait_path: Path, span: Span }, #[error("Cannot `impl` a type defined outside the current crate")] ForeignImpl { span: Span, type_name: String }, #[error("Mismatch number of parameters in of trait implementation")] @@ -119,10 +119,10 @@ impl From for Diagnostic { "Only struct types may have implementation methods".into(), span, ), - DefCollectorErrorKind::NonStructTraitImpl { trait_path, span } => { + DefCollectorErrorKind::TraitImplNotAllowedFor { trait_path, span } => { Diagnostic::simple_error( - format!("Only struct types may implement trait `{trait_path}`"), - "Only struct types may implement traits".into(), + format!("Only limited types may implement trait `{trait_path}`"), + "Only limited types may implement traits".into(), span, ) } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 7f02297a635..f226aa18e48 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -872,29 +872,45 @@ impl NodeInterner { self.trait_implementations.get(key).cloned() } - pub fn add_trait_implementation(&mut self, key: &TraitImplKey, trait_impl: Shared) { + pub fn add_trait_implementation( + &mut self, + key: &TraitImplKey, + trait_impl: Shared, + ) -> bool { self.trait_implementations.insert(key.clone(), trait_impl.clone()); - for func_id in &trait_impl.borrow().methods { - let method_name = self.function_name(func_id).to_owned(); - match &key.typ { - Type::Struct(struct_type, _generics) => { + match &key.typ { + Type::Struct(struct_type, _generics) => { + for func_id in &trait_impl.borrow().methods { + let method_name = self.function_name(func_id).to_owned(); let key = (struct_type.borrow().id, method_name); self.struct_methods.insert(key, *func_id); } - Type::FieldElement - | Type::Array(..) - | Type::Integer(..) - | Type::Bool - | Type::Tuple(..) - | Type::String(..) => { + true + } + Type::FieldElement + | Type::Unit + | Type::Array(..) + | Type::Integer(..) + | Type::Bool + | Type::Tuple(..) + | Type::String(..) + | Type::FmtString(..) + => { + for func_id in &trait_impl.borrow().methods { + let method_name = self.function_name(func_id).to_owned(); let key = (key.typ.clone(), method_name); self.primitive_trait_impls.insert(key, *func_id); } - Type::Error => {} - _ => { - unreachable!("Cannot add a method to the unsupported type '{}'", key.typ) - } + true } + Type::TypeVariable(..) | + Type::NamedGeneric(..) | + Type::Function(..) | + Type::MutableReference(..) | + Type::Forall(..) | + Type::Constant(..) | + Type::NotConstant | + Type::Error => false, } } @@ -954,26 +970,3 @@ fn get_type_method_key(typ: &Type) -> Option { | Type::FmtString(_, _) => None, } } - -pub fn allow_trait_impl_for_type(typ: &Type) -> bool { - match &typ { - Type::FieldElement - | Type::Array(..) - | Type::Integer(..) - | Type::Bool - | Type::Struct(..) - | Type::Tuple(..) - | Type::String(..) => true, - - Type::FmtString(..) - | Type::Unit - | Type::TypeVariable(..) - | Type::NamedGeneric(..) - | Type::Function(..) - | Type::MutableReference(..) - | Type::Forall(..) - | Type::Constant(..) - | Type::NotConstant - | Type::Error => false, - } -} diff --git a/tooling/nargo_cli/tests/execution_success/trait_impl_base_type/src/main.nr b/tooling/nargo_cli/tests/execution_success/trait_impl_base_type/src/main.nr index f4a424d0052..30d79f8ffeb 100644 --- a/tooling/nargo_cli/tests/execution_success/trait_impl_base_type/src/main.nr +++ b/tooling/nargo_cli/tests/execution_success/trait_impl_base_type/src/main.nr @@ -49,6 +49,28 @@ impl Fieldable for str<6> { } } +impl Fieldable for () { + fn to_field(self) -> Field { + 0 + } +} + +type Point2D = [Field; 2]; +type Point2DAlias = Point2D; + +impl Fieldable for Point2DAlias { + fn to_field(self) -> Field { + self[0] + self[1] + } +} + +impl Fieldable for fmtstr<14, (Field, Field)> { + fn to_field(self) -> Field { + 52 + } +} + + // x = 15 fn main(x: u32) { assert(x.to_field() == 15); @@ -66,4 +88,11 @@ fn main(x: u32) { assert(k_false.to_field() == 32); let m = "String"; assert(m.to_field() == 6); + let unit = (); + assert(unit.to_field() == 0); + let point: Point2DAlias = [2, 3]; + assert(point.to_field() == 5); + let i: Field = 2; + let j: Field = 6; + assert(f"i: {i}, j: {j}".to_field() == 52); } \ No newline at end of file