[Bug] Fixing longformer compilation (#403)
1. Added `torch.as_strided` / `torch.Tensor.as_strided` and `torch.flip`
2. Added support for `rounding_mode == 'trunc'` in `torch.divide`
3. Registered `torch.Tensor.new_ones`
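
A quick sketch of the new coverage through the torch frontend (illustrative only; assumes hidet is installed and registers its `hidet` dynamo backend):

```python
import torch

def f(x):
    y = torch.as_strided(x, (2, 2), (3, 1))     # newly registered function/method
    y = torch.flip(y, [0])                      # newly registered function/method
    y = torch.div(y, 2, rounding_mode='trunc')  # newly supported rounding mode
    return y.new_ones(y.shape)                  # newly registered method

x = torch.arange(16, dtype=torch.float32, device='cuda')
y = torch.compile(f, backend='hidet')(x)
```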




Longformer model compilation fails with:
```
RuntimeError: cudaDeviceSynchronize failed with error: cudaErrorMisalignedAddress
```
after running the `fused_matmul_f16_pk_cute_rearrange_add` kernel. Nvidia
Nsight Compute also shows that the matmul kernel fails to launch. This PR
contains all the changes needed to reproduce the issue.

To reproduce:
1. Check out the `zhumakhan/longformer` branch.
2. Run `python3 tests/benchmarks/bench_transformer.py longformer`.

---------

Co-authored-by: Zhumakhan <nazirzhumakhan@gmail.com>
zhumakhan and Zhumakhan authored Aug 12, 2024
1 parent 09f6cc0 commit dfd4257
Showing 6 changed files with 186 additions and 9 deletions.
36 changes: 30 additions & 6 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -558,6 +558,7 @@ def interpolate(
)


@register_function(operator.itruediv)
@register_function(operator.truediv)
@register_function(torch.true_divide)
@register_method(torch.Tensor.true_divide)
@@ -568,6 +569,8 @@ def truediv(x: Union[Tensor, int, float], y: Union[Tensor, int, float]):
def is_integer(v: Union[Tensor, int, float]) -> bool:
return isinstance(v, int) or (isinstance(v, Tensor) and v.dtype.is_integer())

if not isinstance(x, Tensor) and not isinstance(y, Tensor):
return x / y
if is_integer(x) and is_integer(y):
if isinstance(y, (int, float)):
y = hidet.asarray(y).to(device=x.device)
@@ -589,8 +592,23 @@ def div(x: Tensor, y: Tensor, *, rounding_mode: Optional[str] = None, out=None):
elif rounding_mode == 'floor':
return ops.floor(result)
else:
assert rounding_mode == 'trunc'
raise NotImplementedError("torch.div(..., rounding_mode='trunc') is currently not supported by Hidet")
assert rounding_mode == 'trunc', 'rounding_mode should be one of "floor" or "trunc"'
if isinstance(result, Tensor):
dtype = result.dtype
result = result.to(dtype='int64')
return result.to(dtype=dtype)
else:
if isinstance(x, float) or isinstance(y, float):
return float(int(result))
return int(result)
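
For context (illustrative, plain PyTorch, not part of this diff): truncation rounds toward zero while floor rounds toward negative infinity, so the `int64` round-trip above reproduces PyTorch's `'trunc'` semantics:

```python
import torch

a = torch.tensor([-7.0, 7.0])
print(torch.div(a, 2, rounding_mode='floor'))  # tensor([-4., 3.])
print(torch.div(a, 2, rounding_mode='trunc'))  # tensor([-3., 3.])
print(a.div(2).to(torch.int64).to(a.dtype))    # tensor([-3., 3.]), same as 'trunc'
```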


@register_function(torch.as_strided)
@register_method(torch.Tensor.as_strided)
def torch_as_strided(
input: Tensor, size: Union[int, Tuple[int]], stride: Union[int, Tuple[int]], storage_offset: Optional[int] = None
):
return ops.as_strided(input, size, stride, storage_offset)


@register_function(operator.sub)
@@ -1771,17 +1789,17 @@ def torch_einsum(equation, *operands):


@register_function(torch.triu)
@register_function(torch.Tensor.triu)
@register_function(torch.Tensor.triu_)
@register_method(torch.Tensor.triu)
@register_method(torch.Tensor.triu_)
def torch_triu(x: Tensor, diagonal: int = 0, *, out=None):
if out is not None:
raise NotImplementedError("hidet: does not support torch.triu(..., out=...)")
return ops.triu(x, diagonal)


@register_function(torch.tril)
@register_function(torch.Tensor.tril)
@register_function(torch.Tensor.tril_)
@register_method(torch.Tensor.tril)
@register_method(torch.Tensor.tril_)
def torch_tril(x: Tensor, diagonal: int = 0, *, out=None):
if out is not None:
raise NotImplementedError("hidet: does not support torch.tril(..., out=...)")
@@ -1859,6 +1877,12 @@ def torch_unfold(input: Tensor, kernel_size, dilation=1, padding=0, stride=1) ->
return ops.im2col(input, kernel_size, dilation, padding, stride)


@register_function(torch.flip)
@register_method(torch.Tensor.flip)
def torch_flip(input: Tensor, dims) -> Tensor:
return ops.flip(input, dims)


@register_function(torch.sign)
def torch_sign(input: Tensor, *, out=None):
if out is not None:
36 changes: 36 additions & 0 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -283,3 +283,39 @@ def tensor_new_full(
@register_method(torch.Tensor.zero_)
def tensor_zero_(self: Tensor):
return ops.full(self.shape, dtype=self.dtype, device=self.device, value=self.dtype.zero)


@register_method(torch.Tensor.new_ones)
def tensor_new_ones(self: Tensor, size, *, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
if layout is not None and layout != torch.strided:
raise NotImplementedError("layout is not None and layout != torch.strided")
if len(size) == 1:
if isinstance(size[0], (list, tuple)):
size = size[0]
shape = size
    if dtype is None:
        dtype = self.dtype
    if device is None:  # only fall back to the source tensor's device when unspecified
        device = self.device
_ = pin_memory
_ = requires_grad

return ops.full(shape, dtype=dtype, device=device, value=dtype.one)


@register_method(torch.Tensor.new_ones)
def tensor_new_ones_v2(
self: Tensor, *size, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False
):
if layout is not None and layout != torch.strided:
raise NotImplementedError("layout is not None and layout != torch.strided")
if len(size) == 1:
if isinstance(size[0], (list, tuple)):
size = size[0]
shape = size
    if dtype is None:
        dtype = self.dtype
    if device is None:  # only fall back to the source tensor's device when unspecified
        device = self.device
_ = pin_memory
_ = requires_grad

return ops.full(shape, dtype=dtype, device=device, value=dtype.one)
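
The two registrations above are meant to cover both PyTorch calling conventions for `new_ones` (illustrative, plain PyTorch):

```python
import torch

t = torch.empty(4, dtype=torch.float16)
print(t.new_ones((2, 3)).shape)  # size as a single sequence -> torch.Size([2, 3])
print(t.new_ones(2, 3).shape)    # size as varargs           -> torch.Size([2, 3])
```
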
15 changes: 14 additions & 1 deletion python/hidet/graph/ops/__init__.py
@@ -40,7 +40,20 @@
from .quant import symmetric_quantize, symmetric_dequantize
from .reduce import mean, sum, var, min, max, std, prod, argmin, argmax, all, any
from .cumulative import cumsum
from .transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, reshape, im2col
from .transform import (
squeeze,
unsqueeze,
flatten,
concat,
cast,
take,
rearrange,
strided_slice,
reshape,
im2col,
as_strided,
flip,
)
from .transform import transpose, broadcast, pad, tile, split, conv_pad, expand_dims, gather, index_select, triu, tril
from .transform import permute_dims, meshgrid, repeat_interleave
from .fusion import fused_operator
4 changes: 3 additions & 1 deletion python/hidet/graph/ops/linear.py
@@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence
from typing import List, Sequence, Tuple
from hidet.graph.ops import permute_dims, reshape
from .utils import Tensor
from .matmul import matmul
@@ -346,6 +346,8 @@ def subscript_to_label(subscript: int) -> str:

# Do ad-hoc pattern matching: only support simple cases such as matrix multiply
def einsum(equation: str, operands: Sequence[Tensor]):
if isinstance(operands[0], (Tuple, List)):
operands = operands[0]
if len(operands) != 2:
raise NotImplementedError('einsum currently only supports 2 operands')

80 changes: 79 additions & 1 deletion python/hidet/graph/ops/transform.py
@@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union, Sequence
from typing import List, Optional, Tuple, Union, Sequence
from hidet.ir.type import DataType, data_type
from hidet.ir.expr import Expr, Constant, if_then_else, convert, cast as ir_cast, is_constant, logical_or
from hidet.ir.expr import Int
@@ -449,6 +449,54 @@ def fmap(*indices):
)


class AsStridedTask(Task):
def __init__(
self,
x: TensorNode,
size: Union[int, List[int]],
stride: Union[int, List[int]],
storage_offset: Optional[int] = None,
):
        def unravel_index(index, shape):
            # Convert a flat element offset into per-dimension indices for `shape`.
            out = []
            for s in reversed(shape):
                out.append(index % s)
                index = index // s
            return list(reversed(out))

storage_shift = storage_offset if storage_offset is not None else 0

        def fmap(*indices):
            # Dot the output indices with `stride` and add the storage offset to get
            # a flat element offset, then unravel it into input coordinates.
            stride1d = 0
            for i, s in zip(indices, stride):
                stride1d += i * s
            stride1d += storage_shift

            new_indices = unravel_index(stride1d, x.shape)
            return x[new_indices]

out = compute(name='out', shape=size, fcompute=fmap)

super().__init__(
name='as_strided',
inputs=[x],
attributes={'size': size, 'stride': stride, 'storage_offset': storage_offset},
outputs=[out],
)
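
A worked example of the index arithmetic in `fmap` above (illustrative, not part of this diff): each output coordinate is dotted with `stride`, shifted by `storage_offset`, and the flat offset is unraveled back into the input's shape:

```python
# size=(2, 2), stride=(3, 1), storage_offset=1 over a 6-element input
flat = list(range(6))
size, stride, offset = (2, 2), (3, 1), 1
out = [[flat[offset + i * stride[0] + j * stride[1]] for j in range(size[1])]
       for i in range(size[0])]
print(out)  # [[1, 2], [4, 5]], matching torch.as_strided(torch.arange(6), (2, 2), (3, 1), 1)
```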


class FlipTask(Task):
def __init__(self, x: Tensor, dims: Union[List[int], Tuple[int]]):
        def fmap(*indices):
            # Mirror the index along every dimension listed in `dims`;
            # `i in dims` is an ordinary Python check evaluated at task-construction time.
            idx = []
            for i in range(len(indices)):
                idx.append(if_then_else(i in dims, x.shape[i] - indices[i] - 1, indices[i]))
            return x[idx]

out = compute(name='out', shape=x.shape, fcompute=fmap)
super().__init__(name='flip', inputs=[x], attributes={'dims': dims}, outputs=[out])
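
The flip computation simply mirrors coordinates along the listed dims (illustrative, plain PyTorch):

```python
import torch

x = torch.arange(6).reshape(2, 3)
print(torch.flip(x, [0]))  # tensor([[3, 4, 5],
                           #         [0, 1, 2]])
```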


class ReshapeOp(Operator):
def __init__(self, x: Tensor, shape):
task = ReshapeTask(input_like(x, 'x'), shape)
@@ -644,6 +692,26 @@ def __init__(self, x: Tensor, kernel_size: List[Int], dilation: List[Int], paddi
)


class AsStridedOp(Operator):
def __init__(
self,
x: Tensor,
size: Union[int, List[int]],
stride: Union[int, List[int]],
storage_offset: Optional[int] = None,
):
super().__init__(
inputs=[x],
attributes={'size': size, 'stride': stride, 'storage_offset': storage_offset},
task=AsStridedTask(input_like(x, 'x'), size, stride, storage_offset),
)


class FlipOp(Operator):
def __init__(self, x: Tensor, dims: Union[List[int], Tuple[int]]):
super().__init__(inputs=[x], attributes={'dims': dims}, task=FlipTask(input_like(x, 'x'), dims))


def reshape(x: Tensor, shape) -> Tensor:
if same_shape(x.shape, shape):
return x
@@ -917,3 +985,13 @@ def im2col(
if nd == 3:
return x.squeeze(0)
return x


def as_strided(
x: Tensor, size: Union[int, List[int]], stride: Union[int, List[int]], storage_offset: Optional[int] = None
):
return AsStridedOp(x, size, stride, storage_offset).outputs[0]


def flip(x: Tensor, dims: Union[List[int], Tuple[int]]):
return FlipOp(x, dims).outputs[0]
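
A usage sketch of the two new public ops (assumes hidet with this patch applied):

```python
import hidet

x = hidet.asarray([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
print(hidet.ops.flip(x, [0]))                   # rows reversed
print(hidet.ops.as_strided(x, [2, 2], [3, 1]))  # [[0, 1], [3, 4]]
```
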
24 changes: 24 additions & 0 deletions tests/operators/test_transform.py
@@ -433,5 +433,29 @@ def test_unfold(shape, kernel_size, dilation, padding, stride):
)


@pytest.mark.parametrize(
"shape, size, stride, storage_offset",
[
([1, 3, 10, 10], [10, 10], [1, 2], 0),
([3, 44, 41], [20, 3], [1, 1], 5),
([3, 1, 45, 33], [30, 40], [5, 2], 6),
([4, 10, 13], [7, 13, 2], [2, 3, 1], 0),
],
)
def test_as_strided(shape, size, stride, storage_offset):
check_transform_torch(
shape,
lambda x: torch.as_strided(x, size, stride, storage_offset),
lambda x: ops.as_strided(x, size, stride, storage_offset),
)


@pytest.mark.parametrize(
"shape, dims", [([1, 3, 10, 10], [1, 0]), ([3, 44, 41], [0]), ([3, 1, 45, 33], [0, 1, 2, 3]), ([4, 10, 13], [2, 0])]
)
def test_flip(shape, dims):
check_transform_torch(shape, lambda x: torch.flip(x, dims), lambda x: ops.flip(x, dims))


if __name__ == '__main__':
pytest.main([__file__])
