Skip to content

Commit

Permalink
Inherit mode argument from torch.compile and set corresponding op…
Browse files Browse the repository at this point in the history
…tions (#237)

`mode` argument of `torch.compile` now inherits and set the corresponding
hidet options (`search_space` and `use_cude_graph`)
  • Loading branch information
vadiklyutiy committed Jul 22, 2024
1 parent 4add4b9 commit 4d95978
Showing 1 changed file with 51 additions and 18 deletions.
69 changes: 51 additions & 18 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,35 @@

logger = logging.getLogger(__name__)

# TODO: after search_space=1 will be tuned switch search_space from 0 to 1
def process_options(kwargs):
# Default options for case mode is not passed to torch.compile()
hidet.option.search_space(0)
hidet.torch.dynamo_config.search_space(0)
hidet.torch.dynamo_config.use_cuda_graph(False)

if 'mode' in kwargs:
mode = kwargs['mode']
if mode == 'max-autotune':
hidet.option.search_space(2)
hidet.torch.dynamo_config.search_space(2)
hidet.torch.dynamo_config.use_cuda_graph(True)
elif mode == 'max-autotune-no-cudagraphs':
hidet.option.search_space(2)
hidet.torch.dynamo_config.search_space(2)
hidet.torch.dynamo_config.use_cuda_graph(False)
elif mode == 'reduce-overhead':
hidet.option.search_space(0)
hidet.torch.dynamo_config.search_space(0)
hidet.torch.dynamo_config.use_cuda_graph(True)
elif mode == 'default':
hidet.option.search_space(0)
hidet.torch.dynamo_config.search_space(0)
hidet.torch.dynamo_config.use_cuda_graph(False)
else:
raise ValueError(f'hidet_backend: unknown torch.compile mode={mode}')


# NOTES ABOUT DYNAMIC SHAPE.
# From pytorch we got two argument:
# - fxgraph
Expand Down Expand Up @@ -154,32 +183,36 @@ def __call__(self, *args):
return deserialize_output(self.output_format, outputs)


def hidet_backend(graph_module, example_inputs):
def hidet_backend(graph_module, example_inputs, **kwargs):
assert isinstance(graph_module, torch.fx.GraphModule)

logger.info('received a subgraph with %d nodes to optimize', len(graph_module.graph.nodes))
logger.debug('graph: %s', graph_module.graph)

if dynamo_config['print_input_graph']:
graph_module.print_readable()
print('---')
graph_module.graph.print_tabular()
with hidet.option.context():
# Process options passed to torch.compile
process_options(kwargs)

if dynamo_config['print_input_graph']:
graph_module.print_readable()
print('---')
graph_module.graph.print_tabular()

# get the interpreter for the subgraph
interpreter: Interpreter = hidet.frontend.from_torch(graph_module)
# get the interpreter for the subgraph
interpreter: Interpreter = hidet.frontend.from_torch(graph_module)

if dynamo_config['correctness_report']:
# check correctness using random inputs
def wrapper(*args):
report, output = interpreter.forward_with_check(*args)
logger.info('finish checking correctness')
print(report)
return output
if dynamo_config['correctness_report']:
# check correctness using random inputs
def wrapper(*args):
report, output = interpreter.forward_with_check(*args)
logger.info('finish checking correctness')
print(report)
return output

return wrapper
return wrapper

flow_graph, inputs, output_format = get_flow_graph(interpreter, example_inputs)
flow_graph, inputs, output_format = get_flow_graph(interpreter, example_inputs)

cgraph = get_compiled_graph(flow_graph)
cgraph = get_compiled_graph(flow_graph)

return HidetCompiledModel(cgraph, inputs, output_format)
return HidetCompiledModel(cgraph, inputs, output_format)

0 comments on commit 4d95978

Please sign in to comment.