@@ -1055,6 +1055,136 @@ func.func @warpgroup_mma_store(
10551055 return
10561056}
10571057
// Exercises lowering of nvgpu.warpgroup.mma.store for every supported tile
// width N = 16, 24, ..., 256 (M fixed at 64, f32 accumulator). Each store of a
// 64xN accumulator lowers to N/2 scalar memref.store ops, which the
// CHECK-COUNT directives below pin down per width.
// NOTE(review): %shmem_m64n8k is declared but has no matching accumulator or
// store — presumably kept for parameter-list symmetry; confirm it is intended.
// CHECK-LABEL: @warpgroup_mma_store_multiple
func.func @warpgroup_mma_store_multiple(
    %shmem_m64n8k : memref<64x8xf32>,
    %shmem_m64n16k : memref<64x16xf32>,
    %shmem_m64n24k : memref<64x24xf32>,
    %shmem_m64n32k : memref<64x32xf32>,
    %shmem_m64n40k : memref<64x40xf32>,
    %shmem_m64n48k : memref<64x48xf32>,
    %shmem_m64n56k : memref<64x56xf32>,
    %shmem_m64n64k : memref<64x64xf32>,
    %shmem_m64n72k : memref<64x72xf32>,
    %shmem_m64n80k : memref<64x80xf32>,
    %shmem_m64n88k : memref<64x88xf32>,
    %shmem_m64n96k : memref<64x96xf32>,
    %shmem_m64n104k : memref<64x104xf32>,
    %shmem_m64n112k : memref<64x112xf32>,
    %shmem_m64n120k : memref<64x120xf32>,
    %shmem_m64n128k : memref<64x128xf32>,
    %shmem_m64n136k : memref<64x136xf32>,
    %shmem_m64n144k : memref<64x144xf32>,
    %shmem_m64n152k : memref<64x152xf32>,
    %shmem_m64n160k : memref<64x160xf32>,
    %shmem_m64n168k : memref<64x168xf32>,
    %shmem_m64n176k : memref<64x176xf32>,
    %shmem_m64n184k : memref<64x184xf32>,
    %shmem_m64n192k : memref<64x192xf32>,
    %shmem_m64n200k : memref<64x200xf32>,
    %shmem_m64n208k : memref<64x208xf32>,
    %shmem_m64n216k : memref<64x216xf32>,
    %shmem_m64n224k : memref<64x224xf32>,
    %shmem_m64n232k : memref<64x232xf32>,
    %shmem_m64n240k : memref<64x240xf32>,
    %shmem_m64n248k : memref<64x248xf32>,
    %shmem_m64n256k : memref<64x256xf32>,
    %res_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>,
    %res_m64n24k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x24xf32>>,
    %res_m64n32k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>>,
    %res_m64n40k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x40xf32>>,
    %res_m64n48k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x48xf32>>,
    %res_m64n56k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x56xf32>>,
    %res_m64n64k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>>,
    %res_m64n72k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x72xf32>>,
    %res_m64n80k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x80xf32>>,
    %res_m64n88k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x88xf32>>,
    %res_m64n96k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x96xf32>>,
    %res_m64n104k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x104xf32>>,
    %res_m64n112k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x112xf32>>,
    %res_m64n120k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x120xf32>>,
    %res_m64n128k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
    %res_m64n136k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x136xf32>>,
    %res_m64n144k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x144xf32>>,
    %res_m64n152k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x152xf32>>,
    %res_m64n160k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x160xf32>>,
    %res_m64n168k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x168xf32>>,
    %res_m64n176k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x176xf32>>,
    %res_m64n184k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x184xf32>>,
    %res_m64n192k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x192xf32>>,
    %res_m64n200k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x200xf32>>,
    %res_m64n208k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x208xf32>>,
    %res_m64n216k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x216xf32>>,
    %res_m64n224k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x224xf32>>,
    %res_m64n232k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x232xf32>>,
    %res_m64n240k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x240xf32>>,
    %res_m64n248k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x248xf32>>,
    %res_m64n256k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x256xf32>>) {
  // Each 64xN store must lower to exactly N/2 scalar stores into its memref.
  // CHECK-COUNT-8: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x16xf32>
  // CHECK-COUNT-12: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x24xf32>
  // CHECK-COUNT-16: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x32xf32>
  // CHECK-COUNT-20: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x40xf32>
  // CHECK-COUNT-24: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x48xf32>
  // CHECK-COUNT-28: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x56xf32>
  // CHECK-COUNT-32: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x64xf32>
  // CHECK-COUNT-36: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x72xf32>
  // CHECK-COUNT-40: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x80xf32>
  // CHECK-COUNT-44: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x88xf32>
  // CHECK-COUNT-48: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x96xf32>
  // CHECK-COUNT-52: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x104xf32>
  // CHECK-COUNT-56: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x112xf32>
  // CHECK-COUNT-60: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x120xf32>
  // CHECK-COUNT-64: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x128xf32>
  // CHECK-COUNT-68: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x136xf32>
  // CHECK-COUNT-72: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x144xf32>
  // CHECK-COUNT-76: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x152xf32>
  // CHECK-COUNT-80: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x160xf32>
  // CHECK-COUNT-84: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x168xf32>
  // CHECK-COUNT-88: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x176xf32>
  // CHECK-COUNT-92: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x184xf32>
  // CHECK-COUNT-96: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x192xf32>
  // CHECK-COUNT-100: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x200xf32>
  // CHECK-COUNT-104: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x208xf32>
  // CHECK-COUNT-108: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x216xf32>
  // CHECK-COUNT-112: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x224xf32>
  // CHECK-COUNT-116: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x232xf32>
  // CHECK-COUNT-120: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x240xf32>
  // CHECK-COUNT-124: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x248xf32>
  // CHECK-COUNT-128: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x256xf32>
  nvgpu.warpgroup.mma.store %res_m64n16k, %shmem_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32>
  nvgpu.warpgroup.mma.store %res_m64n24k, %shmem_m64n24k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x24xf32>> to memref<64x24xf32>
  nvgpu.warpgroup.mma.store %res_m64n32k, %shmem_m64n32k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x32xf32>> to memref<64x32xf32>
  nvgpu.warpgroup.mma.store %res_m64n40k, %shmem_m64n40k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x40xf32>> to memref<64x40xf32>
  nvgpu.warpgroup.mma.store %res_m64n48k, %shmem_m64n48k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x48xf32>> to memref<64x48xf32>
  nvgpu.warpgroup.mma.store %res_m64n56k, %shmem_m64n56k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x56xf32>> to memref<64x56xf32>
  nvgpu.warpgroup.mma.store %res_m64n64k, %shmem_m64n64k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x64xf32>> to memref<64x64xf32>
  nvgpu.warpgroup.mma.store %res_m64n72k, %shmem_m64n72k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x72xf32>> to memref<64x72xf32>
  nvgpu.warpgroup.mma.store %res_m64n80k, %shmem_m64n80k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x80xf32>> to memref<64x80xf32>
  nvgpu.warpgroup.mma.store %res_m64n88k, %shmem_m64n88k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x88xf32>> to memref<64x88xf32>
  nvgpu.warpgroup.mma.store %res_m64n96k, %shmem_m64n96k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x96xf32>> to memref<64x96xf32>
  nvgpu.warpgroup.mma.store %res_m64n104k, %shmem_m64n104k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x104xf32>> to memref<64x104xf32>
  nvgpu.warpgroup.mma.store %res_m64n112k, %shmem_m64n112k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x112xf32>> to memref<64x112xf32>
  nvgpu.warpgroup.mma.store %res_m64n120k, %shmem_m64n120k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x120xf32>> to memref<64x120xf32>
  nvgpu.warpgroup.mma.store %res_m64n128k, %shmem_m64n128k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to memref<64x128xf32>
  nvgpu.warpgroup.mma.store %res_m64n136k, %shmem_m64n136k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x136xf32>> to memref<64x136xf32>
  nvgpu.warpgroup.mma.store %res_m64n144k, %shmem_m64n144k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x144xf32>> to memref<64x144xf32>
  nvgpu.warpgroup.mma.store %res_m64n152k, %shmem_m64n152k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x152xf32>> to memref<64x152xf32>
  nvgpu.warpgroup.mma.store %res_m64n160k, %shmem_m64n160k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x160xf32>> to memref<64x160xf32>
  nvgpu.warpgroup.mma.store %res_m64n168k, %shmem_m64n168k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x168xf32>> to memref<64x168xf32>
  nvgpu.warpgroup.mma.store %res_m64n176k, %shmem_m64n176k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x176xf32>> to memref<64x176xf32>
  nvgpu.warpgroup.mma.store %res_m64n184k, %shmem_m64n184k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x184xf32>> to memref<64x184xf32>
  nvgpu.warpgroup.mma.store %res_m64n192k, %shmem_m64n192k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x192xf32>> to memref<64x192xf32>
  nvgpu.warpgroup.mma.store %res_m64n200k, %shmem_m64n200k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x200xf32>> to memref<64x200xf32>
  nvgpu.warpgroup.mma.store %res_m64n208k, %shmem_m64n208k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x208xf32>> to memref<64x208xf32>
  nvgpu.warpgroup.mma.store %res_m64n216k, %shmem_m64n216k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x216xf32>> to memref<64x216xf32>
  nvgpu.warpgroup.mma.store %res_m64n224k, %shmem_m64n224k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x224xf32>> to memref<64x224xf32>
  nvgpu.warpgroup.mma.store %res_m64n232k, %shmem_m64n232k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x232xf32>> to memref<64x232xf32>
  nvgpu.warpgroup.mma.store %res_m64n240k, %shmem_m64n240k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x240xf32>> to memref<64x240xf32>
  nvgpu.warpgroup.mma.store %res_m64n248k, %shmem_m64n248k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x248xf32>> to memref<64x248xf32>
  nvgpu.warpgroup.mma.store %res_m64n256k, %shmem_m64n256k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x256xf32>> to memref<64x256xf32>
  return
}

1187+
10581188func.func @warpgroup_mma_init () {
10591189 //CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
10601190 //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
0 commit comments