Skip to content

Commit

Permalink
[Operator] slice&select scatter (#143)
Browse files Browse the repository at this point in the history
* add Ops & UT & Bench

* add full zero ones Ops & UT & Bench

* split normal op

* [Operator] init slice&select scatter

* code format

* PR comment

* split test_special_ops

* add K-S test

* split special perf

* Exponential added. (#138)

* exponential added.
* Added K-S tests to exponential_, fp64 corrected.
* aligned with aten prototype
* Exponential_ uses uint64 offsets in Triton kernel.
* Update pyproject config for new test dependencies.

* resolve conflict

* Use int64 indexing when needed & fix argmax (#146)

 1. fix amax, armax and triu, use int64 indexing when the largest tensor's size_in_bytes exceed int32's max;
2. change the tiling scheme for argmax to loop in the reduction dimension, instead of data-size-dependent-tile-size

* test for op

* test for op

* Making libentry thread safe (#136)

* libentry now is lock protected.

* Add multithreading tests for libentry.

* polish code.

* add argparse

* fix desc

* fix num

* Update test_specific_ops.py

* split UT files

* fix

* fix

* [Operator] Optimize CrossEntropyLoss (#131)

reimplement cross_entropy_loss forward and backward
support; indices/probabilities/weight/reduction/ignore_index/label_smoothing; perform better than torch eager on large scale tensors

* Exponential added. (#138)

* exponential added.
* Added K-S tests to exponential_, fp64 corrected.
* aligned with aten prototype
* Exponential_ uses uint64 offsets in Triton kernel.
* Update pyproject config for new test dependencies.

* Use int64 indexing when needed & fix argmax (#146)

 1. fix amax, armax and triu, use int64 indexing when the largest tensor's size_in_bytes exceed int32's max;
2. change the tiling scheme for argmax to loop in the reduction dimension, instead of data-size-dependent-tile-size

* Making libentry thread safe (#136)

* libentry now is lock protected.

* Add multithreading tests for libentry.

* polish code.

* [Test] Test for op (#151)

* [chore] solve slice&select scatter's test cases

* [fix] fix slice&select scatter's test cases

* [chore] remove out-of-range indices in select_scatter's test cases

* [chore] simplify slice_scatter's test cases

* [fix] Added range that is deleted by mistake

* Merge branch 'master' into slice&select_scatter

* [chore] reformat

* [fix] typo

* [chore] Considering perf, pause the replacement of some aTen operators
* slice_scatter
* select_scatter
* index_select

* [fix] Add libentry in op.cumsum

* [fix] Del slice&select scatter's perf tests

* [Chore] Add pytest mark for slice&select scatter's test

* [Fix] Correct slice_scatter test

* [Fix] Replace CPU Tensor

---------

Co-authored-by: Bowen12992 <zhangbluestars@gmail.com>
Co-authored-by: Tongxin Bai <waffle.bai@gmail.com>
Co-authored-by: Clement Chan <iclementine@outlook.com>
Co-authored-by: Bowen <81504862+Bowen12992@users.noreply.github.com>
Co-authored-by: StrongSpoon <35829812+StrongSpoon@users.noreply.github.com>
  • Loading branch information
6 people authored and machuanjiang committed Nov 13, 2024
1 parent 8dfeece commit 89c65c7
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 25 deletions.
4 changes: 3 additions & 1 deletion src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,10 @@ def enable(lib=aten_lib):
lib.impl("fill.Scalar", fill_scalar, "PrivateUse1")
lib.impl("fill.Tensor", fill_tensor, "PrivateUse1")
lib.impl("flip", flip, "PrivateUse1")
lib.impl("tile", tile, "PrivateUse1")
lib.impl("slice_scatter", slice_scatter, "PrivateUse1")
lib.impl("select_scatter", select_scatter, "PrivateUse1")
lib.impl("index_select", index_select, "PrivateUse1")
lib.impl("tile", tile, "PrivateUse1")
lib.impl("masked_fill", masked_fill, "PrivateUse1")
lib.impl("_unique2", _unique2, "PrivateUse1")
lib.impl("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa, "PrivateUse1")
Expand Down
4 changes: 4 additions & 0 deletions src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@
from .rsqrt import rsqrt
from .rsub import rsub
from .scatter import scatter
from .select_scatter import select_scatter
from .sigmoid import sigmoid
from .silu import silu
from .sin import sin
from .slice_scatter import slice_scatter
from .softmax import softmax
from .stack import stack
from .sub import sub
Expand Down Expand Up @@ -228,6 +230,8 @@
"where_self",
"where_scalar_self",
"where_scalar_other",
"select_scatter",
"slice_scatter",
"masked_fill",
"_unique2",
"_upsample_bicubic2d_aa",
Expand Down
86 changes: 86 additions & 0 deletions src/flag_gems/ops/select_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry, offsetCalculator, restride_dim


def cfggen():
block_m = [1, 2, 4, 8]
configs = [
triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
]
return configs


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def select_scatter_kernel(
inp,
inp_indices,
src,
src_offsets,
M,
N,
index,
stride_dim,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
pid = tl.program_id(0)
rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
rows_mask = rows_offsets < M

for off in range(0, N, BLOCK_N):
cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
cols_mask = cols_offsets < N

offsets = rows_offsets * N + cols_offsets
mask = rows_mask and cols_mask

indices = tl.load(inp_indices + offsets, mask=mask, other=0)
src_indices = tl.load(src_offsets + offsets, mask=mask, other=0)
cur_src = tl.load(src + src_indices, mask=mask, other=0)

indices += index * stride_dim
tl.store(inp + indices, cur_src, mask=mask)


def select_scatter(inp, src, dim, index):
logging.debug("GEMS SELECT_SCATTER")
assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
assert index >= -inp.size(dim) and index < inp.size(dim), "Invalid index"
dim = dim % inp.ndim
index = index % inp.size(dim)
out = inp.clone().contiguous()
src = src.contiguous()

valid_shape = list(inp.shape)
del valid_shape[dim]
assert (
list(src.shape) == valid_shape
), "Expected src to have a size equal to the slice of self"

src_expanded_shape = list(inp.shape)
src_expanded_shape[dim] = 1
out_strided = restride_dim(out, dim, src_expanded_shape)
idx = torch.arange(0, src.numel(), device=inp.device).reshape(src_expanded_shape)
indices = offsetCalculator(
out_strided, idx, out.stride(), dim, isInp=False
).squeeze(dim=dim)
src_offsets = offsetCalculator(src, idx, src.stride(), dim, isInp=False).squeeze(
dim=dim
)

N = valid_shape[src.ndim - 1]
M = src.numel() // N

grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
select_scatter_kernel[grid](
out, indices, src, src_offsets, M, N, index, out.stride(dim)
)

return out
96 changes: 96 additions & 0 deletions src/flag_gems/ops/slice_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry, offsetCalculator, restride_dim


def cfggen():
block_m = [1, 2, 4, 8]
configs = [
triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
]
return configs


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def slice_scatter_kernel(
inp,
inp_indices,
src,
src_offsets,
M,
N,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
pid = tl.program_id(0)
rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
rows_mask = rows_offsets < M

for off in range(0, N, BLOCK_N):
cols_offsets = off + tl.arange(0, BLOCK_N)[None, :]
cols_mask = cols_offsets < N

offsets = rows_offsets * N + cols_offsets
mask = rows_mask and cols_mask

indices = tl.load(inp_indices + offsets, mask=mask, other=0)
src_indices = tl.load(src_offsets + offsets, mask=mask, other=0)
cur_src = tl.load(src + src_indices, mask=mask, other=0)

tl.store(inp + indices, cur_src, mask=mask)


def slice_scatter(inp, src, dim=0, start=None, end=None, step=1):
logging.debug("GEMS SLICE_SCATTER")
assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
assert step > 0, "slice step must be positive"
dim = dim % inp.ndim
out = inp.clone().contiguous()
src = src.contiguous()
size_dim = inp.size(dim)

if start is None:
start = 0
if end is None:
end = size_dim

range = end - start
if end < start:
range = 0
elif (end - start) > size_dim:
range = size_dim
start = 0
end = size_dim

if range == 0:
return out

valid_shape = list(inp.shape)
valid_shape[dim] = (range + (step - 1)) // step
assert (
list(src.shape) == valid_shape
), "Expected src to have a size equal to the slice of self"

storage_offset = out.storage_offset() + start * out.stride(dim)
out_strided = restride_dim(out, dim, valid_shape, step, storage_offset)
idx = torch.arange(0, src.numel(), device=inp.device).reshape(valid_shape)
strides = list(out.stride())
strides[dim] *= step
indices = (
offsetCalculator(out_strided, idx, strides, dim, isInp=False) + storage_offset
)
src_offsets = offsetCalculator(src, idx, src.stride(), dim, isInp=False)

N = valid_shape[src.ndim - 1]
M = src.numel() // N

grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
slice_scatter_kernel[grid](out, indices, src, src_offsets, M, N)

return out
4 changes: 2 additions & 2 deletions src/flag_gems/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
broadcastable,
broadcastable_to,
dim_compress,
offset_calculator,
offsetCalculator,
restride_dim,
)

Expand All @@ -13,7 +13,7 @@
"pointwise_dynamic",
"dim_compress",
"restride_dim",
"offset_calculator",
"offsetCalculator",
"broadcastable_to",
"broadcastable",
]
84 changes: 63 additions & 21 deletions src/flag_gems/utils/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,27 +223,6 @@ def can_use_int32_index(a):
return True


def offsetCalculator(inp, idx, strides, dim, isInp):
ndim = inp.ndim
shape = list(inp.shape)
offsets = 0
idx_dim = 0
for d in range(0, ndim):
mod = idx % shape[d]
add_on = mod * strides[d]
offsets += add_on
if d == dim:
idx_dim = add_on
idx = idx // shape[d]
# FIXME: Should we write a fast div/mod
# to boost the '%' and '//'? (Since they may be run many times)
# See also:
# - https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
# - Division by Invariant Integers Using Multiplication,
# Torbjörn Granlund and Peter L. Montgomery, 1994.
return (offsets) if not isInp else (offsets - idx_dim)


def restride_dim(src, dim, shape, step=0, storage_offset=None):
strides = list(src.stride())
strides[dim] *= step
Expand Down Expand Up @@ -290,6 +269,48 @@ def add_on_kernel(


def offset_calculator(inp, idx, strides, dim, isInp):
"""
Calculate the flat index(a.k.a offset) for a given ravel index in a multi-dimensional array.
The formula can be seen in:
- https://numpy.org/doc/stable/reference/arrays.ndarray.html#internal-memory-layout-of-an-ndarray
- https://numpy.org/devdocs/user/basics.indexing.html#single-element-indexing
Parameters:
inp (tensor): The input multi-dimensional array from which the offset is calculated.
idx (tensor): The linear index for which the offset is to be calculated.
strides (list of int): A list containing the stride lengths for each dimension of the input array.
dim (int): The specific dimension for which the index offset needs to be calculated.
isInp (bool): A flag indicating whether the tensor 'inp' is the parameter 'self'
in scatter/gather/index_* operators or not.
In operators such as scatter/gather and index_*, when the input tensor 'inp'
is the 'self' tensor to be processed, we may need to modify its offsets later.
For instance, in the scatter operator, the offset is calculated using the formula:
inp_offset = origin_offset - stride[dim] * n_dim + stride[dim] * index.
In this case, we return the fixed part of the formula:
origin_offset - stride[dim] * n_dim,
to facilitate subsequent modifications.
For other types of input 'inp', we return the complete calculation result
of origin_offsets directly.
Returns:
The calculated offset. If isInp is True, the fixed offset is returned; otherwise, the origin offset is returned.
Note:
The function includes a comment suggesting the potential optimization of division and modulus operations,
which may be beneficial if this function is called frequently.
See also:
- https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
- Division by Invariant Integers Using Multiplication,
Torbjörn Granlund and Peter L. Montgomery, 1994.
"""
ndim = inp.ndim
shape = list(inp.shape)
offsets = torch.zeros_like(inp, dtype=torch.int32, device=inp.device)
Expand All @@ -309,3 +330,24 @@ def offset_calculator(inp, idx, strides, dim, isInp):
idx_dim = add_on
idx = idx // shape[d]
return offsets if not isInp else (offsets - idx_dim)


def offsetCalculator(inp, idx, strides, dim, isInp):
ndim = inp.ndim
shape = list(inp.shape)
offsets = 0
idx_dim = 0
for d in range(0, ndim):
mod = idx % shape[d]
add_on = mod * strides[d]
offsets += add_on
if d == dim:
idx_dim = add_on
idx = idx // shape[d]
# FIXME: Should we write a fast div/mod
# to boost the '%' and '//'? (Since they may be run many times)
# See also:
# - https://ridiculousfish.com/blog/posts/labor-of-division-episode-i.html
# - Division by Invariant Integers Using Multiplication,
# Torbjörn Granlund and Peter L. Montgomery, 1994.
return (offsets) if not isInp else (offsets - idx_dim)
Loading

0 comments on commit 89c65c7

Please sign in to comment.