From 78ff4ba874a67bcf4e6acd7eb831609fbc524683 Mon Sep 17 00:00:00 2001 From: Kata Choi Date: Thu, 24 Oct 2024 18:47:32 +0800 Subject: [PATCH] Support hint function declaration to bind with builtin function (#199) * support unsafe/hint attributes --- src/circuit_writer/writer.rs | 1 + src/error.rs | 9 +++ src/lexer/mod.rs | 8 +++ src/mast/mod.rs | 7 ++ src/name_resolution/context.rs | 2 +- src/name_resolution/expr.rs | 1 + src/negative_tests.rs | 125 +++++++++++++++++++++++++++++---- src/parser/expr.rs | 17 +++++ src/parser/mod.rs | 20 ++++++ src/parser/types.rs | 31 +++++++- src/stdlib/mod.rs | 1 + src/type_checker/checker.rs | 16 +++++ src/type_checker/mod.rs | 50 +++++++++++++ 13 files changed, 273 insertions(+), 15 deletions(-) diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index 924c57b6f..ec1a744c7 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -379,6 +379,7 @@ impl CircuitWriter { module, fn_name, args, + .. } => { // sanity check if fn_name.value == "main" { diff --git a/src/error.rs b/src/error.rs index d04f6f40b..19f2e1d9b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -240,6 +240,9 @@ pub enum ErrorKind { #[error("function `{0}` not present in scope (did you misspell it?)")] UndefinedFunction(String), + #[error("hint function `{0}` signature is missing its corresponding builtin function")] + MissingHintMapping(String), + #[error("function name `{0}` is already in use by a variable present in the scope")] FunctionNameInUsebyVariable(String), @@ -249,6 +252,12 @@ pub enum ErrorKind { #[error("attribute not recognized: `{0:?}`")] InvalidAttribute(AttributeKind), + #[error("unsafe attribute is needed to call a hint function. eg: `unsafe fn foo()`")] + ExpectedUnsafeAttribute, + + #[error("unsafe attribute should only be applied to hint function calls")] + UnexpectedUnsafeAttribute, + #[error("A return value is not used")] UnusedReturnValue, diff --git a/src/lexer/mod.rs b/src/lexer/mod.rs index 89784b213..21d634fc3 100644 --- a/src/lexer/mod.rs +++ b/src/lexer/mod.rs @@ -44,6 +44,10 @@ pub enum Keyword { Use, /// A function Fn, + /// A hint function + Hint, + /// Attribute required for hint functions + Unsafe, /// New variable Let, /// Public input @@ -75,6 +79,8 @@ impl Keyword { match s { "use" => Some(Self::Use), "fn" => Some(Self::Fn), + "hint" => Some(Self::Hint), + "unsafe" => Some(Self::Unsafe), "let" => Some(Self::Let), "pub" => Some(Self::Pub), "return" => Some(Self::Return), @@ -97,6 +103,8 @@ impl Display for Keyword { let desc = match self { Self::Use => "use", Self::Fn => "fn", + Self::Hint => "hint", + Self::Unsafe => "unsafe", Self::Let => "let", Self::Pub => "pub", Self::Return => "return", diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 23b208512..f16a2cbac 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -518,6 +518,7 @@ fn monomorphize_expr( module, fn_name, args, + unsafe_attr, } => { // compute the observed arguments types let mut observed = Vec::with_capacity(args.len()); @@ -548,6 +549,7 @@ fn monomorphize_expr( module: module.clone(), fn_name: resolved_sig.name, args: args_mono, + unsafe_attr: *unsafe_attr, }, ); let resolved_sig = &fn_info.sig().generics.resolved_sig; @@ -568,6 +570,7 @@ fn monomorphize_expr( module: module.clone(), fn_name: fn_name_mono.clone(), args: args_mono, + unsafe_attr: *unsafe_attr, }, ); @@ -611,6 +614,7 @@ fn monomorphize_expr( let fn_kind = FnKind::Native(method_type.clone()); let mut fn_info = FnInfo { kind: fn_kind, + is_hint: false, span: method_type.span, }; @@ -1215,6 +1219,7 @@ pub fn instantiate_fn_call( let func_def = match fn_info.kind { FnKind::BuiltIn(_, handle) => FnInfo { kind: FnKind::BuiltIn(sig_typed, handle), + is_hint: fn_info.is_hint, span: fn_info.span, }, FnKind::Native(fn_def) => { @@ -1226,7 +1231,9 @@ pub fn instantiate_fn_call( sig: sig_typed, body: stmts_typed, span: fn_def.span, + is_hint: fn_def.is_hint, }), + is_hint: fn_info.is_hint, span: fn_info.span, } } diff --git a/src/name_resolution/context.rs b/src/name_resolution/context.rs index 59a5efad6..d4d196eea 100644 --- a/src/name_resolution/context.rs +++ b/src/name_resolution/context.rs @@ -66,7 +66,7 @@ impl NameResCtx { } pub(crate) fn resolve_fn_def(&self, fn_def: &mut FunctionDef) -> Result<()> { - let FunctionDef { sig, body, span: _ } = fn_def; + let FunctionDef { sig, body, .. } = fn_def; // // signature diff --git a/src/name_resolution/expr.rs b/src/name_resolution/expr.rs index 292afeb97..d2630fe00 100644 --- a/src/name_resolution/expr.rs +++ b/src/name_resolution/expr.rs @@ -20,6 +20,7 @@ impl NameResCtx { module, fn_name, args, + unsafe_attr: _, } => { if matches!(module, ModulePath::Local) && BUILTIN_FN_NAMES.contains(&fn_name.value.as_str()) diff --git a/src/negative_tests.rs b/src/negative_tests.rs index 183228832..66fa0291b 100644 --- a/src/negative_tests.rs +++ b/src/negative_tests.rs @@ -1,30 +1,41 @@ use crate::{ - backends::r1cs::{R1csBn254Field, R1CS}, - circuit_writer::CircuitWriter, + backends::{ + r1cs::{R1csBn254Field, R1CS}, + Backend, + }, + circuit_writer::{CircuitWriter, VarInfo}, compiler::{get_nast, typecheck_next_file_inner, Sources}, + constants::Span, error::{ErrorKind, Result}, + imports::FnKind, + lexer::Token, mast::Mast, name_resolution::NAST, - type_checker::TypeChecker, + parser::{ + types::{FnSig, GenericParameters}, + ParserCtx, + }, + type_checker::{FnInfo, FullyQualified, TypeChecker}, + var::Var, witness::CompiledCircuit, }; -fn nast_pass(code: &str) -> Result<(NAST>, usize)> { +type R1csBackend = R1CS; + +fn nast_pass(code: &str) -> Result<(NAST, usize)> { let mut source = Sources::new(); - let res = get_nast( + get_nast( None, &mut source, "example.no".to_string(), code.to_string(), 0, - ); - - res + ) } -fn tast_pass(code: &str) -> (Result, TypeChecker>, Sources) { +fn tast_pass(code: &str) -> (Result, TypeChecker, Sources) { let mut source = Sources::new(); - let mut tast = TypeChecker::>::new(); + let mut tast = TypeChecker::::new(); let res = typecheck_next_file_inner( &mut tast, None, @@ -37,12 +48,12 @@ fn tast_pass(code: &str) -> (Result, TypeChecker>, S (res, tast, source) } -fn mast_pass(code: &str) -> Result>> { +fn mast_pass(code: &str) -> Result> { let (_, tast, _) = tast_pass(code); crate::mast::monomorphize(tast) } -fn synthesizer_pass(code: &str) -> Result>> { +fn synthesizer_pass(code: &str) -> Result> { let mast = mast_pass(code); CircuitWriter::generate_circuit(mast?, R1CS::new()) } @@ -510,6 +521,94 @@ fn test_generic_missing_parenthesis() { "#; let res = nast_pass(code).err(); - println!("{:?}", res); assert!(matches!(res.unwrap().kind, ErrorKind::MissingParenthesis)); } +fn test_hint_builtin_fn(qualified: &FullyQualified, code: &str) -> Result { + let mut source = Sources::new(); + let mut tast = TypeChecker::::new(); + // mock a builtin function + let ctx = &mut ParserCtx::default(); + let mut tokens = Token::parse(0, "calc(val: Field) -> Field;").unwrap(); + let sig = FnSig::parse(ctx, &mut tokens).unwrap(); + + fn mocked_builtin_fn( + _: &mut CircuitWriter, + _: &GenericParameters, + _: &[VarInfo], + _: Span, + ) -> Result>> { + Ok(None) + } + + let fn_info = FnInfo { + kind: FnKind::BuiltIn(sig, mocked_builtin_fn::), + is_hint: true, + span: Span::default(), + }; + + // add the mocked builtin function + // note that this should happen in the tast phase, instead of mast phase. + // currently this function is the only way to mock a builtin function. + tast.add_monomorphized_fn(qualified.clone(), fn_info); + + typecheck_next_file_inner( + &mut tast, + None, + &mut source, + "example.no".to_string(), + code.to_string(), + 0, + ) +} + +#[test] +fn test_hint_call_missing_unsafe() { + let qualified = FullyQualified { + module: None, + name: "calc".to_string(), + }; + + let valid_code = r#" + hint fn calc(val: Field) -> Field; + + fn main(pub xx: Field) { + let yy = unsafe calc(xx); + } + "#; + + let res = test_hint_builtin_fn(&qualified, valid_code); + assert!(res.is_ok()); + + let invalid_code = r#" + hint fn calc(val: Field) -> Field; + + fn main(pub xx: Field) { + let yy = calc(xx); + } + "#; + + let res = test_hint_builtin_fn(&qualified, invalid_code); + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::ExpectedUnsafeAttribute + )); +} + +#[test] +fn test_nonhint_call_with_unsafe() { + let code = r#" + fn calc(val: Field) -> Field { + return val + 1; + } + + fn main(pub xx: Field) { + let yy = unsafe calc(xx); + } + "#; + + let res = tast_pass(code).0; + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::UnexpectedUnsafeAttribute + )); +} diff --git a/src/parser/expr.rs b/src/parser/expr.rs index 9fed71992..d070290bb 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -58,6 +58,7 @@ pub enum ExprKind { module: ModulePath, fn_name: Ident, args: Vec, + unsafe_attr: bool, }, /// `lhs.method_name(args)` @@ -379,6 +380,21 @@ impl Expr { } } + TokenKind::Keyword(Keyword::Unsafe) => { + let mut fn_call = Expr::parse(ctx, tokens)?; + // should be FnCall + match &mut fn_call.kind { + ExprKind::FnCall { unsafe_attr, .. } => { + *unsafe_attr = true; + } + _ => { + return Err(ctx.error(ErrorKind::InvalidExpression, fn_call.span)); + } + }; + + fn_call + } + // unrecognized pattern _ => { return Err(ctx.error(ErrorKind::InvalidExpression, token.span)); @@ -576,6 +592,7 @@ impl Expr { module, fn_name, args, + unsafe_attr: false, }, span, ) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index da765cbe3..511354478 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -158,6 +158,26 @@ impl AST { }); } + // `hint fn calc() { }` + TokenKind::Keyword(Keyword::Hint) => { + // expect fn token + tokens.bump_expected(ctx, TokenKind::Keyword(Keyword::Fn))?; + + function_observed = true; + + let func = FunctionDef::parse_hint(ctx, &mut tokens)?; + + // expect ;, as the hint function is an empty function wired with a builtin. + // todo: later these hint functions will be migrated from builtins to native functions + // then it will expect a function block instead of ; + tokens.bump_expected(ctx, TokenKind::SemiColon)?; + + ast.push(Root { + kind: RootKind::FunctionDef(func), + span: token.span, + }); + } + // `struct Foo { a: Field, b: Field }` TokenKind::Keyword(Keyword::Struct) => { let s = StructDef::parse(ctx, &mut tokens)?; diff --git a/src/parser/types.rs b/src/parser/types.rs index 428b7dfd6..0691f1356 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -709,6 +709,7 @@ impl Attribute { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FunctionDef { + pub is_hint: bool, pub sig: FnSig, pub body: Vec, pub span: Span, @@ -1181,7 +1182,35 @@ impl FunctionDef { )); } - let func = Self { sig, body, span }; + let func = Self { + sig, + body, + span, + is_hint: false, + }; + + Ok(func) + } + + /// Parse a hint function signature + pub fn parse_hint(ctx: &mut ParserCtx, tokens: &mut Tokens) -> Result { + // parse signature + let sig = FnSig::parse(ctx, tokens)?; + let span = sig.name.span; + + // make sure that it doesn't shadow a builtin + if BUILTIN_FN_NAMES.contains(&sig.name.value.as_ref()) { + return Err(ctx.error(ErrorKind::ShadowingBuiltIn(sig.name.value.clone()), span)); + } + + // for now the body is empty. + // this will be changed once the native hint is implemented. + let func = Self { + sig, + body: vec![], + span, + is_hint: true, + }; Ok(func) } diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index f2ad66a8e..150d31517 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -72,6 +72,7 @@ trait Module { let sig = FnSig::parse(ctx, &mut tokens).unwrap(); res.push(FnInfo { kind: FnKind::BuiltIn(sig, fn_handle), + is_hint: false, span: Span::default(), }); } diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 06216332b..75900ec5b 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -25,6 +25,11 @@ where B: Backend, { pub kind: FnKind, + // TODO: We will remove this once the native hint is supported + // This field is to indicate if a builtin function should be treated as a hint. + // instead of adding this flag to the FnKind::Builtin, we add this field to the FnInfo. + // Then this flag will only present in the FunctionDef. + pub is_hint: bool, pub span: Span, } @@ -130,6 +135,7 @@ impl TypeChecker { module, fn_name, args, + unsafe_attr, } => { // retrieve the function signature let qualified = FullyQualified::new(&module, &fn_name.value); @@ -141,6 +147,16 @@ impl TypeChecker { })?; let fn_sig = fn_info.sig().clone(); + // check if the function is a hint + if fn_info.is_hint && !unsafe_attr { + return Err(self.error(ErrorKind::ExpectedUnsafeAttribute, expr.span)); + } + + // unsafe attribute should only be used on hints + if !fn_info.is_hint && *unsafe_attr { + return Err(self.error(ErrorKind::UnexpectedUnsafeAttribute, expr.span)); + } + // check if generic is allowed if fn_sig.require_monomorphization() && typed_fn_env.is_in_forloop() { for (observed_arg, expected_arg) in args.iter().zip(fn_sig.arguments.iter()) { diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index a2b7469ef..50308a376 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -312,9 +312,59 @@ impl TypeChecker { } // save the function in the typed global env + + if function.is_hint { + // convert to builtin function + let qualified = match &function.sig.kind { + FuncOrMethod::Function(module) => { + FullyQualified::new(module, &function.sig.name.value) + } + FuncOrMethod::Method(_) => unreachable!("methods are not supported"), + }; + + // this will override the builtin function in the global env + let builtin_fn = self + .functions + .get(&qualified) + .ok_or_else(|| { + Error::new( + "type-checker", + ErrorKind::MissingHintMapping(qualified.name.clone()), + function.span, + ) + })? + .kind + .clone(); + + // check it is a builtin function + let fn_handle = match builtin_fn { + FnKind::BuiltIn(_, fn_handle) => fn_handle, + _ => { + return Err(Error::new( + "type-checker", + ErrorKind::UnexpectedError("expected builtin function"), + function.span, + )) + } + }; + + // override the builtin function as a hint function + self.functions.insert( + qualified, + FnInfo { + is_hint: true, + kind: FnKind::BuiltIn(function.sig.clone(), fn_handle), + span: function.span, + }, + ); + + continue; + }; + let fn_kind = FnKind::Native(function.clone()); let fn_info = FnInfo { kind: fn_kind, + is_hint: function.is_hint, span: function.span, };