Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: Infer globals to be u32 when used in a type #6083

Merged
merged 5 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ impl<'context> Elaborator<'context> {
UnresolvedTypeExpression::Constant(0, span)
});

let length = self.convert_expression_type(length, span);
let length = self.convert_expression_type(length, &Kind::u32(), span);
let (repeated_element, elem_type) = self.elaborate_expression(*repeated_element);

let length_clone = length.clone();
Expand Down
57 changes: 31 additions & 26 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,22 @@ impl<'context> Elaborator<'context> {
FieldElement => Type::FieldElement,
Array(size, elem) => {
let elem = Box::new(self.resolve_type_inner(*elem, kind));
let size = self.convert_expression_type(size, span);
let size = self.convert_expression_type(size, &Kind::u32(), span);
Type::Array(Box::new(size), elem)
}
Slice(elem) => {
let elem = Box::new(self.resolve_type_inner(*elem, kind));
Type::Slice(elem)
}
Expression(expr) => self.convert_expression_type(expr, span),
Expression(expr) => self.convert_expression_type(expr, kind, span),
Integer(sign, bits) => Type::Integer(sign, bits),
Bool => Type::Bool,
String(size) => {
let resolved_size = self.convert_expression_type(size, span);
let resolved_size = self.convert_expression_type(size, &Kind::u32(), span);
Type::String(Box::new(resolved_size))
}
FormatString(size, fields) => {
let resolved_size = self.convert_expression_type(size, span);
let resolved_size = self.convert_expression_type(size, &Kind::u32(), span);
let fields = self.resolve_type_inner(*fields, kind);
Type::FmtString(Box::new(resolved_size), Box::new(fields))
}
Expand Down Expand Up @@ -426,37 +426,25 @@ impl<'context> Elaborator<'context> {
pub(super) fn convert_expression_type(
&mut self,
length: UnresolvedTypeExpression,
expected_kind: &Kind,
span: Span,
) -> Type {
match length {
UnresolvedTypeExpression::Variable(path) => {
let resolved_length =
self.lookup_generic_or_global_type(&path).unwrap_or_else(|| {
self.push_err(ResolverError::NoSuchNumericTypeVariable { path });
Type::Constant(0, Kind::u32())
});

if let Type::NamedGeneric(ref _type_var, ref _name, ref kind) = resolved_length {
if !kind.is_numeric() {
self.push_err(TypeCheckError::TypeKindMismatch {
expected_kind: Kind::u32().to_string(),
expr_kind: kind.to_string(),
expr_span: span,
});
return Type::Error;
}
}
resolved_length
let typ = self.resolve_named_type(path, GenericTypeArgs::default());
self.check_kind(typ, expected_kind, span)
}
UnresolvedTypeExpression::Constant(int, _span) => {
Type::Constant(int, expected_kind.clone())
}
UnresolvedTypeExpression::Constant(int, _span) => Type::Constant(int, Kind::u32()),
UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, span) => {
let (lhs_span, rhs_span) = (lhs.span(), rhs.span());
let lhs = self.convert_expression_type(*lhs, lhs_span);
let rhs = self.convert_expression_type(*rhs, rhs_span);
let lhs = self.convert_expression_type(*lhs, expected_kind, lhs_span);
let rhs = self.convert_expression_type(*rhs, expected_kind, rhs_span);

match (lhs, rhs) {
(Type::Constant(lhs, lhs_kind), Type::Constant(rhs, rhs_kind)) => {
if lhs_kind != rhs_kind {
if !lhs_kind.unifies(&rhs_kind) {
self.push_err(TypeCheckError::TypeKindMismatch {
expected_kind: lhs_kind.to_string(),
expr_kind: rhs_kind.to_string(),
Expand All @@ -474,10 +462,27 @@ impl<'context> Elaborator<'context> {
(lhs, rhs) => Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)).canonicalize(),
}
}
UnresolvedTypeExpression::AsTraitPath(path) => self.resolve_as_trait_path(*path),
UnresolvedTypeExpression::AsTraitPath(path) => {
let typ = self.resolve_as_trait_path(*path);
self.check_kind(typ, expected_kind, span)
}
}
}

fn check_kind(&mut self, typ: Type, expected_kind: &Kind, span: Span) -> Type {
if let Some(kind) = typ.kind() {
if !kind.unifies(expected_kind) {
self.push_err(TypeCheckError::TypeKindMismatch {
expected_kind: expected_kind.to_string(),
expr_kind: kind.to_string(),
expr_span: span,
});
return Type::Error;
}
}
typ
}

fn resolve_as_trait_path(&mut self, path: AsTraitPath) -> Type {
let span = path.trait_path.span;
let Some(trait_id) = self.resolve_trait_by_path(path.trait_path.clone()) else {
Expand Down
30 changes: 27 additions & 3 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,36 @@
}

pub(crate) fn matches_opt(&self, other: Option<Self>) -> bool {
other.as_ref().map_or(true, |other_kind| self == other_kind)
other.as_ref().map_or(true, |other_kind| self.unifies(other_kind))
}

pub(crate) fn u32() -> Self {
Self::Numeric(Box::new(Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo)))
}

/// Unifies this kind with the other. Returns true on success
pub(crate) fn unifies(&self, other: &Kind) -> bool {
match (self, other) {
(Kind::Normal, Kind::Normal) => true,
(Kind::Numeric(lhs), Kind::Numeric(rhs)) => {
let mut bindings = TypeBindings::new();
let unifies = lhs.try_unify(rhs, &mut bindings).is_ok();
if unifies {
Type::apply_type_bindings(bindings);
}
unifies
}
_ => false,
}
}

pub(crate) fn unify(&self, other: &Kind) -> Result<(), UnificationError> {
if self.unifies(other) {
Ok(())
} else {
Err(UnificationError)
}
}
}

impl std::fmt::Display for Kind {
Expand Down Expand Up @@ -1465,13 +1489,13 @@
}
}

(NamedGeneric(binding_a, name_a, _), NamedGeneric(binding_b, name_b, _)) => {
michaeljklein marked this conversation as resolved.
Show resolved Hide resolved
(NamedGeneric(binding_a, name_a, kind_a), NamedGeneric(binding_b, name_b, kind_b)) => {
// Bound NamedGenerics are caught by the check above
assert!(binding_a.borrow().is_unbound());
assert!(binding_b.borrow().is_unbound());

if name_a == name_b {
Ok(())
kind_a.unify(kind_b)
} else {
Err(UnificationError)
}
Expand Down Expand Up @@ -1857,7 +1881,7 @@
}

let recur_on_binding = |id, replacement: &Type| {
// Prevent recuring forever if there's a `T := T` binding

Check warning on line 1884 in compiler/noirc_frontend/src/hir_def/types.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (recuring)
if replacement.type_variable_id() == Some(id) {
replacement.clone()
} else {
Expand Down Expand Up @@ -1928,7 +1952,7 @@
Type::Tuple(fields)
}
Type::Forall(typevars, typ) => {
// Trying to substitute_helper a variable de, substitute_bound_typevarsfined within a nested Forall

Check warning on line 1955 in compiler/noirc_frontend/src/hir_def/types.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (typevarsfined)
// is usually impossible and indicative of an error in the type checker somewhere.
for var in typevars {
assert!(!type_bindings.contains_key(&var.id()));
Expand Down Expand Up @@ -2095,7 +2119,7 @@

/// Replace any `Type::NamedGeneric` in this type with a `Type::TypeVariable`
/// using to the same inner `TypeVariable`. This is used during monomorphization
/// to bind to named generics since they are unbindable during type checking.

Check warning on line 2122 in compiler/noirc_frontend/src/hir_def/types.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (unbindable)
pub fn replace_named_generics_with_type_variables(&mut self) {
match self {
Type::FieldElement
Expand Down Expand Up @@ -2470,7 +2494,7 @@
len.hash(state);
env.hash(state);
}
Type::Tuple(elems) => elems.hash(state),

Check warning on line 2497 in compiler/noirc_frontend/src/hir_def/types.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (elems)

Check warning on line 2497 in compiler/noirc_frontend/src/hir_def/types.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (elems)
Type::Struct(def, args) => {
def.hash(state);
args.hash(state);
Expand Down Expand Up @@ -2565,7 +2589,7 @@
// Special case: we consider unbound named generics and type variables to be equal to each
// other if their type variable ids match. This is important for some corner cases in
// monomorphization where we call `replace_named_generics_with_type_variables` but
// still want them to be equal for canonicalization checks in arithmetic generics.

Check warning on line 2592 in compiler/noirc_frontend/src/hir_def/types.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (canonicalization)
// Without this we'd fail the `serialize` test.
(
NamedGeneric(lhs_var, _, _) | TypeVariable(lhs_var, _),
Expand Down
68 changes: 62 additions & 6 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1616,25 +1616,30 @@
#[test]
fn bool_generic_as_loop_bound() {
let src = r#"
pub fn read<let N: bool>() {
let mut fields = [0; N];
for i in 0..N {
pub fn read<let N: bool>() { // error here
let mut fields = [0; N]; // error here
for i in 0..N { // error here
fields[i] = i + 1;
}
assert(fields[0] == 1);
}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 2);
assert_eq!(errors.len(), 3);

assert!(matches!(
errors[0].0,
CompilationError::ResolverError(ResolverError::UnsupportedNumericGenericType { .. }),
));

assert!(matches!(
errors[1].0,
CompilationError::TypeError(TypeCheckError::TypeKindMismatch { .. }),
));

let CompilationError::TypeError(TypeCheckError::TypeMismatch {
expected_typ, expr_typ, ..
}) = &errors[1].0
}) = &errors[2].0
else {
panic!("Got an error other than a type mismatch");
};
Expand All @@ -1646,7 +1651,7 @@
#[test]
fn numeric_generic_in_function_signature() {
let src = r#"
pub fn foo<let N: u8>(arr: [Field; N]) -> [Field; N] { arr }
pub fn foo<let N: u32>(arr: [Field; N]) -> [Field; N] { arr }
"#;
assert_no_errors(src);
}
Expand Down Expand Up @@ -2368,7 +2373,7 @@
}

#[test]
fn underflowing_u8() {

Check warning on line 2376 in compiler/noirc_frontend/src/tests.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (underflowing)
let src = r#"
fn main() {
let _: u8 = -1;
Expand Down Expand Up @@ -2406,7 +2411,7 @@
}

#[test]
fn underflowing_i8() {

Check warning on line 2414 in compiler/noirc_frontend/src/tests.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (underflowing)
let src = r#"
fn main() {
let _: i8 = -129;
Expand Down Expand Up @@ -3551,7 +3556,7 @@
}

#[test]
fn arithmetic_generics_canonicalization_deduplication_regression() {

Check warning on line 3559 in compiler/noirc_frontend/src/tests.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (canonicalization)
let source = r#"
struct ArrData<let N: u32> {
a: [Field; N],
Expand Down Expand Up @@ -3644,3 +3649,54 @@

assert_eq!(name, "undefined");
}

#[test]
fn infer_globals_to_u32_from_type_use() {
let src = r#"
global ARRAY_LEN = 3;
global STR_LEN = 2;
global FMT_STR_LEN = 2;

fn main() {
let _a: [u32; ARRAY_LEN] = [1, 2, 3];
let _b: str<STR_LEN> = "hi";
let _c: fmtstr<FMT_STR_LEN, _> = f"hi";
}
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 0);
}

#[test]
fn non_u32_in_array_length() {
let src = r#"
global ARRAY_LEN: u8 = 3;

fn main() {
let _a: [u32; ARRAY_LEN] = [1, 2, 3];
}
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);

assert!(matches!(
errors[0].0,
CompilationError::TypeError(TypeCheckError::TypeKindMismatch { .. })
));
}

#[test]
fn use_non_u32_generic_in_struct() {
let src = r#"
struct S<let N: u8> {}

fn main() {
let _: S<3> = S {};
}
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 0);
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fn main() {
}

// Used in the signature of a function
fn id<let I: Field>(x: [Field; I]) -> [Field; I] {
fn id<let I: u32>(x: [Field; I]) -> [Field; I] {
x
}

Expand Down
Loading