Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Clang] Fix Handling of Init Capture with Parameter Packs in LambdaScopeForCallOperatorInstantiationRAII #100766

Merged
merged 3 commits into from
Aug 9, 2024

Conversation

LYP951018
Copy link
Contributor

@LYP951018 LYP951018 commented Jul 26, 2024

This PR addresses issues related to the handling of init capture with parameter packs in Clang's LambdaScopeForCallOperatorInstantiationRAII.

Previously, addInstantiatedCapturesToScope would add init capture containing packs to the scope using the type of the init capture to determine the expanded pack size. However, this approach resulted in a pack size of 0 because getType()->containsUnexpandedParameterPack() returns false. After extensive testing, it appears that the correct pack size can only be inferred from getInit.

But getInit may reference parameters and init capture from an outer lambda, as shown in the following example:

auto L = [](auto... z) {
    return [... w = z](auto... y) {
        // ...
    };
};

To address this, addInstantiatedCapturesToScope in LambdaScopeForCallOperatorInstantiationRAII should be called last. Additionally, addInstantiatedCapturesToScope has been modified to only add init capture to the scope. The previous implementation incorrectly called MakeInstantiatedLocalArgPack for other non-init captures containing packs, resulting in a pack size of 0.

Impact

This patch affects scenarios where LambdaScopeForCallOperatorInstantiationRAII is passed with ShouldAddDeclsFromParentScope = false, preventing the correct addition of the current lambda's init capture to the scope. There are two main scenarios for ShouldAddDeclsFromParentScope = false:

  1. Constraints: Sometimes constraints are instantiated in place rather than delayed. In this case, LambdaScopeForCallOperatorInstantiationRAII does not need to add init capture to the scope.
  2. noexcept Expressions: The expressions inside noexcept have already been transformed, and the packs referenced within have been expanded. Only RebuildLambdaInfo needs to add the expanded captures to the scope, without requiring addInstantiatedCapturesToScope from LambdaScopeForCallOperatorInstantiationRAII.

Considerations

An alternative approach could involve adding a data structure within the lambda to record the expanded size of the init capture pack. However, this would increase the lambda's size and require extensive modifications.

This PR is a prerequisite for implmenting #61426

@LYP951018 LYP951018 requested a review from Endilll as a code owner July 26, 2024 15:44
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" labels Jul 26, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 26, 2024

@llvm/pr-subscribers-clang

Author: Yupei Liu (LYP951018)

Changes

This PR addresses issues related to the handling of init capture with parameter packs in Clang's LambdaScopeForCallOperatorInstantiationRAII.

Previously, addInstantiatedCapturesToScope would add init capture containing packs to the scope using the type of the init capture to determine the expanded pack size. However, this approach resulted in a pack size of 0 because getType()->containsUnexpandedParameterPack() returns false. After extensive testing, it appears that the correct pack size can only be inferred from getInit.

But getInit may reference parameters and init capture from an outer lambda, as shown in the following example:

auto L = [](auto... z) {
    return [... w = z](auto... y) {
        // ...
    };
};

To address this, addInstantiatedCapturesToScope in LambdaScopeForCallOperatorInstantiationRAII should be called last. Additionally, addInstantiatedCapturesToScope has been modified to only add init capture to the scope. The previous implementation incorrectly called MakeInstantiatedLocalArgPack for other non-init captures containing packs, resulting in a pack size of 0.

Impact

This patch affects scenarios where LambdaScopeForCallOperatorInstantiationRAII is passed with ShouldAddDeclsFromParentScope = false, preventing the correct addition of the current lambda's init capture to the scope. There are two main scenarios for ShouldAddDeclsFromParentScope = false:

  1. Constraints: Sometimes constraints are instantiated in place rather than delayed. In this case, LambdaScopeForCallOperatorInstantiationRAII does not need to add init capture to the scope.
  2. noexcept Expressions: The expressions inside noexcept have already been transformed, and the packs referenced within have been expanded. Only RebuildLambdaInfo needs to add the expanded captures to the scope, without requiring addInstantiatedCapturesToScope from LambdaScopeForCallOperatorInstantiationRAII.

Considerations

An alternative approach could involve adding a data structure within the lambda to record the expanded size of the init capture pack. However, this would increase the lambda's size and require extensive modifications.

This PR is a prerequisite for implmenting #61426


Full diff: https://github.com/llvm/llvm-project/pull/100766.diff

5 Files Affected:

  • (modified) clang/include/clang/Sema/Sema.h (+4)
  • (modified) clang/lib/Sema/SemaConcept.cpp (+13-5)
  • (modified) clang/lib/Sema/SemaLambda.cpp (+7-2)
  • (modified) clang/lib/Sema/SemaTemplateVariadic.cpp (+11-6)
  • (modified) clang/test/SemaTemplate/concepts-lambda.cpp (+28)
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 2ec6367eccea0..d8f903bc05eb6 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -14186,6 +14186,10 @@ class Sema final : public SemaBase {
   std::optional<unsigned> getNumArgumentsInExpansion(
       QualType T, const MultiLevelTemplateArgumentList &TemplateArgs);
 
+  std::optional<unsigned> getNumArgumentsInExpansionFromUnexpanded(
+      llvm::ArrayRef<UnexpandedParameterPack> Unexpanded,
+      const MultiLevelTemplateArgumentList &TemplateArgs);
+
   /// Determine whether the given declarator contains any unexpanded
   /// parameter packs.
   ///
diff --git a/clang/lib/Sema/SemaConcept.cpp b/clang/lib/Sema/SemaConcept.cpp
index 9e16b67284be4..49065af10b452 100644
--- a/clang/lib/Sema/SemaConcept.cpp
+++ b/clang/lib/Sema/SemaConcept.cpp
@@ -712,8 +712,8 @@ bool Sema::addInstantiatedCapturesToScope(
   auto AddSingleCapture = [&](const ValueDecl *CapturedPattern,
                               unsigned Index) {
     ValueDecl *CapturedVar = LambdaClass->getCapture(Index)->getCapturedVar();
-    if (CapturedVar->isInitCapture())
-      Scope.InstantiatedLocal(CapturedPattern, CapturedVar);
+    assert(CapturedVar->isInitCapture());
+    Scope.InstantiatedLocal(CapturedPattern, CapturedVar);
   };
 
   for (const LambdaCapture &CapturePattern : LambdaPattern->captures()) {
@@ -721,13 +721,21 @@ bool Sema::addInstantiatedCapturesToScope(
       Instantiated++;
       continue;
     }
-    const ValueDecl *CapturedPattern = CapturePattern.getCapturedVar();
+    ValueDecl *CapturedPattern = CapturePattern.getCapturedVar();
+
+    if (!CapturedPattern->isInitCapture()) {
+      continue;
+    }
+
     if (!CapturedPattern->isParameterPack()) {
       AddSingleCapture(CapturedPattern, Instantiated++);
     } else {
       Scope.MakeInstantiatedLocalArgPack(CapturedPattern);
-      std::optional<unsigned> NumArgumentsInExpansion =
-          getNumArgumentsInExpansion(CapturedPattern->getType(), TemplateArgs);
+      SmallVector<UnexpandedParameterPack, 2> Unexpanded;
+      SemaRef.collectUnexpandedParameterPacks(
+          dyn_cast<VarDecl>(CapturedPattern)->getInit(), Unexpanded);
+      auto NumArgumentsInExpansion =
+          getNumArgumentsInExpansionFromUnexpanded(Unexpanded, TemplateArgs);
       if (!NumArgumentsInExpansion)
         continue;
       for (unsigned Arg = 0; Arg < *NumArgumentsInExpansion; ++Arg)
diff --git a/clang/lib/Sema/SemaLambda.cpp b/clang/lib/Sema/SemaLambda.cpp
index 601077e9f3334..f39d35df56ed6 100644
--- a/clang/lib/Sema/SemaLambda.cpp
+++ b/clang/lib/Sema/SemaLambda.cpp
@@ -2389,11 +2389,10 @@ Sema::LambdaScopeForCallOperatorInstantiationRAII::
   if (!FDPattern)
     return;
 
-  SemaRef.addInstantiatedCapturesToScope(FD, FDPattern, Scope, MLTAL);
-
   if (!ShouldAddDeclsFromParentScope)
     return;
 
+  FunctionDecl *OutermostFD = FD, *OutermostFDPattern = FDPattern;
   llvm::SmallVector<std::pair<FunctionDecl *, FunctionDecl *>, 4>
       ParentInstantiations;
   while (true) {
@@ -2417,5 +2416,11 @@ Sema::LambdaScopeForCallOperatorInstantiationRAII::
   for (const auto &[FDPattern, FD] : llvm::reverse(ParentInstantiations)) {
     SemaRef.addInstantiatedParametersToScope(FD, FDPattern, Scope, MLTAL);
     SemaRef.addInstantiatedLocalVarsToScope(FD, FDPattern, Scope);
+
+    if (isLambdaCallOperator(FD))
+      SemaRef.addInstantiatedCapturesToScope(FD, FDPattern, Scope, MLTAL);
   }
+
+  SemaRef.addInstantiatedCapturesToScope(OutermostFD, OutermostFDPattern, Scope,
+                                         MLTAL);
 }
diff --git a/clang/lib/Sema/SemaTemplateVariadic.cpp b/clang/lib/Sema/SemaTemplateVariadic.cpp
index 3d4ccaf68c700..d9886bd06fee4 100644
--- a/clang/lib/Sema/SemaTemplateVariadic.cpp
+++ b/clang/lib/Sema/SemaTemplateVariadic.cpp
@@ -825,12 +825,9 @@ bool Sema::CheckParameterPacksForExpansion(
   return false;
 }
 
-std::optional<unsigned> Sema::getNumArgumentsInExpansion(
-    QualType T, const MultiLevelTemplateArgumentList &TemplateArgs) {
-  QualType Pattern = cast<PackExpansionType>(T)->getPattern();
-  SmallVector<UnexpandedParameterPack, 2> Unexpanded;
-  CollectUnexpandedParameterPacksVisitor(Unexpanded).TraverseType(Pattern);
-
+std::optional<unsigned> Sema::getNumArgumentsInExpansionFromUnexpanded(
+    llvm::ArrayRef<UnexpandedParameterPack> Unexpanded,
+    const MultiLevelTemplateArgumentList &TemplateArgs) {
   std::optional<unsigned> Result;
   for (unsigned I = 0, N = Unexpanded.size(); I != N; ++I) {
     // Compute the depth and index for this parameter pack.
@@ -878,6 +875,14 @@ std::optional<unsigned> Sema::getNumArgumentsInExpansion(
   return Result;
 }
 
+std::optional<unsigned> Sema::getNumArgumentsInExpansion(
+    QualType T, const MultiLevelTemplateArgumentList &TemplateArgs) {
+  QualType Pattern = cast<PackExpansionType>(T)->getPattern();
+  SmallVector<UnexpandedParameterPack, 2> Unexpanded;
+  CollectUnexpandedParameterPacksVisitor(Unexpanded).TraverseType(Pattern);
+  return getNumArgumentsInExpansionFromUnexpanded(Unexpanded, TemplateArgs);
+}
+
 bool Sema::containsUnexpandedParameterPacks(Declarator &D) {
   const DeclSpec &DS = D.getDeclSpec();
   switch (DS.getTypeSpecType()) {
diff --git a/clang/test/SemaTemplate/concepts-lambda.cpp b/clang/test/SemaTemplate/concepts-lambda.cpp
index 252ef08549a48..9c5807bbabdcb 100644
--- a/clang/test/SemaTemplate/concepts-lambda.cpp
+++ b/clang/test/SemaTemplate/concepts-lambda.cpp
@@ -251,3 +251,31 @@ void dependent_param() {
   L(0, 1)(1, 2)(1);
 }
 } // namespace dependent_param_concept
+
+namespace init_captures {
+template <int N> struct V {};
+
+void sink(V<0>, V<1>, V<2>, V<3>, V<4>) {}
+
+void init_capture_pack() {
+  auto L = [](auto... z) {
+    return [=](auto... y) {
+      return [... w = z, y...](auto)
+        requires requires { sink(w..., y...); }
+      {};
+    };
+  };
+  L(V<0>{}, V<1>{}, V<2>{})(V<3>{}, V<4>{})(1);
+}
+
+void dependent_capture_packs() {
+  auto L = [](auto... z) {
+    return [... w = z](auto... y) {
+      return [... c = w](auto)
+        requires requires { sink(c..., y...); }
+      {};
+    };
+  };
+  L(V<0>{}, V<1>{}, V<2>{})(V<3>{}, V<4>{})(1);
+}
+} // namespace init_captures

@cor3ntin cor3ntin requested review from Sirraide and zyn0217 July 26, 2024 18:24
@LYP951018
Copy link
Contributor Author

ping~

Comment on lines +726 to +728
if (!CapturedPattern->isInitCapture()) {
continue;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When is CapturedPattern not an initCapture?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

captures like

[a]() {}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hum, nvm, I misread that code , I failed to realize the assert L715 was in a different scope, so I was confused.

Copy link
Contributor

@cor3ntin cor3ntin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines +726 to +728
if (!CapturedPattern->isInitCapture()) {
continue;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hum, nvm, I misread that code , I failed to realize the assert L715 was in a different scope, so I was confused.

@LYP951018 LYP951018 merged commit 52126dc into llvm:main Aug 9, 2024
9 checks passed
@LYP951018
Copy link
Contributor Author

Thanks for the review ;)

getNumArgumentsInExpansion(CapturedPattern->getType(), TemplateArgs);
SmallVector<UnexpandedParameterPack, 2> Unexpanded;
SemaRef.collectUnexpandedParameterPacks(
dyn_cast<VarDecl>(CapturedPattern)->getInit(), Unexpanded);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use cast?

kutemeikito added a commit to kutemeikito/llvm-project that referenced this pull request Aug 10, 2024
* 'main' of https://github.com/llvm/llvm-project: (700 commits)
  [SandboxIR][NFC] SingleLLVMInstructionImpl class (llvm#102687)
  [ThinLTO]Clean up 'import-assume-unique-local' flag. (llvm#102424)
  [nsan] Make #include more conventional
  [SandboxIR][NFC] Use Tracker.emplaceIfTracking()
  [libc]  Moved range_reduction_double ifdef statement (llvm#102659)
  [libc] Fix CFP long double and add tests (llvm#102660)
  [TargetLowering] Handle vector types in expandFixedPointMul (llvm#102635)
  [compiler-rt][NFC] Replace environment variable with %t (llvm#102197)
  [UnitTests] Convert a test to use opaque pointers (llvm#102668)
  [CodeGen][NFCI] Don't re-implement parts of ASTContext::getIntWidth (llvm#101765)
  [SandboxIR] Clean up tracking code with the help of emplaceIfTracking() (llvm#102406)
  [mlir][bazel] remove extra blanks in mlir-tblgen test
  [NVPTX][NFC] Update tests to use bfloat type (llvm#101493)
  [mlir] Add support for parsing nested PassPipelineOptions (llvm#101118)
  [mlir][bazel] add missing td dependency in mlir-tblgen test
  [flang][cuda] Fix lib dependency
  [libc] Clean up remaining use of *_WIDTH macros in printf (llvm#102679)
  [flang][cuda] Convert cuf.alloc for box to fir.alloca in device context (llvm#102662)
  [SandboxIR] Implement the InsertElementInst class (llvm#102404)
  [libc] Fix use of cpp::numeric_limits<...>::digits (llvm#102674)
  [mlir][ODS] Verify type constraints in Types and Attributes (llvm#102326)
  [LTO] enable `ObjCARCContractPass` only on optimized build  (llvm#101114)
  [mlir][ODS] Consistent `cppType` / `cppClassName` usage (llvm#102657)
  [lldb] Move definition of SBSaveCoreOptions dtor out of header (llvm#102539)
  [libc] Use cpp::numeric_limits in preference to C23 <limits.h> macros (llvm#102665)
  [clang] Implement -fptrauth-auth-traps. (llvm#102417)
  [LLVM][rtsan] rtsan transform to preserve CFGAnalyses (llvm#102651)
  Revert "[AMDGPU] Move `AMDGPUAttributorPass` to full LTO post link stage (llvm#102086)"
  [RISCV][GISel] Add missing tests for G_CTLZ/CTTZ instruction selection. NFC
  Return available function types for BindingDecls. (llvm#102196)
  [clang] Wire -fptrauth-returns to "ptrauth-returns" fn attribute. (llvm#102416)
  [RISCV] Remove riscv-experimental-rv64-legal-i32. (llvm#102509)
  [RISCV] Move PseudoVSET(I)VLI expansion to use PseudoInstExpansion. (llvm#102496)
  [NVPTX] support switch statement with brx.idx (reland) (llvm#102550)
  [libc][newhdrgen]sorted function names in yaml (llvm#102544)
  [GlobalIsel] Combine G_ADD and G_SUB with constants (llvm#97771)
  Suppress spurious warnings due to R_RISCV_SET_ULEB128
  [scudo] Separated committed and decommitted entries. (llvm#101409)
  [MIPS] Fix missing ANDI optimization (llvm#97689)
  [Clang] Add env var for nvptx-arch/amdgpu-arch timeout (llvm#102521)
  [asan] Switch allocator to dynamic base address (llvm#98511)
  [AMDGPU] Move `AMDGPUAttributorPass` to full LTO post link stage (llvm#102086)
  [libc][math][c23] Add fadd{l,f128} C23 math functions (llvm#102531)
  [mlir][bazel] revert bazel rule change for DLTITransformOps
  [msan] Support vst{2,3,4}_lane instructions (llvm#101215)
  Revert "[MLIR][DLTI][Transform] Introduce transform.dlti.query (llvm#101561)"
  [X86] pr57673.ll - generate MIR test checks
  [mlir][vector][test] Split tests from vector-transfer-flatten.mlir (llvm#102584)
  [mlir][bazel] add bazel rule for DLTITransformOps
  OpenMPOpt: Remove dead include
  [IR] Add method to GlobalVariable to change type of initializer. (llvm#102553)
  [flang][cuda] Force default allocator in device code (llvm#102238)
  [llvm] Construct SmallVector<SDValue> with ArrayRef (NFC) (llvm#102578)
  [MLIR][DLTI][Transform] Introduce transform.dlti.query (llvm#101561)
  [AMDGPU][AsmParser][NFC] Remove a misleading comment. (llvm#102604)
  [Arm][AArch64][Clang] Respect function's branch protection attributes. (llvm#101978)
  [mlir] Verifier: steal bit to track seen instead of set. (llvm#102626)
  [Clang] Fix Handling of Init Capture with Parameter Packs in LambdaScopeForCallOperatorInstantiationRAII (llvm#100766)
  [X86] Convert truncsat clamping patterns to use SDPatternMatch. NFC.
  [gn] Give two scripts argparse.RawDescriptionHelpFormatter
  [bazel] Add missing dep for the SPIRVToLLVM target
  [Clang] Simplify specifying passes via -Xoffload-linker (llvm#102483)
  [bazel] Port for d45de80
  [SelectionDAG] Use unaligned store/load to move AVX registers onto stack for `insertelement` (llvm#82130)
  [Clang][OMPX] Add the code generation for multi-dim `num_teams` (llvm#101407)
  [ARM] Regenerate big-endian-vmov.ll. NFC
  [AMDGPU][AsmParser][NFCI] All NamedIntOperands to be of the i32 type. (llvm#102616)
  [libc][math][c23] Add totalorderl function. (llvm#102564)
  [mlir][spirv] Support `memref` in `convert-to-spirv` pass (llvm#102534)
  [MLIR][GPU-LLVM] Convert `gpu.func` to `llvm.func` (llvm#101664)
  Fix a unit test input file (llvm#102567)
  [llvm-readobj][COFF] Dump hybrid objects for ARM64X files. (llvm#102245)
  AMDGPU/NewPM: Port SIFixSGPRCopies to new pass manager (llvm#102614)
  [MemoryBuiltins] Simplify getCalledFunction() helper (NFC)
  [AArch64] Add invalid 1 x vscale costs for reductions and reduction-operations. (llvm#102105)
  [MemoryBuiltins] Handle allocator attributes on call-site
  LSV/test/AArch64: add missing lit.local.cfg; fix build (llvm#102607)
  Revert "Enable logf128 constant folding for hosts with 128bit floats (llvm#96287)"
  [RISCV] Add Syntacore SCR5 RV32/64 processors definition (llvm#102285)
  [InstCombine] Remove unnecessary RUN line from test (NFC)
  [flang][OpenMP] Handle multiple ranges in `num_teams` clause (llvm#102535)
  [mlir][vector] Add tests for scalable vectors in one-shot-bufferize.mlir (llvm#102361)
  [mlir][vector] Disable `vector.matrix_multiply` for scalable vectors (llvm#102573)
  [clang] Implement CWG2627 Bit-fields and narrowing conversions (llvm#78112)
  [NFC] Use references to avoid copying (llvm#99863)
  Revert "[mlir][ArmSME] Pattern to swap shape_cast(tranpose) with transpose(shape_cast) (llvm#100731)" (llvm#102457)
  [IRBuilder] Generate nuw GEPs for struct member accesses (llvm#99538)
  [bazel] Port for 9b06e25
  [CodeGen][NewPM] Improve start/stop pass error message CodeGenPassBuilder (llvm#102591)
  [AArch64] Implement TRBMPAM_EL1 system register (llvm#102485)
  [InstCombine] Fixing wrong select folding in vectors with undef elements (llvm#102244)
  [AArch64] Sink operands to fmuladd. (llvm#102297)
  LSV: document hang reported in llvm#37865 (llvm#102479)
  Enable logf128 constant folding for hosts with 128bit floats (llvm#96287)
  [RISCV][clang] Remove bfloat base type in non-zvfbfmin vcreate (llvm#102146)
  [RISCV][clang] Add missing `zvfbfmin` to `vget_v` intrinsic (llvm#102149)
  [mlir][vector] Add mask elimination transform (llvm#99314)
  [Clang][Interp] Fix display of syntactically-invalid note for member function calls (llvm#102170)
  [bazel] Port for 3fffa6d
  [DebugInfo][RemoveDIs] Use iterator-inserters in clang (llvm#102006)
  ...

Signed-off-by: Edwiin Kusuma Jaya <kutemeikito0905@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants