Skip to content

Commit

Permalink
Support hint function declaration to bind with builtin function (#199)
Browse files Browse the repository at this point in the history
* support unsafe/hint attributes
  • Loading branch information
katat authored Oct 24, 2024
1 parent 28a9e67 commit 78ff4ba
Show file tree
Hide file tree
Showing 13 changed files with 273 additions and 15 deletions.
1 change: 1 addition & 0 deletions src/circuit_writer/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ impl<B: Backend> CircuitWriter<B> {
module,
fn_name,
args,
..
} => {
// sanity check
if fn_name.value == "main" {
Expand Down
9 changes: 9 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ pub enum ErrorKind {
#[error("function `{0}` not present in scope (did you misspell it?)")]
UndefinedFunction(String),

#[error("hint function `{0}` signature is missing its corresponding builtin function")]
MissingHintMapping(String),

#[error("function name `{0}` is already in use by a variable present in the scope")]
FunctionNameInUsebyVariable(String),

Expand All @@ -249,6 +252,12 @@ pub enum ErrorKind {
#[error("attribute not recognized: `{0:?}`")]
InvalidAttribute(AttributeKind),

#[error("unsafe attribute is needed to call a hint function. eg: `unsafe fn foo()`")]
ExpectedUnsafeAttribute,

#[error("unsafe attribute should only be applied to hint function calls")]
UnexpectedUnsafeAttribute,

#[error("A return value is not used")]
UnusedReturnValue,

Expand Down
8 changes: 8 additions & 0 deletions src/lexer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ pub enum Keyword {
Use,
/// A function
Fn,
/// A hint function
Hint,
/// Attribute required for hint functions
Unsafe,
/// New variable
Let,
/// Public input
Expand Down Expand Up @@ -75,6 +79,8 @@ impl Keyword {
match s {
"use" => Some(Self::Use),
"fn" => Some(Self::Fn),
"hint" => Some(Self::Hint),
"unsafe" => Some(Self::Unsafe),
"let" => Some(Self::Let),
"pub" => Some(Self::Pub),
"return" => Some(Self::Return),
Expand All @@ -97,6 +103,8 @@ impl Display for Keyword {
let desc = match self {
Self::Use => "use",
Self::Fn => "fn",
Self::Hint => "hint",
Self::Unsafe => "unsafe",
Self::Let => "let",
Self::Pub => "pub",
Self::Return => "return",
Expand Down
7 changes: 7 additions & 0 deletions src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ fn monomorphize_expr<B: Backend>(
module,
fn_name,
args,
unsafe_attr,
} => {
// compute the observed arguments types
let mut observed = Vec::with_capacity(args.len());
Expand Down Expand Up @@ -548,6 +549,7 @@ fn monomorphize_expr<B: Backend>(
module: module.clone(),
fn_name: resolved_sig.name,
args: args_mono,
unsafe_attr: *unsafe_attr,
},
);
let resolved_sig = &fn_info.sig().generics.resolved_sig;
Expand All @@ -568,6 +570,7 @@ fn monomorphize_expr<B: Backend>(
module: module.clone(),
fn_name: fn_name_mono.clone(),
args: args_mono,
unsafe_attr: *unsafe_attr,
},
);

Expand Down Expand Up @@ -611,6 +614,7 @@ fn monomorphize_expr<B: Backend>(
let fn_kind = FnKind::Native(method_type.clone());
let mut fn_info = FnInfo {
kind: fn_kind,
is_hint: false,
span: method_type.span,
};

Expand Down Expand Up @@ -1215,6 +1219,7 @@ pub fn instantiate_fn_call<B: Backend>(
let func_def = match fn_info.kind {
FnKind::BuiltIn(_, handle) => FnInfo {
kind: FnKind::BuiltIn(sig_typed, handle),
is_hint: fn_info.is_hint,
span: fn_info.span,
},
FnKind::Native(fn_def) => {
Expand All @@ -1226,7 +1231,9 @@ pub fn instantiate_fn_call<B: Backend>(
sig: sig_typed,
body: stmts_typed,
span: fn_def.span,
is_hint: fn_def.is_hint,
}),
is_hint: fn_info.is_hint,
span: fn_info.span,
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/name_resolution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl NameResCtx {
}

pub(crate) fn resolve_fn_def(&self, fn_def: &mut FunctionDef) -> Result<()> {
let FunctionDef { sig, body, span: _ } = fn_def;
let FunctionDef { sig, body, .. } = fn_def;

//
// signature
Expand Down
1 change: 1 addition & 0 deletions src/name_resolution/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ impl NameResCtx {
module,
fn_name,
args,
unsafe_attr: _,
} => {
if matches!(module, ModulePath::Local)
&& BUILTIN_FN_NAMES.contains(&fn_name.value.as_str())
Expand Down
125 changes: 112 additions & 13 deletions src/negative_tests.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,41 @@
use crate::{
backends::r1cs::{R1csBn254Field, R1CS},
circuit_writer::CircuitWriter,
backends::{
r1cs::{R1csBn254Field, R1CS},
Backend,
},
circuit_writer::{CircuitWriter, VarInfo},
compiler::{get_nast, typecheck_next_file_inner, Sources},
constants::Span,
error::{ErrorKind, Result},
imports::FnKind,
lexer::Token,
mast::Mast,
name_resolution::NAST,
type_checker::TypeChecker,
parser::{
types::{FnSig, GenericParameters},
ParserCtx,
},
type_checker::{FnInfo, FullyQualified, TypeChecker},
var::Var,
witness::CompiledCircuit,
};

fn nast_pass(code: &str) -> Result<(NAST<R1CS<R1csBn254Field>>, usize)> {
type R1csBackend = R1CS<R1csBn254Field>;

fn nast_pass(code: &str) -> Result<(NAST<R1csBackend>, usize)> {
let mut source = Sources::new();
let res = get_nast(
get_nast(
None,
&mut source,
"example.no".to_string(),
code.to_string(),
0,
);

res
)
}

fn tast_pass(code: &str) -> (Result<usize>, TypeChecker<R1CS<R1csBn254Field>>, Sources) {
fn tast_pass(code: &str) -> (Result<usize>, TypeChecker<R1csBackend>, Sources) {
let mut source = Sources::new();
let mut tast = TypeChecker::<R1CS<R1csBn254Field>>::new();
let mut tast = TypeChecker::<R1csBackend>::new();
let res = typecheck_next_file_inner(
&mut tast,
None,
Expand All @@ -37,12 +48,12 @@ fn tast_pass(code: &str) -> (Result<usize>, TypeChecker<R1CS<R1csBn254Field>>, S
(res, tast, source)
}

fn mast_pass(code: &str) -> Result<Mast<R1CS<R1csBn254Field>>> {
fn mast_pass(code: &str) -> Result<Mast<R1csBackend>> {
let (_, tast, _) = tast_pass(code);
crate::mast::monomorphize(tast)
}

fn synthesizer_pass(code: &str) -> Result<CompiledCircuit<R1CS<R1csBn254Field>>> {
fn synthesizer_pass(code: &str) -> Result<CompiledCircuit<R1csBackend>> {
let mast = mast_pass(code);
CircuitWriter::generate_circuit(mast?, R1CS::new())
}
Expand Down Expand Up @@ -510,6 +521,94 @@ fn test_generic_missing_parenthesis() {
"#;

let res = nast_pass(code).err();
println!("{:?}", res);
assert!(matches!(res.unwrap().kind, ErrorKind::MissingParenthesis));
}
fn test_hint_builtin_fn(qualified: &FullyQualified, code: &str) -> Result<usize> {
let mut source = Sources::new();
let mut tast = TypeChecker::<R1csBackend>::new();
// mock a builtin function
let ctx = &mut ParserCtx::default();
let mut tokens = Token::parse(0, "calc(val: Field) -> Field;").unwrap();
let sig = FnSig::parse(ctx, &mut tokens).unwrap();

fn mocked_builtin_fn<B: Backend>(
_: &mut CircuitWriter<B>,
_: &GenericParameters,
_: &[VarInfo<B::Field, B::Var>],
_: Span,
) -> Result<Option<Var<B::Field, B::Var>>> {
Ok(None)
}

let fn_info = FnInfo {
kind: FnKind::BuiltIn(sig, mocked_builtin_fn::<R1csBackend>),
is_hint: true,
span: Span::default(),
};

// add the mocked builtin function
// note that this should happen in the tast phase, instead of mast phase.
// currently this function is the only way to mock a builtin function.
tast.add_monomorphized_fn(qualified.clone(), fn_info);

typecheck_next_file_inner(
&mut tast,
None,
&mut source,
"example.no".to_string(),
code.to_string(),
0,
)
}

#[test]
fn test_hint_call_missing_unsafe() {
let qualified = FullyQualified {
module: None,
name: "calc".to_string(),
};

let valid_code = r#"
hint fn calc(val: Field) -> Field;
fn main(pub xx: Field) {
let yy = unsafe calc(xx);
}
"#;

let res = test_hint_builtin_fn(&qualified, valid_code);
assert!(res.is_ok());

let invalid_code = r#"
hint fn calc(val: Field) -> Field;
fn main(pub xx: Field) {
let yy = calc(xx);
}
"#;

let res = test_hint_builtin_fn(&qualified, invalid_code);
assert!(matches!(
res.unwrap_err().kind,
ErrorKind::ExpectedUnsafeAttribute
));
}

#[test]
fn test_nonhint_call_with_unsafe() {
let code = r#"
fn calc(val: Field) -> Field {
return val + 1;
}
fn main(pub xx: Field) {
let yy = unsafe calc(xx);
}
"#;

let res = tast_pass(code).0;
assert!(matches!(
res.unwrap_err().kind,
ErrorKind::UnexpectedUnsafeAttribute
));
}
17 changes: 17 additions & 0 deletions src/parser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ pub enum ExprKind {
module: ModulePath,
fn_name: Ident,
args: Vec<Expr>,
unsafe_attr: bool,
},

/// `lhs.method_name(args)`
Expand Down Expand Up @@ -379,6 +380,21 @@ impl Expr {
}
}

TokenKind::Keyword(Keyword::Unsafe) => {
let mut fn_call = Expr::parse(ctx, tokens)?;
// should be FnCall
match &mut fn_call.kind {
ExprKind::FnCall { unsafe_attr, .. } => {
*unsafe_attr = true;
}
_ => {
return Err(ctx.error(ErrorKind::InvalidExpression, fn_call.span));
}
};

fn_call
}

// unrecognized pattern
_ => {
return Err(ctx.error(ErrorKind::InvalidExpression, token.span));
Expand Down Expand Up @@ -576,6 +592,7 @@ impl Expr {
module,
fn_name,
args,
unsafe_attr: false,
},
span,
)
Expand Down
20 changes: 20 additions & 0 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,26 @@ impl<B: Backend> AST<B> {
});
}

// `hint fn calc() { }`
TokenKind::Keyword(Keyword::Hint) => {
// expect fn token
tokens.bump_expected(ctx, TokenKind::Keyword(Keyword::Fn))?;

function_observed = true;

let func = FunctionDef::parse_hint(ctx, &mut tokens)?;

// expect ;, as the hint function is an empty function wired with a builtin.
// todo: later these hint functions will be migrated from builtins to native functions
// then it will expect a function block instead of ;
tokens.bump_expected(ctx, TokenKind::SemiColon)?;

ast.push(Root {
kind: RootKind::FunctionDef(func),
span: token.span,
});
}

// `struct Foo { a: Field, b: Field }`
TokenKind::Keyword(Keyword::Struct) => {
let s = StructDef::parse(ctx, &mut tokens)?;
Expand Down
31 changes: 30 additions & 1 deletion src/parser/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ impl Attribute {

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDef {
pub is_hint: bool,
pub sig: FnSig,
pub body: Vec<Stmt>,
pub span: Span,
Expand Down Expand Up @@ -1181,7 +1182,35 @@ impl FunctionDef {
));
}

let func = Self { sig, body, span };
let func = Self {
sig,
body,
span,
is_hint: false,
};

Ok(func)
}

/// Parse a hint function signature
pub fn parse_hint(ctx: &mut ParserCtx, tokens: &mut Tokens) -> Result<Self> {
// parse signature
let sig = FnSig::parse(ctx, tokens)?;
let span = sig.name.span;

// make sure that it doesn't shadow a builtin
if BUILTIN_FN_NAMES.contains(&sig.name.value.as_ref()) {
return Err(ctx.error(ErrorKind::ShadowingBuiltIn(sig.name.value.clone()), span));
}

// for now the body is empty.
// this will be changed once the native hint is implemented.
let func = Self {
sig,
body: vec![],
span,
is_hint: true,
};

Ok(func)
}
Expand Down
Loading

0 comments on commit 78ff4ba

Please sign in to comment.