[OPTIONS] Use Attention by default (#261)
Use attention by default
vadiklyutiy authored Jun 10, 2024
1 parent 2dc101e commit 83931d1
Showing 8 changed files with 14 additions and 62 deletions.
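In practice this removes an opt-in step for PyTorch users: the attention rewrite rules are now registered at import time (on sm75+ GPUs) rather than behind a config flag. A minimal before/after sketch, assuming a CUDA build of hidet on an sm75+ GPU; the toy model is illustrative only:

    import torch
    import hidet  # importing hidet registers the 'hidet' dynamo backend

    model = torch.nn.Linear(16, 16).cuda().half()
    x = torch.randn(1, 16, device='cuda', dtype=torch.float16)

    # Before this commit, attention fusion had to be enabled explicitly:
    # hidet.torch.dynamo_config.use_attention(True)

    # After it, no flag is needed: the fused-attention rewrite rules are
    # active by default when compiling through the hidet backend.
    model_opt = torch.compile(model, backend='hidet')
    y = model_opt(x)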
2 changes: 0 additions & 2 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -101,13 +101,11 @@ def get_flow_graph(interpreter: Interpreter, example_inputs):
 
 
 def get_compiled_graph(flow_graph: FlowGraph):
-    use_attention = dynamo_config['use_attention']
     search_space = dynamo_config['search_space']
     parallel_k = dynamo_config['parallel_k']
     tensor_core = dynamo_config['use_tensor_core']
     save_dir = dynamo_config['dump_graph_ir']
     with PassContext() as ctx:
-        ctx.set_use_attention(use_attention)
         if save_dir:
             graph_dir = resolve_save_dir_multigraph(save_dir)
             ctx.save_graph_instrument(graph_dir)
5 changes: 1 addition & 4 deletions python/hidet/graph/frontend/torch/dynamo_config.py
@@ -24,7 +24,6 @@ class DynamoConfig:
     def __init__(self):
         self._search_space: int = 0
         self._parallel_k: str = 'default'
-        self._use_attention: bool = False
         self._use_cuda_graph: bool = True
         self._use_tensor_core: bool = False
         self._print_input_graph: bool = False
@@ -42,7 +41,6 @@ def reset(self):
         """
         self._search_space: int = 0
         self._parallel_k: str = 'default'
-        self._use_attention: bool = False
         self._use_cuda_graph: bool = True
         self._use_tensor_core: bool = False
         self._print_input_graph: bool = False
@@ -113,8 +111,7 @@ def use_attention(self, flag=False):
         """
         Whether to use fused attention schedule
        """
-        self._use_attention = flag
-        return self
+        dynamo_config_warning()
 
     def use_cuda_graph(self, flag=True):
         """
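The use_attention() method survives only for backwards compatibility: with its backing field gone, it now just calls dynamo_config_warning(), presumably to tell callers the flag is obsolete. A sketch of what such a call now amounts to:

    import hidet

    # Accepted but inert: the method body is now a bare
    # dynamo_config_warning() call (assumption: that helper only emits a
    # warning), so this line no longer changes which rewrite rules run.
    hidet.torch.dynamo_config.use_attention(True)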
25 changes: 0 additions & 25 deletions python/hidet/graph/transforms/base.py
@@ -77,9 +77,6 @@ def __init__(self):
             # target reduce precision:
             # [None, 'float16', 'float32']
             'reduce_precision': None,
-            # use attention or not
-            # [True, False]
-            'use_attention': False,
             # mma primitive:
             # ['simt', 'mma']
             'mma': 'simt',
@@ -95,9 +92,6 @@ def __enter__(self) -> PassContext:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        from ..transforms.graph_patterns.attn_patterns import deregister_attn_patterns
-
-        deregister_attn_patterns()
         popped = self._stack.pop()
         assert popped == self
 
@@ -191,25 +185,6 @@ def set_reduce_precision(self, dtype: Optional[str] = None) -> PassContext:
         self.configs['reduce_precision'] = dtype
         return self
 
-    def set_use_attention(self, flag=False) -> PassContext:
-        """
-        Set to use fused attention schedule
-        """
-        # fmha requires sm75+
-        cc = hidet.option.cuda.get_arch_pair()
-        if cc < (7, 5):
-            return self
-
-        from ..transforms.graph_patterns.attn_patterns import register_attn_patterns, deregister_attn_patterns
-
-        self.configs['use_attention'] = flag
-        if flag:
-            register_attn_patterns()
-        else:
-            deregister_attn_patterns()
-        return self
-
     def set_verbose(self) -> PassContext:
         """
         Allow each graph level passes to print detailed information related to its lowering and optimization.
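Graph-level compilation outside of Dynamo loses the knob as well: PassContext no longer accepts set_use_attention(). A minimal sketch of the remaining usage, modeled on the bench_op.py change further down; the float16 settings are illustrative:

    import hidet

    def optimize_fp16(graph: hidet.FlowGraph) -> hidet.FlowGraph:
        # set_use_attention() is gone: attention fusion is applied by
        # default through the rewrite rules registered at import time.
        with hidet.graph.PassContext() as ctx:
            ctx.set_reduce_precision('float16')
            ctx.set_mma('mma')
            return hidet.graph.optimize(graph)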
39 changes: 12 additions & 27 deletions python/hidet/graph/transforms/graph_patterns/attn_patterns.py
@@ -13,12 +13,13 @@
 from hidet.ir.dtypes import f16
 from hidet.ir.expr import is_true
 from hidet.graph.transforms.graph_patterns import MatchDict
-from hidet.graph.transforms.graph_patterns import op_pattern, register_rewrite_rule, deregister_rewrite_rule
+from hidet.graph.transforms.graph_patterns import op_pattern, register_rewrite_rule
 from hidet.graph.transforms.graph_patterns import TensorPattern, SubgraphRewriteRule
 from hidet.graph.ops.matmul import MatmulOp
 from hidet.graph.ops.arithmetic import AddOp, MultiplyScalarOp, DivideScalarOp
 from hidet.graph.ops.activation import SoftmaxOp
 from hidet.graph.ops.attention import attention
+from hidet.utils import initialize
 
 
 class ReorderMulScaleRewriteRule(SubgraphRewriteRule):
@@ -109,30 +110,14 @@ def target(self, matched: MatchDict):
         return None
 
 
-registered_attn_rules = []
-
-
 @initialize()
 def attn_patterns():
-    registered_attn_rules.append(AttentionRewriteRule())
-    registered_attn_rules.append(AttentionMaskAddRewriteRule())
-    registered_attn_rules.append(ReorderMulScaleRewriteRule())
-    registered_attn_rules.append(ReorderDivScaleRewriteRule())
-    for attn_rule in registered_attn_rules:
-        register_rewrite_rule(attn_rule)
-
-
-def register_attn_patterns():
-    if len(registered_attn_rules) != 0:
-        return
-    registered_attn_rules.append(AttentionRewriteRule())
-    registered_attn_rules.append(AttentionMaskAddRewriteRule())
-    registered_attn_rules.append(ReorderMulScaleRewriteRule())
-    registered_attn_rules.append(ReorderDivScaleRewriteRule())
-    for attn_rule in registered_attn_rules:
-        register_rewrite_rule(attn_rule)
-
-
-def deregister_attn_patterns():
-    for attn_rule in registered_attn_rules:
-        deregister_rewrite_rule(attn_rule)
-    registered_attn_rules.clear()
+    from hidet.option import cuda
+
+    cc = cuda.get_arch_pair()
+    # fmha requires sm75+
+    if cc >= (7, 5):
+        register_rewrite_rule(AttentionRewriteRule())
+        register_rewrite_rule(AttentionMaskAddRewriteRule())
+        register_rewrite_rule(ReorderMulScaleRewriteRule())
+        register_rewrite_rule(ReorderDivScaleRewriteRule())
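Together with the imports above (MatmulOp, DivideScalarOp, SoftmaxOp, attention), the rules match the standard scaled-dot-product-attention subgraph and collapse it into the fused attention operator. A rough sketch of a graph these rules should now fuse without any configuration; the shapes and op spellings are assumptions for illustration:

    import math
    import hidet
    from hidet import ops

    # fmha requires float16 inputs and an sm75+ GPU.
    q = hidet.symbol([1, 8, 512, 64], dtype='float16', device='cuda')
    k = hidet.symbol([1, 8, 64, 512], dtype='float16', device='cuda')
    v = hidet.symbol([1, 8, 512, 64], dtype='float16', device='cuda')

    # matmul -> divide-by-scalar -> softmax -> matmul: the subgraph shape
    # that AttentionRewriteRule rewrites into a single fused op.
    score = ops.softmax(ops.matmul(q, k) / math.sqrt(64.0), axis=-1)
    graph = hidet.trace_from(ops.matmul(score, v), inputs=[q, k, v])
    graph_opt = hidet.graph.optimize(graph)  # fusion now applies by default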
2 changes: 1 addition & 1 deletion python/hidet/graph/transforms/graph_patterns/__init__.py
@@ -12,6 +12,6 @@
 # pylint: disable=unused-import
 from .arithmetic_patterns import arithmetic_patterns
 from .transform_patterns import transform_patterns
-from .attn_patterns import attn_patterns, register_attn_patterns, deregister_attn_patterns
+from .attn_patterns import attn_patterns
 from .conv2d_patterns import conv2d_patterns
 from .matmul_patterns import matmul_patterns
1 change: 0 additions & 1 deletion python/hidet/testing/torch_utils.py
@@ -72,7 +72,6 @@ def init_hidet(self):
         import os
 
         hidet.torch.dynamo_config.search_space(self.search_space)
-        hidet.torch.dynamo_config.use_attention(True)
         hidet.torch.dynamo_config.use_tensor_core(True)
         hidet.torch.dynamo_config.use_cuda_graph(True)
         hidet.option.search_space(self.search_space)
1 change: 0 additions & 1 deletion scripts/regression/model_performance.py
@@ -26,7 +26,6 @@ def enable_compiled_server():
 
 def setup_hidet_flags(dtype):
     hidet.torch.dynamo_config.search_space(2)
-    hidet.torch.dynamo_config.use_attention(True)
     hidet.torch.dynamo_config.use_tensor_core(True)
     hidet.torch.dynamo_config.use_cuda_graph(True)
     hidet.torch.dynamo_config.dump_graph_ir("./graph_ir")
1 change: 0 additions & 1 deletion tests/benchmarks/bench_op.py
@@ -151,7 +151,6 @@ def bench_reduce(params: str, *args, **kwargs) -> float:
 
     with hidet.graph.PassContext() as ctx:
         ctx.set_reduce_precision(dtype)
-        ctx.set_use_attention(True)
         ctx.set_mma('mma')
     latency = bench_func(params, dtype)
     print(latency)
