[CuTe] fix longformer (#411)
The issue is caused by an incorrect layout for the bias tensor.
For example, consider a bias tensor of shape (64,); its layout can be written as
`(64,):(1,)`
However, we can expand the layout by adding axes of shape 1, for example
`(64, 1):(1, 1)`
Because that axis has shape 1, its stride can be any number: a stride paired with a 1-shape never affects the address computation. However, two strides equal to 1 do influence instruction selection, and the resulting invalid memory instruction causes the misaligned access.
To fix this, we force the stride paired with a 1-shape to be 0. The canonicalized layout computes the same memory addresses, and it helps the compiler make the right choice in the instruction-selection pass.
closes #404
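
To make the address equivalence concrete, here is a minimal standalone sketch (plain Python, not the hidet API) showing that the stride paired with a size-1 axis never contributes to the linear address, so forcing it to 0 changes nothing about which bytes are touched:

```python
from itertools import product

def linear_address(index, shape, stride):
    # CuTe-style layouts map a multi-dimensional index to an offset via a
    # dot product with the strides.
    assert all(0 <= i < s for i, s in zip(index, shape))
    return sum(i * d for i, d in zip(index, stride))

shape = (64, 1)
expanded = (1, 1)    # stride as written for the expanded bias layout
canonical = (1, 0)   # stride after forcing the 1-shape stride to 0

# The size-1 axis only ever takes index 0, so its stride is multiplied by 0.
for idx in product(*(range(s) for s in shape)):
    assert linear_address(idx, shape, expanded) == linear_address(idx, shape, canonical)
print("the two layouts are address-equivalent")
```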

Co-authored-by: xiaocenxiaocen <xiao.zhang@centml.ai>
xiaocenxiaocen authored Aug 12, 2024
1 parent 99d231c commit b808ca4
Showing 5 changed files with 64 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/hidet/graph/ops/fusion/apply_prologue_epilogue.py
@@ -70,6 +70,7 @@
    coalesce,
    make_layout,
    complement,
    canonicalize,
    CopyAtom,
    TiledCopy,
    ThrValAtom,
@@ -738,6 +739,7 @@ def _backward(self, ops: Sequence[Operator]):
                TensorLayout(tuple(reversed(result_shape)), tuple(reversed(result_stride)))
            )
            layout = composition(tile_mapping, output_layout)
            layout = canonicalize(layout)
            # assert all(is_constant(s) for s in output_shape)
            self.tensor2tile[ti] = (self._tile_divide(layout), layout, tuple(output_shape))
        else:
1 change: 1 addition & 0 deletions python/hidet/ir/cute/__init__.py
@@ -48,6 +48,7 @@
    logical_product,
    logical_divide,
    make_layout,
    canonicalize,
)
from .layout import ThrValAtom, Level
from .algorithm import CopyAtom, TiledCopy
19 changes: 19 additions & 0 deletions python/hidet/ir/cute/int_tuple.py
@@ -377,6 +377,25 @@ def filter_zeros(a, b):
    return 1 if is_constant(a) and a == 0 else b


def canonicalize_uni_shape(a, b):
    """
    Replace the elements of tuple b that are paired with a 1-shape with a
    0-stride.
    Example:
    a = ((3, 4), 1)
    b = ((1, 3), 1)
    =>
    b = ((1, 3), 0)
    """
    if isinstance(a, tuple):
        assert isinstance(b, tuple) and len(b) == len(a)
        return tuple(canonicalize_uni_shape(x, y) for x, y in zip(a, b))
    else:
        assert is_integer(a)
        return 0 if is_constant(a) and a == 1 else b


def is_static(a):
    if isinstance(a, int):
        return True
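
For illustration, a standalone re-statement of the recursion above (dropping the `is_constant`/`is_integer` guards that live in hidet's int_tuple module), checked against the docstring example:

```python
def _canonicalize_uni_shape(shape, stride):
    # Recurse over nested tuples; a stride paired with a shape of 1 becomes 0.
    if isinstance(shape, tuple):
        return tuple(_canonicalize_uni_shape(s, d) for s, d in zip(shape, stride))
    return 0 if shape == 1 else stride

assert _canonicalize_uni_shape(((3, 4), 1), ((1, 3), 1)) == ((1, 3), 0)
assert _canonicalize_uni_shape((64, 1), (1, 1)) == (1, 0)
```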
6 changes: 6 additions & 0 deletions python/hidet/ir/cute/layout.py
@@ -40,6 +40,7 @@
    shape_abs,
    shape_min,
    is_static,
    canonicalize_uni_shape,
)


@@ -638,3 +639,8 @@ def __eq__(self, other):
        left_thr_layout = coalesce(self.thr_layout())
        right_thr_layout = coalesce(other.thr_layout())
        return left_thr_layout == right_thr_layout and self.val_layout() == other.val_layout()


def canonicalize(a: TensorLayout):
    stride = canonicalize_uni_shape(a.shape, a.stride)
    return TensorLayout(a.shape, stride)
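
A hypothetical usage sketch for the new helper, applied to the expanded bias layout from the commit message (assuming `TensorLayout` and `canonicalize` are both importable from `hidet.ir.cute`, as the `__init__.py` change above suggests):

```python
from hidet.ir.cute import TensorLayout, canonicalize  # assumed re-exports; adjust the import path if needed

bias_layout = TensorLayout((64, 1), (1, 1))  # expanded bias layout; the 1-shape axis carries a stride of 1
fixed = canonicalize(bias_layout)            # stride paired with the 1-shape is forced to 0
print(fixed.shape, fixed.stride)             # expected: (64, 1) and (1, 0)
```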
36 changes: 36 additions & 0 deletions tests/lang/cute/test_matmul_bias_cast.py
@@ -76,3 +76,39 @@ def graph(a, b, bias):
    np.testing.assert_allclose(actual=D_hidet.cpu().numpy(), desired=D.cpu().numpy(), rtol=1e-2)
    hidet_mean, hidet_min, hidet_max = bench(graph_hidet, graph_args)
    print(f"hidet(torch.compile): {hidet_mean} ms")


def test_longformer_issue404():
    args = [1, 2304, 768, 512]

    def graph(a, b, bias):
        c = a @ b
        c = c + bias
        return c

    M, N, K, L = args
    graph_args = data(*args)

    import torch._dynamo as dynamo

    options = {"triton.cudagraphs": False, "epilogue_fusion": True, "max_autotune": True}
    D = graph(*graph_args)
    graph_opt = torch.compile(graph, options=options)
    D_opt = graph_opt(*graph_args)
    np.set_printoptions(threshold=3000, linewidth=200, edgeitems=100)
    np.testing.assert_allclose(actual=D_opt.cpu().numpy(), desired=D.cpu().numpy(), rtol=1e-2)

    torch_mean, torch_min, torch_max = bench(graph_opt, graph_args)
    print(f"baseline(torch.compile mode=max-autotune): {torch_mean} ms")

    hidet.torch.dynamo_config.reset()
    hidet.torch.dynamo_config.parallel_k(strategy="disabled")

    D = graph(*graph_args)
    dynamo.reset()
    graph_hidet = torch.compile(graph, backend="hidet", mode="max-autotune-no-cudagraphs")
    D_hidet = graph_hidet(*graph_args)
    np.set_printoptions(threshold=3000, linewidth=200, edgeitems=100)
    np.testing.assert_allclose(actual=D_hidet.cpu().numpy(), desired=D.cpu().numpy(), rtol=1e-2)
    hidet_mean, hidet_min, hidet_max = bench(graph_hidet, graph_args)
    print(f"hidet(torch.compile): {hidet_mean} ms")
