Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(experimental): comptime globals #4918

Merged
merged 17 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aztec_macros/src/utils/hir_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ pub fn inject_global(
module_id,
file_id,
global.attributes.clone(),
false,
);

// Add the statement to the scope so its path can be looked up later
Expand Down
7 changes: 4 additions & 3 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub enum StatementKind {
Break,
Continue,
/// This statement should be executed at compile-time
Comptime(Box<StatementKind>),
Comptime(Box<Statement>),
// This is an expression with a trailing semi-colon
Semi(Expression),
// This statement is the result of a recovered parse error.
Expand Down Expand Up @@ -486,7 +486,8 @@ impl Pattern {
pub fn name_ident(&self) -> &Ident {
match self {
Pattern::Identifier(name_ident) => name_ident,
_ => panic!("only the identifier pattern can return a name"),
Pattern::Mutable(pattern, ..) => pattern.name_ident(),
_ => panic!("Only the Identifier or Mutable patterns can return a name"),
}
}

Expand Down Expand Up @@ -685,7 +686,7 @@ impl Display for StatementKind {
StatementKind::For(for_loop) => for_loop.fmt(f),
StatementKind::Break => write!(f, "break"),
StatementKind::Continue => write!(f, "continue"),
StatementKind::Comptime(statement) => write!(f, "comptime {statement}"),
StatementKind::Comptime(statement) => write!(f, "comptime {}", statement.kind),
StatementKind::Semi(semi) => write!(f, "{semi};"),
StatementKind::Error => write!(f, "Error"),
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl StmtId {
HirStatement::Semi(expr) => StatementKind::Semi(expr.to_ast(interner)),
HirStatement::Error => StatementKind::Error,
HirStatement::Comptime(statement) => {
StatementKind::Comptime(Box::new(statement.to_ast(interner).kind))
StatementKind::Comptime(Box::new(statement.to_ast(interner)))
}
};

Expand Down
29 changes: 19 additions & 10 deletions compiler/noirc_frontend/src/hir/comptime/interpreter.rs
jfecher marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,11 @@ impl<'a> Interpreter<'a> {
/// `exit_function` is called.
pub(super) fn enter_function(&mut self) -> (bool, Vec<HashMap<DefinitionId, Value>>) {
// Drain every scope except the global scope
let scope = self.scopes.drain(1..).collect();
self.push_scope();
let mut scope = Vec::new();
if self.scopes.len() > 1 {
scope = self.scopes.drain(1..).collect();
self.push_scope();
}
(std::mem::take(&mut self.in_loop), scope)
}

Expand All @@ -160,7 +163,7 @@ impl<'a> Interpreter<'a> {
self.scopes.last_mut().unwrap()
}

fn define_pattern(
pub(super) fn define_pattern(
&mut self,
pattern: &HirPattern,
typ: &Type,
Expand Down Expand Up @@ -262,7 +265,7 @@ impl<'a> Interpreter<'a> {
Err(InterpreterError::NonComptimeVarReferenced { name, location })
}

fn lookup(&self, ident: &HirIdent) -> IResult<Value> {
pub(super) fn lookup(&self, ident: &HirIdent) -> IResult<Value> {
self.lookup_id(ident.id, ident.location)
}

Expand Down Expand Up @@ -291,7 +294,7 @@ impl<'a> Interpreter<'a> {
}

/// Evaluate an expression and return the result
fn evaluate(&mut self, id: ExprId) -> IResult<Value> {
pub(super) fn evaluate(&mut self, id: ExprId) -> IResult<Value> {
match self.interner.expression(&id) {
HirExpression::Ident(ident) => self.evaluate_ident(ident, id),
HirExpression::Literal(literal) => self.evaluate_literal(literal, id),
Expand Down Expand Up @@ -322,7 +325,7 @@ impl<'a> Interpreter<'a> {
}
}

fn evaluate_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult<Value> {
pub(super) fn evaluate_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult<Value> {
let definition = self.interner.definition(ident.id);

match &definition.kind {
Expand All @@ -332,9 +335,15 @@ impl<'a> Interpreter<'a> {
}
DefinitionKind::Local(_) => self.lookup(&ident),
DefinitionKind::Global(global_id) => {
let let_ = self.interner.get_global_let_statement(*global_id).unwrap();
self.evaluate_let(let_)?;
self.lookup(&ident)
// Don't need to check let_.comptime, we can evaluate non-comptime globals too.
// Avoid resetting the value if it is already known
if let Ok(value) = self.lookup(&ident) {
Ok(value)
} else {
let let_ = self.interner.get_global_let_statement(*global_id).unwrap();
self.evaluate_let(let_)?;
self.lookup(&ident)
}
}
DefinitionKind::GenericType(type_variable) => {
let value = match &*type_variable.borrow() {
Expand Down Expand Up @@ -1027,7 +1036,7 @@ impl<'a> Interpreter<'a> {
}
}

fn evaluate_let(&mut self, let_: HirLetStatement) -> IResult<Value> {
pub(super) fn evaluate_let(&mut self, let_: HirLetStatement) -> IResult<Value> {
let rhs = self.evaluate(let_.expression)?;
let location = self.interner.expr_location(&let_.expression);
self.define_pattern(&let_.pattern, &let_.r#type, rhs, location)?;
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/hir/comptime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ mod value;

pub use errors::InterpreterError;
pub use interpreter::Interpreter;
pub use value::Value;
50 changes: 47 additions & 3 deletions compiler/noirc_frontend/src/hir/comptime/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@ use crate::{
hir_def::{
expr::{
HirArrayLiteral, HirBlockExpression, HirCallExpression, HirConstructorExpression,
HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda,
HirIdent, HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda,
HirMethodCallExpression,
},
stmt::HirForStatement,
},
macros_api::{HirExpression, HirLiteral, HirStatement},
node_interner::{ExprId, FuncId, StmtId},
node_interner::{DefinitionKind, ExprId, FuncId, GlobalId, StmtId},
};

use super::{
errors::{IResult, InterpreterError},
interpreter::Interpreter,
Value,
};

#[allow(dead_code)]
Expand All @@ -48,9 +49,23 @@ impl<'interner> Interpreter<'interner> {
Ok(())
}

/// Evaluate this global if it is a comptime global.
/// Otherwise, scan through its expression for any comptime blocks to evaluate.
pub fn scan_global(&mut self, global: GlobalId) -> IResult<()> {
if let Some(let_) = self.interner.get_global_let_statement(global) {
if let_.comptime {
self.evaluate_let(let_)?;
} else {
self.scan_expression(let_.expression)?;
}
}

Ok(())
}

fn scan_expression(&mut self, expr: ExprId) -> IResult<()> {
match self.interner.expression(&expr) {
HirExpression::Ident(_) => Ok(()),
HirExpression::Ident(ident) => self.scan_ident(ident, expr),
HirExpression::Literal(literal) => self.scan_literal(literal),
HirExpression::Block(block) => self.scan_block(block),
HirExpression::Prefix(prefix) => self.scan_expression(prefix.rhs),
Expand Down Expand Up @@ -91,6 +106,27 @@ impl<'interner> Interpreter<'interner> {
}
}

// Identifiers have no code to execute but we may need to inline any values
// of comptime variables into runtime code.
fn scan_ident(&mut self, ident: HirIdent, id: ExprId) -> IResult<()> {
let definition = self.interner.definition(ident.id);

match &definition.kind {
DefinitionKind::Function(_) => Ok(()),
jfecher marked this conversation as resolved.
Show resolved Hide resolved
_ => {
// Opportunistically evaluate this identifier to see if it is compile-time known.
// If so, inline its value.
if let Ok(value) = self.evaluate_ident(ident, id) {
// TODO(#4922): Inlining closures is currently unimplemented
if !matches!(value, Value::Closure(..)) {
self.inline_expression(value, id)?;
}
}
Ok(())
}
}
}

fn scan_literal(&mut self, literal: HirLiteral) -> IResult<()> {
match literal {
HirLiteral::Array(elements) | HirLiteral::Slice(elements) => match elements {
Expand Down Expand Up @@ -210,4 +246,12 @@ impl<'interner> Interpreter<'interner> {
self.pop_scope();
Ok(())
}

fn inline_expression(&mut self, value: Value, expr: ExprId) -> IResult<()> {
let location = self.interner.expr_location(&expr);
let new_expr = value.into_expression(self.interner, location)?;
let new_expr = self.interner.expression(&new_expr);
self.interner.replace_expr(&expr, new_expr);
Ok(())
}
}
14 changes: 14 additions & 0 deletions compiler/noirc_frontend/src/hir/comptime/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,20 @@ fn mutating_arrays() {
assert_eq!(result, Value::U8(22));
}

#[test]
fn mutate_in_new_scope() {
let program = "fn main() -> pub u8 {
let mut x = 0;
x += 1;
{
x += 1;
}
x
}";
let result = interpret(program, vec!["main".into()]);
assert_eq!(result, Value::U8(2));
}

#[test]
fn for_loop() {
let program = "fn main() -> pub u8 {
Expand Down
41 changes: 33 additions & 8 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ pub enum CompilationError {
InterpreterError(InterpreterError),
}

impl CompilationError {
fn is_error(&self) -> bool {
let diagnostic = CustomDiagnostic::from(self);
diagnostic.is_error()
}
}

impl<'a> From<&'a CompilationError> for CustomDiagnostic {
fn from(value: &'a CompilationError) -> Self {
match value {
Expand Down Expand Up @@ -404,10 +411,15 @@ impl DefCollector {
);
}

resolved_module.errors.extend(context.def_interner.check_for_dependency_cycles());
let cycle_errors = context.def_interner.check_for_dependency_cycles();
let cycles_present = !cycle_errors.is_empty();
resolved_module.errors.extend(cycle_errors);

resolved_module.type_check(context);
resolved_module.evaluate_comptime(&mut context.def_interner);

if !cycles_present {
resolved_module.evaluate_comptime(&mut context.def_interner);
}

resolved_module.errors
}
Expand Down Expand Up @@ -503,13 +515,21 @@ impl ResolvedModule {

/// Evaluate all `comptime` expressions in this module
fn evaluate_comptime(&mut self, interner: &mut NodeInterner) {
let mut interpreter = Interpreter::new(interner);
if self.count_errors() == 0 {
let mut interpreter = Interpreter::new(interner);

for (_file, function) in &self.functions {
// The file returned by the error may be different than the file the
// function is in so only use the error's file id.
if let Err(error) = interpreter.scan_function(*function) {
self.errors.push(error.into_compilation_error_pair());
for (_file, global) in &self.globals {
if let Err(error) = interpreter.scan_global(*global) {
self.errors.push(error.into_compilation_error_pair());
}
}

for (_file, function) in &self.functions {
// The file returned by the error may be different than the file the
// function is in so only use the error's file id.
if let Err(error) = interpreter.scan_function(*function) {
self.errors.push(error.into_compilation_error_pair());
}
}
}
}
Expand All @@ -524,4 +544,9 @@ impl ResolvedModule {
self.globals.extend(globals.globals);
self.errors.extend(globals.errors);
}

/// Counts the number of errors (minus warnings) this program currently has
fn count_errors(&self) -> usize {
self.errors.iter().filter(|(error, _)| error.is_error()).count()
}
}
5 changes: 4 additions & 1 deletion compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use noirc_errors::Location;

use crate::ast::{
FunctionDefinition, Ident, ItemVisibility, LetStatement, ModuleDeclaration, NoirFunction,
NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, TraitImplItem, TraitItem, TypeImpl,
NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, Pattern, TraitImplItem, TraitItem,
TypeImpl,
};
use crate::{
graph::CrateId,
Expand Down Expand Up @@ -109,6 +110,7 @@ impl<'a> ModCollector<'a> {
self.module_id,
self.file_id,
global.attributes.clone(),
matches!(global.pattern, Pattern::Mutable { .. }),
);

// Add the statement to the scope so its path can be looked up later
Expand Down Expand Up @@ -463,6 +465,7 @@ impl<'a> ModCollector<'a> {
trait_id.0.local_id,
self.file_id,
vec![],
false,
);

if let Err((first_def, second_def)) = self.def_collector.def_map.modules
Expand Down
9 changes: 9 additions & 0 deletions compiler/noirc_frontend/src/hir/resolution/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ pub enum ResolverError {
JumpInConstrainedFn { is_break: bool, span: Span },
#[error("break/continue are only allowed within loops")]
JumpOutsideLoop { is_break: bool, span: Span },
#[error("Only `comptime` globals can be mutable")]
MutableGlobal { span: Span },
#[error("Self-referential structs are not supported")]
SelfReferentialStruct { span: Span },
#[error("#[inline(tag)] attribute is only allowed on constrained functions")]
Expand Down Expand Up @@ -346,6 +348,13 @@ impl<'a> From<&'a ResolverError> for Diagnostic {
*span,
)
},
ResolverError::MutableGlobal { span } => {
Diagnostic::simple_error(
"Only `comptime` globals may be mutable".into(),
String::new(),
*span,
)
},
ResolverError::SelfReferentialStruct { span } => {
Diagnostic::simple_error(
"Self-referential structs are not supported".into(),
Expand Down
Loading
Loading