Commit: Create basic pieces of tl.local_copy and tl.gather

plotfi committed Jul 18, 2024
1 parent c929102 commit 08112f0
Showing 7 changed files with 150 additions and 1 deletion.
4 changes: 4 additions & 0 deletions include/triton/Dialect/Triton/IR/Dialect.h
@@ -23,6 +23,10 @@
namespace mlir {
namespace triton {

struct SharedMemory : public SideEffects::Resource::Base<SharedMemory> {
  StringRef getName() final { return "<SharedMemory>"; }
};

struct GlobalMemory : public SideEffects::Resource::Base<GlobalMemory> {
  StringRef getName() final { return "<GlobalMemory>"; }
};
46 changes: 46 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -24,6 +24,7 @@ include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td"
// Interfaces
//
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
def SharedMemory : Resource<"::mlir::triton::SharedMemory">;

//
// Op Base
@@ -225,6 +226,51 @@ def TT_AdvanceOp : TT_Op<"advance",
let hasFolder = 1;
}

// Copy a tensor from registers into shared (local) memory
def TT_LocalCopyOp : TT_Op<"local_copy", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  let summary = "copy tensor to local memory";
  let arguments = (ins TT_Tensor:$src);

  let extraClassDeclaration = [{
    bool isSharedMemoryAlloc() {
      return true;
    }
  }];
  let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}];

  let results = (outs TT_MemDescType:$result);
}

def TT_GatherOp : TT_Op<"gather", [
  AttrSizedOperandSegments,
  MemoryEffects<[MemRead<SharedMemory>]>,
  // Skip the checks when the optional operand is absent: mask is operand #2
  // and other is operand #3, mirroring the guards on TT_LoadOp.
  TypesMatchWith<"infer mask type from src type",
                 "src", "mask", "getI1SameShape($_self)",
                 "($_op.getOperands().size() <= 2) || std::equal_to<>()">,
  TypesMatchWith<"infer other type from src type",
                 "src", "other", "getPointeeType($_self)",
                 "($_op.getOperands().size() <= 3) || std::equal_to<>()">
]> {
  let summary = "Gather from a buffer into a distributed tensor";

  let arguments = (
    ins TT_MemDescType:$src,
    TT_IntTensor:$indices,
    Optional<TT_BoolLike>:$mask,
    Optional<TT_Type>:$other
  );

  let assemblyFormat = [{
    $src `[` $indices `]` (`,` $mask^)? (`,` $other^)?
    attr-dict `:`
    `(` type($src) `,` type($indices)
    (`,` type($mask)^ )? (`,` type($other)^ )?
    `)` `->` type($result)
  }];

  let results = (outs TT_Tensor:$result);
}

//
// Load/Store Ops
//
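At the Python level these two ops surface as tl.local_copy and tl.gather (see the language changes below). As a rough sketch of how they might compose in a kernel — assuming the gather signature that threads indices through, and a finished lowering, neither of which this early commit guarantees; the kernel and all its names are illustrative only:

import triton
import triton.language as tl

# Illustrative kernel, not from the commit: stage a loaded block in shared
# memory with tl.local_copy, then read it back through tl.gather.
@triton.jit
def gather_rows_kernel(src_ptr, idx_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    in_bounds = offs < n
    vals = tl.load(src_ptr + offs, mask=in_bounds)       # global -> registers
    smem = tl.local_copy(vals)                           # registers -> shared memory
    idx = tl.load(idx_ptr + offs, mask=in_bounds)        # per-lane gather indices
    out = tl.gather(smem, idx, mask=in_bounds, other=0)  # shared -> registers
    tl.store(out_ptr + offs, out, mask=in_bounds)
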
13 changes: 13 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -12,6 +12,19 @@
namespace mlir {
namespace triton {

void LocalCopyOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  Operation *op = getOperation();
  // Only report an allocation once the result is mutable shared memory or an
  // offset has been assigned by the allocation pass.
  if (!getType().getMutableMemory() && !op->hasAttr("allocation.offset"))
    return;
  effects.emplace_back(MemoryEffects::Allocate::get(),
                       mlir::triton::SharedMemory::get());
  if (getSrc())
    effects.emplace_back(MemoryEffects::Write::get(), getResult(),
                         mlir::triton::SharedMemory::get());
}

void LoadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
50 changes: 49 additions & 1 deletion python/src/ir.cc
@@ -1,4 +1,4 @@
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

@@ -1223,6 +1223,54 @@ void init_triton_ir(py::module &&m) {
                   return self.create<LoadOp>(ptrs, cacheModifier, evictionPolicy,
                                              isVolatile);
                 })
      .def("create_local_copy",
           [](TritonOpBuilder &self, Value &ptr) -> Value {
             auto context = ptr.getContext();
             auto tensorType = dyn_cast<RankedTensorType>(ptr.getType());
             auto elemType = tensorType.getElementType();
             auto shape = tensorType.getShape();
             auto rank = tensorType.getRank();
             // The source is expected to come straight from a tt.load; the
             // cast asserts that in debug builds.
             auto op = cast<LoadOp>(ptr.getDefiningOp());
             (void)op;

             // TODO: Set these with something other than the defaults, e.g.
             // via triton::gpu::getOrder(tensorType.getEncoding()).
             auto ctaLayout =
                 triton::gpu::CTALayoutAttr::getDefault(context, rank);
             SmallVector<unsigned int> order = {0, 1};

             Attribute encoding = triton::gpu::SharedEncodingAttr::get(
                 context, 1, 1, 1, order, ctaLayout);
             if (tensorType.getRank() > 1) {
               encoding = triton::gpu::SharedEncodingAttr::get(
                   tensorType.getContext(), tensorType.getShape(), order,
                   ctaLayout, tensorType.getElementType());
             }

             auto sharedMemorySpace =
                 triton::gpu::SharedMemorySpaceAttr::get(context);
             MemDescType memDescType =
                 MemDescType::get(shape, elemType, encoding, sharedMemorySpace,
                                  /*mutableMemory=*/false);
             return self.create<LocalCopyOp>(memDescType, ptr);
           })

      // .def("create_local_gather",
      //      [](TritonOpBuilder &self, Value &ptr, Value &indices) -> Value {
      //        auto tensorType = dyn_cast<RankedTensorType>(ptr.getType());
      //        return self.create<GatherOp>(tensorType, ptr, indices);
      //      })
      .def("create_masked_local_gather",
           [](TritonOpBuilder &self, Value &ptr, Value &indices,
              std::optional<Value> &mask,
              std::optional<Value> &other) -> Value {
             // Result is shaped like the indices tensor but carries the
             // source's element type; `ptr` is the tt.memdesc produced by
             // create_local_copy, so its type is not a RankedTensorType.
             auto indicesType = cast<RankedTensorType>(indices.getType());
             auto srcType = cast<MemDescType>(ptr.getType());
             auto resultType = RankedTensorType::get(
                 indicesType.getShape(), srcType.getElementType(),
                 indicesType.getEncoding());
             return self.create<GatherOp>(resultType, ptr, indices,
                                          mask.value_or(Value()),
                                          other.value_or(Value()));
           })
      .def("create_store",
           [](TritonOpBuilder &self, Value &ptrs, Value &value,
              CacheModifier cacheModifier,
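For orientation, the language layer drives these two entry points roughly as follows; this is a condensed, hypothetical snippet mirroring the semantic.py changes below, where vals, idx, mask, other, and builder stand for the tl.tensor wrappers and the TritonOpBuilder handle (none of these names are from the commit):

# Hypothetical driver; mirrors semantic.py below.
smem_handle = builder.create_local_copy(vals.handle)  # tensor -> tt.memdesc
out_handle = builder.create_masked_local_gather(
    smem_handle,                                  # shared-memory source
    idx.handle,                                   # gather indices
    mask.handle if mask is not None else None,    # optional predicate
    other.handle if other is not None else None,  # optional fallback
)
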
4 changes: 4 additions & 0 deletions python/triton/language/__init__.py
@@ -73,6 +73,8 @@
int64,
int8,
join,
gather,
local_copy,
load,
make_block_ptr,
max_constancy,
@@ -188,6 +190,8 @@
"int8",
"ir",
"join",
"local_copy",
"gather",
"load",
"log",
"log2",
14 changes: 14 additions & 0 deletions python/triton/language/core.py
@@ -1566,6 +1566,20 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
# -----------------------


@builtin
def local_copy(pointer, _builder=None):
    return semantic.local_copy(pointer, _builder)


@builtin
def gather(pointer, indices, mask=None, other=None, _builder=None):
    mask = _constexpr_to_value(mask)
    other = _constexpr_to_value(other)
    if mask is not None:
        mask = _to_tensor(mask, _builder)
    if other is not None:
        other = _to_tensor(other, _builder)
    return semantic.gather(pointer, indices, mask, other, _builder)


@builtin
def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="",
volatile=False, _builder=None):
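As a plain-Python reference for what the masked gather is expected to compute — my reading of the op's intended semantics, not code from the commit:

import numpy as np

def masked_gather_ref(buf, idx, mask=None, other=0):
    # out[i] = buf[idx[i]] where mask[i] holds, else `other`.
    out = buf[idx]
    if mask is not None:
        out = np.where(mask, out, other)
    return out

# Example: gather elements 3, 0, 2, masking out the middle lane.
buf = np.array([10, 11, 12, 13])
print(masked_gather_ref(buf, np.array([3, 0, 2]),
                        mask=np.array([True, False, True]), other=-1))
# -> [13 -1 12]
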
20 changes: 20 additions & 0 deletions python/triton/language/semantic.py
@@ -1023,6 +1023,26 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_
        builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction,
                                   is_volatile), dst_ty)


def local_copy(ptr: tl.tensor, builder: ir.builder) -> tl.tensor:
    # TODO: the result type should describe the shared-memory destination;
    # reuse the source scalar type as a placeholder for now.
    dst_ty = ptr.type.scalar
    return tl.tensor(builder.create_local_copy(ptr.handle), dst_ty)


def gather(ptr: tl.tensor, indices: tl.tensor, mask: Optional[tl.tensor],
           other: Optional[tl.tensor], builder: ir.builder) -> tl.tensor:
    # TODO: as above, derive the real destination type from the source
    # element type; the scalar type is a placeholder.
    dst_ty = ptr.type.scalar
    return tl.tensor(
        builder.create_masked_local_gather(ptr.handle, indices.handle,
                                           mask.handle if mask else None,
                                           other.handle if other else None),
        dst_ty)


def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple,
         padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool,
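Tying the pieces together, a hypothetical end-to-end check might look like the following; it assumes the sketched gather_rows_kernel and masked_gather_ref from earlier, plus a working lowering for the new ops, which this commit does not yet provide:

import numpy as np
import torch

def test_gather_roundtrip():
    n = 128
    src = torch.arange(n, dtype=torch.float32, device="cuda")
    idx = torch.randint(0, n, (n,), dtype=torch.int64, device="cuda")
    out = torch.empty_like(src)
    # Launch the illustrative kernel from the TritonOps.td section above.
    gather_rows_kernel[(1,)](src, idx, out, n, BLOCK=n)
    ref = masked_gather_ref(src.cpu().numpy(), idx.cpu().numpy())
    np.testing.assert_allclose(out.cpu().numpy(), ref)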
