Skip to content

Commit

Permalink
Support f32.mul, f32.div, and f64.div
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Jan 21, 2025
1 parent dbb3858 commit 440ddab
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 7 deletions.
92 changes: 92 additions & 0 deletions crates/floretta/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,96 @@ mod tests {
assert_eq!(fwd.call(&mut store, 1.1).unwrap(), -0.99);
assert_eq!(bwd.call(&mut store, 1.).unwrap(), 0.20000000000000018);
}

#[test]
fn test_f32_mul() {
let input = wat::parse_str(include_str!("wat/f32_mul.wat")).unwrap();

let mut ad = crate::Autodiff::new();
ad.export("mul", "backprop");
let output = ad.transform(&input).unwrap();

let engine = Engine::default();
let mut store = Store::new(&engine, ());
let module = Module::new(&engine, &output).unwrap();
let instance = Instance::new(&mut store, &module, &[]).unwrap();
let fwd = instance
.get_typed_func::<(f32, f32), f32>(&mut store, "mul")
.unwrap();
let bwd = instance
.get_typed_func::<f32, (f32, f32)>(&mut store, "backprop")
.unwrap();

assert_eq!(fwd.call(&mut store, (3., 2.)).unwrap(), 6.);
assert_eq!(bwd.call(&mut store, 1.).unwrap(), (2., 3.));
}

#[test]
fn test_f32_div() {
let input = wat::parse_str(include_str!("wat/f32_div.wat")).unwrap();

let mut ad = crate::Autodiff::new();
ad.export("div", "backprop");
let output = ad.transform(&input).unwrap();

let engine = Engine::default();
let mut store = Store::new(&engine, ());
let module = Module::new(&engine, &output).unwrap();
let instance = Instance::new(&mut store, &module, &[]).unwrap();
let fwd = instance
.get_typed_func::<(f32, f32), f32>(&mut store, "div")
.unwrap();
let bwd = instance
.get_typed_func::<f32, (f32, f32)>(&mut store, "backprop")
.unwrap();

assert_eq!(fwd.call(&mut store, (3., 2.)).unwrap(), 1.5);
assert_eq!(bwd.call(&mut store, 1.).unwrap(), (0.5, -0.75));
}

#[test]
fn test_f64_mul() {
let input = wat::parse_str(include_str!("wat/f64_mul.wat")).unwrap();

let mut ad = crate::Autodiff::new();
ad.export("mul", "backprop");
let output = ad.transform(&input).unwrap();

let engine = Engine::default();
let mut store = Store::new(&engine, ());
let module = Module::new(&engine, &output).unwrap();
let instance = Instance::new(&mut store, &module, &[]).unwrap();
let fwd = instance
.get_typed_func::<(f64, f64), f64>(&mut store, "mul")
.unwrap();
let bwd = instance
.get_typed_func::<f64, (f64, f64)>(&mut store, "backprop")
.unwrap();

assert_eq!(fwd.call(&mut store, (3., 2.)).unwrap(), 6.);
assert_eq!(bwd.call(&mut store, 1.).unwrap(), (2., 3.));
}

#[test]
fn test_f64_div() {
let input = wat::parse_str(include_str!("wat/f64_div.wat")).unwrap();

let mut ad = crate::Autodiff::new();
ad.export("div", "backprop");
let output = ad.transform(&input).unwrap();

let engine = Engine::default();
let mut store = Store::new(&engine, ());
let module = Module::new(&engine, &output).unwrap();
let instance = Instance::new(&mut store, &module, &[]).unwrap();
let fwd = instance
.get_typed_func::<(f64, f64), f64>(&mut store, "div")
.unwrap();
let bwd = instance
.get_typed_func::<f64, (f64, f64)>(&mut store, "backprop")
.unwrap();

assert_eq!(fwd.call(&mut store, (3., 2.)).unwrap(), 1.5);
assert_eq!(bwd.call(&mut store, 1.).unwrap(), (0.5, -0.75));
}
}
56 changes: 49 additions & 7 deletions crates/floretta/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ use wasmparser::{FunctionBody, Operator, Parser, Payload};

use crate::{
helper::{
helpers, FUNC_CONTROL_LOAD, FUNC_CONTROL_STORE, FUNC_F64_MUL_BWD, FUNC_F64_MUL_FWD,
OFFSET_FUNCTIONS, OFFSET_GLOBALS, OFFSET_MEMORIES, OFFSET_TYPES, TYPE_CONTROL_LOAD,
TYPE_CONTROL_STORE, TYPE_F32_BIN_BWD, TYPE_F32_BIN_FWD, TYPE_F64_BIN_BWD, TYPE_F64_BIN_FWD,
helpers, FUNC_CONTROL_LOAD, FUNC_CONTROL_STORE, FUNC_F32_DIV_BWD, FUNC_F32_DIV_FWD,
FUNC_F32_MUL_BWD, FUNC_F32_MUL_FWD, FUNC_F64_DIV_BWD, FUNC_F64_DIV_FWD, FUNC_F64_MUL_BWD,
FUNC_F64_MUL_FWD, OFFSET_FUNCTIONS, OFFSET_GLOBALS, OFFSET_MEMORIES, OFFSET_TYPES,
TYPE_CONTROL_LOAD, TYPE_CONTROL_STORE, TYPE_F32_BIN_BWD, TYPE_F32_BIN_FWD,
TYPE_F64_BIN_BWD, TYPE_F64_BIN_FWD,
},
util::u32_to_usize,
validate::{FunctionValidator, ModuleValidator},
Expand Down Expand Up @@ -233,14 +235,23 @@ fn function(
bwd.local(local);
}
let tmp_f64 = bwd.local(wasm_encoder::ValType::F64);
for i in 0..num_params {
// The first basic block in the forward pass corresponds to the last basic block in the backward
// pass, and because each basic block will be reversed, the first instructions we write will
// become the last instructions in the function body of the backward pass. Because Wasm
// parameters appear in locals but Wasm results appear on the stack, when we reach the end of
// the backward pass, we'll have stored all the parameter adjoints in locals corresponding to
// those from the forward pass. So, we need to push the values of those locals onto the stack in
// order; and because everything will be reversed, that means that here we need to start with
// pushing the local corresponding to the last parameter first, then work downward to pushing
// the first parameter on the bottom of the stack.
for i in (0..num_params).rev() {
bwd.instruction(&Instruction::LocalGet(num_results + i));
}
let mut operators_reader = body.get_operators_reader()?;
let mut func = Func {
num_results,
locals: &locals,
offset: 0, // this initial value should be unused; to be set before each instruction
offset: 0, // This initial value should be unused; to be set before each instruction.
operand_stack: Vec::new(),
operand_stack_height: StackHeight::new(),
operand_stack_height_min: 0,
Expand Down Expand Up @@ -333,9 +344,14 @@ impl Func<'_> {
let ty = self.local(local_index);
self.push(ty);
self.fwd.instruction(&Instruction::LocalGet(local_index));
let i = self.bwd_local(local_index);
match ty {
wasm_encoder::ValType::F32 => {
self.bwd.instruction(&Instruction::LocalSet(i));
self.bwd.instruction(&Instruction::F32Add);
self.bwd.instruction(&Instruction::LocalGet(i));
}
wasm_encoder::ValType::F64 => {
let i = self.num_results + local_index;
self.bwd.instruction(&Instruction::LocalSet(i));
self.bwd.instruction(&Instruction::F64Add);
self.bwd.instruction(&Instruction::LocalGet(i));
Expand All @@ -348,9 +364,9 @@ impl Func<'_> {
self.pop();
self.push(ty);
self.fwd.instruction(&Instruction::LocalTee(local_index));
let i = self.bwd_local(local_index);
match ty {
wasm_encoder::ValType::F64 => {
let i = self.num_results + local_index;
self.bwd.instruction(&Instruction::LocalSet(i));
self.bwd.instruction(&Instruction::F64Const(0.));
self.bwd.instruction(&Instruction::F64Add);
Expand All @@ -372,6 +388,18 @@ impl Func<'_> {
self.bwd.instruction(&Instruction::F64Const(0.));
self.bwd.instruction(&Instruction::Drop);
}
Operator::F32Mul => {
self.pop2();
self.push_f32();
self.fwd.instruction(&Instruction::Call(FUNC_F32_MUL_FWD));
self.bwd.instruction(&Instruction::Call(FUNC_F32_MUL_BWD));
}
Operator::F32Div => {
self.pop2();
self.push_f32();
self.fwd.instruction(&Instruction::Call(FUNC_F32_DIV_FWD));
self.bwd.instruction(&Instruction::Call(FUNC_F32_DIV_BWD));
}
Operator::F64Sub => {
self.pop2();
self.push_f64();
Expand All @@ -386,6 +414,12 @@ impl Func<'_> {
self.fwd.instruction(&Instruction::Call(FUNC_F64_MUL_FWD));
self.bwd.instruction(&Instruction::Call(FUNC_F64_MUL_BWD));
}
Operator::F64Div => {
self.pop2();
self.push_f64();
self.fwd.instruction(&Instruction::Call(FUNC_F64_DIV_FWD));
self.bwd.instruction(&Instruction::Call(FUNC_F64_DIV_BWD));
}
_ => unimplemented!("{op:?}"),
}
Ok(())
Expand All @@ -400,6 +434,10 @@ impl Func<'_> {
self.push(wasm_encoder::ValType::I32);
}

fn push_f32(&mut self) {
self.push(wasm_encoder::ValType::F32);
}

fn push_f64(&mut self) {
self.push(wasm_encoder::ValType::F64);
}
Expand All @@ -424,6 +462,10 @@ impl Func<'_> {
self.locals[u32_to_usize(index)]
}

fn bwd_local(&self, index: u32) -> u32 {
self.num_results + index
}

/// In the forward pass, store the current basic block index on the tape.
fn fwd_control_store(&mut self) {
self.fwd
Expand Down
3 changes: 3 additions & 0 deletions crates/floretta/src/wat/f32_div.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
(module
(func (export "div") (param f32 f32) (result f32)
(f32.div (local.get 0) (local.get 1))))
3 changes: 3 additions & 0 deletions crates/floretta/src/wat/f32_mul.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
(module
(func (export "mul") (param f32 f32) (result f32)
(f32.mul (local.get 0) (local.get 1))))
3 changes: 3 additions & 0 deletions crates/floretta/src/wat/f64_div.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
(module
(func (export "div") (param f64 f64) (result f64)
(f64.div (local.get 0) (local.get 1))))
3 changes: 3 additions & 0 deletions crates/floretta/src/wat/f64_mul.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
(module
(func (export "mul") (param f64 f64) (result f64)
(f64.mul (local.get 0) (local.get 1))))

0 comments on commit 440ddab

Please sign in to comment.