Skip to content

Commit

Permalink
feat: add unsafe blocks for calling unconstrained code from constra…
Browse files Browse the repository at this point in the history
…ined functions (#4429)

# Description

## Problem\*

Resolves <!-- Link to GitHub Issue -->

## Summary\*

This is an experimental PR to play around with unsafe blocks as a method
to make it clearer when unconstrained values are entering the circuit to
avoid under-constrained errors.

Calls to unconstrained functions will now fail to compile unless inside
of an `unsafe` block. Note that any code inside of an `unsafe` block
still lays down constraints as usual so it's possible to fully constrain
the return values from any unconstrained function calls before returning
a safe value out from the block.


## Additional Context



## Documentation\*

Check one:
- [ ] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[Exceptional Case]** Documentation to be submitted in a separate
PR.

# PR Checklist\*

- [ ] I have tested the changes locally.
- [ ] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Ary Borenszweig <asterite@gmail.com>
Co-authored-by: jfecher <jake@aztecprotocol.com>
  • Loading branch information
3 people authored Aug 14, 2024
1 parent 4ea25db commit 79593b4
Show file tree
Hide file tree
Showing 110 changed files with 943 additions and 406 deletions.
5 changes: 4 additions & 1 deletion aztec_macros/src/utils/parse_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ fn empty_expression(expression: &mut Expression) {
ExpressionKind::Comptime(block_expression, _span) => {
empty_block_expression(block_expression);
}
ExpressionKind::Unsafe(block_expression, _span) => {
empty_block_expression(block_expression);
}
ExpressionKind::Quote(..) | ExpressionKind::Resolved(_) | ExpressionKind::Error => (),
ExpressionKind::AsTraitPath(path) => {
empty_unresolved_type(&mut path.typ);
Expand Down Expand Up @@ -325,7 +328,7 @@ fn empty_unresolved_type(unresolved_type: &mut UnresolvedType) {
empty_unresolved_type(unresolved_type)
}
UnresolvedTypeData::Tuple(unresolved_types) => empty_unresolved_types(unresolved_types),
UnresolvedTypeData::Function(args, ret, _env) => {
UnresolvedTypeData::Function(args, ret, _env, _) => {
empty_unresolved_types(args);
empty_unresolved_type(ret);
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_driver/src/abi_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ pub(super) fn abi_type_from_hir_type(context: &Context, typ: &Type) -> AbiType {
| Type::Forall(..)
| Type::Quoted(_)
| Type::Slice(_)
| Type::Function(_, _, _) => unreachable!("{typ} cannot be used in the abi"),
| Type::Function(_, _, _, _) => unreachable!("{typ} cannot be used in the abi"),
Type::FmtString(_, _) => unreachable!("format strings cannot be used in the abi"),
Type::MutableReference(_) => unreachable!("&mut cannot be used in the abi"),
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ impl<'a> FunctionContext<'a> {
}
ast::Type::Unit => panic!("convert_non_tuple_type called on a unit type"),
ast::Type::Tuple(_) => panic!("convert_non_tuple_type called on a tuple: {typ}"),
ast::Type::Function(_, _, _) => Type::Function,
ast::Type::Function(_, _, _, _) => Type::Function,
ast::Type::Slice(_) => panic!("convert_non_tuple_type called on a slice: {typ}"),
ast::Type::MutableReference(element) => {
// Recursive call to panic if element is a tuple
Expand Down
2 changes: 2 additions & 0 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub enum ExpressionKind {
Quote(Tokens),
Unquote(Box<Expression>),
Comptime(BlockExpression, Span),
Unsafe(BlockExpression, Span),
AsTraitPath(AsTraitPath),

// This variant is only emitted when inlining the result of comptime
Expand Down Expand Up @@ -587,6 +588,7 @@ impl Display for ExpressionKind {
Lambda(lambda) => lambda.fmt(f),
Parenthesized(sub_expr) => write!(f, "({sub_expr})"),
Comptime(block, _) => write!(f, "comptime {block}"),
Unsafe(block, _) => write!(f, "unsafe {block}"),
Error => write!(f, "Error"),
Resolved(_) => write!(f, "?Resolved"),
Unquote(expr) => write!(f, "$({expr})"),
Expand Down
7 changes: 6 additions & 1 deletion compiler/noirc_frontend/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ pub enum UnresolvedTypeData {
/*args:*/ Vec<UnresolvedType>,
/*ret:*/ Box<UnresolvedType>,
/*env:*/ Box<UnresolvedType>,
/*unconstrained:*/ bool,
),

/// The type of quoted code for metaprogramming
Expand Down Expand Up @@ -222,7 +223,11 @@ impl std::fmt::Display for UnresolvedTypeData {
Bool => write!(f, "bool"),
String(len) => write!(f, "str<{len}>"),
FormatString(len, elements) => write!(f, "fmt<{len}, {elements}"),
Function(args, ret, env) => {
Function(args, ret, env, unconstrained) => {
if *unconstrained {
write!(f, "unconstrained ")?;
}

let args = vecmap(args, ToString::to_string).join(", ");

match &env.as_ref().typ {
Expand Down
10 changes: 7 additions & 3 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ impl StatementKind {
StatementKind::Expression(expr) => {
match (&expr.kind, semi, last_statement_in_block) {
// Semicolons are optional for these expressions
(ExpressionKind::Block(_), semi, _) | (ExpressionKind::If(_), semi, _) => {
(ExpressionKind::Block(_), semi, _)
| (ExpressionKind::Unsafe(..), semi, _)
| (ExpressionKind::If(_), semi, _) => {
if semi.is_some() {
StatementKind::Semi(expr)
} else {
Expand Down Expand Up @@ -745,8 +747,10 @@ impl ForRange {
let block = ExpressionKind::Block(BlockExpression {
statements: vec![let_array, for_loop],
});
let kind = StatementKind::Expression(Expression::new(block, for_loop_span));
Statement { kind, span: for_loop_span }
Statement {
kind: StatementKind::Expression(Expression::new(block, for_loop_span)),
span: for_loop_span,
}
}
}
}
Expand Down
24 changes: 18 additions & 6 deletions compiler/noirc_frontend/src/debug/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,9 @@ pub fn build_debug_crate_file() -> String {
__debug_var_assign_oracle(var_id, value);
}
pub fn __debug_var_assign<T>(var_id: u32, value: T) {
__debug_var_assign_inner(var_id, value);
unsafe {{
__debug_var_assign_inner(var_id, value);
}}
}
#[oracle(__debug_var_drop)]
Expand All @@ -502,7 +504,9 @@ pub fn build_debug_crate_file() -> String {
__debug_var_drop_oracle(var_id);
}
pub fn __debug_var_drop(var_id: u32) {
__debug_var_drop_inner(var_id);
unsafe {{
__debug_var_drop_inner(var_id);
}}
}
#[oracle(__debug_fn_enter)]
Expand All @@ -511,7 +515,9 @@ pub fn build_debug_crate_file() -> String {
__debug_fn_enter_oracle(fn_id);
}
pub fn __debug_fn_enter(fn_id: u32) {
__debug_fn_enter_inner(fn_id);
unsafe {{
__debug_fn_enter_inner(fn_id);
}}
}
#[oracle(__debug_fn_exit)]
Expand All @@ -520,7 +526,9 @@ pub fn build_debug_crate_file() -> String {
__debug_fn_exit_oracle(fn_id);
}
pub fn __debug_fn_exit(fn_id: u32) {
__debug_fn_exit_inner(fn_id);
unsafe {{
__debug_fn_exit_inner(fn_id);
}}
}
#[oracle(__debug_dereference_assign)]
Expand All @@ -529,7 +537,9 @@ pub fn build_debug_crate_file() -> String {
__debug_dereference_assign_oracle(var_id, value);
}
pub fn __debug_dereference_assign<T>(var_id: u32, value: T) {
__debug_dereference_assign_inner(var_id, value);
unsafe {{
__debug_dereference_assign_inner(var_id, value);
}}
}
"#
.to_string(),
Expand All @@ -553,7 +563,9 @@ pub fn build_debug_crate_file() -> String {
__debug_oracle_member_assign_{n}(var_id, value, {vars});
}}
pub fn __debug_member_assign_{n}<T, Index>(var_id: u32, value: T, {var_sig}) {{
__debug_inner_member_assign_{n}(var_id, value, {vars});
unsafe {{
__debug_inner_member_assign_{n}(var_id, value, {vars});
}}
}}
"#
Expand Down
18 changes: 17 additions & 1 deletion compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ impl<'context> Elaborator<'context> {
ExpressionKind::Comptime(comptime, _) => {
return self.elaborate_comptime_block(comptime, expr.span)
}
ExpressionKind::Unsafe(block_expression, _) => {
self.elaborate_unsafe_block(block_expression)
}
ExpressionKind::Resolved(id) => return (id, self.interner.id_type(id)),
ExpressionKind::Error => (HirExpression::Error, Type::Error),
ExpressionKind::Unquote(_) => {
Expand Down Expand Up @@ -105,6 +108,19 @@ impl<'context> Elaborator<'context> {
(HirBlockExpression { statements }, block_type)
}

fn elaborate_unsafe_block(&mut self, block: BlockExpression) -> (HirExpression, Type) {
// Before entering the block we cache the old value of `in_unsafe_block` so it can be restored.
let old_in_unsafe_block = self.in_unsafe_block;
self.in_unsafe_block = true;

let (hir_block_expression, typ) = self.elaborate_block_expression(block);

// Finally, we restore the original value of `self.in_unsafe_block`.
self.in_unsafe_block = old_in_unsafe_block;

(HirExpression::Unsafe(hir_block_expression), typ)
}

fn elaborate_literal(&mut self, literal: Literal, span: Span) -> (HirExpression, Type) {
use HirExpression::Literal as Lit;
match literal {
Expand Down Expand Up @@ -710,7 +726,7 @@ impl<'context> Elaborator<'context> {

let captures = lambda_context.captures;
let expr = HirExpression::Lambda(HirLambda { parameters, return_type, body, captures });
(expr, Type::Function(arg_types, Box::new(body_type), Box::new(env_type)))
(expr, Type::Function(arg_types, Box::new(body_type), Box::new(env_type), false))
}

fn elaborate_quote(&mut self, mut tokens: Tokens) -> (HirExpression, Type) {
Expand Down
9 changes: 8 additions & 1 deletion compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ pub struct Elaborator<'context> {

file: FileId,

in_unsafe_block: bool,
nested_loops: usize,

/// Contains a mapping of the current struct or functions's generics to
Expand Down Expand Up @@ -194,6 +195,7 @@ impl<'context> Elaborator<'context> {
interner,
def_maps,
file: FileId::dummy(),
in_unsafe_block: false,
nested_loops: 0,
generics: Vec::new(),
lambda_stack: Vec::new(),
Expand Down Expand Up @@ -802,7 +804,12 @@ impl<'context> Elaborator<'context> {

let return_type = Box::new(self.resolve_type(func.return_type()));

let mut typ = Type::Function(parameter_types, return_type, Box::new(Type::Unit));
let mut typ = Type::Function(
parameter_types,
return_type,
Box::new(Type::Unit),
func.def.is_unconstrained,
);

if !generics.is_empty() {
typ = Type::Forall(generics, Box::new(typ));
Expand Down
13 changes: 10 additions & 3 deletions compiler/noirc_frontend/src/elaborator/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,9 @@ impl<'context> Elaborator<'context> {
};

let no_environment = Box::new(Type::Unit);
// TODO: unconstrained
let function_type =
Type::Function(arguments, Box::new(return_type), no_environment);
Type::Function(arguments, Box::new(return_type), no_environment, false);

functions.push(TraitFunction {
name: name.clone(),
Expand Down Expand Up @@ -345,9 +346,15 @@ fn check_function_type_matches_expected_type(
) {
let mut bindings = TypeBindings::new();
// Shouldn't need to unify envs, they should always be equal since they're both free functions
if let (Type::Function(params_a, ret_a, _env_a), Type::Function(params_b, ret_b, _env_b)) =
(expected, actual)
if let (
Type::Function(params_a, ret_a, _env_a, _unconstrained_a),
Type::Function(params_b, ret_b, _env_b, _unconstrained_b),
) = (expected, actual)
{
// TODO: we don't yet allow marking a trait function or a trait impl function as unconstrained,
// so both values will always be false here. Once we support that, we should check that both
// match (adding a test for it).

if params_a.len() == params_b.len() {
for (i, (a, b)) in params_a.iter().zip(params_b.iter()).enumerate() {
if a.try_unify(b, &mut bindings).is_err() {
Expand Down
36 changes: 25 additions & 11 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl<'context> Elaborator<'context> {
Tuple(fields) => {
Type::Tuple(vecmap(fields, |field| self.resolve_type_inner(field, kind)))
}
Function(args, ret, env) => {
Function(args, ret, env, unconstrained) => {
let args = vecmap(args, |arg| self.resolve_type_inner(arg, kind));
let ret = Box::new(self.resolve_type_inner(*ret, kind));

Expand All @@ -139,7 +139,7 @@ impl<'context> Elaborator<'context> {

match *env {
Type::Unit | Type::Tuple(_) | Type::NamedGeneric(_, _, _) => {
Type::Function(args, ret, env)
Type::Function(args, ret, env, unconstrained)
}
_ => {
self.push_err(ResolverError::InvalidClosureEnvironment {
Expand Down Expand Up @@ -735,8 +735,8 @@ impl<'context> Elaborator<'context> {
return Type::Error;
}

for (param, (arg, _, arg_span)) in fn_params.iter().zip(callsite_args) {
self.unify(arg, param, || TypeCheckError::TypeMismatch {
for (param, (arg, arg_expr_id, arg_span)) in fn_params.iter().zip(callsite_args) {
self.unify_with_coercions(arg, param, *arg_expr_id, || TypeCheckError::TypeMismatch {
expected_typ: param.to_string(),
expr_typ: arg.to_string(),
expr_span: *arg_span,
Expand All @@ -763,7 +763,8 @@ impl<'context> Elaborator<'context> {
let ret = self.interner.next_type_variable();
let args = vecmap(args, |(arg, _, _)| arg);
let env_type = self.interner.next_type_variable();
let expected = Type::Function(args, Box::new(ret.clone()), Box::new(env_type));
let expected =
Type::Function(args, Box::new(ret.clone()), Box::new(env_type), false);

if let Err(error) = binding.try_bind(expected, span) {
self.push_err(error);
Expand All @@ -772,7 +773,7 @@ impl<'context> Elaborator<'context> {
}
// The closure env is ignored on purpose: call arguments never place
// constraints on closure environments.
Type::Function(parameters, ret, _env) => {
Type::Function(parameters, ret, _env, _unconstrained) => {
self.bind_function_type_impl(&parameters, &ret, &args, span)
}
Type::Error => Type::Error,
Expand Down Expand Up @@ -1128,7 +1129,7 @@ impl<'context> Elaborator<'context> {
let (method_type, mut bindings) = method.typ.clone().instantiate(self.interner);

match method_type {
Type::Function(args, _, _) => {
Type::Function(args, _, _, _) => {
// We can cheat a bit and match against only the object type here since no operator
// overload uses other generic parameters or return types aside from the object type.
let expected_object_type = &args[0];
Expand Down Expand Up @@ -1314,9 +1315,22 @@ impl<'context> Elaborator<'context> {

let is_current_func_constrained = self.in_constrained_function();

let is_unconstrained_call = self.is_unconstrained_call(call.func);
let func_type_is_unconstrained =
if let Type::Function(_args, _ret, _env, unconstrained) = &func_type {
*unconstrained
} else {
false
};

let is_unconstrained_call =
func_type_is_unconstrained || self.is_unconstrained_call(call.func);
let crossing_runtime_boundary = is_current_func_constrained && is_unconstrained_call;
if crossing_runtime_boundary {
if !self.in_unsafe_block {
self.push_err(TypeCheckError::Unsafe { span });
return Type::Error;
}

let called_func_id = self
.interner
.lookup_function_from_expr(&call.func)
Expand Down Expand Up @@ -1373,9 +1387,9 @@ impl<'context> Elaborator<'context> {
object: &mut ExprId,
) {
let expected_object_type = match function_type {
Type::Function(args, _, _) => args.first(),
Type::Function(args, _, _, _) => args.first(),
Type::Forall(_, typ) => match typ.as_ref() {
Type::Function(args, _, _) => args.first(),
Type::Function(args, _, _, _) => args.first(),
typ => unreachable!("Unexpected type for function: {typ}"),
},
typ => unreachable!("Unexpected type for function: {typ}"),
Expand Down Expand Up @@ -1577,7 +1591,7 @@ impl<'context> Elaborator<'context> {
}
}

Type::Function(parameters, return_type, _env) => {
Type::Function(parameters, return_type, _env, _unconstrained) => {
for parameter in parameters {
Self::find_numeric_generics_in_type(parameter, found);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ impl HirExpression {
HirExpression::Comptime(block) => {
ExpressionKind::Comptime(block.to_display_ast(interner), span)
}
HirExpression::Unsafe(block) => {
ExpressionKind::Unsafe(block.to_display_ast(interner), span)
}
HirExpression::Quote(block) => ExpressionKind::Quote(block.clone()),

// A macro was evaluated here: return the quoted result
Expand Down Expand Up @@ -340,11 +343,11 @@ impl Type {
let name = Path::from_single(name.as_ref().clone(), Span::default());
UnresolvedTypeData::TraitAsType(name, Vec::new())
}
Type::Function(args, ret, env) => {
Type::Function(args, ret, env, unconstrained) => {
let args = vecmap(args, |arg| arg.to_display_ast());
let ret = Box::new(ret.to_display_ast());
let env = Box::new(env.to_display_ast());
UnresolvedTypeData::Function(args, ret, env)
UnresolvedTypeData::Function(args, ret, env, *unconstrained)
}
Type::MutableReference(element) => {
let element = Box::new(element.to_display_ast());
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/hir/comptime/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> {
HirExpression::Lambda(lambda) => self.evaluate_lambda(lambda, id),
HirExpression::Quote(tokens) => self.evaluate_quote(tokens, id),
HirExpression::Comptime(block) => self.evaluate_block(block),
HirExpression::Unsafe(block) => self.evaluate_block(block),
HirExpression::Unquote(tokens) => {
// An Unquote expression being found is indicative of a macro being
// expanded within another comptime fn which we don't currently support.
Expand Down
Loading

0 comments on commit 79593b4

Please sign in to comment.