From 811d645e7a16b5a5f315f0748fd217ca79b46a7c Mon Sep 17 00:00:00 2001
From: Thomas Lively
Date: Tue, 16 Dec 2025 13:56:38 -0800
Subject: [PATCH] Avoid public type collisions in GlobalTypeRewriter

Move BrandTypeIterator from MinimizeRecGroups to wasm-type-shape.h and use
it in a new UniqueRecGroups utility that can rebuild types to be distinct
from previously seen rec groups. Use UniqueRecGroups in GlobalTypeRewriter
to ensure the newly built private types do not conflict with public types.

Split off from #8119 because this can land sooner.
---
 src/ir/type-updating.cpp                    | 40 +++++---
 src/ir/type-updating.h                      | 10 ++
 src/passes/MinimizeRecGroups.cpp            | 78 ---------------
 src/wasm-type-shape.h                       | 98 +++++++++++++++++++
 src/wasm/wasm-type-shape.cpp                | 35 +++++++
 .../signature-pruning-public-collision.wast | 59 +++++++++++
 6 files changed, 227 insertions(+), 93 deletions(-)
 create mode 100644 test/lit/passes/signature-pruning-public-collision.wast

diff --git a/src/ir/type-updating.cpp b/src/ir/type-updating.cpp
index 4870152d74f..d402cd221cc 100644
--- a/src/ir/type-updating.cpp
+++ b/src/ir/type-updating.cpp
@@ -26,24 +26,37 @@
 namespace wasm {
 
-GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm) : wasm(wasm) {}
-
-void GlobalTypeRewriter::update() {
-  mapTypes(rebuildTypes(getSortedTypes(getPrivatePredecessors())));
-}
-
-GlobalTypeRewriter::PredecessorGraph
-GlobalTypeRewriter::getPrivatePredecessors() {
+GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm)
+  : wasm(wasm), publicGroups(wasm.features) {
   // Find the heap types that are not publicly observable. Even in a closed
   // world scenario, don't modify public types because we assume that they may
   // be reflected on or used for linking. Figure out where each private type
   // will be located in the builder.
-  auto typeInfo = ModuleUtils::collectHeapTypeInfo(
+  typeInfo = ModuleUtils::collectHeapTypeInfo(
     wasm,
     ModuleUtils::TypeInclusion::UsedIRTypes,
     ModuleUtils::VisibilityHandling::FindVisibility);
 
-  // Check if a type is private, by looking up its info.
+  std::unordered_set<RecGroup> seenGroups;
+  for (auto& [type, info] : typeInfo) {
+    if (info.visibility == ModuleUtils::Visibility::Public) {
+      auto group = type.getRecGroup();
+      if (seenGroups.insert(type.getRecGroup()).second) {
+        std::vector<HeapType> groupTypes(group.begin(), group.end());
+        publicGroups.insert(std::move(groupTypes));
+      }
+    }
+  }
+}
+
+void GlobalTypeRewriter::update() {
+  mapTypes(rebuildTypes(getSortedTypes(getPrivatePredecessors())));
+}
+
+GlobalTypeRewriter::PredecessorGraph
+GlobalTypeRewriter::getPrivatePredecessors() {
+  // Check if a type is private, looking for its info (if there is none, it is
+  // not private).
   auto isPublic = [&](HeapType type) {
     auto it = typeInfo.find(type);
     assert(it != typeInfo.end());
@@ -185,11 +198,8 @@ GlobalTypeRewriter::rebuildTypes(std::vector<HeapType> types) {
                << " at index " << err->index;
   }
 #endif
-  auto& newTypes = *buildResults;
-
-  // TODO: It is possible that the newly built rec group matches some public rec
-  // group. If that is the case, we need to try a different permutation of the
-  // types or add a brand type to distinguish the private types.
+  // Ensure the new types are different from any public rec group.
+  const auto& newTypes = publicGroups.insert(*buildResults);
 
   // Map the old types to the new ones.
   TypeMap oldToNewTypes;
diff --git a/src/ir/type-updating.h b/src/ir/type-updating.h
index c01a97a0977..3333b288461 100644
--- a/src/ir/type-updating.h
+++ b/src/ir/type-updating.h
@@ -18,8 +18,11 @@
 #define wasm_ir_type_updating_h
 
 #include "ir/branch-utils.h"
+#include "ir/module-utils.h"
 #include "support/insert_ordered.h"
 #include "wasm-traversal.h"
+#include "wasm-type-shape.h"
+#include "wasm-type.h"
 
 namespace wasm {
 
@@ -348,6 +351,13 @@ class GlobalTypeRewriter {
 
   Module& wasm;
 
+  // The module's types and their visibilities.
+  InsertOrderedMap<HeapType, ModuleUtils::HeapTypeInfo> typeInfo;
+
+  // The shapes of public rec groups, so we can be sure that the rewritten
+  // private types do not conflict with public types.
+  UniqueRecGroups publicGroups;
+
   GlobalTypeRewriter(Module& wasm);
   virtual ~GlobalTypeRewriter() {}
 
diff --git a/src/passes/MinimizeRecGroups.cpp b/src/passes/MinimizeRecGroups.cpp
index 13f569e464a..6a3c3223034 100644
--- a/src/passes/MinimizeRecGroups.cpp
+++ b/src/passes/MinimizeRecGroups.cpp
@@ -100,84 +100,6 @@ struct TypeSCCs
   }
 };
 
-// After all their permutations with distinct shapes have been used, different
-// groups with the same shapes must be differentiated by adding in a "brand"
-// type. Even with a brand mixed in, we might run out of permutations with
-// distinct shapes, in which case we need a new brand type. This iterator
-// provides an infinite sequence of possible brand types, prioritizing those
-// with the most compact encoding.
-struct BrandTypeIterator {
-  static constexpr Index optionCount = 18;
-  static constexpr std::array<Field, optionCount> fieldOptions = {{
-    Field(Field::i8, Mutable),
-    Field(Field::i16, Mutable),
-    Field(Type::i32, Mutable),
-    Field(Type::i64, Mutable),
-    Field(Type::f32, Mutable),
-    Field(Type::f64, Mutable),
-    Field(Type(HeapType::any, Nullable), Mutable),
-    Field(Type(HeapType::func, Nullable), Mutable),
-    Field(Type(HeapType::ext, Nullable), Mutable),
-    Field(Type(HeapType::none, Nullable), Mutable),
-    Field(Type(HeapType::nofunc, Nullable), Mutable),
-    Field(Type(HeapType::noext, Nullable), Mutable),
-    Field(Type(HeapType::any, NonNullable), Mutable),
-    Field(Type(HeapType::func, NonNullable), Mutable),
-    Field(Type(HeapType::ext, NonNullable), Mutable),
-    Field(Type(HeapType::none, NonNullable), Mutable),
-    Field(Type(HeapType::nofunc, NonNullable), Mutable),
-    Field(Type(HeapType::noext, NonNullable), Mutable),
-  }};
-
-  struct FieldInfo {
-    uint8_t index = 0;
-    bool immutable = false;
-
-    operator Field() const {
-      auto field = fieldOptions[index];
-      if (immutable) {
-        field.mutable_ = Immutable;
-      }
-      return field;
-    }
-
-    bool advance() {
-      if (!immutable) {
-        immutable = true;
-        return true;
-      }
-      immutable = false;
-      index = (index + 1) % optionCount;
-      return index != 0;
-    }
-  };
-
-  bool useArray = false;
-  std::vector<FieldInfo> fields;
-
-  HeapType operator*() const {
-    if (useArray) {
-      return Array(fields[0]);
-    }
-    return Struct(std::vector<Field>(fields.begin(), fields.end()));
-  }
-
-  BrandTypeIterator& operator++() {
-    for (Index i = fields.size(); i > 0; --i) {
-      if (fields[i - 1].advance()) {
-        return *this;
-      }
-    }
-    if (useArray) {
-      useArray = false;
-      return *this;
-    }
-    fields.emplace_back();
-    useArray = fields.size() == 1;
-    return *this;
-  }
-};
-
 // Create an adjacency list with edges from supertype to subtype and from
 // described type to descriptor.
 std::vector<std::vector<Index>>
diff --git a/src/wasm-type-shape.h b/src/wasm-type-shape.h
index e72f28dd530..b649e1b72db 100644
--- a/src/wasm-type-shape.h
+++ b/src/wasm-type-shape.h
@@ -18,6 +18,8 @@
 #define wasm_wasm_type_shape_h
 
 #include <functional>
+#include <list>
+#include <unordered_set>
 #include <vector>
 
 #include "wasm-features.h"
@@ -79,4 +81,100 @@ template<> class hash<wasm::RecGroupShape> {
 
 } // namespace std
 
+namespace wasm {
+
+// Provides an infinite sequence of possible brand types, prioritizing those
+// with the most compact encoding.
+struct BrandTypeIterator {
+  static constexpr Index optionCount = 18;
+  static constexpr std::array<Field, optionCount> fieldOptions = {{
+    Field(Field::i8, Mutable),
+    Field(Field::i16, Mutable),
+    Field(Type::i32, Mutable),
+    Field(Type::i64, Mutable),
+    Field(Type::f32, Mutable),
+    Field(Type::f64, Mutable),
+    Field(Type(HeapType::any, Nullable), Mutable),
+    Field(Type(HeapType::func, Nullable), Mutable),
+    Field(Type(HeapType::ext, Nullable), Mutable),
+    Field(Type(HeapType::none, Nullable), Mutable),
+    Field(Type(HeapType::nofunc, Nullable), Mutable),
+    Field(Type(HeapType::noext, Nullable), Mutable),
+    Field(Type(HeapType::any, NonNullable), Mutable),
+    Field(Type(HeapType::func, NonNullable), Mutable),
+    Field(Type(HeapType::ext, NonNullable), Mutable),
+    Field(Type(HeapType::none, NonNullable), Mutable),
+    Field(Type(HeapType::nofunc, NonNullable), Mutable),
+    Field(Type(HeapType::noext, NonNullable), Mutable),
+  }};
+
+  struct FieldInfo {
+    uint8_t index = 0;
+    bool immutable = false;
+
+    operator Field() const {
+      auto field = fieldOptions[index];
+      if (immutable) {
+        field.mutable_ = Immutable;
+      }
+      return field;
+    }
+
+    bool advance() {
+      if (!immutable) {
+        immutable = true;
+        return true;
+      }
+      immutable = false;
+      index = (index + 1) % optionCount;
+      return index != 0;
+    }
+  };
+
+  bool useArray = false;
+  std::vector<FieldInfo> fields;
+
+  HeapType operator*() const {
+    if (useArray) {
+      return Array(fields[0]);
+    }
+    return Struct(std::vector<Field>(fields.begin(), fields.end()));
+  }
+
+  BrandTypeIterator& operator++() {
+    for (Index i = fields.size(); i > 0; --i) {
+      if (fields[i - 1].advance()) {
+        return *this;
+      }
+    }
+    if (useArray) {
+      useArray = false;
+      return *this;
+    }
+    fields.emplace_back();
+    useArray = fields.size() == 1;
+    return *this;
+  }
+};
+
+// A set of unique rec group shapes. Upon inserting a new group of types, if it
+// has the same shape as a previously inserted group, the types will be rebuilt
+// with an extra brand type at the end of the group that differentiates it from
+// the previous group.
+struct UniqueRecGroups {
+  std::list<std::vector<HeapType>> groups;
+  std::unordered_set<RecGroupShape> shapes;
+
+  FeatureSet features;
+
+  UniqueRecGroups(FeatureSet features) : features(features) {}
+
+  // Insert a rec group. If it is already unique, return the original types.
+  // Otherwise rebuild the group to make it unique and return the rebuilt
+  // types, including the brand.
+  const std::vector<HeapType>& insert(std::vector<HeapType> group);
+};
+
+} // namespace wasm
+
 #endif // wasm_wasm_type_shape_h
diff --git a/src/wasm/wasm-type-shape.cpp b/src/wasm/wasm-type-shape.cpp
index 5541d8b72a1..d2de6505d44 100644
--- a/src/wasm/wasm-type-shape.cpp
+++ b/src/wasm/wasm-type-shape.cpp
@@ -370,6 +370,41 @@ bool ComparableRecGroupShape::operator>(const RecGroupShape& other) const {
   return GT == compareComparable(*this, other);
 }
 
+const std::vector<HeapType>&
+UniqueRecGroups::insert(std::vector<HeapType> types) {
+  auto& group = *groups.emplace(groups.end(), std::move(types));
+  if (shapes.emplace(RecGroupShape(group, features)).second) {
+    // The types are already unique.
+    return group;
+  }
+  // There is a conflict. Find a brand that makes the group unique.
+  BrandTypeIterator brand;
+  group.push_back(*brand);
+  while (!shapes.emplace(RecGroupShape(group, features)).second) {
+    group.back() = *++brand;
+  }
+  // Rebuild the rec group to include the brand. Map the old types (excluding
+  // the brand) to their corresponding new types to preserve recursions within
+  // the group.
+  Index size = group.size();
+  TypeBuilder builder(size);
+  std::unordered_map<HeapType, HeapType> newTypes;
+  for (Index i = 0; i < size - 1; ++i) {
+    newTypes[group[i]] = builder[i];
+  }
+  for (Index i = 0; i < size; ++i) {
+    builder[i].copy(group[i], [&](HeapType type) {
+      if (auto newType = newTypes.find(type); newType != newTypes.end()) {
+        return newType->second;
+      }
+      return type;
+    });
+  }
+  builder.createRecGroup(0, size);
+  group = *builder.build();
+  return group;
+}
+
 } // namespace wasm
 
 namespace std {
diff --git a/test/lit/passes/signature-pruning-public-collision.wast b/test/lit/passes/signature-pruning-public-collision.wast
new file mode 100644
index 00000000000..4b23ed900db
--- /dev/null
+++ b/test/lit/passes/signature-pruning-public-collision.wast
@@ -0,0 +1,59 @@
+;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited.
+;; RUN: wasm-opt %s -all --closed-world --signature-pruning --fuzz-exec -S -o - | filecheck %s
+
+(module
+  ;; CHECK:      (type $public (func))
+
+  ;; CHECK:      (rec
+  ;; CHECK-NEXT:  (type $private (func))
+
+  ;; CHECK:      (type $2 (struct))
+
+  ;; CHECK:      (type $test (func (result i32)))
+  (type $test (func (result i32)))
+
+  (type $public (func))
+
+  ;; After signature pruning this will be (func), which is the same as $public.
+  ;; We must make sure we keep $private a distinct type.
+  (type $private (func (param i32)))
+
+  ;; CHECK:      (import "" "" (func $public (type $public)))
+  (import "" "" (func $public (type $public)))
+
+  ;; CHECK:      (elem declare func $public)
+
+  ;; CHECK:      (export "test" (func $test))
+
+  ;; CHECK:      (func $private (type $private)
+  ;; CHECK-NEXT:  (local $0 i32)
+  ;; CHECK-NEXT:  (nop)
+  ;; CHECK-NEXT: )
+  (func $private (type $private) (param $unused i32)
+    (nop)
+  )
+
+  ;; CHECK:      (func $test (type $test) (result i32)
+  ;; CHECK-NEXT:  (local $0 funcref)
+  ;; CHECK-NEXT:  (ref.test (ref $private)
+  ;; CHECK-NEXT:   (select (result funcref)
+  ;; CHECK-NEXT:    (ref.func $public)
+  ;; CHECK-NEXT:    (local.get $0)
+  ;; CHECK-NEXT:    (i32.const 1)
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT: )
+  (func $test (export "test") (type $test) (result i32)
+    (local funcref)
+    ;; Test that $private and $public are separate types. This should return 0.
+    (ref.test (ref $private)
+      ;; Use select to prevent the ref.test from being optimized in
+      ;; finalization.
+      (select (result funcref)
+        (ref.func $public)
+        (local.get 0)
+        (i32.const 1)
+      )
+    )
+  )
+)
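
Usage note (an illustrative sketch, not part of the patch itself): the snippet
below shows how the UniqueRecGroups and BrandTypeIterator utilities introduced
above are meant to behave when compiled inside the Binaryen tree. The tiny
main() harness, the FeatureSet::All choice, and the example empty struct type
are assumptions made up for this illustration; the UniqueRecGroups and
BrandTypeIterator calls are exactly the ones added by the patch, and the
HeapType/Struct/Field construction comes from wasm-type.h.

  #include <cassert>
  #include <vector>

  #include "wasm-type-shape.h"
  #include "wasm-type.h"

  using namespace wasm;

  int main() {
    UniqueRecGroups groups(FeatureSet::All);

    // A trivial singleton rec group: one empty struct type (hypothetical).
    HeapType empty(Struct(std::vector<Field>{}));

    // The first insertion sees a new shape, so the original types come back.
    const auto& first = groups.insert({empty});
    assert(first.size() == 1 && first[0] == empty);

    // Inserting the same shape again forces a rebuild: the returned group has
    // a brand type appended so its shape no longer collides with the first.
    const auto& second = groups.insert({empty});
    assert(second.size() == 2);
    assert(second[0] != empty);

    // BrandTypeIterator on its own enumerates candidate brand heap types,
    // starting with the most compactly encodable ones.
    BrandTypeIterator brand;
    HeapType firstCandidate = *brand;    // an empty struct
    HeapType secondCandidate = *++brand; // an array of mutable i8
    assert(firstCandidate != secondCandidate);
    return 0;
  }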