Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions book/src/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
- [Match Expression](./match_expression.md)
- [Functions](./function.md)
- [Programs](./program.md)
- [Builtins](./builtins.md)
69 changes: 69 additions & 0 deletions book/src/builtins.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Builtin functions

## Bounded loop

Run a function repeatedly with a bounded counter. The loop stops early when the function returns a successful value.

- Signature: `for_while::<f>(initial_accumulator: A, readonly_context: C) -> Either<B, A>`
- Loop body: `fn f(acc: A, ctx: C, counter: uN) -> Either<B, A>` where `N ∈ {1, 2, 4, 8, 16}`

Example: stop when `counter == 10`.

```rust
fn stop_at_10(acc: (), _: (), i: u8) -> Either<u8, ()> {
match jet::eq_8(i, 10) {
true => Left(i), // success → exit loop
false => Right(acc), // continue with same accumulator
}
}

fn main() {
let out: Either<u8, ()> = for_while::<stop_at_10>((), ());
assert!(jet::eq_8(10, unwrap_left::<()>(out)));
}
```

## List folding

Fold a list of bounded length by repeatedly applying a function.

- Signature: `fold::<f, N>(list: List<E, N>, initial_accumulator: A) -> A`
- Fold step: `fn f(element: E, acc: A) -> A`
- Note: `N` is a power of two; lists hold fewer than `N` elements.

Example: sum a list of 32-bit integers.

```rust
fn sum(elt: u32, acc: u32) -> u32 {
let (_, acc): (bool, u32) = jet::add_32(elt, acc);
acc
}

fn main() {
let xs: List<u32, 8> = list![1, 2, 3];
let s: u32 = fold::<sum, 8>(xs, 0);
assert!(jet::eq_32(s, 6));
}
```

## Array folding

Fold a fixed-size array by repeatedly applying a function.

- Signature: `array_fold::<f, N>(array: [E; N], initial_accumulator: A) -> A`
- Fold step: `fn f(element: E, acc: A) -> A`

Example: sum an array of 7 elements.

```rust
fn sum(elt: u32, acc: u32) -> u32 {
let (_, acc): (bool, u32) = jet::add_32(elt, acc);
acc
}

fn main() {
let arr: [u32; 7] = [1, 2, 3, 4, 5, 6, 7];
let sum: u32 = array_fold::<sum, 7>(arr, 0);
assert!(jet::eq_32(sum, 28));
}
```
10 changes: 10 additions & 0 deletions examples/array_fold.simf
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
fn sum(elt: u32, acc: u32) -> u32 {
let (_, acc): (bool, u32) = jet::add_32(elt, acc);
acc
}

fn main() {
let arr: [u32; 7] = [1, 2, 3, 4, 5, 6, 7];
let sum: u32 = array_fold::<sum, 7>(arr, 0);
assert!(jet::eq_32(sum, 28));
}
38 changes: 38 additions & 0 deletions src/ast.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::str::FromStr;
use std::sync::Arc;

Expand Down Expand Up @@ -275,6 +276,8 @@ pub enum CallName {
Custom(CustomFunction),
/// Fold of a bounded list with the given function.
Fold(CustomFunction, NonZeroPow2Usize),
/// Fold of an array with the given function.
ArrayFold(CustomFunction, NonZeroUsize),
/// Loop over the given function a bounded number of times until it returns success.
ForWhile(CustomFunction, Pow2Usize),
}
Expand Down Expand Up @@ -1187,6 +1190,26 @@ impl AbstractSyntaxTree for Call {
check_output_type(out_ty, ty).with_span(from)?;
analyze_arguments(from.args(), &args_ty, scope)?
}
CallName::ArrayFold(function, size) => {
// An array fold has the signature:
// array_fold::<f, N>(array: [E; N], initial_accumulator: A) -> A
// where
// fn f(element: E, accumulator: A) -> A
let element_ty = function.params().first().expect("foldable function").ty();
let array_ty = ResolvedType::array(element_ty.clone(), size.get());
let accumulator_ty = function
.params()
.get(1)
.expect("foldable function")
.ty()
.clone();
let args_ty = [array_ty, accumulator_ty];

check_argument_types(from.args(), &args_ty).with_span(from)?;
let out_ty = function.body().ty();
check_output_type(out_ty, ty).with_span(from)?;
analyze_arguments(from.args(), &args_ty, scope)?
}
CallName::ForWhile(function, _bit_width) => {
// A for-while loop has the signature:
// for_while::<f>(initial_accumulator: A, readonly_context: C) -> Either<B, A>
Expand Down Expand Up @@ -1262,6 +1285,21 @@ impl AbstractSyntaxTree for CallName {
.map(Self::Custom)
.ok_or(Error::FunctionUndefined(name.clone()))
.with_span(from),
parse::CallName::ArrayFold(name, size) => {
let function = scope
.get_function(name)
.cloned()
.ok_or(Error::FunctionUndefined(name.clone()))
.with_span(from)?;
// A function that is used in a array fold has the signature:
// fn f(element: E, accumulator: A) -> A
if function.params().len() != 2 || function.params()[1].ty() != function.body().ty()
{
Err(Error::FunctionNotFoldable(name.clone())).with_span(from)
} else {
Ok(Self::ArrayFold(function, *size))
}
}
parse::CallName::Fold(name, bound) => {
let function = scope
.get_function(name)
Expand Down
79 changes: 79 additions & 0 deletions src/builtins.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
use std::num::NonZeroUsize;

use simplicity::node::CoreConstructible;

use crate::{named::CoreExt, ProgNode};

/// Fold an array of size `size` elements using function `f`.
///
/// Function `f: E × A → A`
/// takes an array element of type `E` and an accumulator of type `A`,
/// and it produces an updated accumulator of type `A`.
///
/// The fold `(fold f)_n : E^n × A → A`
/// takes the array of type `E^n` and an initial accumulator of type `A`,
/// and it produces the final accumulator of type `A`.
pub fn array_fold(size: NonZeroUsize, f: &ProgNode) -> Result<ProgNode, simplicity::types::Error> {
/// Recursively fold the array using the precomputed folding functions.
fn tree_fold(
n: usize,
f_powers_of_two: &Vec<ProgNode>,
) -> Result<ProgNode, simplicity::types::Error> {
if n == 1 {
return Ok(f_powers_of_two[0].clone());
}
// For n > 1 the next largest power is always >= 0
let max_pow2 = n.ilog2() as usize;
let size_right = 1 << max_pow2;
// Array is a left-balanced (right-associative) binary tree.
let f_right = f_powers_of_two.get(max_pow2).expect("max_pow2 OOB");
let f_left = tree_fold(n - size_right, f_powers_of_two)?;
f_array_fold(&f_left, f_right)
}

/// Fold the two arrays applying the folding function sequentially left -> right.
fn f_array_fold(
f_left: &ProgNode,
f_right: &ProgNode,
) -> Result<ProgNode, simplicity::types::Error> {
// The input is a tuple ((L, R), acc): ([E; n], A) where:
// - L and R are arrays of varying size E^x and E^y respectively (x + y = n).
// - acc is an accumulator of type A.
let ctx = f_left.inference_context();
let left_arr = ProgNode::o().o().h(ctx);
let right_arr = ProgNode::o().i().h(ctx);
let acc = ProgNode::i().h(ctx);
let left_res = left_arr.pair(acc).comp(f_left)?;
let right_res = right_arr.pair(left_res).comp(f_right)?;
Ok(right_res.build())
}

// Precompute the folding functions for arrays of size 2^i where i < n.
let n = size.get();
let mut f_powers_of_two: Vec<ProgNode> = Vec::with_capacity(n.ilog2() as usize);

// An array of size 1 is just the element itself, so f_array_fold_1 is the same as the folding function.
let mut f_prev = f.clone();
f_powers_of_two.push(f_prev.clone());

let mut i = 1;
while i < n {
f_prev = f_array_fold(&f_prev, &f_prev)?;
f_powers_of_two.push(f_prev.clone());
i *= 2;
}

tree_fold(n, &f_powers_of_two)
}

#[cfg(test)]
mod tests {
use crate::{tests::TestCase, WitnessValues};

#[test]
fn array_fold() {
TestCase::program_file("./examples/array_fold.simf")
.with_witness_values(WitnessValues::default())
.assert_run_success();
}
}
7 changes: 7 additions & 0 deletions src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::ast::{
Call, CallName, Expression, ExpressionInner, Match, Program, SingleExpression,
SingleExpressionInner, Statement,
};
use crate::builtins::array_fold;
use crate::debug::CallTracker;
use crate::error::{Error, RichError, Span, WithSpan};
use crate::named::{CoreExt, PairBuilder};
Expand Down Expand Up @@ -418,6 +419,12 @@ impl Call {
let fold_body = list_fold(*bound, body.as_ref()).with_span(self)?;
args.comp(&fold_body).with_span(self)
}
CallName::ArrayFold(function, size) => {
let mut function_scope = scope.child(function.params_pattern());
let body = function.body().compile(&mut function_scope)?;
let fold_body = array_fold(*size, body.as_ref()).with_span(self)?;
args.comp(&fold_body).with_span(self)
}
CallName::ForWhile(function, bit_width) => {
let mut function_scope = scope.child(function.params_pattern());
let body = function.body().compile(&mut function_scope)?;
Expand Down
5 changes: 5 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ impl From<pest::error::Error<Rule>> for RichError {
/// Records _what_ happened but not where.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum Error {
ArraySizeNonZero(usize),
ListBoundPow2(usize),
BitStringPow2(usize),
HexStringLen(usize),
Expand Down Expand Up @@ -336,6 +337,10 @@ pub enum Error {
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::ArraySizeNonZero(size) => write!(
f,
"Expected a non-negative integer as array size, found {size}"
),
Error::ListBoundPow2(bound) => write!(
f,
"Expected a power of two greater than one (2, 4, 8, 16, 32, ...) as list bound, found {bound}"
Expand Down
5 changes: 3 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub type ProgNode = Arc<named::ConstructNode>;

pub mod array;
pub mod ast;
pub mod builtins;
pub mod compile;
pub mod debug;
pub mod dummy_env;
Expand Down Expand Up @@ -274,7 +275,7 @@ pub trait ArbitraryOfType: Sized {
}

#[cfg(test)]
mod tests {
pub(crate) mod tests {
use base64::display::Base64Display;
use base64::engine::general_purpose::STANDARD;
use simplicity::BitMachine;
Expand All @@ -283,7 +284,7 @@ mod tests {

use crate::*;

struct TestCase<T> {
pub(crate) struct TestCase<T> {
program: T,
lock_time: elements::LockTime,
sequence: elements::Sequence,
Expand Down
5 changes: 3 additions & 2 deletions src/minimal.pest
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jet = @{ "jet::" ~ (ASCII_ALPHANUMERIC | "_")+ }
witness_name = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_")* }
builtin_type = @{ ("Either" | "Option" | "bool" | "List" | unsigned_type) ~ !ASCII_ALPHANUMERIC }

builtin_function = @{ ("unwrap_left" | "unwrap_right" | "for_while" | "is_none" | "unwrap" | "assert" | "panic" | "match" | "into" | "fold" | "dbg") ~ !ASCII_ALPHANUMERIC }
builtin_function = @{ ("unwrap_left" | "unwrap_right" | "array_fold" | "for_while" | "is_none" | "unwrap" | "assert" | "panic" | "match" | "into" | "fold" | "dbg") ~ !ASCII_ALPHANUMERIC }
function_name = { !builtin_function ~ identifier }
typed_identifier = { identifier ~ ":" ~ ty }
function_params = { "(" ~ (typed_identifier ~ ("," ~ typed_identifier)*)? ~ ")" }
Expand Down Expand Up @@ -65,9 +65,10 @@ assert = @{ "assert!" }
panic = @{ "panic!" }
type_cast = { "<" ~ ty ~ ">::into" }
debug = @{ "dbg!" }
array_fold = { "array_fold::<" ~ function_name ~ "," ~ array_size ~ ">" }
fold = { "fold::<" ~ function_name ~ "," ~ list_bound ~ ">" }
for_while = { "for_while::<" ~ function_name ~ ">" }
call_name = { jet | unwrap_left | unwrap_right | is_none | unwrap | assert | panic | type_cast | debug | fold | for_while | function_name }
call_name = { jet | unwrap_left | unwrap_right | is_none | unwrap | assert | panic | type_cast | debug | array_fold | fold | for_while | function_name }
call_args = { "(" ~ (expression ~ ("," ~ expression)*)? ~ ")" }
call_expr = { call_name ~ call_args }
dec_literal = @{ (ASCII_DIGIT | "_")+ }
Expand Down
17 changes: 17 additions & 0 deletions src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//! tokens into an AST.

use std::fmt;
use std::num::NonZeroUsize;
use std::str::FromStr;
use std::sync::Arc;

Expand Down Expand Up @@ -194,6 +195,8 @@ pub enum CallName {
Custom(FunctionName),
/// Fold of a bounded list with the given function.
Fold(FunctionName, NonZeroPow2Usize),
/// Fold of an array with the given function.
ArrayFold(FunctionName, NonZeroUsize),
/// Loop over the given function a bounded number of times until it returns success.
ForWhile(FunctionName),
}
Expand Down Expand Up @@ -757,6 +760,7 @@ impl fmt::Display for CallName {
CallName::TypeCast(ty) => write!(f, "<{ty}>::into"),
CallName::Custom(name) => write!(f, "{name}"),
CallName::Fold(name, bound) => write!(f, "fold::<{name}, {bound}>"),
CallName::ArrayFold(name, size) => write!(f, "array_fold::<{name}, {size}>"),
CallName::ForWhile(name) => write!(f, "for_while::<{name}>"),
}
}
Expand Down Expand Up @@ -1037,6 +1041,19 @@ impl PestParse for CallName {
let bound = NonZeroPow2Usize::parse(it.next().unwrap())?;
Ok(Self::Fold(name, bound))
}
Rule::array_fold => {
let mut it = pair.into_inner();
let name = FunctionName::parse(it.next().unwrap())?;
let non_zero_usize_parse =
|pair: pest::iterators::Pair<Rule>| -> Result<NonZeroUsize, RichError> {
let size = pair.as_str().parse::<usize>().with_span(&pair)?;
NonZeroUsize::new(size)
.ok_or(Error::ArraySizeNonZero(size))
.with_span(&pair)
};
let size = non_zero_usize_parse(it.next().unwrap())?;
Ok(Self::ArrayFold(name, size))
}
Rule::for_while => {
let mut it = pair.into_inner();
let name = FunctionName::parse(it.next().unwrap())?;
Expand Down
2 changes: 1 addition & 1 deletion vscode/syntaxes/simfony.tmLanguage.json
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@
"patterns": [
{
"name": "entity.name.function.simfony",
"match": "\\b(unwrap_left|unwrap_right|for_while|is_none|unwrap|into|fold|dbg)\\b"
"match": "\\b(unwrap_left|unwrap_right|for_while|is_none|array_fold|unwrap|into|fold|dbg)\\b"
},
{
"match": "\\b(fn)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s*\\(",
Expand Down