Skip to content

Commit 0fbaf03

Browse files
[SPIR-V] Cast ptr kernel args to i8* when used as Store's value operand (#78603)
Handle a special case when StoreInst's value operand is a kernel argument of a pointer type. Since these arguments could have either a basic element type (e.g. float*) or OpenCL builtin type (sampler_t), bitcast the StoreInst's value operand to default pointer element type (i8). This pull request addresses the issue #72864
1 parent 07dfa61 commit 0fbaf03

9 files changed

+219
-64
lines changed

llvm/lib/Target/SPIRV/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_llvm_target(SPIRVCodeGen
2626
SPIRVISelLowering.cpp
2727
SPIRVLegalizerInfo.cpp
2828
SPIRVMCInstLower.cpp
29+
SPIRVMetadata.cpp
2930
SPIRVModuleAnalysis.cpp
3031
SPIRVPreLegalizer.cpp
3132
SPIRVPrepareFunctions.cpp

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

+5-58
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "SPIRVBuiltins.h"
1818
#include "SPIRVGlobalRegistry.h"
1919
#include "SPIRVISelLowering.h"
20+
#include "SPIRVMetadata.h"
2021
#include "SPIRVRegisterInfo.h"
2122
#include "SPIRVSubtarget.h"
2223
#include "SPIRVUtils.h"
@@ -117,64 +118,12 @@ static FunctionType *getOriginalFunctionType(const Function &F) {
117118
return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
118119
}
119120

120-
static MDString *getKernelArgAttribute(const Function &KernelFunction,
121-
unsigned ArgIdx,
122-
const StringRef AttributeName) {
123-
assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
124-
"Kernel attributes are attached/belong only to kernel functions");
125-
126-
// Lookup the argument attribute in metadata attached to the kernel function.
127-
MDNode *Node = KernelFunction.getMetadata(AttributeName);
128-
if (Node && ArgIdx < Node->getNumOperands())
129-
return cast<MDString>(Node->getOperand(ArgIdx));
130-
131-
// Sometimes metadata containing kernel attributes is not attached to the
132-
// function, but can be found in the named module-level metadata instead.
133-
// For example:
134-
// !opencl.kernels = !{!0}
135-
// !0 = !{void ()* @someKernelFunction, !1, ...}
136-
// !1 = !{!"kernel_arg_addr_space", ...}
137-
// In this case the actual index of searched argument attribute is ArgIdx + 1,
138-
// since the first metadata node operand is occupied by attribute name
139-
// ("kernel_arg_addr_space" in the example above).
140-
unsigned MDArgIdx = ArgIdx + 1;
141-
NamedMDNode *OpenCLKernelsMD =
142-
KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
143-
if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
144-
return nullptr;
145-
146-
// KernelToMDNodeList contains kernel function declarations followed by
147-
// corresponding MDNodes for each attribute. Search only MDNodes "belonging"
148-
// to the currently lowered kernel function.
149-
MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
150-
bool FoundLoweredKernelFunction = false;
151-
for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
152-
ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
153-
if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() ==
154-
KernelFunction.getName()) {
155-
FoundLoweredKernelFunction = true;
156-
continue;
157-
}
158-
if (MaybeValue && FoundLoweredKernelFunction)
159-
return nullptr;
160-
161-
MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
162-
if (FoundLoweredKernelFunction && MaybeNode &&
163-
cast<MDString>(MaybeNode->getOperand(0))->getString() ==
164-
AttributeName &&
165-
MDArgIdx < MaybeNode->getNumOperands())
166-
return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
167-
}
168-
return nullptr;
169-
}
170-
171121
static SPIRV::AccessQualifier::AccessQualifier
172122
getArgAccessQual(const Function &F, unsigned ArgIdx) {
173123
if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
174124
return SPIRV::AccessQualifier::ReadWrite;
175125

176-
MDString *ArgAttribute =
177-
getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
126+
MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx);
178127
if (!ArgAttribute)
179128
return SPIRV::AccessQualifier::ReadWrite;
180129

@@ -186,9 +135,8 @@ getArgAccessQual(const Function &F, unsigned ArgIdx) {
186135
}
187136

188137
static std::vector<SPIRV::Decoration::Decoration>
189-
getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
190-
MDString *ArgAttribute =
191-
getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
138+
getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
139+
MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
192140
if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0)
193141
return {SPIRV::Decoration::Volatile};
194142
return {};
@@ -209,8 +157,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
209157
isSpecialOpaqueType(OriginalArgType))
210158
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
211159

212-
MDString *MDKernelArgType =
213-
getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
160+
MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx);
214161
if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
215162
!MDKernelArgType->getString().ends_with("_t")))
216163
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

+28-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "SPIRV.h"
15+
#include "SPIRVMetadata.h"
1516
#include "SPIRVTargetMachine.h"
1617
#include "SPIRVUtils.h"
1718
#include "llvm/IR/IRBuilder.h"
@@ -282,7 +283,26 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
282283
Value *Pointer;
283284
Type *ExpectedElementType;
284285
unsigned OperandToReplace;
285-
if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
286+
bool AllowCastingToChar = false;
287+
288+
StoreInst *SI = dyn_cast<StoreInst>(I);
289+
if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
290+
SI->getValueOperand()->getType()->isPointerTy() &&
291+
isa<Argument>(SI->getValueOperand())) {
292+
Argument *Arg = cast<Argument>(SI->getValueOperand());
293+
MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
294+
if (!ArgType || ArgType->getString().starts_with("uchar*"))
295+
return;
296+
297+
// Handle special case when StoreInst's value operand is a kernel argument
298+
// of a pointer type. Since these arguments could have either a basic
299+
// element type (e.g. float*) or OpenCL builtin type (sampler_t), bitcast
300+
// the StoreInst's value operand to default pointer element type (i8).
301+
Pointer = Arg;
302+
ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
303+
OperandToReplace = 0;
304+
AllowCastingToChar = true;
305+
} else if (SI) {
286306
Pointer = SI->getPointerOperand();
287307
ExpectedElementType = SI->getValueOperand()->getType();
288308
OperandToReplace = 1;
@@ -364,13 +384,15 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
364384

365385
// Do not emit spv_ptrcast if it would cast to the default pointer element
366386
// type (i8) of the same address space.
367-
if (ExpectedElementType->isIntegerTy(8))
387+
if (ExpectedElementType->isIntegerTy(8) && !AllowCastingToChar)
368388
return;
369389

370-
// If this would be the first spv_ptrcast and there is no spv_assign_ptr_type
371-
// for this pointer before, do not emit spv_ptrcast but emit
372-
// spv_assign_ptr_type instead.
373-
if (FirstPtrCastOrAssignPtrType && isa<Instruction>(Pointer)) {
390+
// If this would be the first spv_ptrcast, the pointer's defining instruction
391+
// requires spv_assign_ptr_type and does not already have one, do not emit
392+
// spv_ptrcast and emit spv_assign_ptr_type instead.
393+
Instruction *PointerDefInst = dyn_cast<Instruction>(Pointer);
394+
if (FirstPtrCastOrAssignPtrType && PointerDefInst &&
395+
requireAssignPtrType(PointerDefInst)) {
374396
buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
375397
ExpectedElementTypeConst, Pointer,
376398
{IRB->getInt32(AddressSpace)});
+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
//===--- SPIRVMetadata.cpp ---- IR Metadata Parsing Funcs -------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains functions needed for parsing LLVM IR metadata relevant
10+
// to the SPIR-V target.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "SPIRVMetadata.h"
15+
16+
using namespace llvm;
17+
18+
static MDString *getOCLKernelArgAttribute(const Function &F, unsigned ArgIdx,
19+
const StringRef AttributeName) {
20+
assert(
21+
F.getCallingConv() == CallingConv::SPIR_KERNEL &&
22+
"Kernel attributes are attached/belong only to OpenCL kernel functions");
23+
24+
// Lookup the argument attribute in metadata attached to the kernel function.
25+
MDNode *Node = F.getMetadata(AttributeName);
26+
if (Node && ArgIdx < Node->getNumOperands())
27+
return cast<MDString>(Node->getOperand(ArgIdx));
28+
29+
// Sometimes metadata containing kernel attributes is not attached to the
30+
// function, but can be found in the named module-level metadata instead.
31+
// For example:
32+
// !opencl.kernels = !{!0}
33+
// !0 = !{void ()* @someKernelFunction, !1, ...}
34+
// !1 = !{!"kernel_arg_addr_space", ...}
35+
// In this case the actual index of searched argument attribute is ArgIdx + 1,
36+
// since the first metadata node operand is occupied by attribute name
37+
// ("kernel_arg_addr_space" in the example above).
38+
unsigned MDArgIdx = ArgIdx + 1;
39+
NamedMDNode *OpenCLKernelsMD =
40+
F.getParent()->getNamedMetadata("opencl.kernels");
41+
if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
42+
return nullptr;
43+
44+
// KernelToMDNodeList contains kernel function declarations followed by
45+
// corresponding MDNodes for each attribute. Search only MDNodes "belonging"
46+
// to the currently lowered kernel function.
47+
MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
48+
bool FoundLoweredKernelFunction = false;
49+
for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
50+
ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
51+
if (MaybeValue &&
52+
dyn_cast<Function>(MaybeValue->getValue())->getName() == F.getName()) {
53+
FoundLoweredKernelFunction = true;
54+
continue;
55+
}
56+
if (MaybeValue && FoundLoweredKernelFunction)
57+
return nullptr;
58+
59+
MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
60+
if (FoundLoweredKernelFunction && MaybeNode &&
61+
cast<MDString>(MaybeNode->getOperand(0))->getString() ==
62+
AttributeName &&
63+
MDArgIdx < MaybeNode->getNumOperands())
64+
return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
65+
}
66+
return nullptr;
67+
}
68+
69+
namespace llvm {
70+
71+
MDString *getOCLKernelArgAccessQual(const Function &F, unsigned ArgIdx) {
72+
assert(
73+
F.getCallingConv() == CallingConv::SPIR_KERNEL &&
74+
"Kernel attributes are attached/belong only to OpenCL kernel functions");
75+
return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
76+
}
77+
78+
MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
79+
assert(
80+
F.getCallingConv() == CallingConv::SPIR_KERNEL &&
81+
"Kernel attributes are attached/belong only to OpenCL kernel functions");
82+
return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type_qual");
83+
}
84+
85+
MDString *getOCLKernelArgType(const Function &F, unsigned ArgIdx) {
86+
assert(
87+
F.getCallingConv() == CallingConv::SPIR_KERNEL &&
88+
"Kernel attributes are attached/belong only to OpenCL kernel functions");
89+
return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
90+
}
91+
92+
} // namespace llvm

llvm/lib/Target/SPIRV/SPIRVMetadata.h

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===--- SPIRVMetadata.h ---- IR Metadata Parsing Funcs ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains functions needed for parsing LLVM IR metadata relevant
10+
// to the SPIR-V target.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H
15+
#define LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H
16+
17+
#include "llvm/IR/Metadata.h"
18+
#include "llvm/IR/Module.h"
19+
20+
namespace llvm {
21+
22+
//===----------------------------------------------------------------------===//
23+
// OpenCL Metadata
24+
//
25+
26+
MDString *getOCLKernelArgAccessQual(const Function &F, unsigned ArgIdx);
27+
MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx);
28+
MDString *getOCLKernelArgType(const Function &F, unsigned ArgIdx);
29+
30+
} // namespace llvm
31+
#endif // LLVM_LIB_TARGET_SPIRV_METADATA_H
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#CHAR:]] = OpTypeInt 8
5+
; CHECK-DAG: %[[#GLOBAL_PTR_CHAR:]] = OpTypePointer CrossWorkgroup %[[#CHAR]]
6+
7+
define spir_kernel void @foo(ptr addrspace(1) %arg) {
8+
ret void
9+
}
10+
11+
; CHECK: %[[#]] = OpFunctionParameter %[[#GLOBAL_PTR_CHAR]]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#CHAR:]] = OpTypeInt 8
5+
; CHECK-DAG: %[[#GLOBAL_PTR_CHAR:]] = OpTypePointer CrossWorkgroup %[[#CHAR]]
6+
7+
define spir_kernel void @foo(i8 %a, ptr addrspace(1) %p) {
8+
store i8 %a, ptr addrspace(1) %p
9+
ret void
10+
}
11+
12+
; CHECK: %[[#A:]] = OpFunctionParameter %[[#CHAR]]
13+
; CHECK: %[[#P:]] = OpFunctionParameter %[[#GLOBAL_PTR_CHAR]]
14+
; CHECK: OpStore %[[#P]] %[[#A]]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#CHAR:]] = OpTypeInt 8
5+
; CHECK-DAG: %[[#GLOBAL_PTR_CHAR:]] = OpTypePointer CrossWorkgroup %[[#CHAR]]
6+
7+
define spir_kernel void @foo(ptr addrspace(1) %arg) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_base_type !3 !kernel_arg_type_qual !4 {
8+
%var = alloca ptr addrspace(1), align 8
9+
; CHECK: %[[#]] = OpFunctionParameter %[[#GLOBAL_PTR_CHAR]]
10+
; CHECK-NOT: %[[#]] = OpBitcast %[[#]] %[[#]]
11+
store ptr addrspace(1) %arg, ptr %var, align 8
12+
ret void
13+
}
14+
15+
!1 = !{i32 1}
16+
!2 = !{!"none"}
17+
!3 = !{!"char*"}
18+
!4 = !{!""}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
define spir_kernel void @foo(ptr addrspace(1) %arg) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_base_type !3 !kernel_arg_type_qual !4 {
5+
%var = alloca ptr addrspace(1), align 8
6+
; CHECK: %[[#VAR:]] = OpVariable %[[#]] Function
7+
store ptr addrspace(1) %arg, ptr %var, align 8
8+
; The test itends to verify that OpStore uses OpVariable result directly (without a bitcast).
9+
; Other type checking is done by spirv-val.
10+
; CHECK: OpStore %[[#VAR]] %[[#]] Aligned 8
11+
%lod = load ptr addrspace(1), ptr %var, align 8
12+
%idx = getelementptr inbounds i64, ptr addrspace(1) %lod, i64 0
13+
ret void
14+
}
15+
16+
!1 = !{i32 1}
17+
!2 = !{!"none"}
18+
!3 = !{!"ulong*"}
19+
!4 = !{!""}

0 commit comments

Comments
 (0)