diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp index c44c5985ee7bba..dac6d6b64551c8 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -774,12 +774,12 @@ Value sparse_tensor::genReader(OpBuilder &builder, Location loc, return reader; } -Value sparse_tensor::genReaderBuffers(OpBuilder &builder, Location loc, - SparseTensorType stt, - ArrayRef dimShapesValues, - Value dimSizesBuffer, - /*out*/ Value &dim2lvlBuffer, - /*out*/ Value &lvl2dimBuffer) { +Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc, + SparseTensorType stt, + ArrayRef dimShapesValues, + Value dimSizesBuffer, + /*out*/ Value &dim2lvlBuffer, + /*out*/ Value &lvl2dimBuffer) { const Dimension dimRank = stt.getDimRank(); const Level lvlRank = stt.getLvlRank(); // For an identity mapping, the dim2lvl and lvl2dim mappings are diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h index 698b6c491a9aef..1562ea3f20f73d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -353,11 +353,11 @@ Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt, /*out*/ SmallVectorImpl &dimShapeValues, /*out*/ Value &dimSizesBuffer); -/// Generates code to set up the buffer parameters for a reader. -Value genReaderBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, - ArrayRef dimShapeValues, Value dimSizesBuffer, - /*out*/ Value &dim2lvlBuffer, - /*out*/ Value &lvl2dimBuffer); +/// Generates code to set up the buffer parameters for a map. +Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt, + ArrayRef dimShapeValues, Value dimSizesBuffer, + /*out*/ Value &dim2lvlBuffer, + /*out*/ Value &lvl2dimBuffer); //===----------------------------------------------------------------------===// // Inlined constant generators. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 2c03f0a6020e6a..e22789643c90af 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1478,8 +1478,8 @@ struct SparseNewConverter : public OpConversionPattern { // Now construct the dim2lvl and lvl2dim buffers. Value dim2lvlBuffer; Value lvl2dimBuffer; - genReaderBuffers(rewriter, loc, dstTp, dimShapesValues, dimSizesBuffer, - dim2lvlBuffer, lvl2dimBuffer); + genMapBuffers(rewriter, loc, dstTp, dimShapesValues, dimSizesBuffer, + dim2lvlBuffer, lvl2dimBuffer); // Read the COO tensor data. Value xs = desc.getAOSMemRef(); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index e44d8565fc867d..d2d7b46ab834e7 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -205,9 +205,9 @@ class NewCallParams final { params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt); // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers. params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizesValues); - params[kParamLvlSizes] = genReaderBuffers( - builder, loc, stt, dimSizesValues, params[kParamDimSizes], - params[kParamDim2Lvl], params[kParamLvl2Dim]); + params[kParamLvlSizes] = + genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes], + params[kParamDim2Lvl], params[kParamLvl2Dim]); // Secondary and primary types encoding. setTemplateTypes(stt); // Finally, make note that initialization is complete. @@ -446,8 +446,8 @@ class SparseTensorNewConverter : public OpConversionPattern { Value dim2lvlBuffer; Value lvl2dimBuffer; Value lvlSizesBuffer = - genReaderBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer, - dim2lvlBuffer, lvl2dimBuffer); + genMapBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer, + dim2lvlBuffer, lvl2dimBuffer); // Use the `reader` to parse the file. Type opaqueTp = getOpaquePointerType(rewriter); Type eltTp = stt.getElementType();