From 78cde9798ca10bac9359f9ef1c2d78ab44433cac Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Wed, 20 Nov 2024 17:58:44 -0500 Subject: [PATCH] Shrink more by making validation optional --- crates/floretta-wasm/src/lib.rs | 2 +- crates/floretta/src/lib.rs | 272 ++++++-------------------------- crates/floretta/src/run.rs | 231 +++++++++++++++++++++++++++ crates/floretta/src/validate.rs | 142 +++++++++++++++++ 4 files changed, 422 insertions(+), 225 deletions(-) create mode 100644 crates/floretta/src/run.rs create mode 100644 crates/floretta/src/validate.rs diff --git a/crates/floretta-wasm/src/lib.rs b/crates/floretta-wasm/src/lib.rs index 98221f2..a2a770b 100644 --- a/crates/floretta-wasm/src/lib.rs +++ b/crates/floretta-wasm/src/lib.rs @@ -1,4 +1,4 @@ #[no_mangle] fn autodiff(wasm_module: &[u8]) -> Result, floretta::Error> { - floretta::Autodiff::new().transform(wasm_module) + floretta::Autodiff::no_validate().transform(wasm_module) } diff --git a/crates/floretta/src/lib.rs b/crates/floretta/src/lib.rs index 859dedf..20d15fb 100644 --- a/crates/floretta/src/lib.rs +++ b/crates/floretta/src/lib.rs @@ -31,17 +31,13 @@ //! [`wat`]: https://crates.io/crates/wat //! [wasmtime]: https://crates.io/crates/wasmtime +mod run; +mod validate; + use std::collections::HashMap; -use wasm_encoder::{ - reencode::{self, Reencode, RoundtripReencoder}, - CodeSection, ExportKind, ExportSection, Function, FunctionSection, GlobalSection, Instruction, - MemorySection, Module, TypeSection, -}; -use wasmparser::{ - BinaryReaderError, FuncToValidate, FuncValidatorAllocations, FunctionBody, Operator, Parser, - Payload, Validator, ValidatorResources, WasmFeatures, -}; +use wasm_encoder::reencode; +use wasmparser::{BinaryReaderError, Validator, WasmFeatures}; /// An error that occurred during code transformation. #[derive(Debug, thiserror::Error)] @@ -55,249 +51,77 @@ pub enum Error { Reencode(#[from] reencode::Error), } -/// WebAssembly code transformation to perform reverse-mode automatic differentiation. #[derive(Default)] -pub struct Autodiff { +struct Config { /// Exported functions whose backward passes should also be exported. exports: HashMap, } +/// WebAssembly code transformation to perform reverse-mode automatic differentiation. +pub struct Autodiff { + runner: Box, + config: Config, +} + +impl Default for Autodiff { + fn default() -> Self { + Self { + runner: Box::new(Validate), + config: Default::default(), + } + } +} + impl Autodiff { /// Default configuration. pub fn new() -> Self { Self::default() } + /// Do not validate input Wasm. + pub fn no_validate() -> Self { + Self { + runner: Box::new(NoValidate), + config: Default::default(), + } + } + /// Export the backward pass of a function that is already exported. pub fn export(&mut self, function: impl ToString, gradient: impl ToString) { - self.exports + self.config + .exports .insert(function.to_string(), gradient.to_string()); } /// Transform a WebAssembly module using this configuration. pub fn transform(self, wasm_module: &[u8]) -> Result, Error> { - let mut types = TypeSection::new(); - // Types for helper functions to push a floating-point values onto the tape. - types.ty().func_type(&wasm_encoder::FuncType::new( - [wasm_encoder::ValType::F32], - [wasm_encoder::ValType::F32], - )); - types.ty().func_type(&wasm_encoder::FuncType::new( - [wasm_encoder::ValType::F64], - [wasm_encoder::ValType::F64], - )); - assert_eq!(types.len(), OFFSET_TYPES); - let mut functions = FunctionSection::new(); - // Type indices for the tape helper functions. - functions.function(0); - functions.function(1); - assert_eq!(functions.len(), OFFSET_FUNCTIONS); - let mut memories = MemorySection::new(); - // The first memory is always the tape, so it is possible to translate function bodies - // without knowing the total number of memories. - memories.memory(wasm_encoder::MemoryType { - minimum: 0, - maximum: None, - memory64: false, - shared: false, - page_size_log2: None, - }); - assert_eq!(memories.len(), OFFSET_MEMORIES); - let mut globals = GlobalSection::new(); - // The first global is always the tape pointer. - globals.global( - wasm_encoder::GlobalType { - val_type: wasm_encoder::ValType::I32, - mutable: true, - shared: false, - }, - &wasm_encoder::ConstExpr::i32_const(0), - ); - assert_eq!(globals.len(), OFFSET_GLOBALS); - let mut exports = ExportSection::new(); - let mut code = CodeSection::new(); - code.function(&tee_f32()); - code.function(&tee_f64()); - assert_eq!(code.len(), OFFSET_FUNCTIONS); - let mut validator = Validator::new_with_features(features()); - for payload in Parser::new(0).parse_all(wasm_module) { - match payload? { - Payload::TypeSection(section) => { - validator.type_section(§ion)?; - for func_ty in section.into_iter_err_on_gc_types() { - let ty = RoundtripReencoder.func_type(func_ty?)?; - // Forward pass: same type as the original function. For integers, all the - // adjoint values are assumed to be equal to the primal values (e.g. - // pointers, because of our multi-memory strategy), and for floating point, - // all the adjoint values are assumed to be zero. - types.ty().func_type(&ty); - // Backward pass: results become parameters, and parameters become results. - types.ty().func_type(&wasm_encoder::FuncType::new( - ty.results().iter().copied(), - ty.params().iter().copied(), - )); - } - } - Payload::FunctionSection(section) => { - validator.function_section(§ion)?; - for type_index in section { - let t = type_index?; - // Index arithmetic to account for the fact that we split each original - // function type into two; similarly, we also split each actual function - // into two. - functions.function(OFFSET_TYPES + 2 * t); - functions.function(OFFSET_TYPES + 2 * t + 1); - } - } - Payload::MemorySection(section) => { - validator.memory_section(§ion)?; - for memory_ty in section { - let memory_type = RoundtripReencoder.memory_type(memory_ty?); - memories.memory(memory_type); - // Duplicate the memory to store adjoint values. - memories.memory(memory_type); - } - } - Payload::GlobalSection(section) => { - validator.global_section(§ion)?; - for global in section { - let g = global?; - globals.global( - RoundtripReencoder.global_type(g.ty)?, - &RoundtripReencoder.const_expr(g.init_expr)?, - ); - } - } - Payload::ExportSection(section) => { - validator.export_section(§ion)?; - for export in section { - let e = export?; - let kind = RoundtripReencoder.export_kind(e.kind); - match kind { - ExportKind::Func => { - // More index arithmetic because we split every function into a - // forward pass and a backward pass. - exports.export(e.name, kind, OFFSET_FUNCTIONS + 2 * e.index); - if let Some(name) = self.exports.get(e.name) { - // TODO: Should we check that no export with this name already - // exists? - exports.export(name, kind, OFFSET_FUNCTIONS + 2 * e.index + 1); - } - } - ExportKind::Memory => { - exports.export(e.name, kind, OFFSET_MEMORIES + 2 * e.index); - } - _ => { - exports.export(e.name, kind, e.index); - } - } - } - } - Payload::CodeSectionEntry(body) => { - let func = validator.code_section_entry(&body)?; - let (fwd, bwd) = function(body, func)?; - code.function(&fwd); - code.function(&bwd); - } - other => { - validator.payload(&other)?; - } - } - } - let mut module = Module::new(); - module.section(&types); - module.section(&functions); - module.section(&memories); - module.section(&globals); - module.section(&exports); - module.section(&code); - Ok(module.finish()) + self.runner.transform(self.config, wasm_module) } } -fn features() -> WasmFeatures { - WasmFeatures::empty() | WasmFeatures::FLOATS +trait Runner { + fn transform(&self, config: Config, wasm_module: &[u8]) -> Result, Error>; } -const OFFSET_TYPES: u32 = 2; -const OFFSET_FUNCTIONS: u32 = 2; -const OFFSET_MEMORIES: u32 = 1; -const OFFSET_GLOBALS: u32 = 1; +// We make `Runner` a `trait` instead of just an `enum`, to facilitate dead code elimination when +// validation is not needed. -fn tee_f32() -> Function { - let mut f = Function::new([(1, wasm_encoder::ValType::I32)]); - f.instruction(&Instruction::GlobalGet(0)); - f.instruction(&Instruction::LocalTee(1)); - f.instruction(&Instruction::LocalGet(0)); - f.instruction(&Instruction::F32Store(wasm_encoder::MemArg { - offset: 0, - align: 2, - memory_index: 0, - })); - f.instruction(&Instruction::LocalGet(1)); - f.instruction(&Instruction::I32Const(4)); - f.instruction(&Instruction::I32Add); - f.instruction(&Instruction::GlobalSet(0)); - f.instruction(&Instruction::LocalGet(0)); - f.instruction(&Instruction::End); - f -} +struct Validate; -fn tee_f64() -> Function { - let mut f = Function::new([(1, wasm_encoder::ValType::I32)]); - f.instruction(&Instruction::GlobalGet(0)); - f.instruction(&Instruction::LocalTee(1)); - f.instruction(&Instruction::LocalGet(0)); - f.instruction(&Instruction::F64Store(wasm_encoder::MemArg { - offset: 0, - align: 3, - memory_index: 0, - })); - f.instruction(&Instruction::LocalGet(1)); - f.instruction(&Instruction::I32Const(8)); - f.instruction(&Instruction::I32Add); - f.instruction(&Instruction::GlobalSet(0)); - f.instruction(&Instruction::LocalGet(0)); - f.instruction(&Instruction::End); - f -} +struct NoValidate; -fn function( - body: FunctionBody, - func: FuncToValidate, -) -> Result<(Function, Function), Error> { - let mut validator = func.into_validator(FuncValidatorAllocations::default()); - let mut locals = Vec::new(); - let mut locals_reader = body.get_locals_reader()?; - for _ in 0..locals_reader.get_count() { - let offset = locals_reader.original_position(); - let (count, ty) = locals_reader.read()?; - validator.define_locals(offset, count, ty)?; - locals.push((count, RoundtripReencoder.val_type(ty)?)); +impl Runner for Validate { + fn transform(&self, config: Config, wasm_module: &[u8]) -> Result, Error> { + let features = WasmFeatures::empty() | WasmFeatures::FLOATS; + let validator = Validator::new_with_features(features); + run::transform(validator, config, wasm_module) } - let mut fwd = Function::new(locals); - let mut bwd = Function::new([]); - let mut operators_reader = body.get_operators_reader()?; - while !operators_reader.eof() { - let (op, offset) = operators_reader.read_with_offset()?; - validator.op(offset, &op)?; - match op { - Operator::End => { - fwd.instruction(&Instruction::End); - bwd.instruction(&Instruction::End); - } - Operator::LocalGet { .. } => { - // TODO: Don't just hardcode constant return values. - } - Operator::F64Mul => { - fwd.instruction(&Instruction::F64Const(9.)); - bwd.instruction(&Instruction::F64Const(6.)); - } - _ => todo!(), - } +} + +impl Runner for NoValidate { + fn transform(&self, config: Config, wasm_module: &[u8]) -> Result, Error> { + run::transform((), config, wasm_module) } - validator.finish(operators_reader.original_position())?; - Ok((fwd, bwd)) } #[cfg(test)] diff --git a/crates/floretta/src/run.rs b/crates/floretta/src/run.rs new file mode 100644 index 0000000..f1a3df2 --- /dev/null +++ b/crates/floretta/src/run.rs @@ -0,0 +1,231 @@ +use wasm_encoder::{ + reencode::{Reencode, RoundtripReencoder}, + CodeSection, ExportKind, ExportSection, Function, FunctionSection, GlobalSection, Instruction, + MemorySection, Module, TypeSection, +}; +use wasmparser::{FunctionBody, Operator, Parser, Payload}; + +use crate::{ + validate::{FunctionValidator, ModuleValidator}, + Config, Error, +}; + +pub fn transform( + mut validator: impl ModuleValidator, + config: Config, + wasm_module: &[u8], +) -> Result, Error> { + let mut types = TypeSection::new(); + // Types for helper functions to push a floating-point values onto the tape. + types.ty().func_type(&wasm_encoder::FuncType::new( + [wasm_encoder::ValType::F32], + [wasm_encoder::ValType::F32], + )); + types.ty().func_type(&wasm_encoder::FuncType::new( + [wasm_encoder::ValType::F64], + [wasm_encoder::ValType::F64], + )); + assert_eq!(types.len(), OFFSET_TYPES); + let mut functions = FunctionSection::new(); + // Type indices for the tape helper functions. + functions.function(0); + functions.function(1); + assert_eq!(functions.len(), OFFSET_FUNCTIONS); + let mut memories = MemorySection::new(); + // The first memory is always the tape, so it is possible to translate function bodies + // without knowing the total number of memories. + memories.memory(wasm_encoder::MemoryType { + minimum: 0, + maximum: None, + memory64: false, + shared: false, + page_size_log2: None, + }); + assert_eq!(memories.len(), OFFSET_MEMORIES); + let mut globals = GlobalSection::new(); + // The first global is always the tape pointer. + globals.global( + wasm_encoder::GlobalType { + val_type: wasm_encoder::ValType::I32, + mutable: true, + shared: false, + }, + &wasm_encoder::ConstExpr::i32_const(0), + ); + assert_eq!(globals.len(), OFFSET_GLOBALS); + let mut exports = ExportSection::new(); + let mut code = CodeSection::new(); + code.function(&tee_f32()); + code.function(&tee_f64()); + assert_eq!(code.len(), OFFSET_FUNCTIONS); + for payload in Parser::new(0).parse_all(wasm_module) { + match payload? { + Payload::TypeSection(section) => { + validator.type_section(§ion)?; + for func_ty in section.into_iter_err_on_gc_types() { + let ty = RoundtripReencoder.func_type(func_ty?)?; + // Forward pass: same type as the original function. For integers, all the + // adjoint values are assumed to be equal to the primal values (e.g. + // pointers, because of our multi-memory strategy), and for floating point, + // all the adjoint values are assumed to be zero. + types.ty().func_type(&ty); + // Backward pass: results become parameters, and parameters become results. + types.ty().func_type(&wasm_encoder::FuncType::new( + ty.results().iter().copied(), + ty.params().iter().copied(), + )); + } + } + Payload::FunctionSection(section) => { + validator.function_section(§ion)?; + for type_index in section { + let t = type_index?; + // Index arithmetic to account for the fact that we split each original + // function type into two; similarly, we also split each actual function + // into two. + functions.function(OFFSET_TYPES + 2 * t); + functions.function(OFFSET_TYPES + 2 * t + 1); + } + } + Payload::MemorySection(section) => { + validator.memory_section(§ion)?; + for memory_ty in section { + let memory_type = RoundtripReencoder.memory_type(memory_ty?); + memories.memory(memory_type); + // Duplicate the memory to store adjoint values. + memories.memory(memory_type); + } + } + Payload::GlobalSection(section) => { + validator.global_section(§ion)?; + for global in section { + let g = global?; + globals.global( + RoundtripReencoder.global_type(g.ty)?, + &RoundtripReencoder.const_expr(g.init_expr)?, + ); + } + } + Payload::ExportSection(section) => { + validator.export_section(§ion)?; + for export in section { + let e = export?; + let kind = RoundtripReencoder.export_kind(e.kind); + match kind { + ExportKind::Func => { + // More index arithmetic because we split every function into a + // forward pass and a backward pass. + exports.export(e.name, kind, OFFSET_FUNCTIONS + 2 * e.index); + if let Some(name) = config.exports.get(e.name) { + // TODO: Should we check that no export with this name already + // exists? + exports.export(name, kind, OFFSET_FUNCTIONS + 2 * e.index + 1); + } + } + ExportKind::Memory => { + exports.export(e.name, kind, OFFSET_MEMORIES + 2 * e.index); + } + _ => { + exports.export(e.name, kind, e.index); + } + } + } + } + Payload::CodeSectionEntry(body) => { + let func = validator.code_section_entry(&body)?; + let (fwd, bwd) = function(func, body)?; + code.function(&fwd); + code.function(&bwd); + } + other => validator.payload(&other)?, + } + } + let mut module = Module::new(); + module.section(&types); + module.section(&functions); + module.section(&memories); + module.section(&globals); + module.section(&exports); + module.section(&code); + Ok(module.finish()) +} + +const OFFSET_TYPES: u32 = 2; +const OFFSET_FUNCTIONS: u32 = 2; +const OFFSET_MEMORIES: u32 = 1; +const OFFSET_GLOBALS: u32 = 1; + +fn tee_f32() -> Function { + let mut f = Function::new([(1, wasm_encoder::ValType::I32)]); + f.instruction(&Instruction::GlobalGet(0)); + f.instruction(&Instruction::LocalTee(1)); + f.instruction(&Instruction::LocalGet(0)); + f.instruction(&Instruction::F32Store(wasm_encoder::MemArg { + offset: 0, + align: 2, + memory_index: 0, + })); + f.instruction(&Instruction::LocalGet(1)); + f.instruction(&Instruction::I32Const(4)); + f.instruction(&Instruction::I32Add); + f.instruction(&Instruction::GlobalSet(0)); + f.instruction(&Instruction::LocalGet(0)); + f.instruction(&Instruction::End); + f +} + +fn tee_f64() -> Function { + let mut f = Function::new([(1, wasm_encoder::ValType::I32)]); + f.instruction(&Instruction::GlobalGet(0)); + f.instruction(&Instruction::LocalTee(1)); + f.instruction(&Instruction::LocalGet(0)); + f.instruction(&Instruction::F64Store(wasm_encoder::MemArg { + offset: 0, + align: 3, + memory_index: 0, + })); + f.instruction(&Instruction::LocalGet(1)); + f.instruction(&Instruction::I32Const(8)); + f.instruction(&Instruction::I32Add); + f.instruction(&Instruction::GlobalSet(0)); + f.instruction(&Instruction::LocalGet(0)); + f.instruction(&Instruction::End); + f +} + +fn function( + mut validator: impl FunctionValidator, + body: FunctionBody, +) -> Result<(Function, Function), Error> { + let mut locals = Vec::new(); + let mut locals_reader = body.get_locals_reader()?; + for _ in 0..locals_reader.get_count() { + let offset = locals_reader.original_position(); + let (count, ty) = locals_reader.read()?; + validator.define_locals(offset, count, ty)?; + locals.push((count, RoundtripReencoder.val_type(ty)?)); + } + let mut fwd = Function::new(locals); + let mut bwd = Function::new([]); + let mut operators_reader = body.get_operators_reader()?; + while !operators_reader.eof() { + let (op, offset) = operators_reader.read_with_offset()?; + validator.op(offset, &op)?; + match op { + Operator::End => { + fwd.instruction(&Instruction::End); + bwd.instruction(&Instruction::End); + } + Operator::LocalGet { .. } => { + // TODO: Don't just hardcode constant return values. + } + Operator::F64Mul => { + fwd.instruction(&Instruction::F64Const(9.)); + bwd.instruction(&Instruction::F64Const(6.)); + } + _ => todo!(), + } + } + validator.finish(operators_reader.original_position())?; + Ok((fwd, bwd)) +} diff --git a/crates/floretta/src/validate.rs b/crates/floretta/src/validate.rs new file mode 100644 index 0000000..e790803 --- /dev/null +++ b/crates/floretta/src/validate.rs @@ -0,0 +1,142 @@ +use wasmparser::{ + ExportSectionReader, FuncValidator, FuncValidatorAllocations, FunctionBody, + FunctionSectionReader, GlobalSectionReader, MemorySectionReader, Operator, Payload, + TypeSectionReader, Validator, ValidatorResources, WasmModuleResources, +}; + +/// Trait counterpart to [`wasmparser::Validator`]. +pub trait ModuleValidator { + type Func: FunctionValidator; + + fn payload(&mut self, payload: &Payload) -> wasmparser::Result<()>; + + fn type_section(&mut self, section: &TypeSectionReader) -> wasmparser::Result<()>; + + fn function_section(&mut self, section: &FunctionSectionReader) -> wasmparser::Result<()>; + + fn memory_section(&mut self, section: &MemorySectionReader) -> wasmparser::Result<()>; + + fn global_section(&mut self, section: &GlobalSectionReader) -> wasmparser::Result<()>; + + fn export_section(&mut self, section: &ExportSectionReader) -> wasmparser::Result<()>; + + fn code_section_entry(&mut self, body: &FunctionBody) -> wasmparser::Result; +} + +/// Trait counterpart to [`wasmparser::FuncValidator`]. +pub trait FunctionValidator { + fn define_locals( + &mut self, + offset: usize, + count: u32, + ty: wasmparser::ValType, + ) -> wasmparser::Result<()>; + + fn op(&mut self, offset: usize, operator: &Operator) -> wasmparser::Result<()>; + + fn finish(&mut self, offset: usize) -> wasmparser::Result<()>; +} + +impl ModuleValidator for () { + type Func = (); + + fn payload(&mut self, _: &Payload) -> wasmparser::Result<()> { + Ok(()) + } + + fn type_section(&mut self, _: &TypeSectionReader) -> wasmparser::Result<()> { + Ok(()) + } + + fn function_section(&mut self, _: &FunctionSectionReader) -> wasmparser::Result<()> { + Ok(()) + } + + fn memory_section(&mut self, _: &MemorySectionReader) -> wasmparser::Result<()> { + Ok(()) + } + + fn global_section(&mut self, _: &GlobalSectionReader) -> wasmparser::Result<()> { + Ok(()) + } + + fn export_section(&mut self, _: &ExportSectionReader) -> wasmparser::Result<()> { + Ok(()) + } + + fn code_section_entry(&mut self, _: &FunctionBody) -> wasmparser::Result { + Ok(()) + } +} + +impl FunctionValidator for () { + fn define_locals( + &mut self, + _: usize, + _: u32, + _: wasmparser::ValType, + ) -> wasmparser::Result<()> { + Ok(()) + } + + fn op(&mut self, _: usize, _: &Operator) -> wasmparser::Result<()> { + Ok(()) + } + + fn finish(&mut self, _: usize) -> wasmparser::Result<()> { + Ok(()) + } +} + +impl ModuleValidator for Validator { + type Func = FuncValidator; + + fn payload(&mut self, payload: &Payload) -> wasmparser::Result<()> { + self.payload(payload)?; + Ok(()) + } + + fn type_section(&mut self, section: &TypeSectionReader) -> wasmparser::Result<()> { + self.type_section(section) + } + + fn function_section(&mut self, section: &FunctionSectionReader) -> wasmparser::Result<()> { + self.function_section(section) + } + + fn memory_section(&mut self, section: &MemorySectionReader) -> wasmparser::Result<()> { + self.memory_section(section) + } + + fn global_section(&mut self, section: &GlobalSectionReader) -> wasmparser::Result<()> { + self.global_section(section) + } + + fn export_section(&mut self, section: &ExportSectionReader) -> wasmparser::Result<()> { + self.export_section(section) + } + + fn code_section_entry(&mut self, body: &FunctionBody) -> wasmparser::Result { + let func = self.code_section_entry(&body)?; + Ok(func.into_validator(FuncValidatorAllocations::default())) + } +} + +impl FunctionValidator for FuncValidator { + fn define_locals( + &mut self, + offset: usize, + count: u32, + ty: wasmparser::ValType, + ) -> wasmparser::Result<()> { + self.define_locals(offset, count, ty) + } + + fn op(&mut self, offset: usize, operator: &Operator) -> wasmparser::Result<()> { + self.op(offset, operator) + } + + fn finish(&mut self, offset: usize) -> wasmparser::Result<()> { + self.finish(offset) + } +}