Skip to content

Commit

Permalink
refactor builtins
Browse files Browse the repository at this point in the history
  • Loading branch information
katat committed Apr 8, 2024
1 parent b650d9b commit 8947526
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 98 deletions.
91 changes: 33 additions & 58 deletions src/imports.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{collections::HashMap, fmt};

use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};

use crate::{
Expand Down Expand Up @@ -90,65 +91,39 @@ const ASSERT_FN: &str = "assert(condition: Bool)";
const ASSERT_EQ_FN: &str = "assert_eq(lhs: Field, rhs: Field)";
const POSEIDON_FN: &str = "poseidon(input: [Field; 2]) -> [Field; 3]";

// pub enum BuiltInFunctions<B: Backend> {
// Assert(FnInfo<B>),
// AssertEq(FnInfo<B>),
// }

// impl<B: Backend> BuiltInFunctions<B> {
// pub fn fn_info(&self) -> &FnInfo<B> {
// match self {
// BuiltInFunctions::Assert(fn_info) => fn_info,
// BuiltInFunctions::AssertEq(fn_info) => fn_info,
// }
// }

// // TODO: cache the functions, so it won't need to rerun this code that is unnecesasry
// pub fn functions() -> Vec<BuiltInFunctions<B>> {
// // TODO: this makes the code difficult to maintain. there are probably better ways to do this.
// let fn_names = [ASSERT_FN, ASSERT_EQ_FN];

// // create a collection of FnInfo from fn_names
// fn_names.iter().map(|fn_name| {
// BuiltInFunctions::<B>::from_str(fn_name).unwrap()
// })
// .collect::<Vec<BuiltInFunctions<B>>>()
// }
// }

// TODO: can we make this a generic global variable via macro?
#[derive(Debug, Serialize, Deserialize)]
pub struct Builtins<B>
where
B: Backend,
{
pub functions: HashMap<String, FnInfo<B>>,
}
pub static BUILTIN_FNS_SIGS: Lazy<HashMap<&'static str, FnSig>> = Lazy::new(|| {
let sigs = [ASSERT_FN, ASSERT_EQ_FN, POSEIDON_FN];

impl<B: Backend> Builtins<B> {
pub fn new() -> Self {
let mut functions = HashMap::new();
// create a hashmap from the FnSig
sigs.iter().map(|sig| {
let ctx = &mut ParserCtx::default();
let mut tokens = Token::parse(0, sig).unwrap();
let fn_sig = FnSig::parse(ctx, &mut tokens).unwrap();

(sig.to_owned(), fn_sig)
}).collect()
});

pub fn get_builtin_fn<B>(name: &str) -> FnInfo<B> where B: Backend {
let ctx = &mut ParserCtx::default();
let mut tokens = Token::parse(0, name).unwrap();
let sig = FnSig::parse(ctx, &mut tokens).unwrap();

let fn_handle = match name {
ASSERT_FN => B::assert,
ASSERT_EQ_FN => B::assert_eq,
POSEIDON_FN => B::poseidon,
_ => unreachable!(),
};

FnInfo { kind: FnKind::BuiltIn(sig, fn_handle), span: Span::default() }
}

let fn_sigs = [
(ASSERT_FN, B::assert as FnHandle<B>),
(ASSERT_EQ_FN, B::assert_eq as FnHandle<B>),
(POSEIDON_FN, B::poseidon as FnHandle<B>),
];

for (sig, fn_ptr) in fn_sigs.iter() {
let mut tokens = Token::parse(0, sig).unwrap();
let sig = FnSig::parse(ctx, &mut tokens).unwrap();

functions.insert(
sig.name.value.clone(),
FnInfo {
kind: FnKind::BuiltIn(sig, *fn_ptr),
span: Span::default(),
},
);
}

Self { functions }
}
/// a function iterate through builtin functions
pub fn builtin_fns<B: Backend>() -> Vec<FnInfo<B>> {
BUILTIN_FNS_SIGS.iter().map(|(sig, _)| get_builtin_fn::<B>(sig)).collect()
}

pub fn has_builtin_fn(name: &str) -> bool {
BUILTIN_FNS_SIGS.iter().any(|(_, s)| s.name.value == name)
}
4 changes: 0 additions & 4 deletions src/name_resolution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::{
cli::packages::UserRepo,
constants::Span,
error::{Error, ErrorKind, Result},
imports::Builtins,
parser::{
types::{FnArg, FnSig, FuncOrMethod, ModulePath, Stmt, StmtKind, TyKind},
ConstDef, CustomType, FunctionDef, StructDef, UsePath,
Expand All @@ -19,8 +18,6 @@ pub struct NameResCtx<B: Backend> {
/// maps `module` to its original `use a::module`
pub modules: HashMap<String, UsePath>,

pub builtins: Builtins<B>,

phantom: std::marker::PhantomData<B>,
}

Expand All @@ -29,7 +26,6 @@ impl<B: Backend> NameResCtx<B> {
Self {
this_module,
modules: HashMap::new(),
builtins: Builtins::new(),
phantom: std::marker::PhantomData,
}
}
Expand Down
8 changes: 2 additions & 6 deletions src/name_resolution/expr.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
use crate::{
backends::Backend,
cli::packages::UserRepo,
error::Result,
parser::{types::ModulePath, CustomType, Expr, ExprKind},
stdlib::QUALIFIED_BUILTINS,
backends::Backend, cli::packages::UserRepo, error::Result, imports::{has_builtin_fn, BUILTIN_FNS_SIGS}, parser::{types::ModulePath, CustomType, Expr, ExprKind}, stdlib::QUALIFIED_BUILTINS
};

use super::context::NameResCtx;
Expand All @@ -23,7 +19,7 @@ impl<B: Backend> NameResCtx<B> {
args,
} => {
if matches!(module, ModulePath::Local)
&& self.builtins.functions.contains_key(&fn_name.value)
&& has_builtin_fn(fn_name.value.as_str())
{
// if it's a builtin, use `std::builtin`
*module = ModulePath::Absolute(UserRepo::new(QUALIFIED_BUILTINS));
Expand Down
7 changes: 3 additions & 4 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::{
backends::Backend,
constants::Span,
error::{Error, ErrorKind, Result},
imports::Builtins,
lexer::{Keyword, Token, TokenKind, Tokens},
};

Expand Down Expand Up @@ -152,7 +151,7 @@ impl<B: Backend> AST<B> {
TokenKind::Keyword(Keyword::Fn) => {
function_observed = true;

let func = FunctionDef::parse::<B>(ctx, &mut tokens, Builtins::<B>::new())?;
let func = FunctionDef::parse(ctx, &mut tokens)?;
ast.push(Root {
kind: RootKind::FunctionDef(func),
span: token.span,
Expand Down Expand Up @@ -192,7 +191,7 @@ impl<B: Backend> AST<B> {
//
#[cfg(test)]
mod tests {
use crate::{backends::kimchi::Kimchi, parser::types::Stmt};
use crate::{parser::types::Stmt};

use super::*;

Expand All @@ -201,7 +200,7 @@ mod tests {
let code = r#"main(pub public_input: [Fel; 3], private_input: [Fel; 3]) -> [Fel; 3] { return public_input; }"#;
let tokens = &mut Token::parse(0, code).unwrap();
let ctx = &mut ParserCtx::default();
let parsed = FunctionDef::parse::<Kimchi>(ctx, tokens, Builtins::<Kimchi>::new()).unwrap();
let parsed = FunctionDef::parse(ctx, tokens).unwrap();
println!("{:?}", parsed);
}

Expand Down
12 changes: 4 additions & 8 deletions src/parser/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ use serde::{Deserialize, Serialize};
use std::hash::{Hash, Hasher};

use crate::{
backends::Backend,
cli::packages::UserRepo,
constants::{Field, Span},
error::{ErrorKind, Result},
imports::Builtins,
imports::{has_builtin_fn, BUILTIN_FNS_SIGS},
lexer::{Keyword, Token, TokenKind, Tokens},
syntax::is_type,
};
Expand Down Expand Up @@ -792,11 +791,7 @@ impl FunctionDef {
}

/// Parse a function, without the `fn` keyword.
pub fn parse<B: Backend>(
ctx: &mut ParserCtx,
tokens: &mut Tokens,
buitlins: Builtins<B>,
) -> Result<Self> {
pub fn parse(ctx: &mut ParserCtx, tokens: &mut Tokens) -> Result<Self> {
// ghetto way of getting the span of the function: get the span of the first token (name), then try to get the span of the last token
let mut span = tokens
.peek()
Expand All @@ -812,7 +807,8 @@ impl FunctionDef {
let sig = FnSig::parse(ctx, tokens)?;

// make sure that it doesn't shadow a builtin
if buitlins.functions.contains_key(&sig.name.value) {
// TODO: better to compare the whole FnSig?
if has_builtin_fn(&sig.name.value) {
return Err(ctx.error(
ErrorKind::ShadowingBuiltIn(sig.name.value.clone()),
sig.name.span,
Expand Down
40 changes: 22 additions & 18 deletions src/type_checker/mod.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
use std::{collections::HashMap, str::FromStr};
use std::collections::HashMap;

use crate::{
backends::Backend, cli::packages::UserRepo, constants::{Field, Span}, error::{Error, ErrorKind, Result}, helpers::PrettyField, imports::{Builtins, FnKind}, name_resolution::NAST, parser::{
backends::Backend,
cli::packages::UserRepo,
constants::{Field, Span},
error::{Error, ErrorKind, Result},
imports::{builtin_fns, FnKind},
name_resolution::NAST,
parser::{
types::{FuncOrMethod, FunctionDef, ModulePath, RootKind, Ty, TyKind},
CustomType, Expr, StructDef,
}, stdlib::{QUALIFIED_BUILTINS}
},
stdlib::QUALIFIED_BUILTINS,
};

pub use checker::{FnInfo, StructInfo};
pub use fn_env::{TypeInfo, TypedFnEnv};

use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use strum::IntoEnumIterator;

pub mod checker;
pub mod fn_env;
Expand All @@ -21,7 +27,10 @@ const RESERVED_ARGS: [&str; 1] = ["public_output"];

#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstInfo<F> where F: Field {
pub struct ConstInfo<F>
where
F: Field,
{
#[serde_as(as = "crate::serialization::SerdeAs")]
pub value: Vec<F>,
pub typ: Ty,
Expand Down Expand Up @@ -54,7 +63,10 @@ impl FullyQualified {

/// The environment we use to type check a noname program.
#[derive(Debug, Serialize, Deserialize)]
pub struct TypeChecker<B> where B: Backend {
pub struct TypeChecker<B>
where
B: Backend,
{
/// the functions present in the scope
/// contains at least the set of builtin functions (like assert_eq)
functions: HashMap<FullyQualified, FnInfo<B>>,
Expand All @@ -69,8 +81,6 @@ pub struct TypeChecker<B> where B: Backend {
/// This can be used by the circuit-writer when it needs type information.
// TODO: I think we should get rid of this if we can
node_types: HashMap<usize, TyKind>,

builtins: Builtins<B>,
}

impl<B: Backend> TypeChecker<B> {
Expand All @@ -88,12 +98,7 @@ impl<B: Backend> TypeChecker<B> {
}

pub(crate) fn fn_info(&self, qualified: &FullyQualified) -> Option<&FnInfo<B>> {
if qualified.module == Some(UserRepo::new("std/builtins")) {

self.builtins.functions.get(&qualified.name)
} else {
self.functions.get(qualified)
}
self.functions.get(qualified)
}

pub(crate) fn const_info(&self, qualified: &FullyQualified) -> Option<&ConstInfo<B::Field>> {
Expand Down Expand Up @@ -129,20 +134,18 @@ impl<B: Backend> TypeChecker<B> {
impl<B: Backend> TypeChecker<B> {
// TODO: we can probably lazy const this
pub fn new() -> Self {
let builtins = Builtins::<B>::new();
let mut type_checker = Self {
functions: HashMap::new(),
structs: HashMap::new(),
constants: HashMap::new(),
node_types: HashMap::new(),
builtins,
};

// initialize it with the builtins
let builtin_module = ModulePath::Absolute(UserRepo::new(QUALIFIED_BUILTINS));
for fn_info in type_checker.builtins.functions.values() {

for fn_info in builtin_fns() {
let qualified = FullyQualified::new(&builtin_module, &fn_info.sig().name.value);
println!("inserting builtin: {:?}", qualified);
if type_checker
.functions
.insert(qualified, fn_info.clone())
Expand Down Expand Up @@ -186,6 +189,7 @@ impl<B: Backend> TypeChecker<B> {
let mut abort = None;

for root in &nast.ast.0 {
println!("root: {:?}", root);
match &root.kind {
RootKind::ConstDef(cst) => {
// important: no struct or function definition must appear before a constant declaration
Expand Down

0 comments on commit 8947526

Please sign in to comment.