Skip to content

Commit

Permalink
feat(experimental): comptime globals (#4918)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #4916 

## Summary\*

Implements immutable `comptime` globals.

I've made a separate issue for mutable comptime globals since those
would require allowing a `mut` modifier when parsing globals in general.

## Additional Context

The scopes of the interpreter were already set up for globals to work.
The main main part that was needed here was evaluating comptime-known
identifiers during the scan pass. Without this monomorphization would
just use the original rhs of each global.

## Documentation\*

Check one:
- [ ] No documentation needed.
- [ ] Documentation included in this PR.
- [x] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [ ] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com>
  • Loading branch information
jfecher and TomAFrench authored Apr 29, 2024
1 parent 8c89f04 commit 8a3c7f1
Show file tree
Hide file tree
Showing 21 changed files with 256 additions and 48 deletions.
1 change: 1 addition & 0 deletions aztec_macros/src/utils/hir_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ pub fn inject_global(
module_id,
file_id,
global.attributes.clone(),
false,
);

// Add the statement to the scope so its path can be looked up later
Expand Down
7 changes: 4 additions & 3 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub enum StatementKind {
Break,
Continue,
/// This statement should be executed at compile-time
Comptime(Box<StatementKind>),
Comptime(Box<Statement>),
// This is an expression with a trailing semi-colon
Semi(Expression),
// This statement is the result of a recovered parse error.
Expand Down Expand Up @@ -486,7 +486,8 @@ impl Pattern {
pub fn name_ident(&self) -> &Ident {
match self {
Pattern::Identifier(name_ident) => name_ident,
_ => panic!("only the identifier pattern can return a name"),
Pattern::Mutable(pattern, ..) => pattern.name_ident(),
_ => panic!("Only the Identifier or Mutable patterns can return a name"),
}
}

Expand Down Expand Up @@ -685,7 +686,7 @@ impl Display for StatementKind {
StatementKind::For(for_loop) => for_loop.fmt(f),
StatementKind::Break => write!(f, "break"),
StatementKind::Continue => write!(f, "continue"),
StatementKind::Comptime(statement) => write!(f, "comptime {statement}"),
StatementKind::Comptime(statement) => write!(f, "comptime {}", statement.kind),
StatementKind::Semi(semi) => write!(f, "{semi};"),
StatementKind::Error => write!(f, "Error"),
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl StmtId {
HirStatement::Semi(expr) => StatementKind::Semi(expr.to_ast(interner)),
HirStatement::Error => StatementKind::Error,
HirStatement::Comptime(statement) => {
StatementKind::Comptime(Box::new(statement.to_ast(interner).kind))
StatementKind::Comptime(Box::new(statement.to_ast(interner)))
}
};

Expand Down
29 changes: 19 additions & 10 deletions compiler/noirc_frontend/src/hir/comptime/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,11 @@ impl<'a> Interpreter<'a> {
/// `exit_function` is called.
pub(super) fn enter_function(&mut self) -> (bool, Vec<HashMap<DefinitionId, Value>>) {
// Drain every scope except the global scope
let scope = self.scopes.drain(1..).collect();
self.push_scope();
let mut scope = Vec::new();
if self.scopes.len() > 1 {
scope = self.scopes.drain(1..).collect();
self.push_scope();
}
(std::mem::take(&mut self.in_loop), scope)
}

Expand All @@ -160,7 +163,7 @@ impl<'a> Interpreter<'a> {
self.scopes.last_mut().unwrap()
}

fn define_pattern(
pub(super) fn define_pattern(
&mut self,
pattern: &HirPattern,
typ: &Type,
Expand Down Expand Up @@ -262,7 +265,7 @@ impl<'a> Interpreter<'a> {
Err(InterpreterError::NonComptimeVarReferenced { name, location })
}

fn lookup(&self, ident: &HirIdent) -> IResult<Value> {
pub(super) fn lookup(&self, ident: &HirIdent) -> IResult<Value> {
self.lookup_id(ident.id, ident.location)
}

Expand Down Expand Up @@ -291,7 +294,7 @@ impl<'a> Interpreter<'a> {
}

/// Evaluate an expression and return the result
fn evaluate(&mut self, id: ExprId) -> IResult<Value> {
pub(super) fn evaluate(&mut self, id: ExprId) -> IResult<Value> {
match self.interner.expression(&id) {
HirExpression::Ident(ident) => self.evaluate_ident(ident, id),
HirExpression::Literal(literal) => self.evaluate_literal(literal, id),
Expand Down Expand Up @@ -322,7 +325,7 @@ impl<'a> Interpreter<'a> {
}
}

fn evaluate_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult<Value> {
pub(super) fn evaluate_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult<Value> {
let definition = self.interner.definition(ident.id);

match &definition.kind {
Expand All @@ -332,9 +335,15 @@ impl<'a> Interpreter<'a> {
}
DefinitionKind::Local(_) => self.lookup(&ident),
DefinitionKind::Global(global_id) => {
let let_ = self.interner.get_global_let_statement(*global_id).unwrap();
self.evaluate_let(let_)?;
self.lookup(&ident)
// Don't need to check let_.comptime, we can evaluate non-comptime globals too.
// Avoid resetting the value if it is already known
if let Ok(value) = self.lookup(&ident) {
Ok(value)
} else {
let let_ = self.interner.get_global_let_statement(*global_id).unwrap();
self.evaluate_let(let_)?;
self.lookup(&ident)
}
}
DefinitionKind::GenericType(type_variable) => {
let value = match &*type_variable.borrow() {
Expand Down Expand Up @@ -1027,7 +1036,7 @@ impl<'a> Interpreter<'a> {
}
}

fn evaluate_let(&mut self, let_: HirLetStatement) -> IResult<Value> {
pub(super) fn evaluate_let(&mut self, let_: HirLetStatement) -> IResult<Value> {
let rhs = self.evaluate(let_.expression)?;
let location = self.interner.expr_location(&let_.expression);
self.define_pattern(&let_.pattern, &let_.r#type, rhs, location)?;
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/hir/comptime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ mod value;

pub use errors::InterpreterError;
pub use interpreter::Interpreter;
pub use value::Value;
50 changes: 47 additions & 3 deletions compiler/noirc_frontend/src/hir/comptime/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@ use crate::{
hir_def::{
expr::{
HirArrayLiteral, HirBlockExpression, HirCallExpression, HirConstructorExpression,
HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda,
HirIdent, HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda,
HirMethodCallExpression,
},
stmt::HirForStatement,
},
macros_api::{HirExpression, HirLiteral, HirStatement},
node_interner::{ExprId, FuncId, StmtId},
node_interner::{DefinitionKind, ExprId, FuncId, GlobalId, StmtId},
};

use super::{
errors::{IResult, InterpreterError},
interpreter::Interpreter,
Value,
};

#[allow(dead_code)]
Expand All @@ -48,9 +49,23 @@ impl<'interner> Interpreter<'interner> {
Ok(())
}

/// Evaluate this global if it is a comptime global.
/// Otherwise, scan through its expression for any comptime blocks to evaluate.
pub fn scan_global(&mut self, global: GlobalId) -> IResult<()> {
if let Some(let_) = self.interner.get_global_let_statement(global) {
if let_.comptime {
self.evaluate_let(let_)?;
} else {
self.scan_expression(let_.expression)?;
}
}

Ok(())
}

fn scan_expression(&mut self, expr: ExprId) -> IResult<()> {
match self.interner.expression(&expr) {
HirExpression::Ident(_) => Ok(()),
HirExpression::Ident(ident) => self.scan_ident(ident, expr),
HirExpression::Literal(literal) => self.scan_literal(literal),
HirExpression::Block(block) => self.scan_block(block),
HirExpression::Prefix(prefix) => self.scan_expression(prefix.rhs),
Expand Down Expand Up @@ -91,6 +106,27 @@ impl<'interner> Interpreter<'interner> {
}
}

// Identifiers have no code to execute but we may need to inline any values
// of comptime variables into runtime code.
fn scan_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult<()> {
let definition = self.interner.definition(ident.id);

match &definition.kind {
DefinitionKind::Function(_) => Ok(()),
_ => {
// Opportunistically evaluate this identifier to see if it is compile-time known.
// If so, inline its value.
if let Ok(value) = self.evaluate_ident(ident, id) {
// TODO(#4922): Inlining closures is currently unimplemented
if !matches!(value, Value::Closure(..)) {
self.inline_expression(value, id)?;
}
}
Ok(())
}
}
}

fn scan_literal(&mut self, literal: HirLiteral) -> IResult<()> {
match literal {
HirLiteral::Array(elements) | HirLiteral::Slice(elements) => match elements {
Expand Down Expand Up @@ -210,4 +246,12 @@ impl<'interner> Interpreter<'interner> {
self.pop_scope();
Ok(())
}

fn inline_expression(&mut self, value: Value, expr: ExprId) -> IResult<()> {
let location = self.interner.expr_location(&expr);
let new_expr = value.into_expression(self.interner, location)?;
let new_expr = self.interner.expression(&new_expr);
self.interner.replace_expr(&expr, new_expr);
Ok(())
}
}
14 changes: 14 additions & 0 deletions compiler/noirc_frontend/src/hir/comptime/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,20 @@ fn mutating_arrays() {
assert_eq!(result, Value::U8(22));
}

#[test]
fn mutate_in_new_scope() {
let program = "fn main() -> pub u8 {
let mut x = 0;
x += 1;
{
x += 1;
}
x
}";
let result = interpret(program, vec!["main".into()]);
assert_eq!(result, Value::U8(2));
}

#[test]
fn for_loop() {
let program = "fn main() -> pub u8 {
Expand Down
41 changes: 33 additions & 8 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ pub enum CompilationError {
InterpreterError(InterpreterError),
}

impl CompilationError {
fn is_error(&self) -> bool {
let diagnostic = CustomDiagnostic::from(self);
diagnostic.is_error()
}
}

impl<'a> From<&'a CompilationError> for CustomDiagnostic {
fn from(value: &'a CompilationError) -> Self {
match value {
Expand Down Expand Up @@ -404,10 +411,15 @@ impl DefCollector {
);
}

resolved_module.errors.extend(context.def_interner.check_for_dependency_cycles());
let cycle_errors = context.def_interner.check_for_dependency_cycles();
let cycles_present = !cycle_errors.is_empty();
resolved_module.errors.extend(cycle_errors);

resolved_module.type_check(context);
resolved_module.evaluate_comptime(&mut context.def_interner);

if !cycles_present {
resolved_module.evaluate_comptime(&mut context.def_interner);
}

resolved_module.errors
}
Expand Down Expand Up @@ -503,13 +515,21 @@ impl ResolvedModule {

/// Evaluate all `comptime` expressions in this module
fn evaluate_comptime(&mut self, interner: &mut NodeInterner) {
let mut interpreter = Interpreter::new(interner);
if self.count_errors() == 0 {
let mut interpreter = Interpreter::new(interner);

for (_file, function) in &self.functions {
// The file returned by the error may be different than the file the
// function is in so only use the error's file id.
if let Err(error) = interpreter.scan_function(*function) {
self.errors.push(error.into_compilation_error_pair());
for (_file, global) in &self.globals {
if let Err(error) = interpreter.scan_global(*global) {
self.errors.push(error.into_compilation_error_pair());
}
}

for (_file, function) in &self.functions {
// The file returned by the error may be different than the file the
// function is in so only use the error's file id.
if let Err(error) = interpreter.scan_function(*function) {
self.errors.push(error.into_compilation_error_pair());
}
}
}
}
Expand All @@ -524,4 +544,9 @@ impl ResolvedModule {
self.globals.extend(globals.globals);
self.errors.extend(globals.errors);
}

/// Counts the number of errors (minus warnings) this program currently has
fn count_errors(&self) -> usize {
self.errors.iter().filter(|(error, _)| error.is_error()).count()
}
}
5 changes: 4 additions & 1 deletion compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use noirc_errors::Location;

use crate::ast::{
FunctionDefinition, Ident, ItemVisibility, LetStatement, ModuleDeclaration, NoirFunction,
NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, TraitImplItem, TraitItem, TypeImpl,
NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, Pattern, TraitImplItem, TraitItem,
TypeImpl,
};
use crate::{
graph::CrateId,
Expand Down Expand Up @@ -109,6 +110,7 @@ impl<'a> ModCollector<'a> {
self.module_id,
self.file_id,
global.attributes.clone(),
matches!(global.pattern, Pattern::Mutable { .. }),
);

// Add the statement to the scope so its path can be looked up later
Expand Down Expand Up @@ -463,6 +465,7 @@ impl<'a> ModCollector<'a> {
trait_id.0.local_id,
self.file_id,
vec![],
false,
);

if let Err((first_def, second_def)) = self.def_collector.def_map.modules
Expand Down
9 changes: 9 additions & 0 deletions compiler/noirc_frontend/src/hir/resolution/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ pub enum ResolverError {
JumpInConstrainedFn { is_break: bool, span: Span },
#[error("break/continue are only allowed within loops")]
JumpOutsideLoop { is_break: bool, span: Span },
#[error("Only `comptime` globals can be mutable")]
MutableGlobal { span: Span },
#[error("Self-referential structs are not supported")]
SelfReferentialStruct { span: Span },
#[error("#[inline(tag)] attribute is only allowed on constrained functions")]
Expand Down Expand Up @@ -346,6 +348,13 @@ impl<'a> From<&'a ResolverError> for Diagnostic {
*span,
)
},
ResolverError::MutableGlobal { span } => {
Diagnostic::simple_error(
"Only `comptime` globals may be mutable".into(),
String::new(),
*span,
)
},
ResolverError::SelfReferentialStruct { span } => {
Diagnostic::simple_error(
"Self-referential structs are not supported".into(),
Expand Down
Loading

0 comments on commit 8a3c7f1

Please sign in to comment.