From 440ddab5aaaeaac53d1d93e33a4a6b65fbeb0f3f Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Tue, 21 Jan 2025 10:10:04 -0500 Subject: [PATCH] Support `f32.mul`, `f32.div`, and `f64.div` --- crates/floretta/src/lib.rs | 92 +++++++++++++++++++++++++++++ crates/floretta/src/run.rs | 56 +++++++++++++++--- crates/floretta/src/wat/f32_div.wat | 3 + crates/floretta/src/wat/f32_mul.wat | 3 + crates/floretta/src/wat/f64_div.wat | 3 + crates/floretta/src/wat/f64_mul.wat | 3 + 6 files changed, 153 insertions(+), 7 deletions(-) create mode 100644 crates/floretta/src/wat/f32_div.wat create mode 100644 crates/floretta/src/wat/f32_mul.wat create mode 100644 crates/floretta/src/wat/f64_div.wat create mode 100644 crates/floretta/src/wat/f64_mul.wat diff --git a/crates/floretta/src/lib.rs b/crates/floretta/src/lib.rs index a9c70a7..5b7068e 100644 --- a/crates/floretta/src/lib.rs +++ b/crates/floretta/src/lib.rs @@ -177,4 +177,96 @@ mod tests { assert_eq!(fwd.call(&mut store, 1.1).unwrap(), -0.99); assert_eq!(bwd.call(&mut store, 1.).unwrap(), 0.20000000000000018); } + + #[test] + fn test_f32_mul() { + let input = wat::parse_str(include_str!("wat/f32_mul.wat")).unwrap(); + + let mut ad = crate::Autodiff::new(); + ad.export("mul", "backprop"); + let output = ad.transform(&input).unwrap(); + + let engine = Engine::default(); + let mut store = Store::new(&engine, ()); + let module = Module::new(&engine, &output).unwrap(); + let instance = Instance::new(&mut store, &module, &[]).unwrap(); + let fwd = instance + .get_typed_func::<(f32, f32), f32>(&mut store, "mul") + .unwrap(); + let bwd = instance + .get_typed_func::<f32, (f32, f32)>(&mut store, "backprop") + .unwrap(); + + assert_eq!(fwd.call(&mut store, (3., 2.)).unwrap(), 6.); + assert_eq!(bwd.call(&mut store, 1.).unwrap(), (2., 3.)); + } + + #[test] + fn test_f32_div() { + let input = wat::parse_str(include_str!("wat/f32_div.wat")).unwrap(); + + let mut ad = crate::Autodiff::new(); + ad.export("div", "backprop"); + let output = 
ad.transform(&input).unwrap(); + + let engine = Engine::default(); + let mut store = Store::new(&engine, ()); + let module = Module::new(&engine, &output).unwrap(); + let instance = Instance::new(&mut store, &module, &[]).unwrap(); + let fwd = instance + .get_typed_func::<(f32, f32), f32>(&mut store, "div") + .unwrap(); + let bwd = instance + .get_typed_func::<f32, (f32, f32)>(&mut store, "backprop") + .unwrap(); + + assert_eq!(fwd.call(&mut store, (3., 2.)).unwrap(), 1.5); + assert_eq!(bwd.call(&mut store, 1.).unwrap(), (0.5, -0.75)); + } + + #[test] + fn test_f64_mul() { + let input = wat::parse_str(include_str!("wat/f64_mul.wat")).unwrap(); + + let mut ad = crate::Autodiff::new(); + ad.export("mul", "backprop"); + let output = ad.transform(&input).unwrap(); + + let engine = Engine::default(); + let mut store = Store::new(&engine, ()); + let module = Module::new(&engine, &output).unwrap(); + let instance = Instance::new(&mut store, &module, &[]).unwrap(); + let fwd = instance + .get_typed_func::<(f64, f64), f64>(&mut store, "mul") + .unwrap(); + let bwd = instance + .get_typed_func::<f64, (f64, f64)>(&mut store, "backprop") + .unwrap(); + + assert_eq!(fwd.call(&mut store, (3., 2.)).unwrap(), 6.); + assert_eq!(bwd.call(&mut store, 1.).unwrap(), (2., 3.)); + } + + #[test] + fn test_f64_div() { + let input = wat::parse_str(include_str!("wat/f64_div.wat")).unwrap(); + + let mut ad = crate::Autodiff::new(); + ad.export("div", "backprop"); + let output = ad.transform(&input).unwrap(); + + let engine = Engine::default(); + let mut store = Store::new(&engine, ()); + let module = Module::new(&engine, &output).unwrap(); + let instance = Instance::new(&mut store, &module, &[]).unwrap(); + let fwd = instance + .get_typed_func::<(f64, f64), f64>(&mut store, "div") + .unwrap(); + let bwd = instance + .get_typed_func::<f64, (f64, f64)>(&mut store, "backprop") + .unwrap(); + + assert_eq!(fwd.call(&mut store, (3., 2.)).unwrap(), 1.5); + assert_eq!(bwd.call(&mut store, 1.).unwrap(), (0.5, -0.75)); + } } diff --git 
a/crates/floretta/src/run.rs b/crates/floretta/src/run.rs index 197e549..77df97e 100644 --- a/crates/floretta/src/run.rs +++ b/crates/floretta/src/run.rs @@ -9,9 +9,11 @@ use wasmparser::{FunctionBody, Operator, Parser, Payload}; use crate::{ helper::{ - helpers, FUNC_CONTROL_LOAD, FUNC_CONTROL_STORE, FUNC_F64_MUL_BWD, FUNC_F64_MUL_FWD, - OFFSET_FUNCTIONS, OFFSET_GLOBALS, OFFSET_MEMORIES, OFFSET_TYPES, TYPE_CONTROL_LOAD, - TYPE_CONTROL_STORE, TYPE_F32_BIN_BWD, TYPE_F32_BIN_FWD, TYPE_F64_BIN_BWD, TYPE_F64_BIN_FWD, + helpers, FUNC_CONTROL_LOAD, FUNC_CONTROL_STORE, FUNC_F32_DIV_BWD, FUNC_F32_DIV_FWD, + FUNC_F32_MUL_BWD, FUNC_F32_MUL_FWD, FUNC_F64_DIV_BWD, FUNC_F64_DIV_FWD, FUNC_F64_MUL_BWD, + FUNC_F64_MUL_FWD, OFFSET_FUNCTIONS, OFFSET_GLOBALS, OFFSET_MEMORIES, OFFSET_TYPES, + TYPE_CONTROL_LOAD, TYPE_CONTROL_STORE, TYPE_F32_BIN_BWD, TYPE_F32_BIN_FWD, + TYPE_F64_BIN_BWD, TYPE_F64_BIN_FWD, }, util::u32_to_usize, validate::{FunctionValidator, ModuleValidator}, @@ -233,14 +235,23 @@ fn function( bwd.local(local); } let tmp_f64 = bwd.local(wasm_encoder::ValType::F64); - for i in 0..num_params { + // The first basic block in the forward pass corresponds to the last basic block in the backward + // pass, and because each basic block will be reversed, the first instructions we write will + // become the last instructions in the function body of the backward pass. Because Wasm + // parameters appear in locals but Wasm results appear on the stack, when we reach the end of + // the backward pass, we'll have stored all the parameter adjoints in locals corresponding to + // those from the forward pass. So, we need to push the values of those locals onto the stack in + // order; and because everything will be reversed, that means that here we need to start with + // pushing the local corresponding to the last parameter first, then work downward to pushing + // the first parameter on the bottom of the stack. 
+ for i in (0..num_params).rev() { bwd.instruction(&Instruction::LocalGet(num_results + i)); } let mut operators_reader = body.get_operators_reader()?; let mut func = Func { num_results, locals: &locals, - offset: 0, // this initial value should be unused; to be set before each instruction + offset: 0, // This initial value should be unused; to be set before each instruction. operand_stack: Vec::new(), operand_stack_height: StackHeight::new(), operand_stack_height_min: 0, @@ -333,9 +344,14 @@ impl Func<'_> { let ty = self.local(local_index); self.push(ty); self.fwd.instruction(&Instruction::LocalGet(local_index)); + let i = self.bwd_local(local_index); match ty { + wasm_encoder::ValType::F32 => { + self.bwd.instruction(&Instruction::LocalSet(i)); + self.bwd.instruction(&Instruction::F32Add); + self.bwd.instruction(&Instruction::LocalGet(i)); + } wasm_encoder::ValType::F64 => { - let i = self.num_results + local_index; self.bwd.instruction(&Instruction::LocalSet(i)); self.bwd.instruction(&Instruction::F64Add); self.bwd.instruction(&Instruction::LocalGet(i)); @@ -348,9 +364,9 @@ impl Func<'_> { self.pop(); self.push(ty); self.fwd.instruction(&Instruction::LocalTee(local_index)); + let i = self.bwd_local(local_index); match ty { wasm_encoder::ValType::F64 => { - let i = self.num_results + local_index; self.bwd.instruction(&Instruction::LocalSet(i)); self.bwd.instruction(&Instruction::F64Const(0.)); self.bwd.instruction(&Instruction::F64Add); @@ -372,6 +388,18 @@ impl Func<'_> { self.bwd.instruction(&Instruction::F64Const(0.)); self.bwd.instruction(&Instruction::Drop); } + Operator::F32Mul => { + self.pop2(); + self.push_f32(); + self.fwd.instruction(&Instruction::Call(FUNC_F32_MUL_FWD)); + self.bwd.instruction(&Instruction::Call(FUNC_F32_MUL_BWD)); + } + Operator::F32Div => { + self.pop2(); + self.push_f32(); + self.fwd.instruction(&Instruction::Call(FUNC_F32_DIV_FWD)); + self.bwd.instruction(&Instruction::Call(FUNC_F32_DIV_BWD)); + } Operator::F64Sub => { 
self.pop2(); self.push_f64(); @@ -386,6 +414,12 @@ impl Func<'_> { self.fwd.instruction(&Instruction::Call(FUNC_F64_MUL_FWD)); self.bwd.instruction(&Instruction::Call(FUNC_F64_MUL_BWD)); } + Operator::F64Div => { + self.pop2(); + self.push_f64(); + self.fwd.instruction(&Instruction::Call(FUNC_F64_DIV_FWD)); + self.bwd.instruction(&Instruction::Call(FUNC_F64_DIV_BWD)); + } _ => unimplemented!("{op:?}"), } Ok(()) @@ -400,6 +434,10 @@ impl Func<'_> { self.push(wasm_encoder::ValType::I32); } + fn push_f32(&mut self) { + self.push(wasm_encoder::ValType::F32); + } + fn push_f64(&mut self) { self.push(wasm_encoder::ValType::F64); } @@ -424,6 +462,10 @@ impl Func<'_> { self.locals[u32_to_usize(index)] } + fn bwd_local(&self, index: u32) -> u32 { + self.num_results + index + } + /// In the forward pass, store the current basic block index on the tape. fn fwd_control_store(&mut self) { self.fwd diff --git a/crates/floretta/src/wat/f32_div.wat b/crates/floretta/src/wat/f32_div.wat new file mode 100644 index 0000000..5235526 --- /dev/null +++ b/crates/floretta/src/wat/f32_div.wat @@ -0,0 +1,3 @@ +(module + (func (export "div") (param f32 f32) (result f32) + (f32.div (local.get 0) (local.get 1)))) diff --git a/crates/floretta/src/wat/f32_mul.wat b/crates/floretta/src/wat/f32_mul.wat new file mode 100644 index 0000000..a0b2f4f --- /dev/null +++ b/crates/floretta/src/wat/f32_mul.wat @@ -0,0 +1,3 @@ +(module + (func (export "mul") (param f32 f32) (result f32) + (f32.mul (local.get 0) (local.get 1)))) diff --git a/crates/floretta/src/wat/f64_div.wat b/crates/floretta/src/wat/f64_div.wat new file mode 100644 index 0000000..3b6808a --- /dev/null +++ b/crates/floretta/src/wat/f64_div.wat @@ -0,0 +1,3 @@ +(module + (func (export "div") (param f64 f64) (result f64) + (f64.div (local.get 0) (local.get 1)))) diff --git a/crates/floretta/src/wat/f64_mul.wat b/crates/floretta/src/wat/f64_mul.wat new file mode 100644 index 0000000..e558c3e --- /dev/null +++ 
b/crates/floretta/src/wat/f64_mul.wat @@ -0,0 +1,3 @@ +(module + (func (export "mul") (param f64 f64) (result f64) + (f64.mul (local.get 0) (local.get 1))))