Commit a06ae38
Improvements after merge with main
Josh Fromm committed Mar 6, 2023
1 parent f3240fc

Showing 2 changed files with 16 additions and 73 deletions.
65 changes: 4 additions & 61 deletions python/tvm/octo/compile.py
@@ -21,70 +21,12 @@
 from typing import Union, Optional, Dict, List
 import tvm
 from tvm import relax
+from tvm.relax.frontend import from_onnx
 from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
 from .utils import get_cuda_target, get_llvm_target
 from .octo_model import OctoModel
 
 
-# TODO(jwfromm) This will later be replaced by a full pass from Xiyou.
-def cuda_bind_threads(tvm_model: tvm.ir.IRModule, target: tvm.target.Target):
-    """Schedule an IRModule on Cuda.
-    Parameters
-    ----------
-    tvm_model : tvm.ir.IRModule
-        The input module to transform. Each primfunc in the module will be
-        rewritten to include thread and block bindings so that it can be
-        run on cuda.
-    target : tvm.target.Target
-        The full description of the target device.
-    Returns
-    -------
-    output_model : tvm.ir.IRModule
-        The rewritten input module that can now be compile and run on cuda.
-    """
-
-    @tvm.transform.module_pass(opt_level=0)
-    def thread_bind(tvm_model: tvm.ir.IRModule, ctx: tvm.transform.PassContext):
-        """A relax pass to do thread binding for the relax model."""
-        global_vars = tvm_model.get_global_vars()
-        max_threadblocks = 256
-        max_threads_per_block = tvm.target.Target(target).attrs["max_num_threads"]
-
-        for var in global_vars:
-            if isinstance(tvm_model[var], tvm.tir.PrimFunc):
-                func = tvm_model[var]
-                mod = tvm.IRModule({"main": func.with_attr("global_symbol", "main")})
-                sch = tvm.tir.Schedule(mod)
-                get_blocks_func = tvm.get_global_func("tvm.meta_schedule.collect_blocks")
-                blocks = get_blocks_func(sch, None)  # no filter func
-                for block in blocks:
-                    if len(sch.get_loops(block)) == 0:
-                        continue
-                    # Only fuse data parallel loops
-                    iter_vars = sch.get(block).iter_vars
-                    loops = sch.get_loops(block)
-                    data_parralel_loops = []
-                    for i, loop in enumerate(loops):
-                        # Check that the corresponding itervar is data parallel.
-                        if iter_vars[i].iter_type == tvm.tir.IterVar.DataPar:
-                            data_parralel_loops.append(loop)
-
-                    loop = sch.fuse(*data_parralel_loops)
-                    splits = sch.split(
-                        loop, factors=[None, max_threadblocks, max_threads_per_block]
-                    )
-                    sch.reorder(splits[1], splits[2], splits[0])
-                    sch.bind(splits[1], "blockIdx.x")
-                    sch.bind(splits[2], "threadIdx.x")
-
-                tvm_model[var] = sch.mod["main"]
-        return tvm_model
-
-    return thread_bind(tvm_model)
-
-
 def load_onnx_model(
     model_file: Union[str, Path, onnx.ModelProto], shape_dict: Optional[Dict[str, List]] = None
 ) -> tvm.IRModule:
@@ -110,7 +52,7 @@ def load_onnx_model(
         model_file = onnx.load(model_file)
 
     # Convert the graph into a relax implementation.
-    relax_mod = relax.from_onnx(model_file, shape=shape_dict)
+    relax_mod = from_onnx(model_file, shape=shape_dict)
 
     return relax_mod
 
@@ -206,7 +148,8 @@ def compile(

     # Schedule all remaining functions to be compatible with gpu if needed.
     if str(target.kind) == "cuda":
-        relax_mod = cuda_bind_threads(relax_mod, target)
+        with target, tvm.transform.PassContext(opt_level=3):
+            relax_mod = tvm.tir.transform.DefaultGPUSchedule()(relax_mod)
 
     # Compile the module.
     exe = relax.build(relax_mod, target)
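The compile.py change drops the hand-written cuda_bind_threads helper and delegates GPU scheduling to TVM's built-in DefaultGPUSchedule transform. A minimal sketch of the new scheduling path, assuming relax_mod is a Relax IRModule (for example the one returned by load_onnx_model above) and a CUDA-enabled TVM build:

# Sketch only: `relax_mod` and a working CUDA toolchain are assumed to exist.
import tvm
from tvm import relax

target = tvm.target.Target("cuda")
if str(target.kind) == "cuda":
    # Entering the target and a PassContext lets the pass pick up device limits
    # (such as max threads per block) while it binds thread and block indices
    # for the module's PrimFuncs.
    with target, tvm.transform.PassContext(opt_level=3):
        relax_mod = tvm.tir.transform.DefaultGPUSchedule()(relax_mod)

exe = relax.build(relax_mod, target)

The built-in pass covers the same job as the deleted helper, binding loops to GPU threads and blocks, without octo maintaining its own scheduling code.
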
24 changes: 12 additions & 12 deletions tests/python/octo/test_octo_flow.py
@@ -24,19 +24,18 @@


 def get_simple_onnx_model():
-    # Create a single onnx convolution model that can be used for testing.
-    conv_node = helper.make_node("Conv", ["x", "w", "b"], ["y"])
+    # Create a single onnx matmul model that can be used for testing.
+    matmul_node = helper.make_node("MatMul", ["a", "b"], ["c"])
     graph = helper.make_graph(
-        [conv_node],
-        "minimal_conv",
+        [matmul_node],
+        "minimal_matmul",
         inputs=[
-            helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 32, 32]),
-            helper.make_tensor_value_info("w", TensorProto.FLOAT, [16, 3, 3, 3]),
-            helper.make_tensor_value_info("b", TensorProto.FLOAT, [16]),
+            helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32]),
         ],
-        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 16, 30, 30])],
+        outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, [32, 32])],
     )
-    model = helper.make_model(graph, producer_name="minimal_conv")
+    model = helper.make_model(graph, producer_name="minimal_matmul")
     return model
 
 
@@ -45,7 +44,7 @@ def test_e2e_flow():
     test_model = get_simple_onnx_model()
     octo_model = tvm.octo.compile(test_model)
     # Check that the produced model has properly formed shape info.
-    assert octo_model.input_info["x"] == ([1, 3, 32, 32], "float32")
+    assert octo_model.input_info["a"] == ([32, 32], "float32")
 
     # Test that the OctoModel can be saved and loaded.
     temp = utils.tempdir()
@@ -59,6 +58,7 @@

     # Test the running and benchmarking helpers.
     outputs = octo_model.run()
-    assert list(outputs[0].shape) == [1, 16, 30, 30]
+    assert list(outputs[0].shape) == [32, 32]
     report = octo_model.profile()
-    assert "conv2d" in str(report)
+    # Confirm cutlass offload was successful.
+    assert "matmul_cutlass" in str(report)
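
Taken together, the updated test exercises the full octo flow on the minimal MatMul model. A condensed usage sketch, assuming a CUDA-enabled TVM build with the CUTLASS backend available (the environment the test targets):

# Condensed sketch of the end-to-end flow exercised by the test above.
import tvm

model = get_simple_onnx_model()           # the 32x32 MatMul graph defined above
octo_model = tvm.octo.compile(model)
assert octo_model.input_info["a"] == ([32, 32], "float32")

outputs = octo_model.run()                # run() with no explicit inputs, as in the test
assert list(outputs[0].shape) == [32, 32]

report = octo_model.profile()
assert "matmul_cutlass" in str(report)    # the matmul was offloaded to CUTLASS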
