Skip to content

Commit

Permalink
Revert "Revert "[Backend] Improve dot support to target FMA (#4516)""
Browse files Browse the repository at this point in the history
This reverts commit c5f5ac1.
  • Loading branch information
whitneywhtsang committed Dec 20, 2024
1 parent 4506b07 commit 242cd5a
Show file tree
Hide file tree
Showing 13 changed files with 573 additions and 294 deletions.
24 changes: 24 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,12 +348,18 @@ SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
// Overload of delinearize that takes no explicit `order` argument.
// Takes a single `linear` Value plus a static `shape` and returns one Value
// per result element.
// NOTE(review): presumably splits `linear` into per-dimension indices using a
// default dimension order — confirm which order against the definition.
SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
Value linear, ArrayRef<unsigned> shape);

// Overload of delinearize that works on plain `unsigned` values instead of IR
// Values: it takes no rewriter or location and returns its result directly.
// NOTE(review): presumably converts `linear` into one index per dimension of
// `shape`, with `order` selecting the dimension traversal order — confirm the
// exact meaning of `order` against the definition.
SmallVector<unsigned> delinearize(unsigned linear, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);

// Combines the per-dimension Values in `multiDim` into a single Value, given
// a static `shape` and a dimension `order`.
// NOTE(review): presumably the inverse of the delinearize overload that takes
// an `order` — confirm against the definition.
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape, ArrayRef<unsigned> order);

// Overload of linearize that takes no explicit `order` argument.
// NOTE(review): presumably uses a default dimension order — confirm which one
// against the definition.
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape);

// Overload of linearize that works on plain `unsigned` indices instead of IR
// Values: it takes no rewriter or location and returns the result as a size_t.
// NOTE(review): presumably the inverse of the `unsigned` delinearize overload
// for the same `shape` and `order` — confirm against the definition.
size_t linearize(ArrayRef<unsigned> multiDim, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);

// Takes a `key` and a string `content` and returns a Value.
// NOTE(review): presumably adds `content` to the enclosing module as a global
// string identified by `key` and returns a Value referring to it — confirm
// against the definition, including whether an existing key is reused.
Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
StringRef content);

Expand Down Expand Up @@ -496,6 +502,24 @@ inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
return ret;
}

/// Extend a 2d shared object to 3d.
///
/// If the tensor has 3 dimensions, returns the original shared object.
/// If the tensor shape is [M, N], returns a shared object describing shape
/// [1, M, N].
///
/// This function is used to simplify processing of 2d and 3d dot operands,
/// particularly in the conversion of the local_load operation.
///
/// \param rewriter conversion pattern rewriter (NOTE(review): presumably used
///        to create any values needed for the expanded object — confirm
///        against the definition)
/// \param loc location (NOTE(review): presumably attached to any operations
///        created during expansion — confirm)
/// \param smemObj shared object to expand; taken by value
/// \param shape shape of a tensor represented by smemObj
/// \returns shared object describing 3d tensor
SharedMemoryObject
getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc,
SharedMemoryObject smemObj,
ArrayRef<int64_t> shape);

// -----------------------------------------------------------------------
// Blocked layout indices
// -----------------------------------------------------------------------
Expand Down
6 changes: 6 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ void dumpHWLayout(RankedTensorType tensorType);
// Return a string representation of the layout of the tensor.
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);

// Takes a shape `s` and returns a new shape vector of the same element type.
// NOTE(review): presumably prepends a batch dimension so that a rank-2 matrix
// shape becomes rank-3, and leaves rank-3 input as is — confirm against the
// definition.
// NOTE(review): only the template declaration is visible here; if the
// definition lives in a source file, each used `T` needs an explicit
// instantiation there — confirm.
template <typename T>
llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);

// Takes a dimension order `o` and returns a new order vector.
// NOTE(review): presumably the order counterpart of
// expandMatrixShapeWithBatch, extending a rank-2 order to rank-3 so it stays
// consistent with an added batch dimension — confirm against the definition.
llvm::SmallVector<unsigned>
expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);

} // namespace gpu
} // namespace triton
} // namespace mlir
Expand Down
Loading

0 comments on commit 242cd5a

Please sign in to comment.