Skip to content

Commit

Permalink
feat: Implement Operator Overloading (#3931)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Working towards #2568

## Summary\*

This PR implements operator overloading through traits. This works for
all operators in noir, including bitwise operators. The comparison
operators `<`, `<=`, `>`, and `>=`, are notable in that their
corresponding function `cmp` returns an `Ordering` object rather than a
boolean as the operators themselves do.

## Additional Context

There is a bug when using static trait function syntax
`Default::default` or operator overloading when the trait has generic
methods. I'm considering this a separate bug, although for the moment it
prevents us from being able to use e.g. the `Eq` trait for arrays. There
is a boolean in `type_check/expr.rs` that determines if we use a trait
impl or primitive implementation of each operator and because of this
bug I've kept the primitive implementations for string and array
equality.

## Documentation\*

Check one:
- [ ] No documentation needed.
- [x] Documentation included in this PR.
- [ ] **[Exceptional Case]** Documentation to be submitted in a separate
PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Tom French <tom@tomfren.ch>
Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 3, 2024
1 parent 459700c commit 4b16090
Show file tree
Hide file tree
Showing 30 changed files with 1,219 additions and 328 deletions.
8 changes: 6 additions & 2 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1253,8 +1253,10 @@ impl<'a> Resolver<'a> {
if let Some((hir_expr, object_type)) = self.resolve_trait_generic_path(&path) {
let expr_id = self.interner.push_expr(hir_expr);
self.interner.push_expr_location(expr_id, expr.span, self.file);
self.interner
.select_impl_for_ident(expr_id, TraitImplKind::Assumed { object_type });
self.interner.select_impl_for_expression(
expr_id,
TraitImplKind::Assumed { object_type },
);
return expr_id;
} else {
// If the Path is being used as an Expression, then it is referring to a global from a separate module
Expand Down Expand Up @@ -1313,10 +1315,12 @@ impl<'a> Resolver<'a> {
ExpressionKind::Infix(infix) => {
let lhs = self.resolve_expression(infix.lhs);
let rhs = self.resolve_expression(infix.rhs);
let trait_id = self.interner.get_operator_trait_method(infix.operator.contents);

HirExpression::Infix(HirInfixExpression {
lhs,
operator: HirBinaryOp::new(infix.operator, self.file),
trait_method_id: trait_id,
rhs,
})
}
Expand Down
7 changes: 7 additions & 0 deletions compiler/noirc_frontend/src/hir/resolution/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ pub(crate) fn resolve_traits(
context.def_interner.update_trait(trait_id, |trait_def| {
trait_def.set_methods(methods);
});

// This check needs to be after the trait's methods are set since
// the interner may set `interner.ordering_type` based on the result type
// of the Cmp trait, if this is it.
if crate_id.is_stdlib() {
context.def_interner.try_add_operator_trait(trait_id);
}
}
res
}
Expand Down
187 changes: 111 additions & 76 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,21 @@ impl<'interner> TypeChecker<'interner> {
let rhs_span = self.interner.expr_span(&infix_expr.rhs);
let span = lhs_span.merge(rhs_span);

self.infix_operand_type_rules(&lhs_type, &infix_expr.operator, &rhs_type, span)
.unwrap_or_else(|error| {
let operator = &infix_expr.operator;
match self.infix_operand_type_rules(&lhs_type, operator, &rhs_type, span) {
Ok((typ, use_impl)) => {
if use_impl {
let id = infix_expr.trait_method_id;
self.verify_trait_constraint(&lhs_type, id.trait_id, *expr_id, span);
self.typecheck_operator_method(*expr_id, id, &lhs_type, span);
}
typ
}
Err(error) => {
self.errors.push(error);
Type::Error
})
}
}
}
HirExpression::Index(index_expr) => self.check_index_expression(expr_id, index_expr),
HirExpression::Call(call_expr) => {
Expand Down Expand Up @@ -294,16 +304,18 @@ impl<'interner> TypeChecker<'interner> {

// We must also remember to apply these substitutions to the object_type
// referenced by the selected trait impl, if one has yet to be selected.
let impl_kind = self.interner.get_selected_impl_for_ident(*expr_id);
let impl_kind = self.interner.get_selected_impl_for_expression(*expr_id);
if let Some(TraitImplKind::Assumed { object_type }) = impl_kind {
let the_trait = self.interner.get_trait(method.trait_id);
let object_type = object_type.substitute(&bindings);
bindings.insert(
the_trait.self_type_typevar_id,
(the_trait.self_type_typevar.clone(), object_type.clone()),
);
self.interner
.select_impl_for_ident(*expr_id, TraitImplKind::Assumed { object_type });
self.interner.select_impl_for_expression(
*expr_id,
TraitImplKind::Assumed { object_type },
);
}

self.interner.store_instantiation_bindings(*expr_id, bindings);
Expand All @@ -323,7 +335,7 @@ impl<'interner> TypeChecker<'interner> {
span: Span,
) {
match self.interner.lookup_trait_implementation(object_type, trait_id) {
Ok(impl_kind) => self.interner.select_impl_for_ident(function_ident_id, impl_kind),
Ok(impl_kind) => self.interner.select_impl_for_expression(function_ident_id, impl_kind),
Err(erroring_constraints) => {
// Don't show any errors where try_get_trait returns None.
// This can happen if a trait is used that was never declared.
Expand Down Expand Up @@ -753,19 +765,22 @@ impl<'interner> TypeChecker<'interner> {
None
}

// Given a binary comparison operator and another type. This method will produce the output type
// and a boolean indicating whether to use the trait impl corresponding to the operator
// or not. A value of false indicates the caller to use a primitive operation for this
// operator, while a true value indicates a user-provided trait impl is required.
fn comparator_operand_type_rules(
&mut self,
lhs_type: &Type,
rhs_type: &Type,
op: &HirBinaryOp,
span: Span,
) -> Result<Type, TypeCheckError> {
use crate::BinaryOpKind::{Equal, NotEqual};
) -> Result<(Type, bool), TypeCheckError> {
use Type::*;

match (lhs_type, rhs_type) {
// Avoid reporting errors multiple times
(Error, _) | (_, Error) => Ok(Bool),
(Error, _) | (_, Error) => Ok((Bool, false)),

// Matches on TypeVariable must be first to follow any type
// bindings.
Expand All @@ -791,7 +806,7 @@ impl<'interner> TypeChecker<'interner> {
|| other == &Type::Error
{
Type::apply_type_bindings(bindings);
Ok(Bool)
Ok((Bool, false))
} else {
Err(TypeCheckError::TypeMismatchWithSource {
expected: lhs_type.clone(),
Expand All @@ -816,56 +831,33 @@ impl<'interner> TypeChecker<'interner> {
span,
});
}
Ok(Bool)
}
(Integer(..), FieldElement) | (FieldElement, Integer(..)) => {
Err(TypeCheckError::IntegerAndFieldBinaryOperation { span })
}
(Integer(..), typ) | (typ, Integer(..)) => {
Err(TypeCheckError::IntegerTypeMismatch { typ: typ.clone(), span })
Ok((Bool, false))
}
(FieldElement, FieldElement) => {
if op.kind.is_valid_for_field_type() {
Ok(Bool)
Ok((Bool, false))
} else {
Err(TypeCheckError::FieldComparison { span })
}
}

// <= and friends are technically valid for booleans, just not very useful
(Bool, Bool) => Ok(Bool),
(Bool, Bool) => Ok((Bool, false)),

// Special-case == and != for arrays
(Array(x_size, x_type), Array(y_size, y_type))
if matches!(op.kind, Equal | NotEqual) =>
if matches!(op.kind, BinaryOpKind::Equal | BinaryOpKind::NotEqual) =>
{
self.unify(x_type, y_type, || TypeCheckError::TypeMismatchWithSource {
expected: lhs_type.clone(),
actual: rhs_type.clone(),
source: Source::ArrayElements,
span: op.location.span,
});

self.unify(x_size, y_size, || TypeCheckError::TypeMismatchWithSource {
expected: lhs_type.clone(),
actual: rhs_type.clone(),
source: Source::ArrayLen,
span: op.location.span,
});

Ok(Bool)
}
(lhs @ NamedGeneric(binding_a, _), rhs @ NamedGeneric(binding_b, _)) => {
if binding_a == binding_b {
return Ok(Bool);
}
Err(TypeCheckError::TypeMismatchWithSource {
expected: lhs.clone(),
actual: rhs.clone(),
source: Source::Comparison,
span,
})
self.comparator_operand_type_rules(x_type, y_type, op, span)
}

(String(x_size), String(y_size)) => {
self.unify(x_size, y_size, || TypeCheckError::TypeMismatchWithSource {
expected: *x_size.clone(),
Expand All @@ -874,14 +866,17 @@ impl<'interner> TypeChecker<'interner> {
source: Source::StringLen,
});

Ok(Bool)
Ok((Bool, false))
}
(lhs, rhs) => {
self.unify(lhs, rhs, || TypeCheckError::TypeMismatchWithSource {
expected: lhs.clone(),
actual: rhs.clone(),
span: op.location.span,
source: Source::Binary,
});
Ok((Bool, true))
}
(lhs, rhs) => Err(TypeCheckError::TypeMismatchWithSource {
expected: lhs.clone(),
actual: rhs.clone(),
source: Source::Comparison,
span,
}),
}
}

Expand Down Expand Up @@ -1041,21 +1036,24 @@ impl<'interner> TypeChecker<'interner> {
}

// Given a binary operator and another type. This method will produce the output type
// and a boolean indicating whether to use the trait impl corresponding to the operator
// or not. A value of false indicates the caller to use a primitive operation for this
// operator, while a true value indicates a user-provided trait impl is required.
fn infix_operand_type_rules(
&mut self,
lhs_type: &Type,
op: &HirBinaryOp,
rhs_type: &Type,
span: Span,
) -> Result<Type, TypeCheckError> {
) -> Result<(Type, bool), TypeCheckError> {
if op.kind.is_comparator() {
return self.comparator_operand_type_rules(lhs_type, rhs_type, op, span);
}

use Type::*;
match (lhs_type, rhs_type) {
// An error type on either side will always return an error
(Error, _) | (_, Error) => Ok(Error),
(Error, _) | (_, Error) => Ok((Error, false)),

// Matches on TypeVariable must be first so that we follow any type
// bindings.
Expand Down Expand Up @@ -1096,7 +1094,7 @@ impl<'interner> TypeChecker<'interner> {
|| other == &Type::Error
{
Type::apply_type_bindings(bindings);
Ok(other.clone())
Ok((other.clone(), false))
} else {
Err(TypeCheckError::TypeMismatchWithSource {
expected: lhs_type.clone(),
Expand All @@ -1121,25 +1119,8 @@ impl<'interner> TypeChecker<'interner> {
span,
});
}
Ok(Integer(*sign_x, *bit_width_x))
}
(Integer(..), FieldElement) | (FieldElement, Integer(..)) => {
Err(TypeCheckError::IntegerAndFieldBinaryOperation { span })
Ok((Integer(*sign_x, *bit_width_x), false))
}
(Integer(..), typ) | (typ, Integer(..)) => {
Err(TypeCheckError::IntegerTypeMismatch { typ: typ.clone(), span })
}
// These types are not supported in binary operations
(Array(..), _) | (_, Array(..)) => {
Err(TypeCheckError::InvalidInfixOp { kind: "Arrays", span })
}
(Struct(..), _) | (_, Struct(..)) => {
Err(TypeCheckError::InvalidInfixOp { kind: "Structs", span })
}
(Tuple(_), _) | (_, Tuple(_)) => {
Err(TypeCheckError::InvalidInfixOp { kind: "Tuples", span })
}

// The result of two Fields is always a witness
(FieldElement, FieldElement) => {
if op.is_bitwise() {
Expand All @@ -1148,17 +1129,20 @@ impl<'interner> TypeChecker<'interner> {
if op.is_modulo() {
return Err(TypeCheckError::FieldModulo { span });
}
Ok(FieldElement)
Ok((FieldElement, false))
}

(Bool, Bool) => Ok(Bool),
(Bool, Bool) => Ok((Bool, false)),

(lhs, rhs) => Err(TypeCheckError::TypeMismatchWithSource {
expected: lhs.clone(),
actual: rhs.clone(),
source: Source::BinOp(op.kind),
span,
}),
(lhs, rhs) => {
self.unify(lhs, rhs, || TypeCheckError::TypeMismatchWithSource {
expected: lhs.clone(),
actual: rhs.clone(),
span: op.location.span,
source: Source::Binary,
});
Ok((lhs.clone(), true))
}
}
}

Expand Down Expand Up @@ -1210,6 +1194,57 @@ impl<'interner> TypeChecker<'interner> {
}
}
}

/// Prerequisite: verify_trait_constraint of the operator's trait constraint.
///
/// Although by this point the operator is expected to already have a trait impl,
/// we still need to match the operator's type against the method's instantiated type
/// to ensure the instantiation bindings are correct and the monomorphizer can
/// re-apply the needed bindings.
fn typecheck_operator_method(
&mut self,
expr_id: ExprId,
trait_method_id: TraitMethodId,
object_type: &Type,
span: Span,
) {
let the_trait = self.interner.get_trait(trait_method_id.trait_id);

let method = &the_trait.methods[trait_method_id.method_index];
let (method_type, mut bindings) = method.typ.instantiate(self.interner);

match method_type {
Type::Function(args, _, _) => {
// We can cheat a bit and match against only the object type here since no operator
// overload uses other generic parameters or return types aside from the object type.
let expected_object_type = &args[0];
self.unify(object_type, expected_object_type, || TypeCheckError::TypeMismatch {
expected_typ: expected_object_type.to_string(),
expr_typ: object_type.to_string(),
expr_span: span,
});
}
other => {
unreachable!("Expected operator method to have a function type, but found {other}")
}
}

// We must also remember to apply these substitutions to the object_type
// referenced by the selected trait impl, if one has yet to be selected.
let impl_kind = self.interner.get_selected_impl_for_expression(expr_id);
if let Some(TraitImplKind::Assumed { object_type }) = impl_kind {
let the_trait = self.interner.get_trait(trait_method_id.trait_id);
let object_type = object_type.substitute(&bindings);
bindings.insert(
the_trait.self_type_typevar_id,
(the_trait.self_type_typevar.clone(), object_type.clone()),
);
self.interner
.select_impl_for_expression(expr_id, TraitImplKind::Assumed { object_type });
}

self.interner.store_instantiation_bindings(expr_id, bindings);
}
}

/// Taken from: https://stackoverflow.com/a/47127500
Expand Down
Loading

1 comment on commit 4b16090

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉 Published on https://noir-lang.org as production
🚀 Deployed on https://6595b225ada86b050209a29a--noir-docs.netlify.app

Please sign in to comment.