#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"

// TritonGPU depends on Triton
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonGPU/IR/Types.h"

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
namespace mlir {
namespace triton {
namespace gpu {

struct SharedMemory : public SideEffects::Resource::Base<SharedMemory> {
  StringRef getName() final { return "<SharedMemory>"; }
};
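
// Example (illustrative): an op that reads from shared memory can declare its
// effect against this resource in its getEffects() implementation:
//   effects.emplace_back(MemoryEffects::Read::get(), SharedMemory::get());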

// Returns the total number of elements (including replicated ones) held by
// each thread for the given tensor type, or for the given layout/shape pair.
unsigned getTotalElemsPerThread(Type type);
unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
                                Type eltTy);

// Returns, for each dimension, the number of elements (including replicated
// ones) held by each thread for the given tensor type.
SmallVector<unsigned> getElemsPerThread(Type type);

// Returns, for each dimension, the number of threads per warp. Some of these
// threads may map to replicated elements; if you only want threads that hold
// unique data, use getThreadsPerWarpWithUniqueData.
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);

// Returns the total number of threads per warp for the given layout.
unsigned getWarpSize(Attribute layout);

// Returns, for each dimension, the number of warps per CTA. Some of these
// warps may map to replicated elements; if you only want warps that hold
// unique data, use getWarpsPerCTAWithUniqueData.
SmallVector<unsigned> getWarpsPerCTA(Attribute layout);

// Returns, for each dimension, the number of elements each thread holds
// within one tile of the layout (the layout's sizePerThread).
SmallVector<unsigned> getSizePerThread(Attribute layout);
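
// Worked example (illustrative): for a blocked layout with
// sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 2]:
//   getSizePerThread  -> [1, 4]
//   getThreadsPerWarp -> [8, 4]
//   getWarpSize       -> 32   (8 * 4 threads)
//   getWarpsPerCTA    -> [2, 2]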

// Returns the number of contiguous elements each thread has access to along
// each dimension of the tensor. E.g. for a blocked layout with
// sizePerThread = [1, 4], this returns [1, 4] regardless of the tensor shape.
SmallVector<unsigned> getContigPerThread(Attribute layout);

// Returns the number of non-replicated contiguous elements each thread has
// access to along each dimension of the tensor. For a blocked layout with
// sizePerThread = [1, 4] and tensor shape = [128, 1], thread 0 holds
// [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}] (all replicas of one element), so
// this returns [1, 1]. For tensor shape = [128, 128], thread 0 holds
// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], so this returns [1, 4].
SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
                                               ArrayRef<int64_t> tensorShape);

// Returns the number of threads per warp that have access to non-replicated
// elements of the tensor. E.g. for a blocked layout with sizePerThread =
// [1, 1], threadsPerWarp = [2, 16] and tensor shape = [2, 2], threads
// 0, 1, 16, 17 have access to the full tensor, whereas the other threads have
// access to replicated elements, so this function returns [2, 2].
SmallVector<unsigned>
getThreadsPerWarpWithUniqueData(Attribute layout,
                                ArrayRef<int64_t> tensorShape);

// Returns the number of warps per CTA that have access to non-replicated
// elements of the tensor. E.g. for a blocked layout with sizePerThread =
// [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape =
// [2, 2], this returns [1, 1], since the first warp has access to the full
// tensor, whereas the other warps have access to replicated elements.
SmallVector<unsigned>
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

// Returns the dimensions of the tensor from minor (fastest-varying) to major
// (slowest-varying), e.g. [1, 0] for a row-major rank-2 layout. For
// distributed layouts, this is the order of the elements within a thread.
// For shared layouts, it indicates which dimension of the original tensor is
// contiguous in shared memory.
SmallVector<unsigned> getOrder(Attribute layout);

// Returns the dimensions along which warp IDs are distributed.
// warpsPerCTA only describes the warp grid within the CTA, e.g.
// warpsPerCTA = [2, 4] means there are 2 warps along dim0 and 4 warps along
// dim1. warpOrder specifies the order in which warp IDs are assigned.
// E.g. warpOrder = [0, 1] distributes the warp IDs as follows:
//   [warp0 warp2 warp4 warp6]
//   [warp1 warp3 warp5 warp7]
// Note that in most cases getWarpOrder and getOrder return the same result,
// but this is not guaranteed.
SmallVector<unsigned> getWarpOrder(Attribute layout);

// Returns the dimensions along which thread IDs are distributed.
// Like warpOrder, threadOrder specifies the thread distribution within the
// warp. In most cases getThreadOrder and getOrder return the same result,
// but this is not guaranteed. One exception is the mfma.transposed layout,
// where getOrder returns [1, 0] but getThreadOrder returns [0, 1].
SmallVector<unsigned> getThreadOrder(Attribute layout);

// CTA-level queries: the CTALayoutAttr attached to a layout and, for each
// dimension, the number of CTAs per CGA, the CTA split factors, and the
// order in which CTAs are laid out.
CTALayoutAttr getCTALayout(Attribute layout);
SmallVector<unsigned> getCTAsPerCGA(Attribute layout);
SmallVector<unsigned> getCTASplitNum(Attribute layout);
SmallVector<unsigned> getCTAOrder(Attribute layout);

/* The difference between ShapePerCTATile and ShapePerCTA:
 * (1) ShapePerCTATile is defined by SizePerThread * ThreadsPerWarp *
 *     WarpsPerCTA in each dimension and is independent of the tensor shape.
 * (2) ShapePerCTA is defined by shape / CTASplitNum in each dimension.
 * (3) In the implementation of emitIndices, ShapePerCTATile will be
 *     replicated or wrapped to fit ShapePerCTA.
 */
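
// Worked example (illustrative): for a blocked layout with
// sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [2, 2]:
//   ShapePerCTATile = [1*8*2, 4*4*2] = [16, 32].
// For tensor shape = [128, 128] and CTASplitNum = [2, 2]:
//   ShapePerCTA = [128/2, 128/2] = [64, 64].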
SmallVector<unsigned>
getShapePerCTATile(Attribute layout,
                   ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());

SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                    ArrayRef<int64_t> shape);
SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape);
SmallVector<int64_t> getShapePerCTA(Type type);

// Returns the total number of warps per CTA for the given layout.
unsigned getNumWarpsPerCTA(Attribute layout);

// Returns the total number of CTAs per CGA for the given layout.
unsigned getNumCTAs(Attribute layout);

// Return the order for a batch of matrices of shape [*, m, n] (with
// len(shape) == rank), where each m-by-n matrix is row-major or column-major.
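// E.g. (illustrative) for rank = 3, rowMajor = true gives [2, 1, 0] (dim n
// fastest) and rowMajor = false gives [1, 2, 0] (dim m fastest).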
SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);

// Return the order for a dot operand: either kMajor (contiguous in the K,
// i.e. inner, dimension) or contiguous in the outer (non-K) dimension.
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
                                            bool kMajor);

// Return true if lowering the cat op to the target encoding is expensive,
// i.e. cannot be done as a simple concatenation of each thread's registers.
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);

// Return true if a view between the two types cannot be implemented as a
// no-op.
bool isExpensiveView(Type srcType, Type dstType);

// Return a blocked encoding in which the shape is distributed contiguously
// among the threads, warps, and CTAs, with one element per thread.
triton::gpu::BlockedEncodingAttr
getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
                          int numWarps, int threadsPerWarp, int numCTAs);
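
// Usage sketch (illustrative; `ctx` and `f32Ty` assumed to be in scope):
//   auto enc = getDefaultBlockedEncoding(ctx, {128, 128}, /*numWarps=*/4,
//                                        /*threadsPerWarp=*/32, /*numCTAs=*/1);
//   auto ty = RankedTensorType::get({128, 128}, f32Ty, enc);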

// Dump information about which threads/registers contain each of the tensor
// elements.
void dumpLayout(RankedTensorType tensorType);

// Dump the layout from the hardware point of view, printing which tensor
// element is held by each thread and register.
void dumpHWLayout(RankedTensorType tensorType);

// Return a string representation of the layout of the tensor.
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);
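
// Usage sketch (illustrative; `tensorTy` assumed to be a RankedTensorType):
//   dumpLayout(tensorTy);    // logical (software) view
//   dumpHWLayout(tensorTy);  // hardware view
//   std::string s = getLayoutStr(tensorTy, /*useHWPointOfView=*/true);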

} // namespace gpu
} // namespace triton
} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_