@@ -3413,6 +3413,154 @@ class SPIRVComplexFloatInst
34133413_SPIRV_OP (ComplexFMulINTEL)
34143414_SPIRV_OP(ComplexFDivINTEL)
34153415#undef _SPIRV_OP
3416+
3417+ class SPIRVMaskedGatherScatterINTELInstBase : public SPIRVInstTemplateBase {
3418+ protected:
3419+ SPIRVCapVec getRequiredCapability () const override {
3420+ return getVec (internal::CapabilityMaskedGatherScatterINTEL);
3421+ }
3422+ llvm::Optional<ExtensionID> getRequiredExtension () const override {
3423+ return ExtensionID::SPV_INTEL_masked_gather_scatter;
3424+ }
3425+ };
3426+
3427+ class SPIRVMaskedGatherINTELInst
3428+ : public SPIRVMaskedGatherScatterINTELInstBase {
3429+ void validate () const override {
3430+ SPIRVInstruction::validate ();
3431+ SPIRVErrorLog &SPVErrLog = this ->getModule ()->getErrorLog ();
3432+ std::string InstName = " MaskedGatherINTEL" ;
3433+
3434+ SPIRVType *ResTy = this ->getType ();
3435+ SPVErrLog.checkError (ResTy->isTypeVector (), SPIRVEC_InvalidInstruction,
3436+ InstName + " \n Result must be a vector type\n " );
3437+ SPIRVWord ResCompCount = ResTy->getVectorComponentCount ();
3438+ SPIRVType *ResCompTy = ResTy->getVectorComponentType ();
3439+
3440+ SPIRVValue *PtrVec =
3441+ const_cast <SPIRVMaskedGatherINTELInst *>(this )->getOperand (0 );
3442+ SPIRVType *PtrVecTy = PtrVec->getType ();
3443+ SPVErrLog.checkError (
3444+ PtrVecTy->isTypeVectorPointer (), SPIRVEC_InvalidInstruction,
3445+ InstName + " \n PtrVector must be a vector of pointers type\n " );
3446+ SPIRVWord PtrVecCompCount = PtrVecTy->getVectorComponentCount ();
3447+ SPIRVType *PtrVecCompTy = PtrVecTy->getVectorComponentType ();
3448+ SPIRVType *PtrElemTy = PtrVecCompTy->getPointerElementType ();
3449+
3450+ SPVErrLog.checkError (
3451+ this ->isOperandLiteral (1 ), SPIRVEC_InvalidInstruction,
3452+ InstName + " \n Alignment must be a constant expression integer\n " );
3453+ const uint32_t Align =
3454+ static_cast <SPIRVConstant *>(
3455+ const_cast <SPIRVMaskedGatherINTELInst *>(this )->getOperand (2 ))
3456+ ->getZExtIntValue ();
3457+ SPVErrLog.checkError (
3458+ ((Align & (Align - 1 )) == 0 ), SPIRVEC_InvalidInstruction,
3459+ InstName + " \n Alignment must be 0 or power-of-two integer\n " );
3460+
3461+ SPIRVValue *Mask =
3462+ const_cast <SPIRVMaskedGatherINTELInst *>(this )->getOperand (2 );
3463+ SPIRVType *MaskTy = Mask->getType ();
3464+ SPVErrLog.checkError (MaskTy->isTypeVector (), SPIRVEC_InvalidInstruction,
3465+ InstName + " \n Mask must be a vector type\n " );
3466+ SPIRVType *MaskCompTy = MaskTy->getVectorComponentType ();
3467+ SPVErrLog.checkError (MaskCompTy->isTypeBool (), SPIRVEC_InvalidInstruction,
3468+ InstName + " \n Mask must be a boolean vector type\n " );
3469+ SPIRVWord MaskCompCount = MaskTy->getVectorComponentCount ();
3470+
3471+ SPIRVValue *FillEmpty =
3472+ const_cast <SPIRVMaskedGatherINTELInst *>(this )->getOperand (3 );
3473+ SPIRVType *FillEmptyTy = FillEmpty->getType ();
3474+ SPVErrLog.checkError (FillEmptyTy->isTypeVector (),
3475+ SPIRVEC_InvalidInstruction,
3476+ InstName + " \n FillEmpty must be a vector type\n " );
3477+ SPIRVWord FillEmptyCompCount = FillEmptyTy->getVectorComponentCount ();
3478+ SPIRVType *FillEmptyCompTy = FillEmptyTy->getVectorComponentType ();
3479+
3480+ SPVErrLog.checkError (
3481+ ResCompCount == PtrVecCompCount &&
3482+ PtrVecCompCount == FillEmptyCompCount &&
3483+ FillEmptyCompCount == MaskCompCount,
3484+ SPIRVEC_InvalidInstruction,
3485+ InstName + " \n Result, PtrVector, Mask and FillEmpty vectors must have "
3486+ " the same size\n " );
3487+
3488+ SPVErrLog.checkError (
3489+ ResCompTy == PtrElemTy && PtrElemTy == FillEmptyCompTy,
3490+ SPIRVEC_InvalidInstruction,
3491+ InstName + " \n Component Type of Result and FillEmpty vector must be "
3492+ " same as base type of PtrVector the same base type\n " );
3493+ }
3494+ };
3495+
3496+ class SPIRVMaskedScatterINTELInst
3497+ : public SPIRVMaskedGatherScatterINTELInstBase {
3498+ void validate () const override {
3499+ SPIRVInstruction::validate ();
3500+ SPIRVErrorLog &SPVErrLog = this ->getModule ()->getErrorLog ();
3501+ std::string InstName = " MaskedScatterINTEL" ;
3502+
3503+ SPIRVValue *InputVec =
3504+ const_cast <SPIRVMaskedScatterINTELInst *>(this )->getOperand (0 );
3505+ SPIRVType *InputVecTy = InputVec->getType ();
3506+ SPVErrLog.checkError (
3507+ InputVecTy->isTypeVector (), SPIRVEC_InvalidInstruction,
3508+ InstName + " \n InputVector must be a vector of pointers type\n " );
3509+ SPIRVWord InputVecCompCount = InputVecTy->getVectorComponentCount ();
3510+ SPIRVType *InputVecCompTy = InputVecTy->getVectorComponentType ();
3511+
3512+ SPIRVValue *PtrVec =
3513+ const_cast <SPIRVMaskedScatterINTELInst *>(this )->getOperand (1 );
3514+ SPIRVType *PtrVecTy = PtrVec->getType ();
3515+ SPVErrLog.checkError (
3516+ PtrVecTy->isTypeVectorPointer (), SPIRVEC_InvalidInstruction,
3517+ InstName + " \n PtrVector must be a vector of pointers type\n " );
3518+ SPIRVWord PtrVecCompCount = PtrVecTy->getVectorComponentCount ();
3519+ SPIRVType *PtrVecCompTy = PtrVecTy->getVectorComponentType ();
3520+ SPIRVType *PtrElemTy = PtrVecCompTy->getPointerElementType ();
3521+
3522+ SPVErrLog.checkError (
3523+ this ->isOperandLiteral (2 ), SPIRVEC_InvalidInstruction,
3524+ InstName + " \n Alignment must be a constant expression integer\n " );
3525+ const uint32_t Align =
3526+ static_cast <SPIRVConstant *>(
3527+ const_cast <SPIRVMaskedScatterINTELInst *>(this )->getOperand (2 ))
3528+ ->getZExtIntValue ();
3529+ SPVErrLog.checkError (
3530+ ((Align & (Align - 1 )) == 0 ), SPIRVEC_InvalidInstruction,
3531+ InstName + " \n Alignment must be 0 or power-of-two integer\n " );
3532+
3533+ SPIRVValue *Mask =
3534+ const_cast <SPIRVMaskedScatterINTELInst *>(this )->getOperand (2 );
3535+ SPIRVType *MaskTy = Mask->getType ();
3536+ SPVErrLog.checkError (MaskTy->isTypeVector (), SPIRVEC_InvalidInstruction,
3537+ InstName + " \n Mask must be a vector type\n " );
3538+ SPIRVType *MaskCompTy = MaskTy->getVectorComponentType ();
3539+ SPVErrLog.checkError (MaskCompTy->isTypeBool (), SPIRVEC_InvalidInstruction,
3540+ InstName + " \n Mask must be a boolean vector type\n " );
3541+ SPIRVWord MaskCompCount = MaskTy->getVectorComponentCount ();
3542+
3543+ SPVErrLog.checkError (
3544+ InputVecCompCount == PtrVecCompCount &&
3545+ PtrVecCompCount == MaskCompCount,
3546+ SPIRVEC_InvalidInstruction,
3547+ InstName + " \n InputVector, PtrVector and Mask vectors must have "
3548+ " the same size\n " );
3549+
3550+ SPVErrLog.checkError (
3551+ InputVecCompTy == PtrElemTy, SPIRVEC_InvalidInstruction,
3552+ InstName + " \n Component Type of InputVector must be "
3553+ " same as base type of PtrVector the same base type\n " );
3554+ }
3555+ };
3556+
3557+ #define _SPIRV_OP (x, ...) \
3558+ typedef SPIRVInstTemplate<SPIRVMaskedGatherScatterINTELInstBase, \
3559+ internal::Op##x##INTEL, __VA_ARGS__> \
3560+ SPIRV##x##INTEL;
3561+ _SPIRV_OP (MaskedGather, true , 7 )
3562+ _SPIRV_OP(MaskedScatter, false , 5 )
3563+ #undef _SPIRV_OP
34163564} // namespace SPIRV
34173565
34183566#endif // SPIRV_LIBSPIRV_SPIRVINSTRUCTION_H
0 commit comments