From f7aaf92a04e36785befbba738a218fea20a49d62 Mon Sep 17 00:00:00 2001
From: Greg Fischer <greg@lunarg.com>
Date: Sun, 8 May 2022 15:10:07 -0600
Subject: [PATCH] Add structs to eliminate dead input components

Will eliminate all trailing members of input struct that are not
referenced.
---
 .../eliminate_dead_input_components_pass.cpp  | 81 ++++++++++++++-----
 .../eliminate_dead_input_components_pass.h    |  8 +-
 .../eliminate_dead_input_components_test.cpp  | 64 +++++++++++++++
 3 files changed, 133 insertions(+), 20 deletions(-)

diff --git a/source/opt/eliminate_dead_input_components_pass.cpp b/source/opt/eliminate_dead_input_components_pass.cpp
index f383136d55..aa2776bbd3 100644
--- a/source/opt/eliminate_dead_input_components_pass.cpp
+++ b/source/opt/eliminate_dead_input_components_pass.cpp
@@ -56,21 +56,30 @@ Pass::Status EliminateDeadInputComponentsPass::Process() {
       continue;
     }
     const analysis::Array* arr_type = ptr_type->pointee_type()->AsArray();
-    if (arr_type == nullptr) {
+    if (arr_type != nullptr) {
+      unsigned arr_len_id = arr_type->LengthId();
+      Instruction* arr_len_inst = def_use_mgr->GetDef(arr_len_id);
+      if (arr_len_inst->opcode() != SpvOpConstant) {
+        continue;
+      }
+      // SPIR-V requires array size is >= 1, so this works for signed or
+      // unsigned size
+      unsigned original_max =
+          arr_len_inst->GetSingleWordInOperand(kConstantValueInIdx) - 1;
+      unsigned max_idx = FindMaxIndex(var, original_max);
+      if (max_idx != original_max) {
+        ChangeArrayLength(var, max_idx + 1);
+        modified = true;
+      }
       continue;
     }
-    unsigned arr_len_id = arr_type->LengthId();
-    Instruction* arr_len_inst = def_use_mgr->GetDef(arr_len_id);
-    if (arr_len_inst->opcode() != SpvOpConstant) {
-      continue;
-    }
-    // SPIR-V requires array size is >= 1, so this works for signed or
-    // unsigned size
-    unsigned original_max =
-        arr_len_inst->GetSingleWordInOperand(kConstantValueInIdx) - 1;
+    const analysis::Struct* struct_type = ptr_type->pointee_type()->AsStruct();
+    if (struct_type == nullptr) continue;
+    const auto elt_types = struct_type->element_types();
+    unsigned original_max = static_cast<unsigned>(elt_types.size()) - 1;
     unsigned max_idx = FindMaxIndex(var, original_max);
     if (max_idx != original_max) {
-      ChangeArrayLength(var, max_idx + 1);
+      ChangeStructLength(var, max_idx + 1);
       modified = true;
     }
   }
@@ -116,12 +125,13 @@ unsigned EliminateDeadInputComponentsPass::FindMaxIndex(Instruction& var,
   return seen_non_const_ac ? original_max : max;
 }
 
-void EliminateDeadInputComponentsPass::ChangeArrayLength(Instruction& arr,
+void EliminateDeadInputComponentsPass::ChangeArrayLength(Instruction& arr_var,
                                                          unsigned length) {
   analysis::TypeManager* type_mgr = context()->get_type_mgr();
   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
-  analysis::Pointer* ptr_type = type_mgr->GetType(arr.type_id())->AsPointer();
+  analysis::Pointer* ptr_type =
+      type_mgr->GetType(arr_var.type_id())->AsPointer();
   const analysis::Array* arr_ty = ptr_type->pointee_type()->AsArray();
   assert(arr_ty && "expecting array type");
   uint32_t length_id = const_mgr->GetUIntConst(length);
@@ -131,15 +141,48 @@ void EliminateDeadInputComponentsPass::ChangeArrayLength(Instruction& arr,
   analysis::Pointer new_ptr_ty(reg_new_arr_ty, SpvStorageClassInput);
   analysis::Type* reg_new_ptr_ty = type_mgr->GetRegisteredType(&new_ptr_ty);
   uint32_t new_ptr_ty_id = type_mgr->GetTypeInstruction(reg_new_ptr_ty);
-  arr.SetResultType(new_ptr_ty_id);
-  def_use_mgr->AnalyzeInstUse(&arr);
-  // Move array OpVariable instruction after its new type to preserve order
-  USE_ASSERT(arr.GetSingleWordInOperand(kVariableStorageClassInIdx) !=
+  arr_var.SetResultType(new_ptr_ty_id);
+  def_use_mgr->AnalyzeInstUse(&arr_var);
+  // Move arr_var after its new type to preserve order
+  USE_ASSERT(arr_var.GetSingleWordInOperand(kVariableStorageClassInIdx) !=
+                 SpvStorageClassFunction &&
+             "cannot move Function variable");
+  Instruction* new_ptr_ty_inst = def_use_mgr->GetDef(new_ptr_ty_id);
+  arr_var.RemoveFromList();
+  arr_var.InsertAfter(new_ptr_ty_inst);
+}
+
+void EliminateDeadInputComponentsPass::ChangeStructLength(
+    Instruction& struct_var, unsigned length) {
+  analysis::TypeManager* type_mgr = context()->get_type_mgr();
+  analysis::Pointer* ptr_type =
+      type_mgr->GetType(struct_var.type_id())->AsPointer();
+  const analysis::Struct* struct_ty = ptr_type->pointee_type()->AsStruct();
+  assert(struct_ty && "expecting struct type");
+  const auto orig_elt_types = struct_ty->element_types();
+  std::vector<const analysis::Type*> new_elt_types;
+  for (unsigned u = 0; u < length; ++u)
+    new_elt_types.push_back(orig_elt_types[u]);
+  analysis::Struct new_struct_ty(new_elt_types);
+  analysis::Type* reg_new_struct_ty =
+      type_mgr->GetRegisteredType(&new_struct_ty);
+  uint32_t new_struct_ty_id = type_mgr->GetTypeInstruction(reg_new_struct_ty);
+  uint32_t old_struct_ty_id = type_mgr->GetTypeInstruction(struct_ty);
+  analysis::DecorationManager* deco_mgr = context()->get_decoration_mgr();
+  deco_mgr->CloneDecorations(old_struct_ty_id, new_struct_ty_id);
+  analysis::Pointer new_ptr_ty(reg_new_struct_ty, SpvStorageClassInput);
+  analysis::Type* reg_new_ptr_ty = type_mgr->GetRegisteredType(&new_ptr_ty);
+  uint32_t new_ptr_ty_id = type_mgr->GetTypeInstruction(reg_new_ptr_ty);
+  struct_var.SetResultType(new_ptr_ty_id);
+  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
+  def_use_mgr->AnalyzeInstUse(&struct_var);
+  // Move struct_var after its new type to preserve order
+  USE_ASSERT(struct_var.GetSingleWordInOperand(kVariableStorageClassInIdx) !=
                  SpvStorageClassFunction &&
              "cannot move Function variable");
   Instruction* new_ptr_ty_inst = def_use_mgr->GetDef(new_ptr_ty_id);
-  arr.RemoveFromList();
-  arr.InsertAfter(new_ptr_ty_inst);
+  struct_var.RemoveFromList();
+  struct_var.InsertAfter(new_ptr_ty_inst);
 }
 
 }  // namespace opt
diff --git a/source/opt/eliminate_dead_input_components_pass.h b/source/opt/eliminate_dead_input_components_pass.h
index b77857f4e9..a3a133c2bb 100644
--- a/source/opt/eliminate_dead_input_components_pass.h
+++ b/source/opt/eliminate_dead_input_components_pass.h
@@ -30,7 +30,10 @@ class EliminateDeadInputComponentsPass : public Pass {
  public:
   explicit EliminateDeadInputComponentsPass() {}
 
-  const char* name() const override { return "reduce-load-size"; }
+  const char* name() const override {
+    return "eliminate-dead-input-components";
+  }
+
   Status Process() override;
 
   // Return the mask of preserved Analyses.
@@ -51,6 +54,9 @@ class EliminateDeadInputComponentsPass : public Pass {
 
   // Change the length of the array |inst| to |length|
   void ChangeArrayLength(Instruction& inst, unsigned length);
+
+  // Change the length of the struct |struct_var| to |length|
+  void ChangeStructLength(Instruction& struct_var, unsigned length);
 };
 
 }  // namespace opt
diff --git a/test/opt/eliminate_dead_input_components_test.cpp b/test/opt/eliminate_dead_input_components_test.cpp
index b0098f733a..822914a860 100644
--- a/test/opt/eliminate_dead_input_components_test.cpp
+++ b/test/opt/eliminate_dead_input_components_test.cpp
@@ -399,6 +399,70 @@ TEST_F(ElimDeadInputComponentsTest, NoElimNonIndexedAccessChain) {
   SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true);
 }
 
+TEST_F(ElimDeadInputComponentsTest, ElimStructMember) {
+  // Should eliminate uv
+  //
+  // #version 450
+  //
+  // in Vertex {
+  //   vec4 Cd;
+  //   vec2 uv;
+  // } iVert;
+  //
+  // out vec4 fragColor;
+  //
+  // void main()
+  // {
+  //   vec4 color = vec4(iVert.Cd);
+  //   fragColor = color;
+  // }
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %iVert %fragColor
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 450
+               OpName %main "main"
+               OpName %Vertex "Vertex"
+               OpMemberName %Vertex 0 "Cd"
+               OpMemberName %Vertex 1 "uv"
+               OpName %iVert "iVert"
+               OpName %fragColor "fragColor"
+               OpDecorate %Vertex Block
+               OpDecorate %iVert Location 0
+               OpDecorate %fragColor Location 0
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+    %v2float = OpTypeVector %float 2
+     %Vertex = OpTypeStruct %v4float %v2float
+; CHECK: %Vertex = OpTypeStruct %v4float %v2float
+; CHECK: [[sty:%\w+]] = OpTypeStruct %v4float
+%_ptr_Input_Vertex = OpTypePointer Input %Vertex
+; CHECK: [[pty:%\w+]] = OpTypePointer Input [[sty]]
+      %iVert = OpVariable %_ptr_Input_Vertex Input
+; CHECK: %iVert = OpVariable [[pty]] Input
+        %int = OpTypeInt 32 1
+      %int_0 = OpConstant %int 0
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+  %fragColor = OpVariable %_ptr_Output_v4float Output
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+         %17 = OpAccessChain %_ptr_Input_v4float %iVert %int_0
+         %18 = OpLoad %v4float %17
+               OpStore %fragColor %18
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SetTargetEnv(SPV_ENV_VULKAN_1_3);
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndMatch<EliminateDeadInputComponentsPass>(text, true);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools