[CompiledGraph] Add option to store dispatch table (#377)
This option includes the dispatch_table file as part of
save_compiled_graph.

This way, the table is extracted into the hidet/graphs cache when
load_compiled_graph is called, so the client can skip this fine-tuning
step.
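
A minimal sketch of the intended round trip. The 'model.hidet' path and the toy add-one graph are placeholders, and hidet.symbol / hidet.trace_from / FlowGraph.build are assumed from hidet's public tracing API; only save_compiled_graph and load_compiled_graph come from this change:

    import hidet
    from hidet.runtime.compiled_graph import save_compiled_graph, load_compiled_graph

    # Build a toy CompiledGraph to save; any tuned graph works the same way.
    x = hidet.symbol(shape=[2, 3], dtype='float32', device='cpu')
    graph = hidet.trace_from(x + 1.0, inputs=[x])
    cgraph = graph.build()

    # Bundle the dispatch table (kernel tuning results) into the archive.
    save_compiled_graph(cgraph, 'model.hidet', save_dispatch_table=True)

    # Loading extracts dispatch_table.txt into the hidet/graphs cache,
    # so the client does not need to re-tune.
    cgraph2 = load_compiled_graph('model.hidet')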
destefy authored Nov 15, 2023
1 parent e688ab1 commit 32ef876
Showing 2 changed files with 18 additions and 13 deletions.
22 changes: 11 additions & 11 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -74,7 +74,6 @@ def get_compiled_graph(flow_graph: FlowGraph):
     parallel_k = dynamo_config['parallel_k']
     tensor_core = dynamo_config['use_tensor_core']
     save_dir = dynamo_config['dump_graph_ir']
-
     with PassContext() as ctx:
         if use_fp16:
             ctx.set_precision('float16')
@@ -98,6 +97,17 @@ def get_compiled_graph(flow_graph: FlowGraph):
     return cgraph
 
 
+def preprocess_inputs(inputs: Sequence[torch.Tensor]) -> List[hidet.Tensor]:
+    torch_inputs: List[torch.Tensor] = []
+    for x in inputs:
+        if not x.is_contiguous():
+            # warnings.warn_once('Hidet received a non-contiguous torch input tensor, converting it to contiguous')
+            x = x.contiguous()
+        torch_inputs.append(x)
+    hidet_inputs: List[hidet.Tensor] = [hidet.from_torch(tensor) for tensor in torch_inputs]
+    return hidet_inputs
+
+
 def get_wrapper(cgraph: CompiledGraph, inputs, output_format):
     use_cuda_graph = dynamo_config['use_cuda_graph']
     if use_cuda_graph:
@@ -108,16 +118,6 @@ def get_wrapper(cgraph: CompiledGraph, inputs, output_format):
     else:
         runner = cgraph
 
-    def preprocess_inputs(inputs: Sequence[torch.Tensor]) -> List[hidet.Tensor]:
-        torch_inputs: List[torch.Tensor] = []
-        for x in inputs:
-            if not x.is_contiguous():
-                # warnings.warn_once('Hidet received a non-contiguous torch input tensor, converting it to contiguous')
-                x = x.contiguous()
-            torch_inputs.append(x)
-        hidet_inputs: List[hidet.Tensor] = [hidet.from_torch(tensor) for tensor in torch_inputs]
-        return hidet_inputs
-
     def run(*inputs: torch.Tensor):
         hidet_inputs = preprocess_inputs(inputs)
         hidet_outputs: List[hidet.Tensor] = runner.run_async(hidet_inputs)
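A quick illustration of the now module-level helper (the shapes are arbitrary; assumes torch and hidet are installed):

    import torch
    from hidet.graph.frontend.torch.dynamo_backends import preprocess_inputs

    x = torch.randn(4, 8).t()              # transposed view: non-contiguous in memory
    assert not x.is_contiguous()

    hidet_inputs = preprocess_inputs([x])  # copies to contiguous memory, then hidet.from_torch
    print(hidet_inputs[0].shape)           # (8, 4)
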
9 changes: 7 additions & 2 deletions python/hidet/runtime/compiled_graph.py
Expand Up @@ -376,7 +376,7 @@ def save(self, path: str):
         save_compiled_graph(self, path)
 
 
-def save_compiled_graph(model: CompiledGraph, path: str):
+def save_compiled_graph(model: CompiledGraph, path: str, save_dispatch_table: bool = False):
     from hidet.utils.dataclass import asdict
 
     dirname = os.path.dirname(path)
@@ -416,6 +416,12 @@ def _save_under(dir_path: str, dir_in_zip: str, exclude: Optional[List[str]] = None):
             ge_bytes = json.dumps(asdict(model.graph_execution), indent=4).encode('utf-8')
             f.write(ge_bytes)
 
+        # save dispatch table file
+        if save_dispatch_table and os.path.exists(model.dispatch_table_path):
+            with zf.open('dispatch_table.txt', 'w') as f:
+                with open(model.dispatch_table_path, 'rb') as f2:
+                    f.write(f2.read())
+
         # save graph string
         with zf.open('graph_string.txt', 'w') as f:
             f.write(model.graph_string.encode('utf-8'))
@@ -448,7 +454,6 @@ def load_compiled_graph(path: str) -> CompiledGraph:

     # extract all files except weights
     cache_dir = hidet.utils.cache_dir('graphs', meta_data.graph_hash)
-
     if not os.path.exists(os.path.join(cache_dir, 'graph_string.txt')):
         # only extract files if the graph_string.txt is not in the cache
         # here 'graph_string.txt' is just the last file we usually save to disk, we use it as a flag
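Because the table is stored as a plain member of the saved zip archive, its presence can be checked with the standard library ('model.hidet' again being a placeholder path):

    import zipfile

    with zipfile.ZipFile('model.hidet', 'r') as zf:
        # dispatch_table.txt is only present when save_dispatch_table=True was used
        print('dispatch_table.txt' in zf.namelist())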
