Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support global.get in more constant expressions #7996

Merged
merged 1 commit into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions crates/environ/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ impl ModuleTranslation<'_> {
// Get the end of this segment. If out-of-bounds, or too
// large for our dense table representation, then skip the
// segment.
let top = match segment.offset.checked_add(segment.elements.len() as u32) {
let top = match segment.offset.checked_add(segment.elements.len()) {
Some(top) => top,
None => break,
};
Expand All @@ -482,6 +482,13 @@ impl ModuleTranslation<'_> {
WasmHeapType::Extern => break,
}

// Function indices can be optimized here, but fully general
// expressions are deferred to get evaluated at runtime.
let function_elements = match &segment.elements {
TableSegmentElements::Functions(indices) => indices,
TableSegmentElements::Expressions(_) => break,
};

let precomputed =
match &mut self.module.table_initialization.initial_values[defined_index] {
TableInitialValue::Null { precomputed } => precomputed,
Expand All @@ -492,7 +499,7 @@ impl ModuleTranslation<'_> {
// Technically this won't trap so it's possible to process
// further initializers, but that's left as a future
// optimization.
TableInitialValue::FuncRef(_) => break,
TableInitialValue::FuncRef(_) | TableInitialValue::GlobalGet(_) => break,
};

// At this point we're committing to pre-initializing the table
Expand All @@ -504,7 +511,7 @@ impl ModuleTranslation<'_> {
precomputed.resize(top as usize, FuncIndex::reserved_value());
}
let dst = &mut precomputed[(segment.offset as usize)..(top as usize)];
dst.copy_from_slice(&segment.elements[..]);
dst.copy_from_slice(&function_elements);

// advance the iterator to see the next segment
let _ = segments.next();
Expand Down Expand Up @@ -757,6 +764,10 @@ pub enum TableInitialValue {
/// Initialize each table element to the function reference given
/// by the `FuncIndex`.
FuncRef(FuncIndex),

/// At instantiation time this global is loaded and the funcref value is
/// used to initialize the table.
GlobalGet(GlobalIndex),
}

/// A WebAssembly table initializer segment.
Expand All @@ -769,7 +780,39 @@ pub struct TableSegment {
/// The offset to add to the base.
pub offset: u32,
/// The values to write into the table elements.
pub elements: Box<[FuncIndex]>,
pub elements: TableSegmentElements,
}

/// Elements of a table segment, either a list of functions or list of arbitrary
/// expressions.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TableSegmentElements {
/// A sequential list of functions where `FuncIndex::reserved_value()`
/// indicates a null function.
Functions(Box<[FuncIndex]>),
/// Arbitrary expressions, aka either functions, null or a load of a global.
Expressions(Box<[TableElementExpression]>),
}

impl TableSegmentElements {
/// Returns the number of elements in this segment.
pub fn len(&self) -> u32 {
match self {
Self::Functions(s) => s.len() as u32,
Self::Expressions(s) => s.len() as u32,
}
}
}

/// Different kinds of expression that can initialize table elements.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TableElementExpression {
/// `ref.func $f`
Function(FuncIndex),
/// `global.get $g`
GlobalGet(GlobalIndex),
/// `ref.null $ty`
Null,
}

/// Different types that can appear in a module.
Expand Down Expand Up @@ -815,7 +858,7 @@ pub struct Module {
pub memory_initialization: MemoryInitialization,

/// WebAssembly passive elements.
pub passive_elements: Vec<Box<[FuncIndex]>>,
pub passive_elements: Vec<TableSegmentElements>,

/// The map from passive element index (element segment index space) to index in `passive_elements`.
pub passive_elements_map: BTreeMap<ElemIndex, usize>,
Expand Down
38 changes: 24 additions & 14 deletions crates/environ/src/module_environ.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use crate::module::{
FuncRefIndex, Initializer, MemoryInitialization, MemoryInitializer, MemoryPlan, Module,
ModuleType, TablePlan, TableSegment,
ModuleType, TableElementExpression, TablePlan, TableSegment, TableSegmentElements,
};
use crate::{
DataIndex, DefinedFuncIndex, ElemIndex, EntityIndex, EntityType, FuncIndex, GlobalIndex,
GlobalInit, MemoryIndex, ModuleTypesBuilder, PrimaryMap, TableIndex, TableInitialValue,
Tunables, TypeConvert, TypeIndex, Unsigned, WasmError, WasmHeapType, WasmResult, WasmValType,
WasmparserTypeConverter,
};
use cranelift_entity::packed_option::ReservedValue;
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::PathBuf;
Expand Down Expand Up @@ -320,6 +319,10 @@ impl<'a, 'data> ModuleEnvironment<'a, 'data> {
self.flag_func_escaped(index);
TableInitialValue::FuncRef(index)
}
Operator::GlobalGet { global_index } => {
let index = GlobalIndex::from_u32(global_index);
TableInitialValue::GlobalGet(index)
}
s => {
return Err(WasmError::Unsupported(format!(
"unsupported init expr in table section: {:?}",
Expand Down Expand Up @@ -449,25 +452,31 @@ impl<'a, 'data> ModuleEnvironment<'a, 'data> {
// possible to create anything other than a `ref.null
// extern` for externref segments, so those just get
// translated to the reserved value of `FuncIndex`.
let mut elements = Vec::new();
match items {
let elements = match items {
ElementItems::Functions(funcs) => {
elements.reserve(usize::try_from(funcs.count()).unwrap());
let mut elems =
Vec::with_capacity(usize::try_from(funcs.count()).unwrap());
for func in funcs {
let func = FuncIndex::from_u32(func?);
self.flag_func_escaped(func);
elements.push(func);
elems.push(func);
}
TableSegmentElements::Functions(elems.into())
}
ElementItems::Expressions(_ty, funcs) => {
elements.reserve(usize::try_from(funcs.count()).unwrap());
for func in funcs {
let func = match func?.get_binary_reader().read_operator()? {
Operator::RefNull { .. } => FuncIndex::reserved_value(),
ElementItems::Expressions(_ty, items) => {
let mut exprs =
Vec::with_capacity(usize::try_from(items.count()).unwrap());
for expr in items {
let expr = match expr?.get_binary_reader().read_operator()? {
Operator::RefNull { .. } => TableElementExpression::Null,
Operator::RefFunc { function_index } => {
let func = FuncIndex::from_u32(function_index);
self.flag_func_escaped(func);
func
TableElementExpression::Function(func)
}
Operator::GlobalGet { global_index } => {
let global = GlobalIndex::from_u32(global_index);
TableElementExpression::GlobalGet(global)
}
s => {
return Err(WasmError::Unsupported(format!(
Expand All @@ -476,10 +485,11 @@ impl<'a, 'data> ModuleEnvironment<'a, 'data> {
)));
}
};
elements.push(func);
exprs.push(expr);
}
TableSegmentElements::Expressions(exprs.into())
}
}
};

match kind {
ElementKind::Active {
Expand Down
86 changes: 61 additions & 25 deletions crates/runtime/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ use wasmtime_environ::ModuleInternedTypeIndex;
use wasmtime_environ::{
packed_option::ReservedValue, DataIndex, DefinedGlobalIndex, DefinedMemoryIndex,
DefinedTableIndex, ElemIndex, EntityIndex, EntityRef, EntitySet, FuncIndex, GlobalIndex,
GlobalInit, HostPtr, MemoryIndex, MemoryPlan, Module, PrimaryMap, TableIndex,
TableInitialValue, Trap, VMOffsets, WasmHeapType, WasmRefType, WasmValType, VMCONTEXT_MAGIC,
GlobalInit, HostPtr, MemoryIndex, MemoryPlan, Module, PrimaryMap, TableElementExpression,
TableIndex, TableInitialValue, TableSegmentElements, Trap, VMOffsets, WasmHeapType,
WasmRefType, WasmValType, VMCONTEXT_MAGIC,
};
#[cfg(feature = "wmemcheck")]
use wasmtime_wmemcheck::Wmemcheck;
Expand Down Expand Up @@ -803,50 +804,83 @@ impl Instance {
// disconnected from the lifetime of `self`.
let module = self.module().clone();

let empty = TableSegmentElements::Functions(Box::new([]));
let elements = match module.passive_elements_map.get(&elem_index) {
Some(index) if !self.dropped_elements.contains(elem_index) => {
module.passive_elements[*index].as_ref()
&module.passive_elements[*index]
}
_ => &[],
_ => &empty,
};
self.table_init_segment(table_index, elements, dst, src, len)
}

pub(crate) fn table_init_segment(
&mut self,
table_index: TableIndex,
elements: &[FuncIndex],
elements: &TableSegmentElements,
dst: u32,
src: u32,
len: u32,
) -> Result<(), Trap> {
// https://webassembly.github.io/bulk-memory-operations/core/exec/instructions.html#exec-table-init

let table = unsafe { &mut *self.get_table(table_index) };

let elements = match elements
.get(usize::try_from(src).unwrap()..)
.and_then(|s| s.get(..usize::try_from(len).unwrap()))
{
Some(elements) => elements,
None => return Err(Trap::TableOutOfBounds),
};

match table.element_type() {
TableElementType::Func => {
table.init_funcs(
let src = usize::try_from(src).map_err(|_| Trap::TableOutOfBounds)?;
let len = usize::try_from(len).map_err(|_| Trap::TableOutOfBounds)?;

match elements {
TableSegmentElements::Functions(funcs) => {
let elements = funcs
.get(src..)
.and_then(|s| s.get(..len))
.ok_or(Trap::TableOutOfBounds)?;
table.init(
dst,
elements
.iter()
.map(|idx| self.get_func_ref(*idx).unwrap_or(std::ptr::null_mut())),
elements.iter().map(|idx| {
TableElement::FuncRef(
self.get_func_ref(*idx).unwrap_or(std::ptr::null_mut()),
)
}),
)?;
}

TableElementType::Extern => {
debug_assert!(elements.iter().all(|e| *e == FuncIndex::reserved_value()));
table.fill(dst, TableElement::ExternRef(None), len)?;
TableSegmentElements::Expressions(exprs) => {
let ty = table.element_type();
let exprs = exprs
.get(src..)
.and_then(|s| s.get(..len))
.ok_or(Trap::TableOutOfBounds)?;
table.init(
dst,
exprs.iter().map(|expr| match ty {
TableElementType::Func => {
let funcref = match expr {
TableElementExpression::Null => std::ptr::null_mut(),
TableElementExpression::Function(idx) => {
self.get_func_ref(*idx).unwrap()
}
TableElementExpression::GlobalGet(idx) => {
let global = self.defined_or_imported_global_ptr(*idx);
unsafe { (*global).as_func_ref() }
}
};
TableElement::FuncRef(funcref)
}
TableElementType::Extern => {
let externref = match expr {
TableElementExpression::Null => None,
TableElementExpression::Function(_) => unreachable!(),
TableElementExpression::GlobalGet(idx) => {
let global = self.defined_or_imported_global_ptr(*idx);
unsafe { (*global).as_externref().clone() }
}
};
TableElement::ExternRef(externref)
}
}),
)?;
}
}

Ok(())
}

Expand Down Expand Up @@ -1059,7 +1093,9 @@ impl Instance {
let module = self.module();
let precomputed = match &module.table_initialization.initial_values[idx] {
TableInitialValue::Null { precomputed } => precomputed,
TableInitialValue::FuncRef(_) => unreachable!(),
TableInitialValue::FuncRef(_) | TableInitialValue::GlobalGet(_) => {
unreachable!()
}
};
let func_index = precomputed.get(i as usize).cloned();
let func_ref = func_index
Expand Down
11 changes: 9 additions & 2 deletions crates/runtime/src/instance/allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ fn check_table_init_bounds(instance: &mut Instance, module: &Module) -> Result<(
let table = unsafe { &*instance.get_table(segment.table_index) };
let start = get_table_init_start(segment, instance)?;
let start = usize::try_from(start).unwrap();
let end = start.checked_add(segment.elements.len());
let end = start.checked_add(usize::try_from(segment.elements.len()).unwrap());

match end {
Some(end) if end <= table.size() as usize => {
Expand All @@ -533,6 +533,13 @@ fn initialize_tables(instance: &mut Instance, module: &Module) -> Result<()> {
let table = unsafe { &mut *instance.get_defined_table(table) };
table.init_func(funcref)?;
}

TableInitialValue::GlobalGet(idx) => unsafe {
let global = instance.defined_or_imported_global_ptr(*idx);
let funcref = (*global).as_func_ref();
let table = &mut *instance.get_defined_table(table);
table.init_func(funcref)?;
},
}
}

Expand All @@ -550,7 +557,7 @@ fn initialize_tables(instance: &mut Instance, module: &Module) -> Result<()> {
&segment.elements,
start,
0,
segment.elements.len() as u32,
segment.elements.len(),
)?;
}

Expand Down
Loading