diff --git a/source/opt/copy_prop_arrays.cpp b/source/opt/copy_prop_arrays.cpp index 1c30138eb5..05453b6f20 100644 --- a/source/opt/copy_prop_arrays.cpp +++ b/source/opt/copy_prop_arrays.cpp @@ -151,9 +151,17 @@ Instruction* CopyPropagateArrays::BuildNewAccessChain( return source->GetVariable(); } + source->BuildConstants(); + std::vector access_ids(source->AccessChain().size()); + std::transform( + source->AccessChain().cbegin(), source->AccessChain().cend(), + access_ids.begin(), [](const AccessChainEntry& entry) { + assert(entry.is_result_id && "Constants needs to be built first."); + return entry.result_id; + }); + return builder.AddAccessChain(source->GetPointerTypeId(this), - source->GetVariable()->result_id(), - source->AccessChain()); + source->GetVariable()->result_id(), access_ids); } bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) { @@ -270,30 +278,20 @@ std::unique_ptr CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) { assert(extract_inst->opcode() == SpvOpCompositeExtract && "Expecting an OpCompositeExtract instruction."); - analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); - std::unique_ptr result = GetSourceObjectIfAny( extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand)); - if (result) { - analysis::Integer int_type(32, false); - const analysis::Type* uint32_type = - context()->get_type_mgr()->GetRegisteredType(&int_type); - - std::vector components; - // Convert the indices in the extract instruction to a series of ids that - // can be used by the |OpAccessChain| instruction. - for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) { - uint32_t index = extract_inst->GetSingleWordInOperand(i); - const analysis::Constant* index_const = - const_mgr->GetConstant(uint32_type, {index}); - components.push_back( - const_mgr->GetDefiningInstruction(index_const)->result_id()); - } - result->GetMember(components); - return result; + if (!result) { + return nullptr; } - return nullptr; + + // Copy the indices of the extract instruction to |OpAccessChain| indices. + std::vector components; + for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) { + components.push_back({false, {extract_inst->GetSingleWordInOperand(i)}}); + } + result->PushIndirection(components); + return result; } std::unique_ptr @@ -317,19 +315,12 @@ CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct( return nullptr; } - analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); - const analysis::Constant* last_access = - const_mgr->FindDeclaredConstant(memory_object->AccessChain().back()); - if (!last_access || !last_access->type()->AsInteger()) { + AccessChainEntry last_access = memory_object->AccessChain().back(); + if (!IsAccessChainIndexValidAndEqualTo(last_access, 0)) { return nullptr; } - if (last_access->GetU32() != 0) { - return nullptr; - } - - memory_object->GetParent(); - + memory_object->PopIndirection(); if (memory_object->GetNumberOfMembers() != conststruct_inst->NumInOperands()) { return nullptr; @@ -351,13 +342,8 @@ CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct( return nullptr; } - last_access = - const_mgr->FindDeclaredConstant(member_object->AccessChain().back()); - if (!last_access || !last_access->type()->AsInteger()) { - return nullptr; - } - - if (last_access->GetU32() != i) { + last_access = member_object->AccessChain().back(); + if (!IsAccessChainIndexValidAndEqualTo(last_access, i)) { return nullptr; } } @@ -411,17 +397,12 @@ CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) { return nullptr; } - const analysis::Constant* last_access = - const_mgr->FindDeclaredConstant(memory_object->AccessChain().back()); - if (!last_access || !last_access->type()->AsInteger()) { - return nullptr; - } - - if (last_access->GetU32() != number_of_elements - 1) { + AccessChainEntry last_access = memory_object->AccessChain().back(); + if (!IsAccessChainIndexValidAndEqualTo(last_access, number_of_elements - 1)) { return nullptr; } - memory_object->GetParent(); + memory_object->PopIndirection(); Instruction* current_insert = def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1)); @@ -458,14 +439,9 @@ CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) { return nullptr; } - const analysis::Constant* current_last_access = - const_mgr->FindDeclaredConstant( - current_memory_object->AccessChain().back()); - if (!current_last_access || !current_last_access->type()->AsInteger()) { - return nullptr; - } - - if (current_last_access->GetU32() != i - 1) { + AccessChainEntry current_last_access = + current_memory_object->AccessChain().back(); + if (!IsAccessChainIndexValidAndEqualTo(current_last_access, i - 1)) { return nullptr; } current_insert = @@ -475,6 +451,21 @@ CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) { return memory_object; } +bool CopyPropagateArrays::IsAccessChainIndexValidAndEqualTo( + const AccessChainEntry& entry, uint32_t value) const { + if (!entry.is_result_id) { + return entry.immediate == value; + } + + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + const analysis::Constant* constant = + const_mgr->FindDeclaredConstant(entry.result_id); + if (!constant || !constant->type()->AsInteger()) { + return false; + } + return constant->GetU32() == value; +} + bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) { analysis::TypeManager* type_mgr = context()->get_type_mgr(); analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer(); @@ -787,8 +778,8 @@ uint32_t CopyPropagateArrays::GetMemberTypeId( return id; } -void CopyPropagateArrays::MemoryObject::GetMember( - const std::vector& access_chain) { +void CopyPropagateArrays::MemoryObject::PushIndirection( + const std::vector& access_chain) { access_chain_.insert(access_chain_.end(), access_chain.begin(), access_chain.end()); } @@ -823,23 +814,29 @@ uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() { template CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst, iterator begin, iterator end) - : variable_inst_(var_inst), access_chain_(begin, end) {} + : variable_inst_(var_inst) { + std::transform(begin, end, std::back_inserter(access_chain_), + [](uint32_t id) { + return AccessChainEntry{true, {id}}; + }); +} std::vector CopyPropagateArrays::MemoryObject::GetAccessIds() const { analysis::ConstantManager* const_mgr = variable_inst_->context()->get_constant_mgr(); - std::vector access_indices; - for (uint32_t id : AccessChain()) { - const analysis::Constant* element_index_const = - const_mgr->FindDeclaredConstant(id); - if (!element_index_const) { - access_indices.push_back(0); - } else { - access_indices.push_back(element_index_const->GetU32()); - } - } - return access_indices; + std::vector indices(AccessChain().size()); + std::transform(AccessChain().cbegin(), AccessChain().cend(), indices.begin(), + [&const_mgr](const AccessChainEntry& entry) { + if (entry.is_result_id) { + const analysis::Constant* constant = + const_mgr->FindDeclaredConstant(entry.result_id); + return constant == nullptr ? 0 : constant->GetU32(); + } + + return entry.immediate; + }); + return indices; } bool CopyPropagateArrays::MemoryObject::Contains( @@ -860,5 +857,24 @@ bool CopyPropagateArrays::MemoryObject::Contains( return true; } +void CopyPropagateArrays::MemoryObject::BuildConstants() { + for (auto& entry : access_chain_) { + if (entry.is_result_id) { + continue; + } + + auto context = variable_inst_->context(); + analysis::Integer int_type(32, false); + const analysis::Type* uint32_type = + context->get_type_mgr()->GetRegisteredType(&int_type); + analysis::ConstantManager* const_mgr = context->get_constant_mgr(); + const analysis::Constant* index_const = + const_mgr->GetConstant(uint32_type, {entry.immediate}); + entry.result_id = + const_mgr->GetDefiningInstruction(index_const)->result_id(); + entry.is_result_id = true; + } +} + } // namespace opt } // namespace spvtools diff --git a/source/opt/copy_prop_arrays.h b/source/opt/copy_prop_arrays.h index 07747c109b..9e7641f615 100644 --- a/source/opt/copy_prop_arrays.h +++ b/source/opt/copy_prop_arrays.h @@ -52,6 +52,22 @@ class CopyPropagateArrays : public MemPass { } private: + // Represents one index in the OpAccessChain instruction. It can be either + // an instruction's result_id (OpConstant by ex), or a immediate value. + // Immediate values are used to prepare the final access chain without + // creating OpConstant instructions until done. + struct AccessChainEntry { + bool is_result_id; + union { + uint32_t result_id; + uint32_t immediate; + }; + + bool operator!=(const AccessChainEntry& other) const { + return other.is_result_id != is_result_id || other.result_id != result_id; + } + }; + // The class used to identify a particular memory object. This memory object // will be owned by a particular variable, meaning that the memory is part of // that variable. It could be the entire variable or a member of the @@ -70,12 +86,12 @@ class CopyPropagateArrays : public MemPass { // (starting from the current member). The elements in |access_chain| are // interpreted the same as the indices in the |OpAccessChain| // instruction. - void GetMember(const std::vector& access_chain); + void PushIndirection(const std::vector& access_chain); // Change |this| to now represent the first enclosing object to which it // belongs. (Remove the last element off the access_chain). It is invalid // to call this function if |this| does not represent a member of its owner. - void GetParent() { + void PopIndirection() { assert(IsMember()); access_chain_.pop_back(); } @@ -95,7 +111,13 @@ class CopyPropagateArrays : public MemPass { // member that |this| represents starting from the owning variable. These // values are to be interpreted the same way the indices are in an // |OpAccessChain| instruction. - const std::vector& AccessChain() const { return access_chain_; } + const std::vector& AccessChain() const { + return access_chain_; + } + + // Converts all immediate values in the AccessChain their OpConstant + // equivalent. + void BuildConstants(); // Returns the type id of the pointer type that can be used to point to this // memory object. @@ -137,7 +159,7 @@ class CopyPropagateArrays : public MemPass { // The access chain to reach the particular member the memory object // represents. It should be interpreted the same way the indices in an // |OpAccessChain| are interpreted. - std::vector access_chain_; + std::vector access_chain_; std::vector GetAccessIds() const; }; @@ -192,6 +214,10 @@ class CopyPropagateArrays : public MemPass { std::unique_ptr BuildMemoryObjectFromInsert( Instruction* insert_inst); + // Return true if the given entry can represent the given value. + bool IsAccessChainIndexValidAndEqualTo(const AccessChainEntry& entry, + uint32_t value) const; + // Return true if |type_id| is a pointer type whose pointee type is an array. bool IsPointerToArrayType(uint32_t type_id);