Skip to content

Commit

Permalink
Define and export a sum function
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Apr 10, 2024
1 parent 49a893e commit d87f323
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 16 deletions.
13 changes: 13 additions & 0 deletions crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1586,6 +1586,19 @@ impl Block {
self.instr(f, id::ty(t), expr)
}

/// Return the variable ID for a new instruction accumulating `addend` into `accum`.
///
/// Assumes `accum` and `addend` are defined and in scope.
#[wasm_bindgen(js_name = "addTo")]
pub fn add_to(&mut self, f: &mut FuncBuilder, accum: usize, addend: usize) -> usize {
let t = id::ty(f.ty_unit());
let expr = rose::Expr::Add {
accum: id::var(accum),
addend: id::var(addend),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new instruction resolving the given accumulator `var`.
///
/// Assumes `var` is defined and in scope, and that `t` is the inner type of the reference type
Expand Down
13 changes: 13 additions & 0 deletions packages/core/src/impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,19 @@ export const vec = <const I, const T>(
return idVal(ctx, t, id) as Vec<Symbolic<T>>;
};

/** Return the sum after computing each number via `f`. */
export const sum = <const I>(index: I, f: (i: Symbolic<I>) => Real): Real => {
const ctx = getCtx();
const reals = ctx.func.tyF64();
const acc = ctx.block.accum(ctx.func, ctx.func.tyRef(reals), realId(ctx, 0));
vec(index, Null, (i) => {
const x = realId(ctx, f(i));
const t = ctx.func.tyUnit();
return idVal(ctx, t, ctx.block.addTo(ctx.func, acc, x)) as Null;
});
return idVal(ctx, reals, ctx.block.resolve(ctx.func, reals, acc)) as Real;
};

/** Return the variable ID for the abstract number or tangent `x`. */
const numId = (ctx: Context, x: Real | Tan): number => {
if (typeof x === "object") return (x as any)[variable];
Expand Down
19 changes: 3 additions & 16 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import {
sqrt,
struct,
sub,
sum,
trunc,
vec,
vjp,
Expand Down Expand Up @@ -233,12 +234,7 @@ describe("valid", () => {

test("dot product", () => {
const R3 = Vec(3, Real);
const dot = fn([R3, R3], Real, (u, v) => {
const x = mul(u[0], v[0]);
const y = mul(u[1], v[1]);
const z = mul(u[2], v[2]);
return add(add(x, y), z);
});
const dot = fn([R3, R3], Real, (u, v) => sum(3, (i) => mul(u[i], v[i])));
const f = interp(dot);
expect(f([1, 3, -5], [4, -2, -1])).toBe(3);
});
Expand Down Expand Up @@ -280,16 +276,7 @@ describe("valid", () => {

const Rn = Vec(n, Real);

const dot = fn([Rn, Rn], Real, (u, v) => {
const w = vec(n, Real, (i) => mul(u[i], v[i]));
let s = w[0];
s = add(s, w[1]);
s = add(s, w[2]);
s = add(s, w[3]);
s = add(s, w[4]);
s = add(s, w[5]);
return s;
});
const dot = fn([Rn, Rn], Real, (u, v) => sum(n, (i) => mul(u[i], v[i])));

const m = 5;
const p = 7;
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export {
sqrt,
struct,
sub,
sum,
trunc,
vec,
vjp,
Expand Down

0 comments on commit d87f323

Please sign in to comment.