Skip to content

Commit 8c4fd1b

Browse files
committed
Add array_fold builtin function
1 parent 0ecec39 commit 8c4fd1b

File tree

11 files changed

+233
-5
lines changed

11 files changed

+233
-5
lines changed

book/src/SUMMARY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@
1212
- [Match Expression](./match_expression.md)
1313
- [Functions](./function.md)
1414
- [Programs](./program.md)
15+
- [Builtins](./builtins.md)

book/src/builtins.md

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Builtin functions
2+
3+
## Bounded loop
4+
5+
Run a function repeatedly with a bounded counter. The loop stops early when the function returns a successful value.
6+
7+
- Signature: `for_while::<f>(initial_accumulator: A, readonly_context: C) -> Either<B, A>`
8+
- Loop body: `fn f(acc: A, ctx: C, counter: uN) -> Either<B, A>` where `N ∈ {1, 2, 4, 8, 16}`
9+
10+
Example: stop when `counter == 10`.
11+
12+
```rust
13+
fn stop_at_10(acc: (), _: (), i: u8) -> Either<u8, ()> {
14+
match jet::eq_8(i, 10) {
15+
true => Left(i), // success → exit loop
16+
false => Right(acc), // continue with same accumulator
17+
}
18+
}
19+
20+
fn main() {
21+
let out: Either<u8, ()> = for_while::<stop_at_10>((), ());
22+
assert!(jet::eq_8(10, unwrap_left::<()>(out)));
23+
}
24+
```
25+
26+
## List folding
27+
28+
Fold a list of bounded length by repeatedly applying a function.
29+
30+
- Signature: `fold::<f, N>(list: List<E, N>, initial_accumulator: A) -> A`
31+
- Fold step: `fn f(element: E, acc: A) -> A`
32+
- Note: `N` is a power of two; lists hold fewer than `N` elements.
33+
34+
Example: sum a list of 32-bit integers.
35+
36+
```rust
37+
fn sum(elt: u32, acc: u32) -> u32 {
38+
let (_, acc): (bool, u32) = jet::add_32(elt, acc);
39+
acc
40+
}
41+
42+
fn main() {
43+
let xs: List<u32, 8> = list![1, 2, 3];
44+
let s: u32 = fold::<sum, 8>(xs, 0);
45+
assert!(jet::eq_32(s, 6));
46+
}
47+
```
48+
49+
## Array folding
50+
51+
Fold a fixed-size array by repeatedly applying a function.
52+
53+
- Signature: `array_fold::<f, N>(array: [E; N], initial_accumulator: A) -> A`
54+
- Fold step: `fn f(element: E, acc: A) -> A`
55+
56+
Example: sum an array of 7 elements.
57+
58+
```rust
59+
fn sum(elt: u32, acc: u32) -> u32 {
60+
let (_, acc): (bool, u32) = jet::add_32(elt, acc);
61+
acc
62+
}
63+
64+
fn main() {
65+
let arr: [u32; 7] = [1, 2, 3, 4, 5, 6, 7];
66+
let sum: u32 = array_fold::<sum, 7>(arr, 0);
67+
assert!(jet::eq_32(sum, 28));
68+
}
69+
```

examples/array_fold.simf

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
fn sum(elt: u32, acc: u32) -> u32 {
2+
let (_, acc): (bool, u32) = jet::add_32(elt, acc);
3+
acc
4+
}
5+
6+
fn main() {
7+
let arr: [u32; 7] = [1, 2, 3, 4, 5, 6, 7];
8+
let sum: u32 = array_fold::<sum, 7>(arr, 0);
9+
assert!(jet::eq_32(sum, 28));
10+
}

src/ast.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::collections::hash_map::Entry;
22
use std::collections::HashMap;
3+
use std::num::NonZeroUsize;
34
use std::str::FromStr;
45
use std::sync::Arc;
56

@@ -275,6 +276,8 @@ pub enum CallName {
275276
Custom(CustomFunction),
276277
/// Fold of a bounded list with the given function.
277278
Fold(CustomFunction, NonZeroPow2Usize),
279+
/// Fold of an array with the given function.
280+
ArrayFold(CustomFunction, NonZeroUsize),
278281
/// Loop over the given function a bounded number of times until it returns success.
279282
ForWhile(CustomFunction, Pow2Usize),
280283
}
@@ -1187,6 +1190,26 @@ impl AbstractSyntaxTree for Call {
11871190
check_output_type(out_ty, ty).with_span(from)?;
11881191
analyze_arguments(from.args(), &args_ty, scope)?
11891192
}
1193+
CallName::ArrayFold(function, size) => {
1194+
// An array fold has the signature:
1195+
// array_fold::<f, N>(array: [E; N], initial_accumulator: A) -> A
1196+
// where
1197+
// fn f(element: E, accumulator: A) -> A
1198+
let element_ty = function.params().first().expect("foldable function").ty();
1199+
let array_ty = ResolvedType::array(element_ty.clone(), size.get());
1200+
let accumulator_ty = function
1201+
.params()
1202+
.get(1)
1203+
.expect("foldable function")
1204+
.ty()
1205+
.clone();
1206+
let args_ty = [array_ty, accumulator_ty];
1207+
1208+
check_argument_types(from.args(), &args_ty).with_span(from)?;
1209+
let out_ty = function.body().ty();
1210+
check_output_type(out_ty, ty).with_span(from)?;
1211+
analyze_arguments(from.args(), &args_ty, scope)?
1212+
}
11901213
CallName::ForWhile(function, _bit_width) => {
11911214
// A for-while loop has the signature:
11921215
// for_while::<f>(initial_accumulator: A, readonly_context: C) -> Either<B, A>
@@ -1262,6 +1285,21 @@ impl AbstractSyntaxTree for CallName {
12621285
.map(Self::Custom)
12631286
.ok_or(Error::FunctionUndefined(name.clone()))
12641287
.with_span(from),
1288+
parse::CallName::ArrayFold(name, size) => {
1289+
let function = scope
1290+
.get_function(name)
1291+
.cloned()
1292+
.ok_or(Error::FunctionUndefined(name.clone()))
1293+
.with_span(from)?;
1294+
// A function that is used in a array fold has the signature:
1295+
// fn f(element: E, accumulator: A) -> A
1296+
if function.params().len() != 2 || function.params()[1].ty() != function.body().ty()
1297+
{
1298+
Err(Error::FunctionNotFoldable(name.clone())).with_span(from)
1299+
} else {
1300+
Ok(Self::ArrayFold(function, *size))
1301+
}
1302+
}
12651303
parse::CallName::Fold(name, bound) => {
12661304
let function = scope
12671305
.get_function(name)

src/builtins.rs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
use std::num::NonZeroUsize;
2+
3+
use simplicity::node::CoreConstructible;
4+
5+
use crate::{named::CoreExt, ProgNode};
6+
7+
/// Fold an array of size `size` elements using function `f`.
8+
///
9+
/// Function `f: E × A → A`
10+
/// takes an array element of type `E` and an accumulator of type `A`,
11+
/// and it produces an updated accumulator of type `A`.
12+
///
13+
/// The fold `(fold f)_n : E^n × A → A`
14+
/// takes the array of type `E^n` and an initial accumulator of type `A`,
15+
/// and it produces the final accumulator of type `A`.
16+
pub fn array_fold(size: NonZeroUsize, f: &ProgNode) -> Result<ProgNode, simplicity::types::Error> {
17+
/// Recursively fold the array using the precomputed folding functions.
18+
fn tree_fold(
19+
n: usize,
20+
f_powers_of_two: &Vec<ProgNode>,
21+
) -> Result<ProgNode, simplicity::types::Error> {
22+
if n == 1 {
23+
return Ok(f_powers_of_two[0].clone());
24+
}
25+
// For n > 1 the next largest power is always >= 0
26+
let max_pow2 = n.ilog2() as usize;
27+
let size_right = 1 << max_pow2;
28+
// Array is a left-balanced (right-associative) binary tree.
29+
let f_right = f_powers_of_two.get(max_pow2).expect("max_pow2 OOB");
30+
let f_left = tree_fold(n - size_right, f_powers_of_two)?;
31+
f_array_fold(&f_left, f_right)
32+
}
33+
34+
/// Fold the two arrays applying the folding function sequentially left -> right.
35+
fn f_array_fold(
36+
f_left: &ProgNode,
37+
f_right: &ProgNode,
38+
) -> Result<ProgNode, simplicity::types::Error> {
39+
// The input is a tuple ((L, R), acc): ([E; n], A) where:
40+
// - L and R are arrays of varying size E^x and E^y respectively (x + y = n).
41+
// - acc is an accumulator of type A.
42+
let ctx = f_left.inference_context();
43+
let left_arr = ProgNode::o().o().h(ctx);
44+
let right_arr = ProgNode::o().i().h(ctx);
45+
let acc = ProgNode::i().h(ctx);
46+
let left_res = left_arr.pair(acc).comp(f_left)?;
47+
let right_res = right_arr.pair(left_res).comp(f_right)?;
48+
Ok(right_res.build())
49+
}
50+
51+
// Precompute the folding functions for arrays of size 2^i where i < n.
52+
let n = size.get();
53+
let mut f_powers_of_two: Vec<ProgNode> = Vec::with_capacity(n.ilog2() as usize);
54+
55+
// An array of size 1 is just the element itself, so f_array_fold_1 is the same as the folding function.
56+
let mut f_prev = f.clone();
57+
f_powers_of_two.push(f_prev.clone());
58+
59+
let mut i = 1;
60+
while i < n {
61+
f_prev = f_array_fold(&f_prev, &f_prev)?;
62+
f_powers_of_two.push(f_prev.clone());
63+
i *= 2;
64+
}
65+
66+
tree_fold(n, &f_powers_of_two)
67+
}
68+
69+
#[cfg(test)]
70+
mod tests {
71+
use crate::{tests::TestCase, WitnessValues};
72+
73+
#[test]
74+
fn array_fold() {
75+
TestCase::program_file("./examples/array_fold.simf")
76+
.with_witness_values(WitnessValues::default())
77+
.assert_run_success();
78+
}
79+
}

src/compile.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::ast::{
1212
Call, CallName, Expression, ExpressionInner, Match, Program, SingleExpression,
1313
SingleExpressionInner, Statement,
1414
};
15+
use crate::builtins::array_fold;
1516
use crate::debug::CallTracker;
1617
use crate::error::{Error, RichError, Span, WithSpan};
1718
use crate::named::{CoreExt, PairBuilder};
@@ -418,6 +419,12 @@ impl Call {
418419
let fold_body = list_fold(*bound, body.as_ref()).with_span(self)?;
419420
args.comp(&fold_body).with_span(self)
420421
}
422+
CallName::ArrayFold(function, size) => {
423+
let mut function_scope = scope.child(function.params_pattern());
424+
let body = function.body().compile(&mut function_scope)?;
425+
let fold_body = array_fold(*size, body.as_ref()).with_span(self)?;
426+
args.comp(&fold_body).with_span(self)
427+
}
421428
CallName::ForWhile(function, bit_width) => {
422429
let mut function_scope = scope.child(function.params_pattern());
423430
let body = function.body().compile(&mut function_scope)?;

src/error.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ impl From<pest::error::Error<Rule>> for RichError {
294294
/// Records _what_ happened but not where.
295295
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
296296
pub enum Error {
297+
ArraySizeNonZero(usize),
297298
ListBoundPow2(usize),
298299
BitStringPow2(usize),
299300
HexStringLen(usize),
@@ -336,6 +337,10 @@ pub enum Error {
336337
impl fmt::Display for Error {
337338
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
338339
match self {
340+
Error::ArraySizeNonZero(size) => write!(
341+
f,
342+
"Expected a non-negative integer as array size, found {size}"
343+
),
339344
Error::ListBoundPow2(bound) => write!(
340345
f,
341346
"Expected a power of two greater than one (2, 4, 8, 16, 32, ...) as list bound, found {bound}"

src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pub type ProgNode = Arc<named::ConstructNode>;
44

55
pub mod array;
66
pub mod ast;
7+
pub mod builtins;
78
pub mod compile;
89
pub mod debug;
910
pub mod dummy_env;
@@ -274,7 +275,7 @@ pub trait ArbitraryOfType: Sized {
274275
}
275276

276277
#[cfg(test)]
277-
mod tests {
278+
pub(crate) mod tests {
278279
use base64::display::Base64Display;
279280
use base64::engine::general_purpose::STANDARD;
280281
use simplicity::BitMachine;
@@ -283,7 +284,7 @@ mod tests {
283284

284285
use crate::*;
285286

286-
struct TestCase<T> {
287+
pub(crate) struct TestCase<T> {
287288
program: T,
288289
lock_time: elements::LockTime,
289290
sequence: elements::Sequence,

src/minimal.pest

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jet = @{ "jet::" ~ (ASCII_ALPHANUMERIC | "_")+ }
1212
witness_name = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_")* }
1313
builtin_type = @{ ("Either" | "Option" | "bool" | "List" | unsigned_type) ~ !ASCII_ALPHANUMERIC }
1414

15-
builtin_function = @{ ("unwrap_left" | "unwrap_right" | "for_while" | "is_none" | "unwrap" | "assert" | "panic" | "match" | "into" | "fold" | "dbg") ~ !ASCII_ALPHANUMERIC }
15+
builtin_function = @{ ("unwrap_left" | "unwrap_right" | "array_fold" | "for_while" | "is_none" | "unwrap" | "assert" | "panic" | "match" | "into" | "fold" | "dbg") ~ !ASCII_ALPHANUMERIC }
1616
function_name = { !builtin_function ~ identifier }
1717
typed_identifier = { identifier ~ ":" ~ ty }
1818
function_params = { "(" ~ (typed_identifier ~ ("," ~ typed_identifier)*)? ~ ")" }
@@ -65,9 +65,10 @@ assert = @{ "assert!" }
6565
panic = @{ "panic!" }
6666
type_cast = { "<" ~ ty ~ ">::into" }
6767
debug = @{ "dbg!" }
68+
array_fold = { "array_fold::<" ~ function_name ~ "," ~ array_size ~ ">" }
6869
fold = { "fold::<" ~ function_name ~ "," ~ list_bound ~ ">" }
6970
for_while = { "for_while::<" ~ function_name ~ ">" }
70-
call_name = { jet | unwrap_left | unwrap_right | is_none | unwrap | assert | panic | type_cast | debug | fold | for_while | function_name }
71+
call_name = { jet | unwrap_left | unwrap_right | is_none | unwrap | assert | panic | type_cast | debug | array_fold | fold | for_while | function_name }
7172
call_args = { "(" ~ (expression ~ ("," ~ expression)*)? ~ ")" }
7273
call_expr = { call_name ~ call_args }
7374
dec_literal = @{ (ASCII_DIGIT | "_")+ }

src/parse.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
//! tokens into an AST.
33
44
use std::fmt;
5+
use std::num::NonZeroUsize;
56
use std::str::FromStr;
67
use std::sync::Arc;
78

@@ -194,6 +195,8 @@ pub enum CallName {
194195
Custom(FunctionName),
195196
/// Fold of a bounded list with the given function.
196197
Fold(FunctionName, NonZeroPow2Usize),
198+
/// Fold of an array with the given function.
199+
ArrayFold(FunctionName, NonZeroUsize),
197200
/// Loop over the given function a bounded number of times until it returns success.
198201
ForWhile(FunctionName),
199202
}
@@ -757,6 +760,7 @@ impl fmt::Display for CallName {
757760
CallName::TypeCast(ty) => write!(f, "<{ty}>::into"),
758761
CallName::Custom(name) => write!(f, "{name}"),
759762
CallName::Fold(name, bound) => write!(f, "fold::<{name}, {bound}>"),
763+
CallName::ArrayFold(name, size) => write!(f, "array_fold::<{name}, {size}>"),
760764
CallName::ForWhile(name) => write!(f, "for_while::<{name}>"),
761765
}
762766
}
@@ -1037,6 +1041,19 @@ impl PestParse for CallName {
10371041
let bound = NonZeroPow2Usize::parse(it.next().unwrap())?;
10381042
Ok(Self::Fold(name, bound))
10391043
}
1044+
Rule::array_fold => {
1045+
let mut it = pair.into_inner();
1046+
let name = FunctionName::parse(it.next().unwrap())?;
1047+
let non_zero_usize_parse =
1048+
|pair: pest::iterators::Pair<Rule>| -> Result<NonZeroUsize, RichError> {
1049+
let size = pair.as_str().parse::<usize>().with_span(&pair)?;
1050+
NonZeroUsize::new(size)
1051+
.ok_or(Error::ArraySizeNonZero(size))
1052+
.with_span(&pair)
1053+
};
1054+
let size = non_zero_usize_parse(it.next().unwrap())?;
1055+
Ok(Self::ArrayFold(name, size))
1056+
}
10401057
Rule::for_while => {
10411058
let mut it = pair.into_inner();
10421059
let name = FunctionName::parse(it.next().unwrap())?;

0 commit comments

Comments
 (0)