Make attn disabled by default
hjjq committed Apr 18, 2023
1 parent af3ec3d commit 4926922
Showing 3 changed files with 9 additions and 7 deletions.
6 changes: 3 additions & 3 deletions python/hidet/graph/frontend/torch/dynamo_config.py
```diff
@@ -18,7 +18,7 @@ def __init__(self):
         self._parallel_k: str = 'default'
         self._use_fp16: bool = False
         self._use_fp16_reduction: bool = False
-        self._use_attention: bool = True
+        self._use_attention: bool = False
         self._use_cuda_graph: bool = True
         self._use_tensor_core: bool = False
         self._print_input_graph: bool = False
@@ -37,7 +37,7 @@ def reset(self):
         self._parallel_k: str = 'default'
         self._use_fp16: bool = False
         self._use_fp16_reduction: bool = False
-        self._use_attention: bool = True
+        self._use_attention: bool = False
         self._use_cuda_graph: bool = True
         self._use_tensor_core: bool = False
         self._print_input_graph: bool = False
@@ -105,7 +105,7 @@ def use_fp16_reduction(self, flag=True):
         self._use_fp16_reduction = flag
         return self
 
-    def use_attention(self, flag=True):
+    def use_attention(self, flag=False):
        """
        Whether to use fused attention schedule
        """
```
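With this commit, fused attention becomes opt-in for the torch dynamo frontend. A minimal sketch of enabling it again, assuming the `hidet.torch.dynamo_config` entry point; the model and `torch.compile` call are illustrative:

```python
import torch
import hidet

# fused attention is now disabled by default; opt in explicitly
hidet.torch.dynamo_config.use_attention(True)

model = torch.nn.Linear(16, 16).cuda().eval()
model_opt = torch.compile(model, backend='hidet')  # compile through hidet's dynamo backend
```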
7 changes: 5 additions & 2 deletions python/hidet/graph/transforms/base.py
```diff
@@ -75,7 +75,7 @@ def __init__(self):
             'reduce_precision': None,
             # use attention or not
             # [True, False]
-            'use_attention': True,
+            'use_attention': False,
             # mma primitive:
             # ['simt', 'wmma', 'mma']
             'mma': 'simt',
@@ -91,6 +91,9 @@ def __enter__(self) -> PassContext:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
+        from ..transforms.graph_patterns import deregister_attn_patterns
+
+        deregister_attn_patterns()
         popped = self._stack.pop()
         assert popped == self
 
@@ -151,7 +154,7 @@ def set_reduce_precision(self, dtype: Optional[str] = None) -> PassContext:
         self.configs['reduce_precision'] = dtype
         return self
 
-    def set_use_attention(self, flag=True) -> PassContext:
+    def set_use_attention(self, flag=False) -> PassContext:
        """
        Set to use fused attention schedule
```
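The same default change applies at the graph-pass level. A sketch of opting in through `PassContext`, assuming the `hidet.symbol` / `hidet.trace_from` / `hidet.graph.optimize` entry points; the tiny graph is illustrative:

```python
import hidet

# build a small FlowGraph to optimize (illustrative)
x = hidet.symbol([2, 8], dtype='float32', device='cuda')
y = hidet.ops.softmax(x, axis=-1)
graph = hidet.trace_from(y, inputs=[x])

with hidet.graph.PassContext() as ctx:
    ctx.set_use_attention(True)  # default is now False after this commit
    graph_opt = hidet.graph.optimize(graph)
```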
```diff
@@ -3,7 +3,7 @@
 from hidet.graph.transforms.graph_patterns import MatchDict
 from hidet.graph.transforms.graph_patterns import op_pattern, register_rewrite_rule, deregister_rewrite_rule
 from hidet.graph.transforms.graph_patterns import TensorPattern, SubgraphRewriteRule
-from hidet.utils import same_list, initialize
+from hidet.utils import same_list
 from hidet.graph.ops.definitions.matmul import MatmulOp
 from hidet.graph.ops.definitions.arithmetic import AddOp, MultiplyScalarOp, DivideScalarOp
 from hidet.graph.ops.definitions.activation import SoftmaxOp
@@ -105,7 +105,6 @@ def target(self, matched: MatchDict):
 registered_attn_rules = []
 
 
-@initialize()
 def attn_patterns():
     registered_attn_rules.append(AttentionRewriteRule())
     registered_attn_rules.append(AttentionMaskAddRewriteRule())
```
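Removing the `@initialize()` decorator means these rewrite rules are no longer registered as an import-time side effect: `attn_patterns()` now runs only when fused attention is requested, and `deregister_attn_patterns()` (called from `PassContext.__exit__` above) tears the rules back down. A hypothetical reconstruction of that teardown, mirroring the registration here; its actual body is not shown in this diff:

```python
def deregister_attn_patterns():
    # hypothetical: undo the registrations performed by attn_patterns()
    for rule in registered_attn_rules:
        deregister_rewrite_rule(rule)
    registered_attn_rules.clear()
```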
