Skip to content

Commit

Permalink
fix!: Infer globals to be u32 when used in a type (noir-lang#6083)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves noir-lang#6081
Resolves noir-lang#6082

## Summary\*

Now that we check kinds explicitly we were erroring if globals weren't
already `u32` yet leading to a lot of existing code needing to be
changed.

If we change the exact equality to a unification check we can instead
infer the globals need to be a `u32` without requiring user code to be
changed.

This PR is still breaking since there were causes we didn't catch before
of invalid kinds being used in array lengths, e.g. `fn foo<let N: u8>(a:
[Field; N])` in one of our unit tests.

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
jfecher authored Sep 18, 2024
1 parent e678091 commit 78262c9
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 37 deletions.
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ impl<'context> Elaborator<'context> {
UnresolvedTypeExpression::Constant(0, span)
});

let length = self.convert_expression_type(length, span);
let length = self.convert_expression_type(length, &Kind::u32(), span);
let (repeated_element, elem_type) = self.elaborate_expression(*repeated_element);

let length_clone = length.clone();
Expand Down
57 changes: 31 additions & 26 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,22 @@ impl<'context> Elaborator<'context> {
FieldElement => Type::FieldElement,
Array(size, elem) => {
let elem = Box::new(self.resolve_type_inner(*elem, kind));
let size = self.convert_expression_type(size, span);
let size = self.convert_expression_type(size, &Kind::u32(), span);
Type::Array(Box::new(size), elem)
}
Slice(elem) => {
let elem = Box::new(self.resolve_type_inner(*elem, kind));
Type::Slice(elem)
}
Expression(expr) => self.convert_expression_type(expr, span),
Expression(expr) => self.convert_expression_type(expr, kind, span),
Integer(sign, bits) => Type::Integer(sign, bits),
Bool => Type::Bool,
String(size) => {
let resolved_size = self.convert_expression_type(size, span);
let resolved_size = self.convert_expression_type(size, &Kind::u32(), span);
Type::String(Box::new(resolved_size))
}
FormatString(size, fields) => {
let resolved_size = self.convert_expression_type(size, span);
let resolved_size = self.convert_expression_type(size, &Kind::u32(), span);
let fields = self.resolve_type_inner(*fields, kind);
Type::FmtString(Box::new(resolved_size), Box::new(fields))
}
Expand Down Expand Up @@ -426,37 +426,25 @@ impl<'context> Elaborator<'context> {
pub(super) fn convert_expression_type(
&mut self,
length: UnresolvedTypeExpression,
expected_kind: &Kind,
span: Span,
) -> Type {
match length {
UnresolvedTypeExpression::Variable(path) => {
let resolved_length =
self.lookup_generic_or_global_type(&path).unwrap_or_else(|| {
self.push_err(ResolverError::NoSuchNumericTypeVariable { path });
Type::Constant(0, Kind::u32())
});

if let Type::NamedGeneric(ref _type_var, ref _name, ref kind) = resolved_length {
if !kind.is_numeric() {
self.push_err(TypeCheckError::TypeKindMismatch {
expected_kind: Kind::u32().to_string(),
expr_kind: kind.to_string(),
expr_span: span,
});
return Type::Error;
}
}
resolved_length
let typ = self.resolve_named_type(path, GenericTypeArgs::default());
self.check_kind(typ, expected_kind, span)
}
UnresolvedTypeExpression::Constant(int, _span) => {
Type::Constant(int, expected_kind.clone())
}
UnresolvedTypeExpression::Constant(int, _span) => Type::Constant(int, Kind::u32()),
UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, span) => {
let (lhs_span, rhs_span) = (lhs.span(), rhs.span());
let lhs = self.convert_expression_type(*lhs, lhs_span);
let rhs = self.convert_expression_type(*rhs, rhs_span);
let lhs = self.convert_expression_type(*lhs, expected_kind, lhs_span);
let rhs = self.convert_expression_type(*rhs, expected_kind, rhs_span);

match (lhs, rhs) {
(Type::Constant(lhs, lhs_kind), Type::Constant(rhs, rhs_kind)) => {
if lhs_kind != rhs_kind {
if !lhs_kind.unifies(&rhs_kind) {
self.push_err(TypeCheckError::TypeKindMismatch {
expected_kind: lhs_kind.to_string(),
expr_kind: rhs_kind.to_string(),
Expand All @@ -474,10 +462,27 @@ impl<'context> Elaborator<'context> {
(lhs, rhs) => Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)).canonicalize(),
}
}
UnresolvedTypeExpression::AsTraitPath(path) => self.resolve_as_trait_path(*path),
UnresolvedTypeExpression::AsTraitPath(path) => {
let typ = self.resolve_as_trait_path(*path);
self.check_kind(typ, expected_kind, span)
}
}
}

fn check_kind(&mut self, typ: Type, expected_kind: &Kind, span: Span) -> Type {
if let Some(kind) = typ.kind() {
if !kind.unifies(expected_kind) {
self.push_err(TypeCheckError::TypeKindMismatch {
expected_kind: expected_kind.to_string(),
expr_kind: kind.to_string(),
expr_span: span,
});
return Type::Error;
}
}
typ
}

fn resolve_as_trait_path(&mut self, path: AsTraitPath) -> Type {
let span = path.trait_path.span;
let Some(trait_id) = self.resolve_trait_by_path(path.trait_path.clone()) else {
Expand Down
30 changes: 27 additions & 3 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,36 @@ impl Kind {
}

pub(crate) fn matches_opt(&self, other: Option<Self>) -> bool {
other.as_ref().map_or(true, |other_kind| self == other_kind)
other.as_ref().map_or(true, |other_kind| self.unifies(other_kind))
}

pub(crate) fn u32() -> Self {
Self::Numeric(Box::new(Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo)))
}

/// Unifies this kind with the other. Returns true on success
pub(crate) fn unifies(&self, other: &Kind) -> bool {
match (self, other) {
(Kind::Normal, Kind::Normal) => true,
(Kind::Numeric(lhs), Kind::Numeric(rhs)) => {
let mut bindings = TypeBindings::new();
let unifies = lhs.try_unify(rhs, &mut bindings).is_ok();
if unifies {
Type::apply_type_bindings(bindings);
}
unifies
}
_ => false,
}
}

pub(crate) fn unify(&self, other: &Kind) -> Result<(), UnificationError> {
if self.unifies(other) {
Ok(())
} else {
Err(UnificationError)
}
}
}

impl std::fmt::Display for Kind {
Expand Down Expand Up @@ -1465,13 +1489,13 @@ impl Type {
}
}

(NamedGeneric(binding_a, name_a, _), NamedGeneric(binding_b, name_b, _)) => {
(NamedGeneric(binding_a, name_a, kind_a), NamedGeneric(binding_b, name_b, kind_b)) => {
// Bound NamedGenerics are caught by the check above
assert!(binding_a.borrow().is_unbound());
assert!(binding_b.borrow().is_unbound());

if name_a == name_b {
Ok(())
kind_a.unify(kind_b)
} else {
Err(UnificationError)
}
Expand Down
68 changes: 62 additions & 6 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1616,25 +1616,30 @@ fn numeric_generic_binary_operation_type_mismatch() {
#[test]
fn bool_generic_as_loop_bound() {
let src = r#"
pub fn read<let N: bool>() {
let mut fields = [0; N];
for i in 0..N {
pub fn read<let N: bool>() { // error here
let mut fields = [0; N]; // error here
for i in 0..N { // error here
fields[i] = i + 1;
}
assert(fields[0] == 1);
}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 2);
assert_eq!(errors.len(), 3);

assert!(matches!(
errors[0].0,
CompilationError::ResolverError(ResolverError::UnsupportedNumericGenericType { .. }),
));

assert!(matches!(
errors[1].0,
CompilationError::TypeError(TypeCheckError::TypeKindMismatch { .. }),
));

let CompilationError::TypeError(TypeCheckError::TypeMismatch {
expected_typ, expr_typ, ..
}) = &errors[1].0
}) = &errors[2].0
else {
panic!("Got an error other than a type mismatch");
};
Expand All @@ -1646,7 +1651,7 @@ fn bool_generic_as_loop_bound() {
#[test]
fn numeric_generic_in_function_signature() {
let src = r#"
pub fn foo<let N: u8>(arr: [Field; N]) -> [Field; N] { arr }
pub fn foo<let N: u32>(arr: [Field; N]) -> [Field; N] { arr }
"#;
assert_no_errors(src);
}
Expand Down Expand Up @@ -3644,3 +3649,54 @@ fn does_not_crash_when_passing_mutable_undefined_variable() {

assert_eq!(name, "undefined");
}

#[test]
fn infer_globals_to_u32_from_type_use() {
let src = r#"
global ARRAY_LEN = 3;
global STR_LEN = 2;
global FMT_STR_LEN = 2;
fn main() {
let _a: [u32; ARRAY_LEN] = [1, 2, 3];
let _b: str<STR_LEN> = "hi";
let _c: fmtstr<FMT_STR_LEN, _> = f"hi";
}
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 0);
}

#[test]
fn non_u32_in_array_length() {
let src = r#"
global ARRAY_LEN: u8 = 3;
fn main() {
let _a: [u32; ARRAY_LEN] = [1, 2, 3];
}
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);

assert!(matches!(
errors[0].0,
CompilationError::TypeError(TypeCheckError::TypeKindMismatch { .. })
));
}

#[test]
fn use_non_u32_generic_in_struct() {
let src = r#"
struct S<let N: u8> {}
fn main() {
let _: S<3> = S {};
}
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 0);
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fn main() {
}

// Used in the signature of a function
fn id<let I: Field>(x: [Field; I]) -> [Field; I] {
fn id<let I: u32>(x: [Field; I]) -> [Field; I] {
x
}

Expand Down

0 comments on commit 78262c9

Please sign in to comment.