Skip to content

Commit

Permalink
Fix callables in binary operators
Browse files Browse the repository at this point in the history
  • Loading branch information
sharkdp authored and David Peter committed May 3, 2024
1 parent 3a7a654 commit 54a158e
Showing 1 changed file with 43 additions and 2 deletions.
45 changes: 43 additions & 2 deletions numbat/src/typechecker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,11 @@ impl TypeChecker {
Box::new(rhs_checked),
lhs_type,
));
} else if let Type::Fn(parameter_types, return_type) = rhs_type {
} else if rhs_type.is_fn_type() && op == &BinaryOperator::ConvertTo {
let (parameter_types, return_type) = match rhs_type {
Type::Fn(p, r) => (p, r),
_ => unreachable!(),
};
// make sure that there is just one paramter (return arity error otherwise)
if parameter_types.len() != 1 {
return Err(TypeCheckError::WrongArity {
Expand Down Expand Up @@ -927,6 +931,14 @@ impl TypeChecker {
typed_ast::BinaryOperator::Equal | typed_ast::BinaryOperator::NotEqual => {
if lhs_type.is_dtype() || rhs_type.is_dtype() {
let _ = get_type_and_assert_equality()?;
} else if lhs_type.is_fn_type() || rhs_type.is_fn_type() {
return Err(TypeCheckError::IncompatibleTypesInComparison(
span_op.unwrap(),
lhs_type,
lhs.full_span(),
rhs_type,
rhs.full_span(),
));
} else if lhs_type != rhs_type {
return Err(TypeCheckError::IncompatibleTypesInComparison(
span_op.unwrap(),
Expand Down Expand Up @@ -1678,6 +1690,8 @@ mod tests {
fn error(m: String) -> !
fn returns_never() -> ! = error(\"\")
fn takes_never_returns_a(x: !) -> A = a
let callable = takes_a_returns_b
";

fn base_type(name: &str) -> BaseRepresentation {
Expand Down Expand Up @@ -1783,7 +1797,7 @@ mod tests {
}

#[test]
fn converisons() {
fn comparisons() {
assert_successful_typecheck("2 a > a");
assert_successful_typecheck("2 a / (3 a) > 3");

Expand Down Expand Up @@ -2300,4 +2314,31 @@ mod tests {
",
);
}

#[test]
fn callables() {
assert_successful_typecheck("callable(a)");
assert_successful_typecheck("a -> callable");
assert!(matches!(
get_typecheck_error("callable(b)"),
TypeCheckError::IncompatibleTypesInFunctionCall(..)
));
assert!(matches!(
get_typecheck_error("callable()"),
TypeCheckError::WrongArity { .. }
));
assert!(matches!(
get_typecheck_error("callable(a, a)"),
TypeCheckError::WrongArity { .. }
));

assert!(matches!(
get_typecheck_error("a + callable"),
TypeCheckError::ExpectedDimensionType { .. }
));
assert!(matches!(
get_typecheck_error("callable == callable"),
TypeCheckError::IncompatibleTypesInComparison { .. }
));
}
}

0 comments on commit 54a158e

Please sign in to comment.