@@ -842,11 +842,13 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
842842// / one sparse level in the list.
843843static Operation *genCoIteration (CodegenEnv &env, OpBuilder &builder,
844844 ArrayRef<TensorLevel> tidLvls,
845- bool tryParallel, bool needsUniv) {
845+ unsigned numCases, bool tryParallel,
846+ bool needsUniv) {
846847 Operation *loop = *env.genLoopBoundary ([&](MutableArrayRef<Value> reduc) {
847848 // Construct while-loop with a parameter for each index.
848849 return env.emitter ().enterCoIterationOverTensorsAtLvls (
849- builder, env.op ().getLoc (), tidLvls, reduc, tryParallel, needsUniv);
850+ builder, env.op ().getLoc (), tidLvls, numCases, reduc, tryParallel,
851+ needsUniv);
850852 });
851853 assert (loop);
852854 return loop;
@@ -855,9 +857,11 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
855857// / Generates a for-loop or a while-loop, depending on whether it implements
856858// / singleton iteration or co-iteration over the given conjunction.
857859static Operation *genLoop (CodegenEnv &env, OpBuilder &builder, LoopId curr,
858- bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
860+ unsigned numCases, bool needsUniv,
861+ ArrayRef<TensorLevel> tidLvls) {
859862 bool tryParallel = shouldTryParallize (env, curr, tidLvls);
860- return genCoIteration (env, builder, tidLvls, tryParallel, needsUniv);
863+ return genCoIteration (env, builder, tidLvls, numCases, tryParallel,
864+ needsUniv);
861865}
862866
863867// / Generates the induction structure for a while-loop.
@@ -900,6 +904,26 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
900904 // basic block where scf::Yield should be inserted.
901905}
902906
907+ // / Generates case region `caseIdx` of the coiterate operation; `curCase` must equal `allCase` or satisfy `latGT (allCase, curCase)`.
908+ static void genCoIterationCase (CodegenEnv &env, OpBuilder &builder,
909+ unsigned caseIdx, LatPointId allCase,
910+ LatPointId curCase,
911+ MutableArrayRef<Value> reduc) {
912+ assert (allCase == curCase || env.merger ().latGT (allCase, curCase));
913+ const BitVector &allCaseBits = env.merger ().lat (allCase).simple ;
914+ const BitVector &curCaseBits = env.merger ().lat (curCase).simple ;
915+
916+ // / Computes the subset of iterators that are valid in the current case being
917+ // / generated: bit `idx` of `caseBit` refers to the idx-th set bit of `allCaseBits` and is set only when `curCaseBits` also has that bit.
918+ I64BitSet caseBit (0 );
919+ for (auto [idx, set] : llvm::enumerate (allCaseBits.set_bits ()))
920+ if (curCaseBits.test (set))
921+ caseBit.set (idx);
922+
923+ env.emitter ().enterCurrentCoIterationCase (builder, env.op ().getLoc (), caseBit,
924+ caseIdx, reduc);
925+ }
926+
903927// / Generates a single if-statement within a while-loop.
904928static scf::IfOp genIf (CodegenEnv &env, OpBuilder &builder, LoopId curr,
905929 LatPointId p) {
@@ -1175,7 +1199,10 @@ static bool translateBitsToTidLvlPairs(
11751199// / Starts a single loop in current sequence.
11761200static std::pair<Operation *, bool > startLoop (CodegenEnv &env,
11771201 OpBuilder &builder, LoopId curr,
1178- LatPointId li, bool needsUniv) {
1202+ LatPointId li, unsigned numCases,
1203+ bool needsUniv) {
1204+ // TODO: numCases is only used when generating iterator-based loops. Clean up
1205+ // after the migration is complete.
11791206 // The set of tensors + lvls to generate loops on
11801207 SmallVector<TensorLevel> tidLvls;
11811208
@@ -1186,7 +1213,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
11861213 translateBitsToTidLvlPairs (env, li, curr, tidLvls, affineTidLvls);
11871214
11881215 // Emit the for/while-loop control.
1189- Operation *loop = genLoop (env, builder, curr, needsUniv, tidLvls);
1216+ Operation *loop = genLoop (env, builder, curr, numCases, needsUniv, tidLvls);
11901217 Location loc = env.op ().getLoc ();
11911218 for (auto [tidLvl, exp] : affineTidLvls) {
11921219 env.emitter ().locateLvlAtAffineAddress (builder, loc, tidLvl, exp);
@@ -1259,42 +1286,73 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
12591286 // Start a loop sequence.
12601287 bool needsUniv = startLoopSeq (env, rewriter, exp, curr, lts);
12611288
1262- // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1263- // We cannot change this to `for (const LatPointId li : env.set(lts))`
1264- // because the loop body causes data-movement which invalidates
1265- // the iterator.
1289+ // When using sparse-iterator-based loops, we only need one loop, as
1290+ // opposed to a loop sequence, to cover all the iterator spaces.
12661291 const unsigned lsize = env.set (lts).size ();
1267- for (unsigned i = 0 ; i < lsize; i++) {
1268- const LatPointId li = env.set (lts)[i];
1269- // Start a loop.
1270- auto [loop, isSingleCond] = startLoop (env, rewriter, curr, li, needsUniv);
1271-
1272- // Visit all lattices points with Li >= Lj to generate the
1273- // loop-body, possibly with if statements for coiteration.
1274- Value redInput = env.getReduc ();
1275- Value cntInput = env.getExpandCount ();
1276- Value insInput = env.getInsertionChain ();
1277- Value validIns = env.getValidLexInsert ();
1278- // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1279- // because the loop body causes data-movement which invalidates the
1280- // iterator.
1292+ if (env.generatingSparseIterator ()) {
1293+ // Get the largest lattice point and start a loop.
1294+ const LatPointId li = env.set (lts)[0 ];
1295+ auto [loop, isSingleCond] =
1296+ startLoop (env, rewriter, curr, li, lsize, needsUniv);
1297+ assert (isSingleCond == llvm::isa<IterateOp>(loop));
1298+ // We cannot change this to `for (const LatPointId li : env.set(lts))`
1299+ // because the loop body causes data-movement which invalidates
1300+ // the iterator.
12811301 for (unsigned j = 0 ; j < lsize; j++) {
12821302 const LatPointId lj = env.set (lts)[j];
12831303 const ExprId ej = env.lat (lj).exp ;
1284- if (li == lj || env.merger ().latGT (li, lj)) {
1285- // Recurse into body of each branch.
1286- if (!isSingleCond) {
1287- scf::IfOp ifOp = genIf (env, rewriter, curr, lj);
1288- genStmt (env, rewriter, ej, curr + 1 );
1289- endIf (env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1290- } else {
1304+ // Recurse into body of each branch.
1305+ if (!isSingleCond) {
1306+ env.genLoopBoundary ([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
1307+ genCoIterationCase (env, rewriter, /* caseIdx*/ j, li, lj, reduc);
12911308 genStmt (env, rewriter, ej, curr + 1 );
1292- }
1309+ // TODO: handle yield values.
1310+ assert (reduc.empty () && " Not Implemented" );
1311+ rewriter.create <sparse_tensor::YieldOp>(env.op ().getLoc ());
1312+ return std::nullopt ;
1313+ });
1314+ // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1315+ } else {
1316+ genStmt (env, rewriter, ej, curr + 1 );
12931317 }
12941318 }
1295-
12961319 // End a loop.
12971320 needsUniv = endLoop (env, rewriter, loop, curr, needsUniv, isSingleCond);
1321+ } else {
1322+ // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1323+ for (unsigned i = 0 ; i < lsize; i++) {
1324+ const LatPointId li = env.set (lts)[i];
1325+ // Start a loop.
1326+ auto [loop, isSingleCond] =
1327+ startLoop (env, rewriter, curr, li, lsize, needsUniv);
1328+
1329+ // Visit all lattice points with Li >= Lj to generate the
1330+ // loop-body, possibly with if statements for coiteration.
1331+ Value redInput = env.getReduc ();
1332+ Value cntInput = env.getExpandCount ();
1333+ Value insInput = env.getInsertionChain ();
1334+ Value validIns = env.getValidLexInsert ();
1335+ // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1336+ // because the loop body causes data-movement which invalidates the
1337+ // iterator.
1338+ for (unsigned j = 0 ; j < lsize; j++) {
1339+ const LatPointId lj = env.set (lts)[j];
1340+ const ExprId ej = env.lat (lj).exp ;
1341+ if (li == lj || env.merger ().latGT (li, lj)) {
1342+ // Recurse into body of each branch.
1343+ if (!isSingleCond) {
1344+ scf::IfOp ifOp = genIf (env, rewriter, curr, lj);
1345+ genStmt (env, rewriter, ej, curr + 1 );
1346+ endIf (env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1347+ } else {
1348+ genStmt (env, rewriter, ej, curr + 1 );
1349+ }
1350+ }
1351+ }
1352+
1353+ // End a loop.
1354+ needsUniv = endLoop (env, rewriter, loop, curr, needsUniv, isSingleCond);
1355+ }
12981356 }
12991357
13001358 // End a loop sequence.
0 commit comments