@@ -829,47 +829,40 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
   }
 };
 
+// A trivial wrapper to help generate different operations for dense/sparse
+// tensors.
 struct TensorLike {
   TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
-             ValueRange sizes)
-      : isSparse(rtt.getEncoding() != nullptr) {
+             ValueRange sizes) {
     SmallVector<Value> dynSzs;
     getDynamicSizes(rtt, sizes, dynSzs);
 
-    if (isSparse)
-      val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
-    else
-      val = allocDenseTensor(builder, loc, rtt, sizes);
-  };
-
-  void insertOrStore(OpBuilder &builder, Location loc, Value v,
-                     ValueRange crds) {
-    if (isSparse)
-      val = builder.create<InsertOp>(loc, v, val, crds);
-    else
-      builder.create<memref::StoreOp>(loc, v, val, crds);
+    val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
+    if (!isSparse()) {
+      Value c0 = constantZero(builder, loc, rtt.getElementType());
+      val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
+    }
   }
 
-  Value getSSA() const {
-    // We don't need to maintain the SSA chain for a memref value.
-    return isSparse ? val : nullptr;
+  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
+    // TODO: Unify these two.
+    if (isSparse())
+      val = builder.create<sparse_tensor::InsertOp>(loc, v, val, crds);
+    else
+      val = builder.create<tensor::InsertOp>(loc, v, val, crds);
   }
 
   Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
-    if (isSparse)
+    if (isSparse())
       return builder.create<LoadOp>(loc, val, true);
-    return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
+    return val;
   }
 
-  void updateSSA(Value v) {
-    // Dense memref is a non-SSA value.
-    assert(isSparse);
-    val = v;
+  bool isSparse() const {
+    return getSparseTensorEncoding(val.getType()) != nullptr;
   }
 
-private:
-  bool isSparse;
-  Value val; // either a memref (for dense tensor) or a sparse tensor.
+  Value val;
 };
 
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
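
Aside (not part of the patch): the refactor works because TensorLike now always carries a tensor-typed SSA value in `val`, for dense as well as sparse destinations. A minimal lifecycle sketch, assuming `rewriter`, `loc`, `rtt`, `sizes`, `v`, and `crds` are in scope as in the patterns below:

TensorLike buf(rewriter, loc, rtt, sizes);
// sparse rtt: bufferization.alloc_tensor
// dense rtt:  bufferization.alloc_tensor + linalg.fill with zero
buf.insert(rewriter, loc, v, crds);
// sparse: sparse_tensor.insert extends the insertion chain
// dense:  tensor.insert yields a fresh tensor SSA value
Value ret = buf.finalize(rewriter, loc, rtt);
// sparse: sparse_tensor.load with hasInserts = true
// dense:  returns buf.val as-is (no bufferization.to_tensor round-trip)
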
@@ -901,14 +894,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
 
     TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
     Value offset = constantIndex(rewriter, loc, 0);
-    Value iterArg = dstBuf.getSSA();
+    Value iterArg = dstBuf.val;
 
     ForeachOp foreachOp;
     for (Value input : op.getInputs()) {
       // Builds a for op for each input tensor to append new values into the
       // output tensor.
       foreachOp = rewriter.create<ForeachOp>(
-          loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
+          loc, input, iterArg,
           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
               ValueRange reduc) {
             SmallVector<Value> dstLcvs(dstTp.getLvlRank());
@@ -920,32 +913,26 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
               // FIXME: `toStoredDim` is deprecated
               dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
             }
-
-            if (!reduc.empty())
-              dstBuf.updateSSA(reduc.front());
-
+            // Enters foreach, updates the SSA chain.
+            dstBuf.val = reduc.front();
             if (!dstTp.isAllDense()) {
               Value cond = genIsNonzero(builder, loc, v);
               auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                     /*else=*/true);
               builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              dstBuf.insert(builder, loc, v, dstLcvs);
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               // Exits the ifOp, update the sparse tensor SSA value.
               builder.setInsertionPointAfter(ifOp);
-              assert(!reduc.empty());
-              dstBuf.updateSSA(ifOp.getResult(0));
+              dstBuf.val = ifOp.getResult(0);
             } else {
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
+              dstBuf.insert(builder, loc, v, dstLcvs);
             }
-            if (reduc.empty())
-              builder.create<sparse_tensor::YieldOp>(loc);
-            else
-              builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
           });
       // Accumulates the offset. Note that only static-shaped inputs are allowed
       // by concatenate op verifier, which saves us from computing the offset
@@ -955,15 +942,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       offset = rewriter.create<arith::AddIOp>(
           loc, offset, constantIndex(rewriter, loc, *sh));
 
-      if (!foreachOp.getResults().empty()) {
-        iterArg = foreachOp.getResult(0);
-        dstBuf.updateSSA(iterArg);
-      }
+      iterArg = foreachOp.getResult(0);
+      dstBuf.val = iterArg;
     }
 
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(iterArg);
-
+    dstBuf.val = iterArg;
     Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
     rewriter.replaceOp(op, ret);
     return success();
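
For intuition, a tiny standalone sketch of the offset bookkeeping above, with hypothetical static sizes {4, 5, 3} along the concatenation dimension: the three inputs are written at offsets 0, 4, and 9, and a source coordinate i lands at i + offset in the output.

#include <cstdint>
#include <initializer_list>

int64_t concatOffsets() {
  int64_t offset = 0;
  for (int64_t sh : {4, 5, 3}) {
    // Coordinates [0, sh) of this input map to [offset, offset + sh)
    // in the output; the offsets seen per input are 0, 4, 9.
    offset += sh;
  }
  return offset; // 12: the static size of the output's concat dimension.
}
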
@@ -1010,15 +993,12 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
     ValueRange vs;
     TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
 
-    Value iterArg = dstBuf.getSSA();
     auto foreachOp = rewriter.create<ForeachOp>(
-        loc, src, iterArg ? ValueRange{iterArg} : ValueRange{}, foreachOrder,
+        loc, src, dstBuf.val, foreachOrder,
         [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
             ValueRange reduc) {
           // Enters the loop, update the SSA value for insertion chain.
-          if (!reduc.empty())
-            dstBuf.updateSSA(reduc.front());
-
+          dstBuf.val = reduc.front();
           const Dimension dimRank = dstStt.getDimRank();
           const Level lvlRank = dstStt.getLvlRank();
           SmallVector<Value> lcvs(lvlRank);
@@ -1028,34 +1008,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
           }
 
           if (!skipZeroCheck) {
-            assert(!reduc.empty());
             Value cond = genIsNonzero(builder, loc, v);
             auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                   /*else=*/true);
             builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<scf::YieldOp>(loc, dstBuf.val);
 
             builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            dstBuf.insert(builder, loc, v, lcvs);
+            builder.create<scf::YieldOp>(loc, dstBuf.val);
 
             // Exits the ifOp, update the sparse tensor SSA value.
             builder.setInsertionPointAfter(ifOp);
-            dstBuf.updateSSA(ifOp.getResult(0));
+            dstBuf.val = ifOp.getResult(0);
           } else {
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
+            dstBuf.insert(builder, loc, v, lcvs);
           }
-          if (reduc.empty())
-            builder.create<sparse_tensor::YieldOp>(loc);
-          else
-            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+          builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
         });
 
     rewriter.setInsertionPointAfter(foreachOp);
 
     // Exits the for loop, links the SSA chain.
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(foreachOp.getResult(0));
+    dstBuf.val = foreachOp.getResult(0);
 
     Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
     rewriter.replaceOp(op, ret);
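
Aside: both rewriters above now reduce to the same unconditional SSA threading through the foreach loop. A hedged sketch of that shared shape, assuming an identity dim-to-lvl mapping so the foreach coordinates `dcvs` can be used directly as insertion coordinates (the helper name `copyViaTensorLike` is hypothetical):

static Value copyViaTensorLike(PatternRewriter &rewriter, Location loc,
                               Value src, RankedTensorType rtt,
                               ValueRange sizes) {
  TensorLike dstBuf(rewriter, loc, rtt, sizes);
  auto foreachOp = rewriter.create<ForeachOp>(
      loc, src, dstBuf.val,
      [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
          ValueRange reduc) {
        dstBuf.val = reduc.front();           // enter loop: adopt the iter_arg
        dstBuf.insert(builder, loc, v, dcvs); // sparse or dense insert
        builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
      });
  dstBuf.val = foreachOp.getResult(0);        // exit loop: rejoin the chain
  return dstBuf.finalize(rewriter, loc, rtt);
}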