Skip to content

Commit

Permalink
[Option] Add an option to disable imperative execution (#362)
Browse files Browse the repository at this point in the history
  • Loading branch information
serach24 authored Sep 30, 2023
1 parent 2c2c819 commit 3272fc3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
9 changes: 7 additions & 2 deletions python/hidet/graph/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Dict, Any

import hidet.option
from hidet.ir.type import TensorType
from hidet.ir.task import Task
from hidet.runtime.compiled_task import CompiledTask
Expand Down Expand Up @@ -119,11 +121,14 @@ def compiled_task(self) -> CompiledTask:
def run(self) -> List[Tensor]:
from hidet.ir.tools import collect

# we imperatively run the operator if
# We imperatively run the operator if
# 1. all inputs are concrete tensors (i.e., t.storage is not None)
# 2. there is no symbol variable in the task
# 3. configuration option "imperative" is True
could_imperative_run = (
all(t.storage is not None for t in self.inputs) and len(collect(self.task, SymbolVar)) == 0
all(t.storage is not None for t in self.inputs)
and len(collect(self.task, SymbolVar)) == 0
and hidet.option.get_option('imperative')
)

if could_imperative_run:
Expand Down
30 changes: 30 additions & 0 deletions python/hidet/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,12 @@ def register_hidet_options():
default_value='auto',
description='The CUDA architecture to compile the kernels for (e.g., "sm_70"). "auto" for auto-detect.',
)
register_option(
name='imperative',
type_hint='bool',
default_value=True,
description='Whether to enable imperative execution when op arguments allows',
)

config_file_path = os.path.join(os.path.expanduser('~'), '.config', 'hidet')
if not os.path.exists(config_file_path):
Expand Down Expand Up @@ -734,6 +740,30 @@ def get_runtime_check() -> bool:
return OptionContext.current().get_option('runtime_check')


def imperative(enable: bool = True):
"""
Whether to enable imperative execution when op arguments allows.
Parameters
----------
enable: bool
Whether to enable imperative execution when op arguments allows.
"""
OptionContext.current().set_option('imperative', enable)


def get_imperative() -> bool:
"""
Get whether to enable imperative execution when op arguments allows.
Returns
-------
ret: bool
Get whether to enable imperative execution when op arguments allows.
"""
return OptionContext.current().get_option('imperative')


def debug_show_verbose_flow_graph(enable: bool = True):
"""Whether to show verbose information (like task) when we convert flow graph in to human-readable text.
Expand Down

0 comments on commit 3272fc3

Please sign in to comment.