Skip to content

Commit

Permalink
feat: Insert trait impls into the program from type annotations (#5327)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves the vertical slice from this comment
#4594 (comment)

## Summary\*

This PR lets us actually insert the generated trait impl from the called
macro into the program. Currently we only support inserting trait impls,
functions, and globals to keep things simple. Each time these are
inserted we have to call the def collection code on them so I wanted to
avoid adding them all at once in a larger PR.

## Additional Context

We can now generate impls for simple traits on a type! See the
`derive_impl` test for details. The call site currently looks like this:
```rs
#[derive_default]
struct Foo {
    x: Field,
    y: Bar,
}

#[derive_default]
struct Bar {}

fn main() {
    let _foo: Foo = Default::default();
}
```
If `Bar` doesn't also derive `Default` the error that is issued is in
the code to derive the impl unfortunately:
```
error: No matching impl found for `Bar: Default`
   ┌─ /.../derive_impl/src/main.nr:33:50
   │
33 │         result = result.push_back(quote { $name: Default::default(), });
   │                                                  ---------------- No impl for `Bar: Default`
   │
```
Since we only support unquoting a few items at top-level currently, here
is what it looks like when we try to unquote a different item. In this
case, a non-trait impl:
```
error: Unsupported statement type to unquote
   ┌─ /.../derive_impl/src/main.nr:23:1
   │  
23 │ ╭ #[derive_default]
24 │ │ struct Foo {
25 │ │     x: Field,
26 │ │     y: Bar,
27 │ │ }
   │ ╰─' Only functions, globals, and trait impls can be unquoted here
   │  
   = Unquoted item was:
     impl Foo {
         Attributes { function: None, secondary: [] }
         fn bar(self: Self) -> Self {
             self
         }
     }
```

## Documentation\*

Check one:
- [ ] No documentation needed.
- [ ] Documentation included in this PR.
- [x] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com>
  • Loading branch information
jfecher and TomAFrench authored Jun 26, 2024
1 parent 083070e commit efdd818
Show file tree
Hide file tree
Showing 12 changed files with 298 additions and 97 deletions.
18 changes: 14 additions & 4 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -824,22 +824,32 @@ impl Display for FunctionDefinition {
writeln!(f, "{:?}", self.attributes)?;

let parameters = vecmap(&self.parameters, |Param { visibility, pattern, typ, span: _ }| {
format!("{pattern}: {visibility} {typ}")
if *visibility == Visibility::Public {
format!("{pattern}: {visibility} {typ}")
} else {
format!("{pattern}: {typ}")
}
});

let where_clause = vecmap(&self.where_clause, ToString::to_string);
let where_clause_str = if !where_clause.is_empty() {
format!("where {}", where_clause.join(", "))
format!(" where {}", where_clause.join(", "))
} else {
"".to_string()
};

let return_type = if matches!(&self.return_type, FunctionReturnType::Default(_)) {
String::new()
} else {
format!(" -> {}", self.return_type)
};

write!(
f,
"fn {}({}) -> {} {} {}",
"fn {}({}){}{} {}",
self.name,
parameters.join(", "),
self.return_type,
return_type,
where_clause_str,
self.body
)
Expand Down
169 changes: 143 additions & 26 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use std::{
use crate::{
ast::{FunctionKind, UnresolvedTraitConstraint},
hir::{
comptime::{self, Interpreter, Value},
comptime::{self, Interpreter, InterpreterError, Value},
def_collector::{
dc_crate::{
filter_literal_globals, CompilationError, ImplMap, UnresolvedGlobal,
UnresolvedStruct, UnresolvedTypeAlias,
},
dc_mod,
errors::DuplicateType,
},
resolution::{errors::ResolverError, path_resolver::PathResolver, resolver::LambdaContext},
Expand All @@ -31,6 +32,7 @@ use crate::{
node_interner::{
DefinitionId, DefinitionKind, DependencyId, ExprId, FuncId, GlobalId, TraitId, TypeAliasId,
},
parser::TopLevelStatement,
Shared, Type, TypeVariable,
};
use crate::{
Expand Down Expand Up @@ -229,7 +231,7 @@ impl<'context> Elaborator<'context> {
}

// Must resolve structs before we resolve globals.
self.collect_struct_definitions(items.types);
let generated_items = self.collect_struct_definitions(items.types);

self.define_function_metas(&mut items.functions, &mut items.impls, &mut items.trait_impls);

Expand All @@ -255,6 +257,16 @@ impl<'context> Elaborator<'context> {
self.elaborate_global(global);
}

// After everything is collected, we can elaborate our generated items.
// It may be better to inline these within `items` entirely since elaborating them
// all here means any globals will not see these. Inlining them completely within `items`
// means we must be more careful about missing any additional items that need to be already
// elaborated. E.g. if a new struct is created, we've already passed the code path to
// elaborate them.
if !generated_items.is_empty() {
self.elaborate_items(generated_items);
}

for functions in items.functions {
self.elaborate_functions(functions);
}
Expand Down Expand Up @@ -1147,11 +1159,18 @@ impl<'context> Elaborator<'context> {
self.generics.clear();
}

fn collect_struct_definitions(&mut self, structs: BTreeMap<StructId, UnresolvedStruct>) {
fn collect_struct_definitions(
&mut self,
structs: BTreeMap<StructId, UnresolvedStruct>,
) -> CollectedItems {
// This is necessary to avoid cloning the entire struct map
// when adding checks after each struct field is resolved.
let struct_ids = structs.keys().copied().collect::<Vec<_>>();

// This will contain any additional top-level items that are generated at compile-time
// via macros. This often includes derived trait impls.
let mut generated_items = CollectedItems::default();

// Resolve each field in each struct.
// Each struct should already be present in the NodeInterner after def collection.
for (type_id, mut typ) in structs {
Expand Down Expand Up @@ -1188,7 +1207,7 @@ impl<'context> Elaborator<'context> {
}
});

self.run_comptime_attributes_on_struct(attributes, type_id, span);
self.run_comptime_attributes_on_struct(attributes, type_id, span, &mut generated_items);
}

// Check whether the struct fields have nested slices
Expand All @@ -1210,43 +1229,64 @@ impl<'context> Elaborator<'context> {
}
}
}

generated_items
}

fn run_comptime_attributes_on_struct(
&mut self,
attributes: Vec<SecondaryAttribute>,
struct_id: StructId,
span: Span,
generated_items: &mut CollectedItems,
) {
for attribute in attributes {
if let SecondaryAttribute::Custom(name) = attribute {
match self.lookup_global(Path::from_single(name, span)) {
Ok(id) => {
let definition = self.interner.definition(id);
if let DefinitionKind::Function(function) = &definition.kind {
let function = *function;
let mut interpreter = Interpreter::new(
self.interner,
&mut self.comptime_scopes,
self.crate_id,
);

let location = Location::new(span, self.file);
let arguments = vec![(Value::TypeDefinition(struct_id), location)];
let result = interpreter.call_function(function, arguments, location);
if let Err(error) = result {
self.errors.push(error.into_compilation_error_pair());
}
} else {
self.push_err(ResolverError::NonFunctionInAnnotation { span });
}
}
Err(_) => self.push_err(ResolverError::UnknownAnnotation { span }),
if let Err(error) =
self.run_comptime_attribute_on_struct(name, struct_id, span, generated_items)
{
self.errors.push(error);
}
}
}
}

fn run_comptime_attribute_on_struct(
&mut self,
attribute: String,
struct_id: StructId,
span: Span,
generated_items: &mut CollectedItems,
) -> Result<(), (CompilationError, FileId)> {
let id = self
.lookup_global(Path::from_single(attribute, span))
.map_err(|_| (ResolverError::UnknownAnnotation { span }.into(), self.file))?;

let definition = self.interner.definition(id);
let DefinitionKind::Function(function) = definition.kind else {
return Err((ResolverError::NonFunctionInAnnotation { span }.into(), self.file));
};
let mut interpreter =
Interpreter::new(self.interner, &mut self.comptime_scopes, self.crate_id);

let location = Location::new(span, self.file);
let arguments = vec![(Value::TypeDefinition(struct_id), location)];

let value = interpreter
.call_function(function, arguments, location)
.map_err(|error| error.into_compilation_error_pair())?;

if value != Value::Unit {
let item = value
.into_top_level_item(location)
.map_err(|error| error.into_compilation_error_pair())?;

self.add_item(item, generated_items, location);
}

Ok(())
}

pub fn resolve_struct_fields(
&mut self,
unresolved: NoirStruct,
Expand Down Expand Up @@ -1460,4 +1500,81 @@ impl<'context> Elaborator<'context> {
items.functions = function_sets;
(comptime, items)
}

fn add_item(
&mut self,
item: TopLevelStatement,
generated_items: &mut CollectedItems,
location: Location,
) {
match item {
TopLevelStatement::Function(function) => {
let id = self.interner.push_empty_fn();
let module = self.module_id();
self.interner.push_function(id, &function.def, module, location);
let functions = vec![(self.local_module, id, function)];
generated_items.functions.push(UnresolvedFunctions {
file_id: self.file,
functions,
trait_id: None,
self_type: None,
});
}
TopLevelStatement::TraitImpl(mut trait_impl) => {
let methods = dc_mod::collect_trait_impl_functions(
self.interner,
&mut trait_impl,
self.crate_id,
self.file,
self.local_module,
);

generated_items.trait_impls.push(UnresolvedTraitImpl {
file_id: self.file,
module_id: self.local_module,
trait_generics: trait_impl.trait_generics,
trait_path: trait_impl.trait_name,
object_type: trait_impl.object_type,
methods,
generics: trait_impl.impl_generics,
where_clause: trait_impl.where_clause,

// These last fields are filled in later
trait_id: None,
impl_id: None,
resolved_object_type: None,
resolved_generics: Vec::new(),
resolved_trait_generics: Vec::new(),
});
}
TopLevelStatement::Global(global) => {
let (global, error) = dc_mod::collect_global(
self.interner,
self.def_maps.get_mut(&self.crate_id).unwrap(),
global,
self.file,
self.local_module,
);

generated_items.globals.push(global);
if let Some(error) = error {
self.errors.push(error);
}
}
// Assume that an error has already been issued
TopLevelStatement::Error => (),

TopLevelStatement::Module(_)
| TopLevelStatement::Import(_)
| TopLevelStatement::Struct(_)
| TopLevelStatement::Trait(_)
| TopLevelStatement::Impl(_)
| TopLevelStatement::TypeAlias(_)
| TopLevelStatement::SubModule(_) => {
let item = item.to_string();
let error = InterpreterError::UnsupportedTopLevelItemUnquote { item, location };
self.errors.push(error.into_compilation_error_pair());
}
}
}
}
13 changes: 12 additions & 1 deletion compiler/noirc_frontend/src/hir/comptime/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub enum InterpreterError {
CannotInlineMacro { value: Value, location: Location },
UnquoteFoundDuringEvaluation { location: Location },
FailedToParseMacro { error: ParserError, tokens: Rc<Tokens>, rule: &'static str, file: FileId },
UnsupportedTopLevelItemUnquote { item: String, location: Location },
NonComptimeFnCallInSameCrate { function: String, location: Location },

Unimplemented { item: String, location: Location },
Expand Down Expand Up @@ -102,6 +103,7 @@ impl InterpreterError {
| InterpreterError::NonStructInConstructor { location, .. }
| InterpreterError::CannotInlineMacro { location, .. }
| InterpreterError::UnquoteFoundDuringEvaluation { location, .. }
| InterpreterError::UnsupportedTopLevelItemUnquote { location, .. }
| InterpreterError::NonComptimeFnCallInSameCrate { location, .. }
| InterpreterError::Unimplemented { location, .. }
| InterpreterError::BreakNotInLoop { location, .. }
Expand Down Expand Up @@ -261,7 +263,8 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic {
CustomDiagnostic::simple_error(msg, String::new(), location.span)
}
InterpreterError::CannotInlineMacro { value, location } => {
let msg = "Cannot inline value into runtime code if it contains references".into();
let typ = value.get_type();
let msg = format!("Cannot inline values of type `{typ}` into this position");
let secondary = format!("Cannot inline value {value:?}");
CustomDiagnostic::simple_error(msg, secondary, location.span)
}
Expand Down Expand Up @@ -295,6 +298,14 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic {
diagnostic.add_note(push_the_problem_on_the_library_author);
diagnostic
}
InterpreterError::UnsupportedTopLevelItemUnquote { item, location } => {
let msg = "Unsupported statement type to unquote".into();
let secondary =
"Only functions, globals, and trait impls can be unquoted here".into();
let mut error = CustomDiagnostic::simple_error(msg, secondary, location.span);
error.add_note(format!("Unquoted item was:\n{item}"));
error
}
InterpreterError::NonComptimeFnCallInSameCrate { function, location } => {
let msg = format!("`{function}` cannot be called in a `comptime` context here");
let secondary =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ fn type_def_as_type(

let struct_def = interner.get_struct(type_def);
let struct_def = struct_def.borrow();
let make_token = |name| SpannedToken::new(Token::Str(name), span);
let make_token = |name| SpannedToken::new(Token::Ident(name), span);

let mut tokens = vec![make_token(struct_def.name.to_string())];

Expand Down Expand Up @@ -111,7 +111,7 @@ fn type_def_generics(
.generics
.iter()
.map(|generic| {
let name = SpannedToken::new(Token::Str(generic.type_var.borrow().to_string()), span);
let name = SpannedToken::new(Token::Ident(generic.type_var.borrow().to_string()), span);
Value::Code(Rc::new(Tokens(vec![name])))
})
.collect();
Expand All @@ -137,7 +137,7 @@ fn type_def_fields(
let struct_def = interner.get_struct(type_def);
let struct_def = struct_def.borrow();

let make_token = |name| SpannedToken::new(Token::Str(name), span);
let make_token = |name| SpannedToken::new(Token::Ident(name), span);
let make_quoted = |tokens| Value::Code(Rc::new(Tokens(tokens)));

let mut fields = im::Vector::new();
Expand Down
20 changes: 19 additions & 1 deletion compiler/noirc_frontend/src/hir/comptime/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
StructId,
},
node_interner::{ExprId, FuncId},
parser,
parser::{self, NoirParser, TopLevelStatement},
token::{SpannedToken, Token, Tokens},
QuotedType, Shared, Type,
};
Expand Down Expand Up @@ -319,13 +319,31 @@ impl Value {
_ => None,
}
}

pub(crate) fn into_top_level_item(self, location: Location) -> IResult<TopLevelStatement> {
match self {
Value::Code(tokens) => parse_tokens(tokens, parser::top_level_item(), location.file),
value => Err(InterpreterError::CannotInlineMacro { value, location }),
}
}
}

/// Unwraps an Rc value without cloning the inner value if the reference count is 1. Clones otherwise.
pub(crate) fn unwrap_rc<T: Clone>(rc: Rc<T>) -> T {
Rc::try_unwrap(rc).unwrap_or_else(|rc| (*rc).clone())
}

fn parse_tokens<T>(tokens: Rc<Tokens>, parser: impl NoirParser<T>, file: fm::FileId) -> IResult<T> {
match parser.parse(tokens.as_ref().clone()) {
Ok(expr) => Ok(expr),
Err(mut errors) => {
let error = errors.swap_remove(0);
let rule = "an expression";
Err(InterpreterError::FailedToParseMacro { error, file, tokens, rule })
}
}
}

impl Display for Value {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand Down
Loading

0 comments on commit efdd818

Please sign in to comment.