Skip to content

Commit

Permalink
feat: Run comptime code from annotations on a type definition (#5256)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #5255

## Summary\*

Implements the ability to run comptime code when an annotation is put on
a type definition. This annotation must resolve to a comptime function
in scope, which is then called with the type definition as an argument.
There are currently no API functions to actually do anything with a
`TypeDefinition` object. The plan is to add functions in the future to
inspect or add fields & generics.

```rs
#[print_type]
struct Foo {
    bar: Field,
}

comptime fn print_type(typ: TypeDefinition) {
    println("hello from print_type_name at compile-time");
    println(typ); // only prints "(type definition)" currently
}

fn main(){}
```

## Additional Context



## Documentation\*

Check one:
- [ ] No documentation needed.
- [ ] Documentation included in this PR.
- [x] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Maxim Vezenov <mvezenov@gmail.com>
  • Loading branch information
jfecher and vezenovm authored Jun 18, 2024
1 parent 9db65d8 commit 6cbe6a0
Show file tree
Hide file tree
Showing 21 changed files with 269 additions and 74 deletions.
2 changes: 1 addition & 1 deletion compiler/noirc_driver/src/abi_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ pub(super) fn abi_type_from_hir_type(context: &Context, typ: &Type) -> AbiType {
| Type::TypeVariable(_, _)
| Type::NamedGeneric(..)
| Type::Forall(..)
| Type::Expr
| Type::Quoted(_)
| Type::Slice(_)
| Type::Function(_, _, _) => unreachable!("{typ} cannot be used in the abi"),
Type::FmtString(_, _) => unreachable!("format strings cannot be used in the abi"),
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ pub enum UnresolvedTypeData {
),

// The type of quoted code for metaprogramming
Expr,
Quoted(crate::QuotedType),

Unspecified, // This is for when the user declares a variable without specifying it's type
Error,
Expand Down Expand Up @@ -216,7 +216,7 @@ impl std::fmt::Display for UnresolvedTypeData {
}
}
MutableReference(element) => write!(f, "&mut {element}"),
Expr => write!(f, "Expr"),
Quoted(quoted) => write!(f, "{}", quoted),
Unit => write!(f, "()"),
Error => write!(f, "error"),
Unspecified => write!(f, "unspecified"),
Expand Down
9 changes: 4 additions & 5 deletions compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{
MethodCallExpression, PrefixExpression,
},
node_interner::{DefinitionKind, ExprId, FuncId},
Shared, StructType, Type,
QuotedType, Shared, StructType, Type,
};

use super::Elaborator;
Expand Down Expand Up @@ -650,7 +650,7 @@ impl<'context> Elaborator<'context> {
let mut unquoted_exprs = Vec::new();
self.find_unquoted_exprs_in_block(&mut block, &mut unquoted_exprs);
let quoted = HirQuoted { quoted_block: block, unquoted_exprs };
(HirExpression::Quote(quoted), Type::Expr)
(HirExpression::Quote(quoted), Type::Quoted(QuotedType::Expr))
}

fn elaborate_comptime_block(&mut self, block: BlockExpression, span: Span) -> (ExprId, Type) {
Expand Down Expand Up @@ -716,9 +716,8 @@ impl<'context> Elaborator<'context> {
location: Location,
return_type: Type,
) -> Option<(HirExpression, Type)> {
self.unify(&return_type, &Type::Expr, || TypeCheckError::MacroReturningNonExpr {
typ: return_type.clone(),
span: location.span,
self.unify(&return_type, &Type::Quoted(QuotedType::Expr), || {
TypeCheckError::MacroReturningNonExpr { typ: return_type.clone(), span: location.span }
});

let function = match self.try_get_comptime_function(func, location) {
Expand Down
133 changes: 114 additions & 19 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
use crate::{
ast::{FunctionKind, UnresolvedTraitConstraint},
hir::{
comptime::{self, Interpreter},
comptime::{self, Interpreter, Value},
def_collector::{
dc_crate::{
filter_literal_globals, CompilationError, ImplMap, UnresolvedGlobal,
Expand Down Expand Up @@ -205,69 +205,79 @@ impl<'context> Elaborator<'context> {
pub fn elaborate(
context: &'context mut Context,
crate_id: CrateId,
mut items: CollectedItems,
items: CollectedItems,
) -> Vec<(CompilationError, FileId)> {
let mut this = Self::new(context, crate_id);

// Filter out comptime items to execute their functions first if needed.
// This step is why comptime items can only refer to other comptime items
// in the same crate, but can refer to any item in dependencies. Trying to
// run these at the same time as other items would lead to them seeing empty
// function bodies from functions that have yet to be elaborated.
let (comptime_items, runtime_items) = Self::filter_comptime_items(items);
this.elaborate_items(comptime_items);
this.elaborate_items(runtime_items);
this.errors
}

fn elaborate_items(&mut self, mut items: CollectedItems) {
// We must first resolve and intern the globals before we can resolve any stmts inside each function.
// Each function uses its own resolver with a newly created ScopeForest, and must be resolved again to be within a function's scope
//
// Additionally, we must resolve integer globals before structs since structs may refer to
// the values of integer globals as numeric generics.
let (literal_globals, non_literal_globals) = filter_literal_globals(items.globals);
for global in non_literal_globals {
this.unresolved_globals.insert(global.global_id, global);
self.unresolved_globals.insert(global.global_id, global);
}

for global in literal_globals {
this.elaborate_global(global);
self.elaborate_global(global);
}

for (alias_id, alias) in items.type_aliases {
this.define_type_alias(alias_id, alias);
self.define_type_alias(alias_id, alias);
}

this.define_function_metas(&mut items.functions, &mut items.impls, &mut items.trait_impls);
this.collect_traits(items.traits);
self.define_function_metas(&mut items.functions, &mut items.impls, &mut items.trait_impls);
self.collect_traits(items.traits);

// Must resolve structs before we resolve globals.
this.collect_struct_definitions(items.types);
self.collect_struct_definitions(items.types);

// Before we resolve any function symbols we must go through our impls and
// re-collect the methods within into their proper module. This cannot be
// done during def collection since we need to be able to resolve the type of
// the impl since that determines the module we should collect into.
for ((_self_type, module), impls) in &mut items.impls {
this.collect_impls(*module, impls);
self.collect_impls(*module, impls);
}

// Bind trait impls to their trait. Collect trait functions, that have a
// default implementation, which hasn't been overridden.
for trait_impl in &mut items.trait_impls {
this.collect_trait_impl(trait_impl);
self.collect_trait_impl(trait_impl);
}

// We must wait to resolve non-literal globals until after we resolve structs since struct
// globals will need to reference the struct type they're initialized to ensure they are valid.
while let Some((_, global)) = this.unresolved_globals.pop_first() {
this.elaborate_global(global);
while let Some((_, global)) = self.unresolved_globals.pop_first() {
self.elaborate_global(global);
}

for functions in items.functions {
this.elaborate_functions(functions);
self.elaborate_functions(functions);
}

for impls in items.impls.into_values() {
this.elaborate_impls(impls);
self.elaborate_impls(impls);
}

for trait_impl in items.trait_impls {
this.elaborate_trait_impl(trait_impl);
self.elaborate_trait_impl(trait_impl);
}

let cycle_errors = this.interner.check_for_dependency_cycles();
this.errors.extend(cycle_errors);
this.errors
self.errors.extend(self.interner.check_for_dependency_cycles());
}

/// Runs `f` and if it modifies `self.generics`, `self.generics` is truncated
Expand Down Expand Up @@ -1085,15 +1095,20 @@ impl<'context> Elaborator<'context> {

// Resolve each field in each struct.
// Each struct should already be present in the NodeInterner after def collection.
for (type_id, typ) in structs {
for (type_id, mut typ) in structs {
self.file = typ.file_id;
self.local_module = typ.module_id;

let attributes = std::mem::take(&mut typ.struct_def.attributes);
let span = typ.struct_def.span;
let (generics, fields) = self.resolve_struct_fields(typ.struct_def, type_id);

self.interner.update_struct(type_id, |struct_def| {
struct_def.set_fields(fields);
struct_def.generics = generics;
});

self.run_comptime_attributes_on_struct(attributes, type_id, span);
}

// Check whether the struct fields have nested slices
Expand All @@ -1117,6 +1132,38 @@ impl<'context> Elaborator<'context> {
}
}

fn run_comptime_attributes_on_struct(
&mut self,
attributes: Vec<SecondaryAttribute>,
struct_id: StructId,
span: Span,
) {
for attribute in attributes {
if let SecondaryAttribute::Custom(name) = attribute {
match self.lookup_global(Path::from_single(name, span)) {
Ok(id) => {
let definition = self.interner.definition(id);
if let DefinitionKind::Function(function) = &definition.kind {
let function = *function;
let mut interpreter =
Interpreter::new(self.interner, &mut self.comptime_scopes);

let location = Location::new(span, self.file);
let arguments = vec![(Value::TypeDefinition(struct_id), location)];
let result = interpreter.call_function(function, arguments, location);
if let Err(error) = result {
self.errors.push(error.into_compilation_error_pair());
}
} else {
self.push_err(ResolverError::NonFunctionInAnnotation { span });
}
}
Err(_) => self.push_err(ResolverError::UnknownAnnotation { span }),
}
}
}
}

pub fn resolve_struct_fields(
&mut self,
unresolved: NoirStruct,
Expand Down Expand Up @@ -1261,4 +1308,52 @@ impl<'context> Elaborator<'context> {
});
}
}

/// Filters out comptime items from non-comptime items.
/// Returns a pair of (comptime items, non-comptime items)
fn filter_comptime_items(mut items: CollectedItems) -> (CollectedItems, CollectedItems) {
let mut function_sets = Vec::with_capacity(items.functions.len());
let mut comptime_function_sets = Vec::new();

for function_set in items.functions {
let mut functions = Vec::with_capacity(function_set.functions.len());
let mut comptime_functions = Vec::new();

for function in function_set.functions {
if function.2.def.is_comptime {
comptime_functions.push(function);
} else {
functions.push(function);
}
}

let file_id = function_set.file_id;
let self_type = function_set.self_type;
let trait_id = function_set.trait_id;

if !comptime_functions.is_empty() {
comptime_function_sets.push(UnresolvedFunctions {
functions: comptime_functions,
file_id,
trait_id,
self_type: self_type.clone(),
});
}

function_sets.push(UnresolvedFunctions { functions, file_id, trait_id, self_type });
}

let comptime = CollectedItems {
functions: comptime_function_sets,
types: BTreeMap::new(),
type_aliases: BTreeMap::new(),
traits: BTreeMap::new(),
trait_impls: Vec::new(),
globals: Vec::new(),
impls: std::collections::HashMap::new(),
};

items.functions = function_sets;
(comptime, items)
}
}
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl<'context> Elaborator<'context> {
let fields = self.resolve_type_inner(*fields);
Type::FmtString(Box::new(resolved_size), Box::new(fields))
}
Expr => Type::Expr,
Quoted(quoted) => Type::Quoted(quoted),
Unit => Type::Unit,
Unspecified => Type::Error,
Error => Type::Error,
Expand Down Expand Up @@ -1398,7 +1398,7 @@ impl<'context> Elaborator<'context> {
| Type::TypeVariable(_, _)
| Type::Constant(_)
| Type::NamedGeneric(_, _)
| Type::Expr
| Type::Quoted(_)
| Type::Forall(_, _) => (),

Type::TraitAsType(_, _, args) => {
Expand Down
6 changes: 3 additions & 3 deletions compiler/noirc_frontend/src/hir/comptime/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,18 @@ impl<'a> Interpreter<'a> {
"array_len" => builtin::array_len(&arguments),
"as_slice" => builtin::as_slice(arguments),
_ => {
let item = format!("Evaluation for builtin function {builtin}");
let item = format!("Comptime evaluation for builtin function {builtin}");
Err(InterpreterError::Unimplemented { item, location })
}
}
} else if let Some(foreign) = func_attrs.foreign() {
let item = format!("Evaluation for foreign functions like {foreign}");
let item = format!("Comptime evaluation for foreign functions like {foreign}");
Err(InterpreterError::Unimplemented { item, location })
} else if let Some(oracle) = func_attrs.oracle() {
if oracle == "print" {
self.print_oracle(arguments)
} else {
let item = format!("Evaluation for oracle functions like {oracle}");
let item = format!("Comptime evaluation for oracle functions like {oracle}");
Err(InterpreterError::Unimplemented { item, location })
}
} else {
Expand Down
14 changes: 9 additions & 5 deletions compiler/noirc_frontend/src/hir/comptime/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ use crate::{
hir_def::expr::{HirArrayLiteral, HirConstructorExpression, HirIdent, HirLambda, ImplKind},
macros_api::{
Expression, ExpressionKind, HirExpression, HirLiteral, Literal, NodeInterner, Path,
StructId,
},
node_interner::{ExprId, FuncId},
Shared, Type,
QuotedType, Shared, Type,
};
use rustc_hash::FxHashMap as HashMap;

Expand Down Expand Up @@ -42,6 +43,7 @@ pub enum Value {
Array(Vector<Value>, Type),
Slice(Vector<Value>, Type),
Code(Rc<BlockExpression>),
TypeDefinition(StructId),
}

impl Value {
Expand Down Expand Up @@ -70,7 +72,8 @@ impl Value {
Value::Struct(_, typ) => return Cow::Borrowed(typ),
Value::Array(_, typ) => return Cow::Borrowed(typ),
Value::Slice(_, typ) => return Cow::Borrowed(typ),
Value::Code(_) => Type::Expr,
Value::Code(_) => Type::Quoted(QuotedType::Expr),
Value::TypeDefinition(_) => Type::Quoted(QuotedType::TypeDefinition),
Value::Pointer(element) => {
let element = element.borrow().get_type().into_owned();
Type::MutableReference(Box::new(element))
Expand Down Expand Up @@ -172,7 +175,7 @@ impl Value {
ExpressionKind::Literal(Literal::Slice(ArrayLiteral::Standard(elements)))
}
Value::Code(block) => ExpressionKind::Block(unwrap_rc(block)),
Value::Pointer(_) => {
Value::Pointer(_) | Value::TypeDefinition(_) => {
return Err(InterpreterError::CannotInlineMacro { value: self, location })
}
};
Expand Down Expand Up @@ -273,7 +276,7 @@ impl Value {
HirExpression::Literal(HirLiteral::Slice(HirArrayLiteral::Standard(elements)))
}
Value::Code(block) => HirExpression::Unquote(unwrap_rc(block)),
Value::Pointer(_) => {
Value::Pointer(_) | Value::TypeDefinition(_) => {
return Err(InterpreterError::CannotInlineMacro { value: self, location })
}
};
Expand Down Expand Up @@ -348,7 +351,8 @@ impl Display for Value {
let values = vecmap(values, ToString::to_string);
write!(f, "&[{}]", values.join(", "))
}
Value::Code(_) => todo!(),
Value::Code(block) => write!(f, "quote {block}"),
Value::TypeDefinition(_) => write!(f, "(type definition)"),
}
}
}
Loading

0 comments on commit 6cbe6a0

Please sign in to comment.