Skip to content

Commit

Permalink
[CIR][CodeGen] Special treatment of 3-element extended vector load an…
Browse files Browse the repository at this point in the history
…d store (#674)

Continue the work of #613 .

Original CodeGen treat vec3 as vec4 to get aligned memory access. This
PR enable these paths.
  • Loading branch information
seven-mile authored Jun 11, 2024
1 parent 1c0064b commit 5e9148f
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 6 deletions.
39 changes: 33 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,10 +596,23 @@ void CIRGenFunction::buildStoreOfScalar(mlir::Value Value, Address Addr,
return;
}

mlir::Type SrcTy = Value.getType();
if (const auto *ClangVecTy = Ty->getAs<clang::VectorType>()) {
auto VecTy = dyn_cast<mlir::cir::VectorType>(SrcTy);
if (!CGM.getCodeGenOpts().PreserveVec3Type &&
ClangVecTy->getNumElements() == 3)
llvm_unreachable("NYI: Special treatment of 3-element vector store");
ClangVecTy->getNumElements() == 3) {
// Handle vec3 special.
if (VecTy && VecTy.getSize() == 3) {
// Our source is a vec3, do a shuffle vector to make it a vec4.
Value = builder.createVecShuffle(Value.getLoc(), Value,
ArrayRef<int64_t>{0, 1, 2, -1});
SrcTy = mlir::cir::VectorType::get(VecTy.getContext(),
VecTy.getEltType(), 4);
}
if (Addr.getElementType() != SrcTy) {
Addr = Addr.withElementType(SrcTy);
}
}
}

// Update the alloca with more info on initialization.
Expand Down Expand Up @@ -772,7 +785,7 @@ void CIRGenFunction::buildStoreThroughExtVectorComponentLValue(RValue Src,
// of the Elts constant array will be one past the size of the vector.
// Ignore the last element here, if it is greater than the mask size.
if (getAccessedFieldNo(NumSrcElts - 1, Elts) == Mask.size())
llvm_unreachable("NYI");
NumSrcElts--;

// modify when what gets shuffled in
for (unsigned i = 0; i != NumSrcElts; ++i)
Expand Down Expand Up @@ -2770,14 +2783,28 @@ mlir::Value CIRGenFunction::buildLoadOfScalar(Address Addr, bool Volatile,
llvm_unreachable("NYI");
}

auto ElemTy = Addr.getElementType();

if (const auto *ClangVecTy = Ty->getAs<clang::VectorType>()) {
// Handle vectors of size 3 like size 4 for better performance.
const auto VTy = cast<mlir::cir::VectorType>(ElemTy);

if (!CGM.getCodeGenOpts().PreserveVec3Type &&
ClangVecTy->getNumElements() == 3)
llvm_unreachable("NYI: Special treatment of 3-element vector load");
ClangVecTy->getNumElements() == 3) {
auto loc = Addr.getPointer().getLoc();
auto vec4Ty =
mlir::cir::VectorType::get(VTy.getContext(), VTy.getEltType(), 4);
Address Cast = Addr.withElementType(vec4Ty);
// Now load value.
mlir::Value V = builder.createLoad(loc, Cast);

// Shuffle vector to get vec3.
V = builder.createVecShuffle(loc, V, ArrayRef<int64_t>{0, 1, 2});
return buildFromMemory(V, Ty);
}
}

auto Ptr = Addr.getPointer();
auto ElemTy = Addr.getElementType();
if (ElemTy.isa<mlir::cir::VoidType>()) {
ElemTy = mlir::cir::IntType::get(builder.getContext(), 8, true);
auto ElemPtrTy = mlir::cir::PointerType::get(builder.getContext(), ElemTy);
Expand Down
58 changes: 58 additions & 0 deletions clang/test/CIR/CodeGen/vectype-ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM

typedef int vi4 __attribute__((ext_vector_type(4)));
typedef int vi3 __attribute__((ext_vector_type(3)));
typedef int vi2 __attribute__((ext_vector_type(2)));
typedef double vd2 __attribute__((ext_vector_type(2)));
typedef long vl2 __attribute__((ext_vector_type(2)));
Expand Down Expand Up @@ -349,6 +350,10 @@ void test_store() {
// CIR-NEXT: %[[#PVECB:]] = cir.alloca !cir.vector<!s32i x 2>
// LLVM-NEXT: %[[#PVECB:]] = alloca <2 x i32>

vi3 c = {};
// CIR-NEXT: %[[#PVECC:]] = cir.alloca !cir.vector<!s32i x 3>
// LLVM-NEXT: %[[#PVECC:]] = alloca <3 x i32>

a.xy = b;
// CIR: %[[#LOAD4RHS:]] = cir.load %{{[0-9]+}} : !cir.ptr<!cir.vector<!s32i x 2>>, !cir.vector<!s32i x 2>
// CIR-NEXT: %[[#LOAD5LHS:]] = cir.load %{{[0-9]+}} : !cir.ptr<!cir.vector<!s32i x 4>>, !cir.vector<!s32i x 4>
Expand Down Expand Up @@ -388,6 +393,35 @@ void test_store() {
// LLVM-NEXT: %[[#RESULT:]] = shufflevector <4 x i32> %[[#VECA]], <4 x i32> %[[#EXTVECB]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
// LLVM-NEXT: store <4 x i32> %[[#RESULT]], ptr %[[#PVECA]], align 16

// OpenCL C Specification 6.3.7. Vector Components
// The suffixes .lo (or .even) and .hi (or .odd) for a 3-component vector type
// operate as if the 3-component vector type is a 4-component vector type with
// the value in the w component undefined.
b = c.hi;

// CIR-NEXT: %[[#VECC:]] = cir.load %[[#PVECC]] : !cir.ptr<!cir.vector<!s32i x 3>>, !cir.vector<!s32i x 3>
// CIR-NEXT: %[[#HIPART:]] = cir.vec.shuffle(%[[#VECC]], %[[#VECC]] : !cir.vector<!s32i x 3>) [#cir.int<2> : !s32i, #cir.int<3> : !s32i] : !cir.vector<!s32i x 2>
// CIR-NEXT: cir.store %[[#HIPART]], %[[#PVECB]] : !cir.vector<!s32i x 2>, !cir.ptr<!cir.vector<!s32i x 2>>

// LLVM-NEXT: %[[#VECC:]] = load <3 x i32>, ptr %[[#PVECC]], align 16
// LLVM-NEXT: %[[#HIPART:]] = shufflevector <3 x i32> %[[#VECC]], <3 x i32> %[[#VECC]], <2 x i32> <i32 2, i32 3>
// LLVM-NEXT: store <2 x i32> %[[#HIPART]], ptr %[[#PVECB]], align 8

// c.hi is c[2, 3], in which 3 should be ignored in CIRGen for store
c.hi = b;

// CIR-NEXT: %[[#VECB:]] = cir.load %[[#PVECB]] : !cir.ptr<!cir.vector<!s32i x 2>>, !cir.vector<!s32i x 2>
// CIR-NEXT: %[[#VECC:]] = cir.load %[[#PVECC]] : !cir.ptr<!cir.vector<!s32i x 3>>, !cir.vector<!s32i x 3>
// CIR-NEXT: %[[#EXTVECB:]] = cir.vec.shuffle(%[[#VECB]], %[[#VECB]] : !cir.vector<!s32i x 2>) [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<-1> : !s32i] : !cir.vector<!s32i x 3>
// CIR-NEXT: %[[#RESULT:]] = cir.vec.shuffle(%[[#VECC]], %[[#EXTVECB]] : !cir.vector<!s32i x 3>) [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<3> : !s32i] : !cir.vector<!s32i x 3>
// CIR-NEXT: cir.store %[[#RESULT]], %[[#PVECC]] : !cir.vector<!s32i x 3>, !cir.ptr<!cir.vector<!s32i x 3>>

// LLVM-NEXT: %[[#VECB:]] = load <2 x i32>, ptr %[[#PVECB]], align 8
// LLVM-NEXT: %[[#VECC:]] = load <3 x i32>, ptr %[[#PVECC]], align 16
// LLVM-NEXT: %[[#EXTVECB:]] = shufflevector <2 x i32> %[[#VECB]], <2 x i32> %[[#VECB]], <3 x i32> <i32 0, i32 1, i32 poison>
// LLVM-NEXT: %[[#RESULT:]] = shufflevector <3 x i32> %[[#VECC]], <3 x i32> %[[#EXTVECB]], <3 x i32> <i32 0, i32 1, i32 3>
// LLVM-NEXT: store <3 x i32> %[[#RESULT]], ptr %[[#PVECC]], align 16

}

// CIR: cir.func {{@.*test_build_lvalue.*}}
Expand Down Expand Up @@ -452,3 +486,27 @@ void test_build_lvalue() {
// LLVM-NEXT: store i32 %[[#RESULT]], ptr %[[#ALLOCAR]], align 4

}

// CIR: cir.func {{@.*test_vec3.*}}
// LLVM: define void {{@.*test_vec3.*}}
void test_vec3() {
vi3 v = {};
// CIR-NEXT: %[[#PV:]] = cir.alloca !cir.vector<!s32i x 3>, !cir.ptr<!cir.vector<!s32i x 3>>, ["v", init] {alignment = 16 : i64}
// CIR: %[[#VEC4:]] = cir.vec.shuffle(%{{[0-9]+}}, %{{[0-9]+}} : !cir.vector<!s32i x 3>) [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<-1> : !s32i] : !cir.vector<!s32i x 4>
// CIR-NEXT: %[[#PV4:]] = cir.cast(bitcast, %[[#PV]] : !cir.ptr<!cir.vector<!s32i x 3>>), !cir.ptr<!cir.vector<!s32i x 4>>
// CIR-NEXT: cir.store %[[#VEC4]], %[[#PV4]] : !cir.vector<!s32i x 4>, !cir.ptr<!cir.vector<!s32i x 4>>

// LLVM-NEXT: %[[#PV:]] = alloca <3 x i32>, i64 1, align 16
// LLVM-NEXT: store <4 x i32> <i32 0, i32 0, i32 0, i32 undef>, ptr %[[#PV]], align 16

v + 1;
// CIR-NEXT: %[[#PV4:]] = cir.cast(bitcast, %[[#PV]] : !cir.ptr<!cir.vector<!s32i x 3>>), !cir.ptr<!cir.vector<!s32i x 4>>
// CIR-NEXT: %[[#V4:]] = cir.load %[[#PV4]] : !cir.ptr<!cir.vector<!s32i x 4>>, !cir.vector<!s32i x 4>
// CIR-NEXT: %[[#V3:]] = cir.vec.shuffle(%[[#V4]], %[[#V4]] : !cir.vector<!s32i x 4>) [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i] : !cir.vector<!s32i x 3>
// CIR: %[[#RES:]] = cir.binop(add, %[[#V3]], %{{[0-9]+}}) : !cir.vector<!s32i x 3>

// LLVM-NEXT: %[[#V4:]] = load <4 x i32>, ptr %[[#PV:]], align 16
// LLVM-NEXT: %[[#V3:]] = shufflevector <4 x i32> %[[#V4]], <4 x i32> %[[#V4]], <3 x i32> <i32 0, i32 1, i32 2>
// LLVM-NEXT: %[[#RES:]] = add <3 x i32> %[[#V3]], <i32 1, i32 1, i32 1>

}

0 comments on commit 5e9148f

Please sign in to comment.