[Clang] Add vector gather / scatter builtins to clang #157895
Conversation
@llvm/pr-subscribers-clang-codegen @llvm/pr-subscribers-clang

Author: Joseph Huber (jhuber6)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/157895.diff

7 Files Affected:
- clang/docs/LanguageExtensions.rst
- clang/docs/ReleaseNotes.rst
- clang/include/clang/Basic/Builtins.td
- clang/lib/CodeGen/CGBuiltin.cpp
- clang/lib/Sema/SemaChecking.cpp
- clang/test/CodeGen/builtin-masked.c
- clang/test/Sema/builtin-masked.c
diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst
index ad190eace5b05..f3ce5ee534609 100644
--- a/clang/docs/LanguageExtensions.rst
+++ b/clang/docs/LanguageExtensions.rst
@@ -957,6 +957,11 @@ builtins have the same interface but store the result in consecutive indices.
Effectively this performs the ``if (mask[i]) val[i] = ptr[j++]`` and ``if
(mask[i]) ptr[j++] = val[i]`` pattern respectively.
+The ``__builtin_masked_gather`` and ``__builtin_masked_scatter`` builtins handle
+non-sequential memory access for vector types. These use a base pointer and a
+vector of integer indices to gather memory into a vector type or scatter it to
+separate indices.
+
Example:
.. code-block:: c++
@@ -978,6 +983,14 @@ Example:
__builtin_masked_compress_store(mask, val, ptr);
}
+ v8i gather(v8b mask, v8i idx, int *ptr) {
+ return __builtin_masked_gather(mask, idx, ptr);
+ }
+
+ void scatter(v8b mask, v8i val, v8i idx, int *ptr) {
+ __builtin_masked_scatter(mask, idx, val, ptr);
+ }
+
Matrix Types
============
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index a20b1ab298f9c..8094e1d8ca4cb 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -183,6 +183,10 @@ Non-comprehensive list of changes in this release
conditional memory loads from vectors. Binds to the LLVM intrinsics of the
same name.
+- Added ``__builtin_masked_gather`` and ``__builtin_masked_scatter`` for
+ conditional gathering and scattering operations on vectors. Binds to the LLVM
+ intrinsics of the same name.
+
- The ``__builtin_popcountg``, ``__builtin_ctzg``, and ``__builtin_clzg``
functions now accept fixed-size boolean vectors.
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 27639f06529cb..97be087aa752a 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -1256,6 +1256,18 @@ def MaskedCompressStore : Builtin {
let Prototype = "void(...)";
}
+def MaskedGather : Builtin {
+ let Spellings = ["__builtin_masked_gather"];
+ let Attributes = [NoThrow, CustomTypeChecking];
+ let Prototype = "void(...)";
+}
+
+def MaskedScatter : Builtin {
+ let Spellings = ["__builtin_masked_scatter"];
+ let Attributes = [NoThrow, CustomTypeChecking];
+ let Prototype = "void(...)";
+}
+
def AllocaUninitialized : Builtin {
let Spellings = ["__builtin_alloca_uninitialized"];
let Attributes = [FunctionWithBuiltinPrefix, NoThrow];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 172a521e63c17..ef50ba8328bfd 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -4298,6 +4298,30 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
}
return RValue::get(Result);
};
+ case Builtin::BI__builtin_masked_gather: {
+ llvm::Value *Mask = EmitScalarExpr(E->getArg(0));
+ llvm::Value *Idx = EmitScalarExpr(E->getArg(1));
+ llvm::Value *Ptr = EmitScalarExpr(E->getArg(2));
+
+ llvm::Type *RetTy = CGM.getTypes().ConvertType(E->getType());
+ llvm::Type *ElemTy = CGM.getTypes().ConvertType(
+ E->getType()->getAs<VectorType>()->getElementType());
+ llvm::Value *AlignVal = llvm::ConstantInt::get(Int32Ty, 1);
+
+ llvm::Value *PassThru = llvm::PoisonValue::get(RetTy);
+ if (E->getNumArgs() > 3)
+ PassThru = EmitScalarExpr(E->getArg(3));
+
+ llvm::Value *PtrVec = Builder.CreateGEP(ElemTy, Ptr, Idx);
+
+ llvm::Value *Result;
+ Function *F =
+ CGM.getIntrinsic(Intrinsic::masked_gather, {RetTy, PtrVec->getType()});
+
+ Result = Builder.CreateCall(F, {PtrVec, AlignVal, Mask, PassThru},
+ "masked_gather");
+ return RValue::get(Result);
+ }
case Builtin::BI__builtin_masked_store:
case Builtin::BI__builtin_masked_compress_store: {
llvm::Value *Mask = EmitScalarExpr(E->getArg(0));
@@ -4323,7 +4347,24 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
}
return RValue::get(nullptr);
}
+ case Builtin::BI__builtin_masked_scatter: {
+ llvm::Value *Mask = EmitScalarExpr(E->getArg(0));
+ llvm::Value *Val = EmitScalarExpr(E->getArg(1));
+ llvm::Value *Idx = EmitScalarExpr(E->getArg(2));
+ llvm::Value *Ptr = EmitScalarExpr(E->getArg(3));
+ llvm::Type *ElemTy = CGM.getTypes().ConvertType(
+ E->getArg(1)->getType()->getAs<VectorType>()->getElementType());
+ llvm::Value *AlignVal = llvm::ConstantInt::get(Int32Ty, 1);
+
+ llvm::Value *PtrVec = Builder.CreateGEP(ElemTy, Ptr, Idx);
+
+ Function *F = CGM.getIntrinsic(Intrinsic::masked_scatter,
+ {Val->getType(), PtrVec->getType()});
+
+ Builder.CreateCall(F, {Val, PtrVec, AlignVal, Mask});
+ return RValue();
+ }
case Builtin::BI__builtin_isinf_sign: {
// isinf_sign(x) -> fabs(x) == infinity ? (signbit(x) ? -1 : 1) : 0
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(*this, E);
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 077f4311ed729..6634f38182e41 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2270,7 +2270,7 @@ static bool BuiltinCountZeroBitsGeneric(Sema &S, CallExpr *TheCall) {
}
static bool CheckMaskedBuiltinArgs(Sema &S, Expr *MaskArg, Expr *PtrArg,
- unsigned Pos) {
+ unsigned Pos, bool Vector = true) {
QualType MaskTy = MaskArg->getType();
if (!MaskTy->isExtVectorBoolType())
return S.Diag(MaskArg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
@@ -2278,9 +2278,11 @@ static bool CheckMaskedBuiltinArgs(Sema &S, Expr *MaskArg, Expr *PtrArg,
<< MaskTy;
QualType PtrTy = PtrArg->getType();
- if (!PtrTy->isPointerType() || !PtrTy->getPointeeType()->isVectorType())
+ if (!PtrTy->isPointerType() ||
+ (Vector && !PtrTy->getPointeeType()->isVectorType()) ||
+ (!Vector && PtrTy->getPointeeType()->isVectorType()))
return S.Diag(PtrArg->getExprLoc(), diag::err_vec_masked_load_store_ptr)
- << Pos << "pointer to vector";
+ << Pos << (Vector ? "pointer to vector" : "scalar pointer");
return false;
}
@@ -2361,6 +2363,101 @@ static ExprResult BuiltinMaskedStore(Sema &S, CallExpr *TheCall) {
return TheCall;
}
+static ExprResult BuiltinMaskedGather(Sema &S, CallExpr *TheCall) {
+ if (S.checkArgCountRange(TheCall, 3, 4))
+ return ExprError();
+
+ Expr *MaskArg = TheCall->getArg(0);
+ Expr *IdxArg = TheCall->getArg(1);
+ Expr *PtrArg = TheCall->getArg(2);
+ if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 3, /*Vector=*/false))
+ return ExprError();
+
+ QualType IdxTy = IdxArg->getType();
+ const VectorType *IdxVecTy = IdxTy->getAs<VectorType>();
+ if (!IdxTy->isExtVectorType() || !IdxVecTy->getElementType()->isIntegerType())
+ return S.Diag(MaskArg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
+ << 1 << /* vector of */ 4 << /* integer */ 1 << /* no fp */ 0
+ << IdxTy;
+
+ QualType MaskTy = MaskArg->getType();
+ QualType PtrTy = PtrArg->getType();
+ QualType PointeeTy = PtrTy->getPointeeType();
+ const VectorType *MaskVecTy = MaskTy->getAs<VectorType>();
+ if (MaskVecTy->getNumElements() != IdxVecTy->getNumElements())
+ return ExprError(
+ S.Diag(TheCall->getBeginLoc(), diag::err_vec_masked_load_store_size)
+ << S.getASTContext().BuiltinInfo.getQuotedName(
+ TheCall->getBuiltinCallee())
+ << MaskTy << IdxTy);
+
+ QualType RetTy =
+ S.Context.getExtVectorType(PointeeTy, MaskVecTy->getNumElements());
+ if (TheCall->getNumArgs() == 4) {
+ Expr *PassThruArg = TheCall->getArg(3);
+ QualType PassThruTy = PassThruArg->getType();
+ if (!S.Context.hasSameType(PassThruTy, RetTy))
+ return S.Diag(PassThruArg->getExprLoc(),
+ diag::err_vec_masked_load_store_ptr)
+ << /* fourth argument */ 4 << RetTy;
+ }
+
+ TheCall->setType(RetTy);
+ return TheCall;
+}
+
+static ExprResult BuiltinMaskedScatter(Sema &S, CallExpr *TheCall) {
+ if (S.checkArgCount(TheCall, 4))
+ return ExprError();
+
+ Expr *MaskArg = TheCall->getArg(0);
+ Expr *IdxArg = TheCall->getArg(1);
+ Expr *ValArg = TheCall->getArg(2);
+ Expr *PtrArg = TheCall->getArg(3);
+
+ if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 3, /*Vector=*/false))
+ return ExprError();
+
+ QualType IdxTy = IdxArg->getType();
+ const VectorType *IdxVecTy = IdxTy->getAs<VectorType>();
+ if (!IdxTy->isExtVectorType() || !IdxVecTy->getElementType()->isIntegerType())
+ return S.Diag(MaskArg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
+ << 2 << /* vector of */ 4 << /* integer */ 1 << /* no fp */ 0
+ << IdxTy;
+
+ QualType ValTy = ValArg->getType();
+ QualType MaskTy = MaskArg->getType();
+ QualType PtrTy = PtrArg->getType();
+ QualType PointeeTy = PtrTy->getPointeeType();
+
+ const VectorType *MaskVecTy = MaskTy->getAs<VectorType>();
+ const VectorType *ValVecTy = ValTy->getAs<VectorType>();
+ if (MaskVecTy->getNumElements() != IdxVecTy->getNumElements())
+ return ExprError(
+ S.Diag(TheCall->getBeginLoc(), diag::err_vec_masked_load_store_size)
+ << S.getASTContext().BuiltinInfo.getQuotedName(
+ TheCall->getBuiltinCallee())
+ << MaskTy << IdxTy);
+ if (MaskVecTy->getNumElements() != ValVecTy->getNumElements())
+ return ExprError(
+ S.Diag(TheCall->getBeginLoc(), diag::err_vec_masked_load_store_size)
+ << S.getASTContext().BuiltinInfo.getQuotedName(
+ TheCall->getBuiltinCallee())
+ << MaskTy << ValTy);
+
+ QualType ArgTy =
+ S.Context.getExtVectorType(PointeeTy, MaskVecTy->getNumElements());
+ if (!S.Context.hasSameType(ValTy, ArgTy))
+ return ExprError(S.Diag(TheCall->getBeginLoc(),
+ diag::err_vec_builtin_incompatible_vector)
+ << TheCall->getDirectCallee() << /*isMorethantwoArgs*/ 2
+ << SourceRange(TheCall->getArg(1)->getBeginLoc(),
+ TheCall->getArg(1)->getEndLoc()));
+
+ TheCall->setType(S.Context.VoidTy);
+ return TheCall;
+}
+
static ExprResult BuiltinInvoke(Sema &S, CallExpr *TheCall) {
SourceLocation Loc = TheCall->getBeginLoc();
MutableArrayRef Args(TheCall->getArgs(), TheCall->getNumArgs());
@@ -2619,6 +2716,10 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
case Builtin::BI__builtin_masked_store:
case Builtin::BI__builtin_masked_compress_store:
return BuiltinMaskedStore(*this, TheCall);
+ case Builtin::BI__builtin_masked_gather:
+ return BuiltinMaskedGather(*this, TheCall);
+ case Builtin::BI__builtin_masked_scatter:
+ return BuiltinMaskedScatter(*this, TheCall);
case Builtin::BI__builtin_invoke:
return BuiltinInvoke(*this, TheCall);
case Builtin::BI__builtin_prefetch:
diff --git a/clang/test/CodeGen/builtin-masked.c b/clang/test/CodeGen/builtin-masked.c
index 579cf5c413c9b..66e6d10f1f3b1 100644
--- a/clang/test/CodeGen/builtin-masked.c
+++ b/clang/test/CodeGen/builtin-masked.c
@@ -129,3 +129,61 @@ void test_store(v8b m, v8i v, v8i *p) {
void test_compress_store(v8b m, v8i v, v8i *p) {
__builtin_masked_compress_store(m, v, p);
}
+
+// CHECK-LABEL: define dso_local <8 x i32> @test_gather(
+// CHECK-SAME: i8 noundef [[MASK_COERCE:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP0:%.*]], ptr noundef [[PTR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[MASK:%.*]] = alloca i8, align 1
+// CHECK-NEXT: [[MASK_ADDR:%.*]] = alloca i8, align 1
+// CHECK-NEXT: [[IDX_ADDR:%.*]] = alloca <8 x i32>, align 32
+// CHECK-NEXT: [[PTR_ADDR:%.*]] = alloca ptr, align 8
+// CHECK-NEXT: store i8 [[MASK_COERCE]], ptr [[MASK]], align 1
+// CHECK-NEXT: [[LOAD_BITS:%.*]] = load i8, ptr [[MASK]], align 1
+// CHECK-NEXT: [[MASK1:%.*]] = bitcast i8 [[LOAD_BITS]] to <8 x i1>
+// CHECK-NEXT: [[IDX:%.*]] = load <8 x i32>, ptr [[TMP0]], align 32
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[MASK1]] to i8
+// CHECK-NEXT: store i8 [[TMP1]], ptr [[MASK_ADDR]], align 1
+// CHECK-NEXT: store <8 x i32> [[IDX]], ptr [[IDX_ADDR]], align 32
+// CHECK-NEXT: store ptr [[PTR]], ptr [[PTR_ADDR]], align 8
+// CHECK-NEXT: [[LOAD_BITS2:%.*]] = load i8, ptr [[MASK_ADDR]], align 1
+// CHECK-NEXT: [[TMP2:%.*]] = bitcast i8 [[LOAD_BITS2]] to <8 x i1>
+// CHECK-NEXT: [[TMP3:%.*]] = load <8 x i32>, ptr [[IDX_ADDR]], align 32
+// CHECK-NEXT: [[TMP4:%.*]] = load ptr, ptr [[PTR_ADDR]], align 8
+// CHECK-NEXT: [[TMP5:%.*]] = getelementptr i32, ptr [[TMP4]], <8 x i32> [[TMP3]]
+// CHECK-NEXT: [[MASKED_GATHER:%.*]] = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> [[TMP5]], i32 1, <8 x i1> [[TMP2]], <8 x i32> poison)
+// CHECK-NEXT: ret <8 x i32> [[MASKED_GATHER]]
+//
+v8i test_gather(v8b mask, v8i idx, int *ptr) {
+ return __builtin_masked_gather(mask, idx, ptr);
+}
+
+// CHECK-LABEL: define dso_local void @test_scatter(
+// CHECK-SAME: i8 noundef [[MASK_COERCE:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP0:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP1:%.*]], ptr noundef [[PTR:%.*]]) #[[ATTR3]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[MASK:%.*]] = alloca i8, align 1
+// CHECK-NEXT: [[MASK_ADDR:%.*]] = alloca i8, align 1
+// CHECK-NEXT: [[VAL_ADDR:%.*]] = alloca <8 x i32>, align 32
+// CHECK-NEXT: [[IDX_ADDR:%.*]] = alloca <8 x i32>, align 32
+// CHECK-NEXT: [[PTR_ADDR:%.*]] = alloca ptr, align 8
+// CHECK-NEXT: store i8 [[MASK_COERCE]], ptr [[MASK]], align 1
+// CHECK-NEXT: [[LOAD_BITS:%.*]] = load i8, ptr [[MASK]], align 1
+// CHECK-NEXT: [[MASK1:%.*]] = bitcast i8 [[LOAD_BITS]] to <8 x i1>
+// CHECK-NEXT: [[VAL:%.*]] = load <8 x i32>, ptr [[TMP0]], align 32
+// CHECK-NEXT: [[IDX:%.*]] = load <8 x i32>, ptr [[TMP1]], align 32
+// CHECK-NEXT: [[TMP2:%.*]] = bitcast <8 x i1> [[MASK1]] to i8
+// CHECK-NEXT: store i8 [[TMP2]], ptr [[MASK_ADDR]], align 1
+// CHECK-NEXT: store <8 x i32> [[VAL]], ptr [[VAL_ADDR]], align 32
+// CHECK-NEXT: store <8 x i32> [[IDX]], ptr [[IDX_ADDR]], align 32
+// CHECK-NEXT: store ptr [[PTR]], ptr [[PTR_ADDR]], align 8
+// CHECK-NEXT: [[LOAD_BITS2:%.*]] = load i8, ptr [[MASK_ADDR]], align 1
+// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8 [[LOAD_BITS2]] to <8 x i1>
+// CHECK-NEXT: [[TMP4:%.*]] = load <8 x i32>, ptr [[VAL_ADDR]], align 32
+// CHECK-NEXT: [[TMP5:%.*]] = load <8 x i32>, ptr [[IDX_ADDR]], align 32
+// CHECK-NEXT: [[TMP6:%.*]] = load ptr, ptr [[PTR_ADDR]], align 8
+// CHECK-NEXT: [[TMP7:%.*]] = getelementptr i32, ptr [[TMP6]], <8 x i32> [[TMP5]]
+// CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> [[TMP4]], <8 x ptr> [[TMP7]], i32 1, <8 x i1> [[TMP3]])
+// CHECK-NEXT: ret void
+//
+void test_scatter(v8b mask, v8i val, v8i idx, int *ptr) {
+ __builtin_masked_scatter(mask, val, idx, ptr);
+}
diff --git a/clang/test/Sema/builtin-masked.c b/clang/test/Sema/builtin-masked.c
index 05c6580651964..eb0070b0276af 100644
--- a/clang/test/Sema/builtin-masked.c
+++ b/clang/test/Sema/builtin-masked.c
@@ -44,3 +44,23 @@ void test_masked_compress_store(v8i *pf, v8f *pf2, v8b mask, v2b mask2) {
__builtin_masked_compress_store(mask2, *pf, pf); // expected-error {{all arguments to '__builtin_masked_compress_store' must have the same number of elements}}
__builtin_masked_compress_store(mask, *pf, pf2); // expected-error {{last two arguments to '__builtin_masked_compress_store' must have the same type}}
}
+
+void test_masked_gather(int *p, v8i idx, v8b mask, v2b mask2, v2b thru) {
+ __builtin_masked_gather(mask); // expected-error {{too few arguments to function call, expected 3, have 1}}
+ __builtin_masked_gather(mask, p, p, p, p, p); // expected-error {{too many arguments to function call, expected at most 4, have 6}}
+ __builtin_masked_gather(p, p, p); // expected-error {{1st argument must be a vector of boolean types (was 'int *')}}
+ __builtin_masked_gather(mask, p, p); // expected-error {{1st argument must be a vector of integer types (was 'int *')}}
+ __builtin_masked_gather(mask2, idx, p); // expected-error {{all arguments to '__builtin_masked_gather' must have the same number of elements (was 'v2b'}}
+ __builtin_masked_gather(mask, idx, p, thru); // expected-error {{4th argument must be a 'int __attribute__((ext_vector_type(8)))' (vector of 8 'int' values)}}
+ __builtin_masked_gather(mask, idx, &idx); // expected-error {{3rd argument must be a scalar pointer}}
+}
+
+void test_masked_scatter(int *p, v8i idx, v8b mask, v2b mask2, v8i val) {
+ __builtin_masked_scatter(mask); // expected-error {{too few arguments to function call, expected 4, have 1}}
+ __builtin_masked_scatter(mask, p, p, p, p, p); // expected-error {{too many arguments to function call, expected 4, have 6}}
+ __builtin_masked_scatter(p, p, p, p); // expected-error {{1st argument must be a vector of boolean types (was 'int *')}}
+ __builtin_masked_scatter(mask, p, p, p); // expected-error {{2nd argument must be a vector of integer types (was 'int *')}}
+ __builtin_masked_scatter(mask, idx, mask, p); // expected-error {{last two arguments to '__builtin_masked_scatter' must have the same type}}
+ __builtin_masked_scatter(mask, idx, val, idx); // expected-error {{3rd argument must be a scalar pointer}}
+ __builtin_masked_scatter(mask, idx, val, &idx); // expected-error {{3rd argument must be a scalar pointer}}
+}
Force-pushed from c5ee870 to 6e5fdeb
Ping, one of the last remaining functions we're intending to use in the LLVM libmvec.
Summary: This patch exposes `__builtin_masked_gather` and `__builtin_masked_scatter` to clang. These map to the underlying intrinsics relatively cleanly, needing only a level of indirection to turn a vector of indices and a base pointer into a vector of pointers.
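As a rough mental model (not part of the patch), here is a minimal sketch of the per-lane semantics the two builtins are intended to expose. The `gather_ref`/`scatter_ref` helper names and the fixed width of 8 are illustrative assumptions; the argument order follows the documentation example in the diff, and the actual lowering goes through the `llvm.masked.gather` / `llvm.masked.scatter` intrinsics as shown in the CodeGen test.

```c
typedef int   v8i __attribute__((ext_vector_type(8)));
typedef _Bool v8b __attribute__((ext_vector_type(8)));

// Reference semantics for __builtin_masked_gather(mask, idx, ptr[, passthru]):
// each active lane loads ptr[idx[i]]; inactive lanes keep the pass-through
// value (poison when no pass-through argument is supplied).
static v8i gather_ref(v8b mask, v8i idx, const int *ptr, v8i passthru) {
  v8i result = passthru;
  for (int i = 0; i < 8; ++i)
    if (mask[i])
      result[i] = ptr[idx[i]];
  return result;
}

// Reference semantics for the scatter: each active lane stores its value to
// ptr[idx[i]]; inactive lanes leave memory untouched.
static void scatter_ref(v8b mask, v8i idx, v8i val, int *ptr) {
  for (int i = 0; i < 8; ++i)
    if (mask[i])
      ptr[idx[i]] = val[i];
}
```

The "level of indirection" mentioned above is the vector GEP emitted in CGBuiltin.cpp, which combines the scalar base pointer with the index vector to form the vector of per-lane pointers the intrinsics expect.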
ping