Skip to content

Commit

Permalink
Merge pull request #187 from zksecurity/fix/unnecessary-generic-arr
Browse files Browse the repository at this point in the history
Fix function instantiation
  • Loading branch information
katat authored Sep 23, 2024
2 parents 1711beb + 8fa2dba commit 34e0d16
Show file tree
Hide file tree
Showing 11 changed files with 193 additions and 68 deletions.
12 changes: 11 additions & 1 deletion examples/array.no
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,18 @@ fn init() -> [Field; size] { // array type can depends on constant var
return [4; size]; // array init with constant var
}

fn init_concrete() -> [Field; 3] {
// as this function won't be monomorphized,
// this is to test this array is constructed as Array instead of GenericSizedArray.
let mut arr = [0; 3];
for idx in 0..3 {
arr[idx] = idx + 1;
}
return arr;
}

fn main(pub public_input: [Field; 2]) {
let xx = [1, 2, 3];
let xx = init_concrete();

assert_eq(public_input[0], xx[0]);
assert_eq(public_input[1], xx[1]);
Expand Down
13 changes: 13 additions & 0 deletions examples/fixture/asm/kimchi/generic_nested_func.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
@ noname.0.7.0

DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
(0,0) -> (4,0)
(1,0) -> (5,0)
(2,0) -> (6,0)
(3,0) -> (4,1) -> (5,1) -> (6,1)
13 changes: 13 additions & 0 deletions examples/fixture/asm/kimchi/generic_nested_method.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
@ noname.0.7.0

DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
(0,0) -> (4,0)
(1,0) -> (5,0)
(2,0) -> (6,0)
(3,0) -> (4,1) -> (5,1) -> (6,1)
5 changes: 5 additions & 0 deletions examples/fixture/asm/r1cs/generic_nested_func.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
@ noname.0.7.0

v_4 == (v_1) * (1)
v_4 == (v_2) * (1)
v_4 == (v_3) * (1)
5 changes: 5 additions & 0 deletions examples/fixture/asm/r1cs/generic_nested_method.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
@ noname.0.7.0

v_4 == (v_1) * (1)
v_4 == (v_2) * (1)
v_4 == (v_3) * (1)
17 changes: 17 additions & 0 deletions examples/generic_nested_func.no
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
fn nested_func(const LEN: Field) -> [Field; LEN] {
return [0; LEN];
}

fn mod_arr(val: Field) -> [Field; 3] {
// this generic function should be instantiated
let mut result = nested_func(3);
for idx in 0..3 {
result[idx] = val;
}
return result;
}

fn main(pub val: Field) -> [Field; 3] {
let result = mod_arr(val);
return result;
}
21 changes: 21 additions & 0 deletions examples/generic_nested_method.no
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
struct Thing {
xx: Field,
}

fn Thing.nested_func(const LEN: Field) -> [Field; LEN] {
return [0; LEN];
}

fn Thing.mod_arr(self) -> [Field; 3] {
// this generic function should be instantiated
let mut result = self.nested_func(3);
for idx in 0..3 {
result[idx] = self.xx;
}
return result;
}

fn main(pub val: Field) -> [Field; 3] {
let thing = Thing {xx: val};
return thing.mod_arr();
}
17 changes: 9 additions & 8 deletions src/mast/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ use super::MastCtx;
impl Expr {
/// Convert an expression to another expression, with the same span and a regenerated node id.
pub fn to_mast<B: Backend>(&self, ctx: &mut MastCtx<B>, kind: &ExprKind) -> Expr {
if !ctx.in_generic_func {
return self.clone();
}

Expr {
node_id: ctx.next_node_id(),
kind: kind.clone(),
..self.clone()
match ctx.generic_func_scope {
// not in any generic function scope
Some(0) => self.clone(),
// in a generic function scope
_ => Expr {
node_id: ctx.next_node_id(),
kind: kind.clone(),
..self.clone()
},
}
}
}
105 changes: 51 additions & 54 deletions src/mast/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use num_bigint::BigUint;
use num_traits::ToPrimitive;
use std::collections::HashMap;

use crate::{
Expand Down Expand Up @@ -150,7 +151,7 @@ where
B: Backend,
{
tast: TypeChecker<B>,
in_generic_func: bool,
generic_func_scope: Option<usize>,
// fully qualified function name
functions_to_delete: Vec<FullyQualified>,
// fully qualified struct name, method name
Expand All @@ -161,7 +162,7 @@ impl<B: Backend> MastCtx<B> {
pub fn new(tast: TypeChecker<B>) -> Self {
Self {
tast,
in_generic_func: false,
generic_func_scope: Some(0),
functions_to_delete: vec![],
methods_to_delete: vec![],
}
Expand All @@ -174,11 +175,11 @@ impl<B: Backend> MastCtx<B> {
}

pub fn start_monomorphize_func(&mut self) {
self.in_generic_func = true;
self.generic_func_scope = Some(self.generic_func_scope.unwrap() + 1);
}

pub fn finish_monomorphize_func(&mut self) {
self.in_generic_func = false;
self.generic_func_scope = Some(self.generic_func_scope.unwrap() - 1);
}

pub fn add_monomorphized_fn(
Expand All @@ -187,8 +188,11 @@ impl<B: Backend> MastCtx<B> {
new_qualified: FullyQualified,
fn_info: FnInfo<B>,
) {
self.tast.add_monomorphized_fn(new_qualified, fn_info);
self.functions_to_delete.push(old_qualified);
self.tast
.add_monomorphized_fn(new_qualified.clone(), fn_info);
if new_qualified != old_qualified {
self.functions_to_delete.push(old_qualified);
}
}

pub fn add_monomorphized_method(
Expand All @@ -200,8 +204,11 @@ impl<B: Backend> MastCtx<B> {
) {
self.tast
.add_monomorphized_method(struct_qualified.clone(), method_name, fn_info);
self.methods_to_delete
.push((struct_qualified, old_method_name.to_string()));

if method_name != old_method_name {
self.methods_to_delete
.push((struct_qualified, old_method_name.to_string()));
}
}

pub fn clear_generic_fns(&mut self) {
Expand Down Expand Up @@ -411,30 +418,23 @@ fn monomorphize_expr<B: Backend>(
.to_owned();

// monomorphize the function call
let (mexpr, typ) = if fn_info.sig().require_monomorphization() {
let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?;
let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?;

let args_mono = observed.clone().into_iter().map(|e| e.expr).collect();
let args_mono = observed.clone().into_iter().map(|e| e.expr).collect();

let fn_name_mono = &fn_info_mono.sig().name;
let mexpr = Expr {
kind: ExprKind::FnCall {
module: module.clone(),
fn_name: fn_name_mono.clone(),
args: args_mono,
},
..expr.clone()
};

let qualified = FullyQualified::new(module, &fn_name_mono.value);
ctx.add_monomorphized_fn(old_qualified, qualified, fn_info_mono);

(mexpr, typ)
} else {
// otherwise, reuse the expression node and the computed type
(expr.clone(), ctx.tast.expr_type(expr).cloned())
let fn_name_mono = &fn_info_mono.sig().name;
let mexpr = Expr {
kind: ExprKind::FnCall {
module: module.clone(),
fn_name: fn_name_mono.clone(),
args: args_mono,
},
..expr.clone()
};

let qualified = FullyQualified::new(module, &fn_name_mono.value);
ctx.add_monomorphized_fn(old_qualified, qualified, fn_info_mono);

// assume the function call won't return constant value
ExprMonoInfo::new(mexpr, typ, None)
}
Expand Down Expand Up @@ -491,29 +491,22 @@ fn monomorphize_expr<B: Backend>(
}

// monomorphize the function call
let (mexpr, typ) = if fn_info.sig().require_monomorphization() {
let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?;

let fn_name_mono = &fn_info_mono.sig().name;
let mexpr = Expr {
kind: ExprKind::MethodCall {
lhs: Box::new(lhs_mono.expr),
method_name: fn_name_mono.clone(),
args: args_mono,
},
..expr.clone()
};

let fn_def = fn_info_mono.native();
ctx.tast
.add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def);
let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?;

(mexpr, typ)
} else {
// otherwise, reuse the expression node and the computed type
(expr.clone(), ctx.tast.expr_type(expr).cloned())
let fn_name_mono = &fn_info_mono.sig().name;
let mexpr = Expr {
kind: ExprKind::MethodCall {
lhs: Box::new(lhs_mono.expr),
method_name: fn_name_mono.clone(),
args: args_mono,
},
..expr.clone()
};

let fn_def = fn_info_mono.native();
ctx.tast
.add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def);

// assume the function call won't return constant value
ExprMonoInfo::new(mexpr, typ, None)
}
Expand Down Expand Up @@ -566,8 +559,12 @@ fn monomorphize_expr<B: Backend>(
| Op2::BoolOr => lhs_mono.typ,
};

let cst = match (lhs_mono.constant, rhs_mono.constant) {
(Some(lhs), Some(rhs)) => match op {
let ExprMonoInfo { expr: lhs_expr, .. } = lhs_mono;
let ExprMonoInfo { expr: rhs_expr, .. } = rhs_mono;

// fold constants
let cst = match (&lhs_expr.kind, &rhs_expr.kind) {
(ExprKind::BigUInt(lhs), ExprKind::BigUInt(rhs)) => match op {
Op2::Addition => Some(lhs + rhs),
Op2::Subtraction => Some(lhs - rhs),
Op2::Multiplication => Some(lhs * rhs),
Expand All @@ -579,18 +576,18 @@ fn monomorphize_expr<B: Backend>(

match cst {
Some(v) => {
let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(BigUint::from(v)));
let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(v.clone()));

ExprMonoInfo::new(mexpr, typ, Some(v))
ExprMonoInfo::new(mexpr, typ, v.to_u32())
}
None => {
let mexpr = expr.to_mast(
ctx,
&ExprKind::BinaryOp {
op: op.clone(),
protected: *protected,
lhs: Box::new(lhs_mono.expr),
rhs: Box::new(rhs_mono.expr),
lhs: Box::new(lhs_expr),
rhs: Box::new(rhs_expr),
},
);

Expand Down
36 changes: 36 additions & 0 deletions src/tests/examples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -664,3 +664,39 @@ fn test_generic_iterator(#[case] backend: BackendKind) -> miette::Result<()> {

Ok(())
}

#[rstest]
#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))]
#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))]
fn test_generic_nested_func(#[case] backend: BackendKind) -> miette::Result<()> {
let public_inputs = r#"{"val":"1"}"#;
let private_inputs = r#"{}"#;

test_file(
"generic_nested_func",
public_inputs,
private_inputs,
vec!["1", "1", "1"],
backend,
)?;

Ok(())
}

#[rstest]
#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))]
#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))]
fn test_generic_nested_method(#[case] backend: BackendKind) -> miette::Result<()> {
let public_inputs = r#"{"val":"1"}"#;
let private_inputs = r#"{}"#;

test_file(
"generic_nested_method",
public_inputs,
private_inputs,
vec!["1", "1", "1"],
backend,
)?;

Ok(())
}
17 changes: 12 additions & 5 deletions src/type_checker/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,11 +540,18 @@ impl<B: Backend> TypeChecker<B> {
.expect("expected a typed size");

if is_numeric(&size_node.typ) {
// use generic array as the size node might include generic parameters or constant vars
let res = ExprTyInfo::new_anon(TyKind::GenericSizedArray(
Box::new(item_node.typ),
Symbolic::parse(size)?,
));
let sym = Symbolic::parse(size)?;
let res = if let Symbolic::Concrete(size) = sym {
// if sym is a concrete variant, then just return concrete array type
ExprTyInfo::new_anon(TyKind::Array(Box::new(item_node.typ), size))
} else {
// use generic array as the size node might include generic parameters or constant vars
ExprTyInfo::new_anon(TyKind::GenericSizedArray(
Box::new(item_node.typ),
sym,
))
};

Some(res)
} else {
return Err(self.error(ErrorKind::InvalidArraySize, expr.span));
Expand Down

0 comments on commit 34e0d16

Please sign in to comment.