diff --git a/numbat/src/diagnostic.rs b/numbat/src/diagnostic.rs index 6e0fbb95..309c9e88 100644 --- a/numbat/src/diagnostic.rs +++ b/numbat/src/diagnostic.rs @@ -473,15 +473,23 @@ impl ErrorDiagnostic for TypeCheckError { TypeCheckError::ExponentiationNeedsTypeAnnotation(span) => d.with_labels(vec![span .diagnostic_label(LabelStyle::Primary) .with_message(inner_error)]), - TypeCheckError::TypedHoleInStatement(span, type_, statement) => d - .with_labels(vec![span - .diagnostic_label(LabelStyle::Primary) - .with_message(type_)]) - .with_message("Found typed hole") - .with_notes(vec![ + TypeCheckError::TypedHoleInStatement(span, type_, statement, matches) => { + let mut notes = vec![ format!("Found a hole of type '{type_}' in the statement:"), format!(" {statement}"), - ]), + ]; + + if !matches.is_empty() { + notes.push("Relevant matches for this hole include:".into()); + notes.push(format!(" {}", matches.join(", "))); + } + + d.with_labels(vec![span + .diagnostic_label(LabelStyle::Primary) + .with_message(type_)]) + .with_message("Found typed hole") + .with_notes(notes) + } }; vec![d] } diff --git a/numbat/src/typechecker/environment.rs b/numbat/src/typechecker/environment.rs index 67d7ad8c..7caa084a 100644 --- a/numbat/src/typechecker/environment.rs +++ b/numbat/src/typechecker/environment.rs @@ -27,8 +27,9 @@ pub struct FunctionMetadata { #[derive(Clone, Debug)] pub enum IdentifierKind { - /// A normal identifier (variable, unit) with the place where it has been defined - Normal(TypeScheme, Span), + /// A normal identifier (variable, unit) with the place where it has been defined. + /// The boolean flag signifies whether the identifier is a unit or not + Normal(TypeScheme, Span, bool), /// A function Function(FunctionSignature, FunctionMetadata), /// Identifiers that are defined by the language: `_` and `ans` (see LAST_RESULT_IDENTIFIERS) @@ -39,7 +40,7 @@ impl IdentifierKind { fn get_type(&self) -> TypeScheme { match self { IdentifierKind::Predefined(t) => t.clone(), - IdentifierKind::Normal(t, _) => t.clone(), + IdentifierKind::Normal(t, _, _) => t.clone(), IdentifierKind::Function(s, _) => s.fn_type.clone(), } } @@ -51,14 +52,16 @@ pub struct Environment { } impl Environment { - pub fn add(&mut self, i: Identifier, type_: Type, span: Span) { - self.identifiers - .insert(i, IdentifierKind::Normal(TypeScheme::Concrete(type_), span)); + pub fn add(&mut self, i: Identifier, type_: Type, span: Span, is_unit: bool) { + self.identifiers.insert( + i, + IdentifierKind::Normal(TypeScheme::Concrete(type_), span, is_unit), + ); } - pub fn add_scheme(&mut self, i: Identifier, scheme: TypeScheme, span: Span) { + pub fn add_scheme(&mut self, i: Identifier, scheme: TypeScheme, span: Span, is_unit: bool) { self.identifiers - .insert(i, IdentifierKind::Normal(scheme, span)); + .insert(i, IdentifierKind::Normal(scheme, span, is_unit)); } pub(crate) fn add_function( @@ -84,6 +87,17 @@ impl Environment { self.identifiers.keys() } + pub fn iter_relevant_matches(&self) -> impl Iterator { + self.identifiers + .iter() + .filter(|(_, kind)| match kind { + IdentifierKind::Normal(_, _, true) => false, + IdentifierKind::Predefined(..) => false, + _ => true, + }) + .map(|(id, kind)| (id, kind.get_type())) + } + pub(crate) fn get_function_info( &self, name: &str, @@ -97,7 +111,7 @@ impl Environment { pub(crate) fn generalize_types(&mut self, dtype_variables: &[TypeVariable]) { for (_, kind) in self.identifiers.iter_mut() { match kind { - IdentifierKind::Normal(t, _) => { + IdentifierKind::Normal(t, _, _) => { t.generalize(dtype_variables); } IdentifierKind::Function(signature, _) => { @@ -115,7 +129,7 @@ impl ApplySubstitution for Environment { fn apply(&mut self, substitution: &Substitution) -> Result<(), SubstitutionError> { for (_, kind) in self.identifiers.iter_mut() { match kind { - IdentifierKind::Normal(t, _) => { + IdentifierKind::Normal(t, _, _) => { t.apply(substitution)?; } IdentifierKind::Function(signature, _) => { diff --git a/numbat/src/typechecker/error.rs b/numbat/src/typechecker/error.rs index e49918ee..324c32e5 100644 --- a/numbat/src/typechecker/error.rs +++ b/numbat/src/typechecker/error.rs @@ -150,7 +150,7 @@ pub enum TypeCheckError { DerivedUnitDefinitionMustNotBeGeneric(Span), #[error("Typed hole")] - TypedHoleInStatement(Span, String, String), + TypedHoleInStatement(Span, String, String, Vec), #[error("Multiple typed holes in statement")] MultipleTypedHoles(Span), diff --git a/numbat/src/typechecker/mod.rs b/numbat/src/typechecker/mod.rs index c0942d72..ee18404e 100644 --- a/numbat/src/typechecker/mod.rs +++ b/numbat/src/typechecker/mod.rs @@ -1139,7 +1139,7 @@ impl TypeChecker { for (name, _) in decorator::name_and_aliases(identifier, decorators) { self.env - .add(name.clone(), type_deduced.clone(), *identifier_span); + .add(name.clone(), type_deduced.clone(), *identifier_span, false); self.value_namespace.add_identifier_allow_override( name.clone(), @@ -1183,8 +1183,12 @@ impl TypeChecker { .into() }; for (name, _) in decorator::name_and_aliases(unit_name, decorators) { - self.env - .add(name.clone(), Type::Dimension(type_specified.clone()), *span); + self.env.add( + name.clone(), + Type::Dimension(type_specified.clone()), + *span, + true, + ); } typed_ast::Statement::DefineBaseUnit( @@ -1261,7 +1265,7 @@ impl TypeChecker { for (name, _) in decorator::name_and_aliases(identifier, decorators) { self.env - .add(name.clone(), type_deduced.clone(), *identifier_span); + .add(name.clone(), type_deduced.clone(), *identifier_span, true); } typed_ast::Statement::DefineDerivedUnit( identifier.clone(), @@ -1350,6 +1354,7 @@ impl TypeChecker { parameter.clone(), TypeScheme::make_quantified(parameter_type.clone()), *parameter_span, + false, ); typed_parameters.push((*parameter_span, parameter.clone(), parameter_type)); } @@ -1764,6 +1769,13 @@ impl TypeChecker { span, type_of_hole.to_readable_type(&self.registry).to_string(), elaborated_statement.pretty_print().to_string(), + self.env + .iter_relevant_matches() + .filter(|(_, t)| t == &type_of_hole) + .take(10) + .map(|(n, _)| n) + .cloned() + .collect(), )); } diff --git a/numbat/src/typechecker/tests/type_inference.rs b/numbat/src/typechecker/tests/type_inference.rs index 88a62dab..f8aa1f03 100644 --- a/numbat/src/typechecker/tests/type_inference.rs +++ b/numbat/src/typechecker/tests/type_inference.rs @@ -403,26 +403,26 @@ fn recursive_functions() { fn typed_holes() { assert!(matches!( get_typecheck_error("a + ?"), - TypeCheckError::TypedHoleInStatement(_, type_, _) if type_ == "A" + TypeCheckError::TypedHoleInStatement(_, type_, _, _) if type_ == "A" )); assert!(matches!( get_typecheck_error("c + a × ?"), - TypeCheckError::TypedHoleInStatement(_, type_, _) if type_ == "B" + TypeCheckError::TypedHoleInStatement(_, type_, _, _) if type_ == "B" )); assert!(matches!( get_typecheck_error("let x: B = c / ?"), - TypeCheckError::TypedHoleInStatement(_, type_, _) if type_ == "A" + TypeCheckError::TypedHoleInStatement(_, type_, _, _) if type_ == "A" )); assert!(matches!( get_typecheck_error("if true then a else ?"), - TypeCheckError::TypedHoleInStatement(_, type_, _) if type_ == "A" + TypeCheckError::TypedHoleInStatement(_, type_, _, _) if type_ == "A" )); assert!(matches!( get_typecheck_error("let x: C = ?(a, b)"), - TypeCheckError::TypedHoleInStatement(_, type_, _) if type_ == "Fn[(A, B) -> A × B]" + TypeCheckError::TypedHoleInStatement(_, type_, _, _) if type_ == "Fn[(A, B) -> A × B]" )); } diff --git a/numbat/src/typechecker/type_scheme.rs b/numbat/src/typechecker/type_scheme.rs index f579e492..ef0dc96f 100644 --- a/numbat/src/typechecker/type_scheme.rs +++ b/numbat/src/typechecker/type_scheme.rs @@ -141,6 +141,7 @@ impl TypeScheme { // Generate qualified type let bounds = dtype_variables .iter() + .filter(|v| type_.contains(v, true)) .map(|v| Bound::IsDim(Type::TVar(v.clone()))) .collect(); let qualified_type = QualifiedType::new(type_.clone(), bounds);