[Operator] slice&select scatter (#143)
* add Ops & UT & Bench
* add full zero ones Ops & UT & Bench
* split normal op
* [Operator] init slice&select scatter
* code format
* PR comment
* split test_special_ops
* add K-S test
* split special perf
* Exponential added. (#138)
  * exponential added.
  * Added K-S tests to exponential_, fp64 corrected.
  * aligned with aten prototype
  * Exponential_ uses uint64 offsets in Triton kernel.
  * Update pyproject config for new test dependencies.
* resolve conflict
* Use int64 indexing when needed & fix argmax (#146)
  1. fix amax, argmax and triu: use int64 indexing when the largest tensor's size_in_bytes exceeds int32's max;
  2. change the tiling scheme for argmax to loop over the reduction dimension instead of using a data-size-dependent tile size
* test for op
* test for op
* Making libentry thread safe (#136)
  * libentry is now lock protected.
  * Add multithreading tests for libentry.
  * polish code.
* add argparse
* fix desc
* fix num
* Update test_specific_ops.py
* split UT files
* fix
* fix
* [Operator] Optimize CrossEntropyLoss (#131)
  reimplement cross_entropy_loss forward and backward; support indices/probabilities/weight/reduction/ignore_index/label_smoothing; performs better than torch eager on large-scale tensors
* [Test] Test for op (#151)
* [chore] solve slice&select scatter's test cases
* [fix] fix slice&select scatter's test cases
* [chore] remove out-of-range indices in select_scatter's test cases
* [chore] simplify slice_scatter's test cases
* [fix] Added range that was deleted by mistake
* Merge branch 'master' into slice&select_scatter
* [chore] reformat
* [fix] typo
* [chore] Considering perf, pause the replacement of some ATen operators
  * slice_scatter
  * select_scatter
  * index_select
* [fix] Add libentry in op.cumsum
* [fix] Del slice&select scatter's perf tests
* [Chore] Add pytest mark for slice&select scatter's test
* [Fix] Correct slice_scatter test
* [Fix] Replace CPU Tensor

---------

Co-authored-by: Bowen12992 <zhangbluestars@gmail.com>
Co-authored-by: Tongxin Bai <waffle.bai@gmail.com>
Co-authored-by: Clement Chan <iclementine@outlook.com>
Co-authored-by: Bowen <81504862+Bowen12992@users.noreply.github.com>
Co-authored-by: StrongSpoon <35829812+StrongSpoon@users.noreply.github.com>
Parent: 8dfeece. Commit: 89c65c7. 7 changed files with 314 additions and 25 deletions.
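
For reference, the two operators added here follow the ATen semantics of torch.select_scatter and torch.slice_scatter: each returns a copy of the input with src embedded at a single index, or into a strided slice, along dim. A small eager-mode illustration in plain PyTorch (independent of the kernels below):

import torch

inp = torch.zeros(3, 4)

# select_scatter: src must have inp's shape with `dim` removed.
row = torch.arange(4.0)
out = torch.select_scatter(inp, row, dim=0, index=1)
# out[1] == row; every other row is still zero.

# slice_scatter: src must have one element per step of the slice.
cols = torch.ones(3, 2)
out2 = torch.slice_scatter(inp, cols, dim=1, start=0, end=4, step=2)
# out2[:, 0::2] == 1; columns 1 and 3 are still zero.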
New file (+86 lines), implementing select_scatter:
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry, offsetCalculator, restride_dim


def cfggen():
    # Autotune over the number of rows handled per program; columns are
    # swept in fixed chunks of 1024.
    block_m = [1, 2, 4, 8]
    configs = [
        triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
    ]
    return configs


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def select_scatter_kernel(
    inp,
    inp_indices,
    src,
    src_offsets,
    M,
    N,
    index,
    stride_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Each program covers BLOCK_M rows and loops over all N columns.
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        mask = rows_mask and cols_mask

        # Gather the precomputed destination and source offsets, then shift
        # the destination by `index` along the scatter dimension.
        indices = tl.load(inp_indices + offsets, mask=mask, other=0)
        src_indices = tl.load(src_offsets + offsets, mask=mask, other=0)
        cur_src = tl.load(src + src_indices, mask=mask, other=0)

        indices += index * stride_dim
        tl.store(inp + indices, cur_src, mask=mask)


def select_scatter(inp, src, dim, index):
    logging.debug("GEMS SELECT_SCATTER")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index >= -inp.size(dim) and index < inp.size(dim), "Invalid index"
    dim = dim % inp.ndim
    index = index % inp.size(dim)
    out = inp.clone().contiguous()
    src = src.contiguous()

    # src must match inp's shape with `dim` removed.
    valid_shape = list(inp.shape)
    del valid_shape[dim]
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    # View `out` as if `dim` had extent 1, and precompute flat offsets for
    # every element of the destination slice and of src.
    src_expanded_shape = list(inp.shape)
    src_expanded_shape[dim] = 1
    out_strided = restride_dim(out, dim, src_expanded_shape)
    idx = torch.arange(0, src.numel(), device=inp.device).reshape(src_expanded_shape)
    indices = offsetCalculator(
        out_strided, idx, out.stride(), dim, isInp=False
    ).squeeze(dim=dim)
    src_offsets = offsetCalculator(src, idx, src.stride(), dim, isInp=False).squeeze(
        dim=dim
    )

    # Flatten src to an (M, N) view: N is the innermost extent, M the rest.
    N = valid_shape[src.ndim - 1]
    M = src.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    select_scatter_kernel[grid](
        out, indices, src, src_offsets, M, N, index, out.stride(dim)
    )

    return out
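
A minimal sanity check for the wrapper above, assuming a flag_gems-style import path (the exact module location is an assumption, not part of this diff):

import torch
from flag_gems.ops.select_scatter import select_scatter  # hypothetical path

inp = torch.randn(128, 256, device="cuda")
src = torch.randn(256, device="cuda")  # inp's shape with dim 0 removed

res = select_scatter(inp, src, dim=0, index=3)
ref = torch.select_scatter(inp, src, dim=0, index=3)
assert torch.equal(res, ref)

# Here valid_shape == [256], so N = 256 and M = src.numel() // N = 1:
# the launch grid is cdiv(M, BLOCK_M) programs, each sweeping the 256
# columns in BLOCK_N-sized chunks.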
New file (+96 lines), implementing slice_scatter:
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry, offsetCalculator, restride_dim


def cfggen():
    block_m = [1, 2, 4, 8]
    configs = [
        triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
    ]
    return configs


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def slice_scatter_kernel(
    inp,
    inp_indices,
    src,
    src_offsets,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Each program covers BLOCK_M rows and loops over all N columns.
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
        cols_mask = cols_offsets < N

        offsets = rows_offsets * N + cols_offsets
        mask = rows_mask and cols_mask

        # Copy src elements to their precomputed destinations in inp.
        indices = tl.load(inp_indices + offsets, mask=mask, other=0)
        src_indices = tl.load(src_offsets + offsets, mask=mask, other=0)
        cur_src = tl.load(src + src_indices, mask=mask, other=0)

        tl.store(inp + indices, cur_src, mask=mask)


def slice_scatter(inp, src, dim=0, start=None, end=None, step=1):
    logging.debug("GEMS SLICE_SCATTER")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert step > 0, "slice step must be positive"
    dim = dim % inp.ndim
    out = inp.clone().contiguous()
    src = src.contiguous()
    size_dim = inp.size(dim)

    if start is None:
        start = 0
    if end is None:
        end = size_dim

    # Clamp the slice: an inverted slice writes nothing; a slice longer than
    # the dimension is truncated to the whole dimension.
    range = end - start  # note: shadows the builtin within this scope
    if end < start:
        range = 0
    elif (end - start) > size_dim:
        range = size_dim
        start = 0
        end = size_dim

    if range == 0:
        return out

    # src must match inp's shape except along `dim`, where it has one
    # element per step taken: ceil(range / step).
    valid_shape = list(inp.shape)
    valid_shape[dim] = (range + (step - 1)) // step
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    # Restride `out` so it walks the target slice, then precompute flat
    # offsets for the destination and for src.
    storage_offset = out.storage_offset() + start * out.stride(dim)
    out_strided = restride_dim(out, dim, valid_shape, step, storage_offset)
    idx = torch.arange(0, src.numel(), device=inp.device).reshape(valid_shape)
    strides = list(out.stride())
    strides[dim] *= step
    indices = (
        offsetCalculator(out_strided, idx, strides, dim, isInp=False) + storage_offset
    )
    src_offsets = offsetCalculator(src, idx, src.stride(), dim, isInp=False)

    # Flatten src to an (M, N) view: N is the innermost extent, M the rest.
    N = valid_shape[src.ndim - 1]
    M = src.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    slice_scatter_kernel[grid](out, indices, src, src_offsets, M, N)

    return out
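
And a matching sketch for slice_scatter, again with a hypothetical import path; note how the expected src shape derives from the clamped slice:

import torch
from flag_gems.ops.slice_scatter import slice_scatter  # hypothetical path

inp = torch.randn(64, 100, device="cuda")
# The slice [10:60:5] along dim 1 covers ceil((60 - 10) / 5) = 10 positions,
# so valid_shape is [64, 10] and src must match it.
src = torch.randn(64, 10, device="cuda")

res = slice_scatter(inp, src, dim=1, start=10, end=60, step=5)
ref = torch.slice_scatter(inp, src, dim=1, start=10, end=60, step=5)
assert torch.equal(res, ref)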