This repository has been archived by the owner on Oct 25, 2023. It is now read-only.

[TUZ-150] Add a simplified access point for Unity Flow #32

Merged: 162 commits, Mar 8, 2023

Commits
f7165a1
[microTVM] Fix tvmc tutorial (#14076)
mehrdadh Feb 25, 2023
10fb8c5
[MetaSchedule] Introduce Async Pipeline in MultiLevelTiling (#14009)
cblmemo Feb 25, 2023
9fab56c
[TVMScript] Use op attribute to control whether to print dtype in TVM…
liangW-intellif Feb 25, 2023
1ad1994
[Fix][TVMScript] Fix index of metadata in printed script (#14130)
Ubospica Feb 25, 2023
f21a17b
[Pytorch] frontend full_impl fix (#14122)
Feb 26, 2023
d9b0a80
[DOCKER] Configurable NDK version support (#14000)
srkreddy1238 Feb 27, 2023
54a62c1
[Fix][TIR] SampleCategorical apply-to-schedule (#14133)
MasterJH5574 Feb 27, 2023
74603ee
[Arith] ConstIntBound was incorrectly assuming bounds were over int64…
Feb 27, 2023
0e046da
[CMSIS-NN] Reduction in code size of AOT test runner binary (#13815)
NicolaLancellotti Feb 27, 2023
77df6e8
[CMSIS-NN] Add a runtime error message (#13643)
NicolaLancellotti Feb 27, 2023
bf589f3
[CRT]Cleanup unused macros in crt_config.h.template (#14125)
mehrdadh Feb 27, 2023
663f7ae
[Fix][Relay] Fix axis transformation in squeeze shape function (#14135)
Lucien0 Feb 27, 2023
4d152fe
[Unittest] merge test_cp_async_in_if_then_else into test_tir_transfor…
cblmemo Feb 27, 2023
2feb243
[Frontend][TFLite] Fix conv2d import bug (#14124)
mehrdadh Feb 27, 2023
6097df5
[ONNX][TORCH] Replace scatter op by scatter_elements (#14019)
vvchernov Feb 28, 2023
2b2cb96
[TVMScript][Printer] Remove relax prefix for now (#14140)
tqchen Feb 28, 2023
7d67bb1
[microNPU] Sum legalization support (#13997)
Aleksei-grovety Feb 28, 2023
7c06de5
[Fix][MetaSchedule] Fix redundant stages in async pipeline for mlt (#…
cblmemo Feb 28, 2023
428400c
[COMMUNITY] Cheng Wen -> Reviewer (#14153)
Hzfengsy Mar 1, 2023
1043136
[Runtime] Fix high RAM usage when saving / loading paramters of big m…
masahi Mar 1, 2023
e9cf04e
[Relay][Frontend] Span Filling PyTorch (#14050)
chunit-quic Mar 1, 2023
6c04ac5
[TRT][BYOC] allow strided_slice ops on selected dimensions (#14142) (…
AreopagX Mar 1, 2023
69acdfb
[ONNX][TOPI] Add `DFT` operator (#13999)
KJlaccHoeUM9l Mar 1, 2023
908dc8f
[CRT][microTVM] Enable USMP by default for AoTExecutor + CRT runtime …
mehrdadh Mar 2, 2023
25f4d06
[Android] Fix using system libraries in Android apps (#14145)
echuraev Mar 2, 2023
05cbe32
[microTVM]Enable TVMC micro with AoT Executor (#14077)
mehrdadh Mar 2, 2023
bd8e7d3
[bugfix] Fix the write buffer scope of `mma_store_impl` (#14174)
yzh119 Mar 2, 2023
cb37b82
[Relay] Enhance EliminateCommonSubexpr to support Tuple argument (#14…
lixiaoquan Mar 2, 2023
91dc8ef
[TIR] Fix typo in doc (#14178)
vinx13 Mar 2, 2023
a42e98b
[microTVM] Use QNN schedules to give SOTA performance (#13752)
guberti Mar 2, 2023
bc92a3f
Add v0.11.0 docs link to site (#14181)
areusch Mar 3, 2023
df429c5
[TIR] Allow TransformLayout with non-inversible index map (#14095)
vinx13 Mar 3, 2023
c0f148a
[TIR][Analysis] Implement IdentifyMemCpy analysis function (#13947)
Lunderberg Mar 4, 2023
736ceca
[HotFix][MetaSchedule] Turn off database shash check (#14188)
MasterJH5574 Mar 4, 2023
22c47ee
[TOPI] Batch Norm Training Mode (#14190)
SiriusNEO Mar 4, 2023
baedf7f
[TOPI] Group normalization (#14193)
MasterJH5574 Mar 5, 2023
befdc4e
[Fix][TIR] LowerCrossThreadReduction with write-back predicate (#14199)
MasterJH5574 Mar 5, 2023
e7b02f2
[Unity] Relax VM (#13878)
YuchenJin Feb 1, 2023
9508a18
[Unity] Relax expressions and types (#13901)
YuchenJin Feb 2, 2023
4d46290
[Unity][IR] First-class StructInfo (#13907)
YuchenJin Feb 3, 2023
b59ad48
[Unity][CI] Unity specific jenkins setup (do not upstream to main) (#…
tqchen Feb 3, 2023
ff8bfa2
[Unity] Basic StructInfo Analysis and Expr construction (#13916)
YuchenJin Feb 5, 2023
bb0c129
[Unity] Relax BlockBuilder and ExprMutator (#13926)
YuchenJin Feb 7, 2023
1807e6f
[Unity] Relax TVMScript Parser. (#13932)
Hzfengsy Feb 8, 2023
846a2c5
[Unity] Relax TVMScript Printer (#13944)
junrushao Feb 10, 2023
850d6a4
[Unity] Relax VM codegen (#13954)
YuchenJin Feb 11, 2023
a966cf1
[Unity] Relax VM shape lowering pass (#13956)
YuchenJin Feb 11, 2023
f735d93
[Unity] e2e Relax minimum build flow (#13961)
YuchenJin Feb 11, 2023
ad4185c
[Unity][TVMScript] Use explicit `R.shape` in TVMScript (#13979)
Hzfengsy Feb 14, 2023
e8227b9
[Unity] Relax op: index (#13987)
MasterJH5574 Feb 14, 2023
2e08c8c
[Unity] Relax op: datatype (#13986)
MasterJH5574 Feb 14, 2023
68a04a8
[Unity] Relax op: set (#13990)
MasterJH5574 Feb 14, 2023
5723ebb
[Unity] Relax op: image (#13994)
MasterJH5574 Feb 14, 2023
72bca0f
[Unity] Relax op: arithmetic, comparison (#13983)
MasterJH5574 Feb 14, 2023
f491b96
[Unity] Relax op: statistical (#13991)
MasterJH5574 Feb 14, 2023
71437f7
[Unity] Relax op: neural networks (#13993)
MasterJH5574 Feb 14, 2023
9113fc9
[Unity] Relax op: creation (#13984)
MasterJH5574 Feb 14, 2023
c788135
[Unity] Relax op: linear algebra (#13988)
MasterJH5574 Feb 14, 2023
b6818bb
[Unity] Relax op: search (#13992)
MasterJH5574 Feb 14, 2023
17cf446
[Unity] Relax op: manipulation (#13989)
MasterJH5574 Feb 14, 2023
9317ec8
[Unity] NestedMsg Support utility (#13995)
tqchen Feb 14, 2023
07e0dfb
[Unity][Pass] Operator Fusion Passes (#14001)
Hzfengsy Feb 15, 2023
33c4aab
[Unity][Pass] LambdaLift pass (#14012)
yongwww Feb 16, 2023
fd5c73d
[Unity][VM] Supporting "compiled" exec mode. (#14015)
tqchen Feb 17, 2023
733fc00
[Unity][Pass] BindParams pass, FoldConstant pass (#14016)
sunggg Feb 17, 2023
88852c1
[Unity][Pass][TuningAPI] Introduce TuningAPI and MetaSchedule pass (#…
sunggg Feb 17, 2023
24470c9
[Unity] Relay -> Relax translator (#14026)
YuchenJin Feb 17, 2023
fc0540c
[Unity][Pass] Normalize Pass (#14031)
LeshengJin Feb 18, 2023
91adf7b
[Unity][BlockBuilder] CallTE convert PrimValue args (#14028)
MasterJH5574 Feb 18, 2023
53f800d
[Unity][Pass] Wellformed Analysis (#14032)
LeshengJin Feb 18, 2023
596d472
[Unity][TVMScript] Move tir/relax import in script out of __init__.py…
MasterJH5574 Feb 18, 2023
b150b1a
[Unity][Pass] Operator legalization (#14029)
MasterJH5574 Feb 18, 2023
449e094
[Unity][Op] Add ShapeExpr Tests for Reshape Op (#14035)
Ubospica Feb 18, 2023
9b5f214
[Unity] Initial PyTorch Frontend (#14037)
MasterJH5574 Feb 18, 2023
fd35d1e
[Unity][Pass] Block-level static memory planning (#14038)
MasterJH5574 Feb 18, 2023
d8fdd5c
[Unity] Disallow inline prim_func in relax IR (#14040)
yongwww Feb 18, 2023
8039f6a
[Unity] Update tests to adapt to latest TVMScript syntax (#14039)
MasterJH5574 Feb 18, 2023
782c632
[Unity] Relax dataflow pattern language (matching) (#14041)
ganler Feb 18, 2023
1b85765
[Unity] Statement rewriter for DataflowBlock (#14043)
ganler Feb 19, 2023
c1439b3
[Unity][Pass] FuseOps FuseTIR fixes (#14044)
MasterJH5574 Feb 19, 2023
180bead
[Unity][TVMScript] Overload `__neg__` for relax expr (#14045)
SiriusNEO Feb 19, 2023
fe528f6
[Unity][VM] Add per-op profiling support (#14053)
masahi Feb 20, 2023
aa55c05
[Unity][BYOC] Add pattern-based partitioning pass (#14054)
masahi Feb 20, 2023
fada709
[Unity] Relax op: collapse sum (#14059)
SiriusNEO Feb 21, 2023
06de35e
[Unity][Fix][Pass] Fix FuseOps for lack graph edges (#14058)
MasterJH5574 Feb 21, 2023
a9032d9
[Unity][Pass] Remove Unused Function (#14061)
sunggg Feb 21, 2023
ed2696a
[Unity][BYOC] Add pass to merge composite functions to offload large …
masahi Feb 21, 2023
0d58835
[Unity][Frontend] Annotate number of non-static input of FX function …
vinx13 Feb 21, 2023
246c4c1
[Unity][Transform] Add LiftTransformParams pass (#14069)
vinx13 Feb 21, 2023
e8a0c4d
[Unity][BYOC][Pass] RunCodegen and TensorRT (#14078)
sunggg Feb 22, 2023
4ad8d64
[Unity][Pass] Canonicalize Bindings (#14079)
YuchenJin Feb 22, 2023
3a64963
[Unity] Add testcases for `expr_args_converter` (#14080)
Hzfengsy Feb 22, 2023
b8460eb
[Unity][BYOC] Add CUTLASS backend (#14081)
masahi Feb 22, 2023
defc15b
[Unity][BYOC] Add DNNL backend (#14082)
masahi Feb 22, 2023
7645aa7
[Unity][Op] `log_softmax` and `cross_entropy_with_logits` (#14083)
SiriusNEO Feb 22, 2023
6e2d7bb
[Unity][Analysis] TIR pattern kind analysis for multi-buffer write bl…
MasterJH5574 Feb 22, 2023
acd0e0b
[Unity][Fix][Pass] FoldConstant with DCE in dataflow block (#14087)
MasterJH5574 Feb 22, 2023
1950940
[Unity] Refactor Relax Build JIT UX (#14088)
tqchen Feb 22, 2023
cf36b7b
[Unity][Relax] Set Shape Function to Be Host Function (#14090)
zxybazh Feb 22, 2023
b1f2d53
[Unity] Fix typo in the comment (#14096)
vinx13 Feb 22, 2023
a8338e6
[Unity] Lower `shape_of` to a builtin (#14093)
YuchenJin Feb 22, 2023
74f3007
[Unity] Relax Recursive function (#14092)
yongwww Feb 23, 2023
98d0a01
[Unity][Layout] Add layout transformation analysis for PrimFunc (#14066)
psrivas2 Feb 23, 2023
b755a6f
[Unity] Remove attributes of relax.print, assert and unique (#14101)
yongwww Feb 23, 2023
eaaa1fb
[Unity][BYOC]Add relax backend pattern registry (#14106)
yelite Feb 24, 2023
89bb68b
[Unity] Update tests again to adapt to latest TVMScript syntax (#14115)
Ubospica Feb 24, 2023
02b3a1f
[Unity][Fix] Fix bug in MergeCompositeFunctions (#14117)
Ubospica Feb 24, 2023
c7d2c38
[Unity][BlockBuilder] Add `name_hint` argument for `emit` and `emit_o…
SiriusNEO Feb 25, 2023
57c86eb
[Unity][WEB] Relax vm on web runtime (#14131)
tqchen Feb 25, 2023
61c2761
[Unity] Add Global info (#14132)
jinhongyii Feb 26, 2023
368d9f6
[Unity][BYOC] Add transposed matmul support to Relax CUTLASS BYOC (#1…
yelite Feb 27, 2023
4713b52
[Unity][TVMScript] emit_te sugar (#14123)
yongwww Feb 27, 2023
28c6825
[Unity][BYOC] Assign group to unused bindings and ignroe PrimFunc (#1…
vinx13 Feb 27, 2023
63ce37d
[Unity] Add callback to FuseOpsByPattern to check match result is acc…
vinx13 Feb 27, 2023
7a5d313
[Unity][Legalize] Fix Scalar Constant Legalization (#14127)
zxybazh Feb 28, 2023
c29ac7e
[Unity][Pass] Enhance constant folding to fold relax ops by evaluatin…
psrivas2 Feb 28, 2023
4c085d2
[Unity][Debugging] AST printer (#14152)
slyubomirsky Mar 1, 2023
8a1e623
[Unity][Pass] Support Symbolic Shape Deduction during BindParam (#14154)
Hzfengsy Mar 1, 2023
67659ac
[Unity][Analysis] Checking function return struct info in well-formed…
Hzfengsy Mar 1, 2023
e57f591
[Unity][BYOC] Use Relax legalize + CPU build for reference in tests (…
masahi Mar 1, 2023
3fa880a
[Unity] Add bind_constants option to FuseOpsByPattern (#14151)
vinx13 Mar 1, 2023
6ee79e1
[Unity][Analysis] Analysis for detecting recursion in Relax (#14149)
slyubomirsky Mar 1, 2023
8423811
[Unity][BYOC] Add batch matmul support to Relax CUTLASS BYOC (#14166)
yelite Mar 2, 2023
781bfe0
[Unity][Op] Full support of Relax op `power` (#14171)
SiriusNEO Mar 2, 2023
475f3c2
[Unity][Analysis] Restore Python bindings for var analyses (#14180)
slyubomirsky Mar 4, 2023
38315af
[Unity][OP] Add an operator for fused multi head attention (#14150)
cyx-6 Mar 4, 2023
ed5367d
[Unity][WEBGPU] Codegen improvements and WebRuntime (#14187)
tqchen Mar 4, 2023
1f04221
[Unity][Transform] LiftTransformParams handling multiple functions (#…
MasterJH5574 Mar 4, 2023
22b65bc
[Unity][Op] Group normalization (#14194)
MasterJH5574 Mar 4, 2023
88ab730
[Unity][Op] Argmax and argmin (#14195)
MasterJH5574 Mar 5, 2023
0e98e6e
[Unity][Op] Legalize `round`, `floor`, `ceil`, `sign` (#14198)
MasterJH5574 Mar 5, 2023
284b278
[Unity][Frontend] FX translator supporting more ops (#14196)
MasterJH5574 Mar 5, 2023
f7f24b7
[Unity][Frontend] FX translator returning weights with `keep_params_a…
MasterJH5574 Mar 5, 2023
d103ee2
[Unity][Fix] FX translating dtype (#14201)
MasterJH5574 Mar 5, 2023
8f7c343
[Unity][TIR][Pass] ForceNarrowIndexToInt32 (#14203)
MasterJH5574 Mar 6, 2023
6de551b
[Unity][Frontend] FX translator support torch.baddbmm (#14202)
MasterJH5574 Mar 6, 2023
bb34d97
[CI] Point cpu ci to dep with onnx (#40)
Mar 6, 2023
0937202
[Unity] Introduce Default GPU Schedule Pass (#14182)
zxybazh Mar 6, 2023
1878d7b
Merge with upstream Unity
Mar 6, 2023
b12320b
Remove now unnecessary ScheduleForTarget pass.
Mar 6, 2023
627fb0a
Go back to standard ci_cpu image
Mar 6, 2023
3d72050
Refactor importer locations for consistency and cleaner import.
Mar 7, 2023
1dd7a56
Add initial octoml utility functions.
Feb 21, 2023
d843a77
Add gpu target extraction
Feb 21, 2023
ea79707
Start compilation helper file.
Feb 22, 2023
dfe37fd
Add entrypoint compile function.
Feb 27, 2023
ab05b8c
Add OctoModel helper class and testing.
Feb 27, 2023
44f4742
Add full feature support and testing.
Feb 27, 2023
61e03a4
Chatgpt refactoring of my regular expressions.
Feb 28, 2023
e49e924
Cleanup lint issues.
Feb 28, 2023
e7ab17e
API cleanups after rebase.
Feb 28, 2023
dfc0313
Simplify cuda thread binding until Xiyous full pass lands.
Mar 2, 2023
ee05ae0
Importer improvements such that full flow works end to end
Mar 2, 2023
b4d172f
Improvements after merge with main
Mar 6, 2023
57cf2ff
Merge branch 'relax' into TUZ-150
Mar 7, 2023
f5b86c1
Improve tests
Mar 7, 2023
447bc59
Lint cleanup
Mar 8, 2023
e68e731
Add octo tests to CI
Mar 8, 2023
022042c
Incorporate feedback
Mar 8, 2023
b9813e8
Fix lint
Mar 8, 2023
b5777e1
Fix small type bug
Mar 8, 2023
dfc2931
Fix test target
Mar 8, 2023
22 changes: 22 additions & 0 deletions python/tvm/octo/__init__.py
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, wrong-import-position, redefined-builtin
"""OctoML Simplified API utilities."""

from . import utils
from .compile import compile
from .octo_model import OctoModel
170 changes: 170 additions & 0 deletions python/tvm/octo/compile.py
@@ -0,0 +1,170 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, wrong-import-position, redefined-builtin, not-callable
"""Simplified interface for TVM Unity Flow."""
from pathlib import Path
from typing import Union, Optional, Dict, List
import onnx
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
from .utils import get_cuda_target, get_llvm_target
from .octo_model import OctoModel


def load_onnx_model(
model_file: Union[str, Path, onnx.ModelProto], shape_dict: Optional[Dict[str, List]] = None
) -> tvm.IRModule:
"""Convert an input onnx model into a relax module.

Parameters
----------
model_file : Union[str, Path, onnx.ModelProto]
An input onnx model to convert. Can either be a path to a model or an already
loaded onnx protobuf.

shape_dict : Optional[Dict[str, List]]
An optional dictionary that maps inputs to specific shapes. If not provided,
the default values in the onnx graph will be used.

Returns
-------
relax_mod : tvm.IRModule
A Relax module implementing the input onnx graph.
"""
# Check input format and load if needed.
if isinstance(model_file, (Path, str)):
model_file = onnx.load(model_file)
else:
assert isinstance(
model_file, onnx.ModelProto
), f"model_file must be one of (str, Path, onnx.ModelProto) but got {type(model_file)}"

# Convert the graph into a relax implementation.
relax_mod = from_onnx(model_file, shape_dict=shape_dict)

return relax_mod
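
The type dispatch at the top of `load_onnx_model` can be sketched standalone. A minimal sketch, with a placeholder class standing in for `onnx.ModelProto` so it runs without onnx installed:

```python
from pathlib import Path


class FakeModelProto:
    """Placeholder standing in for onnx.ModelProto in this sketch."""


def normalize_model(model_file):
    # Paths and strings are loaded from disk (onnx.load in the real code);
    # in-memory protos pass through unchanged; anything else is rejected.
    if isinstance(model_file, (Path, str)):
        return f"loaded-from:{model_file}"
    if isinstance(model_file, FakeModelProto):
        return model_file
    raise TypeError(
        f"model_file must be one of (str, Path, ModelProto) but got {type(model_file)}"
    )


print(normalize_model("model.onnx"))  # loaded-from:model.onnx
```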


def offload_cutlass(mod: tvm.IRModule, target: tvm.target.Target) -> tvm.IRModule:
"""Converts appropriate subgraphs to CUTLASS.

Parameters
----------
mod : tvm.IRModule
The input module that should have subgraphs rewritten to CUTLASS.
Review comment (Contributor): Is there a way to introspect the output and see what was offloaded (and why)?
Reply (Contributor): We can add an IR visitor after the partition / codegen pass to collect the information of the lifted subgraphs.

target : tvm.target.Target
The target used for compilation. Needed to parameterize CUTLASS.

Returns
-------
cutlass_mod : tvm.IRModule
The input module after the partition_for_cutlass and RunCodegen passes
are applied. In the first step, subgraphs that cutlass supports are
found and annotated. Next, those subgraphs are compiled using nvcc.
The result is a graph containing a mixture of relax operators
and external calls to the compiled cutlass kernels.
"""
# Extract the sm version of the current target.
assert target.arch, "Target architecture must be specified."
sm = int(target.arch.split("_")[1])
# Cutlass only has support up to sm80, future sms will work with
# earlier kernels though.
if sm > 80:
sm = 80

# Apply partitioning to offload patterns to cutlass.
mod = partition_for_cutlass(mod)

# Construct CUTLASS codegen pass.
cutlass_codegen_pass = relax.transform.RunCodegen(
{"cutlass": {"sm": sm, "find_first_valid": True}}
)

# Generate code for matched cutlass kernels.
mod = cutlass_codegen_pass(mod)
return mod
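
The SM-version handling above can be isolated as a small pure function. A sketch, assuming `target.arch` strings of the form `"sm_86"` as read by the real code:

```python
def clamp_sm(arch: str, max_sm: int = 80) -> int:
    # Extract the compute capability from an arch string like "sm_86".
    sm = int(arch.split("_")[1])
    # CUTLASS support in this flow tops out at sm80; newer GPUs fall back
    # to the sm80 kernels.
    return min(sm, max_sm)


print(clamp_sm("sm_86"), clamp_sm("sm_75"))  # 80 75
```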


def compile(
model: Union[str, Path, onnx.ModelProto],
target: Optional[tvm.target.Target] = None,
shape_dict: Optional[Dict[str, List]] = None,
):
"""Entrypoint to compiling a model using the Unity Flow.

Parameters
----------
model : Union[str, Path, onnx.ModelProto]
An input onnx model to convert. Can either be a path to a model or an already
loaded onnx protobuf.

target : Optional[tvm.target.Target]
A description of the hardware to compile to. If not provided, one will be extracted for
the current host machine.

shape_dict : Optional[Dict[str, List]]
An optional dictionary that maps inputs to specific shapes. If not provided,
the default values in the onnx graph will be used.

Returns
-------
octo_model: OctoModel
A convenience wrapper around the compiled model that provides utility functions.
"""
# Determine current target.
if target is None:
# Check if this is gpu enabled.
if tvm.cuda(0).exist:
target = get_cuda_target()
else:
target = get_llvm_target()
print(f"Auto-selected target {target}")

# Convert model into a relax module.
relax_mod = load_onnx_model(model, shape_dict)

# Extract information about input shapes and types so we can
# randomly generate them later if needed.
input_info = {}
for inp in relax_mod["main"].params:
input_shape = [i.value for i in inp.struct_info.shape]
input_dtype = inp.struct_info.dtype
input_info[inp.name_hint] = (input_shape, input_dtype)

# If target is gpu and compiled with Cutlass, offload where possible.
if target.kind.name == "cuda":
if tvm.get_global_func("relax.ext.cutlass", True):
# Match subgraphs that can be offloaded to cutlass and offload them.
relax_mod = offload_cutlass(relax_mod, target)
else:
print("Cutlass backend not detected. Consider enabling it for better performance.")

# Perform legalization to lower Relax operators.
relax_mod = relax.transform.LegalizeOps()(relax_mod)

# Schedule all remaining functions to be compatible with gpu if needed.
if target.kind.name == "cuda":
with target, tvm.transform.PassContext(opt_level=3):
relax_mod = tvm.tir.transform.DefaultGPUSchedule()(relax_mod)

# Compile the module.
exe = relax.build(relax_mod, target)

# Create an OctoModel from the compiled artifact.
return OctoModel(exe, input_info, target=target)
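
The branching in `compile` reduces to a short decision sequence. This sketch only records which passes would run for a given target; the step names follow the code above, and `cutlass_enabled` mirrors the `tvm.get_global_func("relax.ext.cutlass", True)` probe:

```python
from typing import List


def pipeline_steps(target_kind: str, cutlass_enabled: bool) -> List[str]:
    steps = []
    if target_kind == "cuda" and cutlass_enabled:
        steps.append("offload_cutlass")     # partition + RunCodegen
    steps.append("LegalizeOps")             # lower Relax ops to TIR
    if target_kind == "cuda":
        steps.append("DefaultGPUSchedule")  # bind thread axes for GPU
    steps.append("relax.build")             # produce the executable
    return steps


print(pipeline_steps("cuda", True))
# ['offload_cutlass', 'LegalizeOps', 'DefaultGPUSchedule', 'relax.build']
print(pipeline_steps("llvm", False))
# ['LegalizeOps', 'relax.build']
```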
195 changes: 195 additions & 0 deletions python/tvm/octo/octo_model.py
@@ -0,0 +1,195 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, wrong-import-position
"""Wrapper class for compiled models."""
import json
import tarfile
from pathlib import Path
from typing import Optional, Union, Dict, Tuple, List
import numpy as np
import tvm
from tvm import relax
from tvm.contrib import utils


class OctoModel(object):
"""A compiled model wrapper that provides helpful utilities.

Parameters
----------
exe : Optional[relax.Executable]
A compiled executable that can be loaded and run by a relax VM.
input_info : Optional[Dict[str, Tuple[List, str]]]
Information about the input names, shapes, and types for the VM.
Will be loaded from memory if possible.
model_path : Optional[Union[str, Path]]
The path to a saved OctoModel, one of exe and model_path must
be specified.
target : Optional[tvm.target.Target]
The target being compiled for.
"""

def __init__(
self,
exe: Optional[relax.Executable] = None,
input_info: Optional[Dict[str, Tuple[List, str]]] = None,
model_path: Optional[Union[str, Path]] = None,
target: Optional[tvm.target.Target] = None,
Review comment (Contributor): Does this target have to be the same as the one that produced the relax.Executable?
):
self.target = target

if exe is None and model_path is None:
raise ValueError("One of exe and model_path must be provided.")

self._tmp_dir = utils.tempdir()

if model_path is not None:
exe, input_info = self.load(model_path)

self.dev = tvm.device(self.target.get_target_device_type())
self.exe = exe
self.input_info = input_info

# Create a vm from exe.
self.vm = relax.VirtualMachine(self.exe, self.dev, profile=True)

def save(self, model_path: Union[str, Path]) -> None:
"""Save the OctoModel to disk.

The current format used is a simple tar of the exported model library (exe.so),
the input information of the model (input_info.json), and a metadata
file containing strings such as the target.

Parameters
----------
model_path : Union[str, Path]
A full path to save this OctoModel to including the output file name.
The file will be saved as a tar file so using a ".tar" extension is advised.
"""
# Only two artifacts need to be saved, the exe and the input struct info.
# Serialize both to a temp directory.
exe_path = self._tmp_dir.relpath("exe.so")
self.exe.mod.export_library(exe_path)
input_info_path = self._tmp_dir.relpath("input_info.json")
with open(input_info_path, "w") as fo:
json.dump(self.input_info, fo)

# Save additional metadata.
metadata = {"target": str(self.target)}
metadata_path = self._tmp_dir.relpath("metadata.json")
with open(metadata_path, "w") as fo:
json.dump(metadata, fo)

# Tar the tempfile and save to the designated model_path.
with tarfile.open(model_path, "w") as tar:
Review comment (Member): Should we recycle Model Library Format here to avoid creating two of these?
Reply (Contributor): Outlining the format would be a good idea in any case, and it would be even better to reuse an existing format if possible.
Reply (Contributor Author): I think using MLF is a good idea but not currently applicable since it is tied to micro and would need to be extended to support Relax. I'd be in favor of bumping MLF to a tvm namescoped standard that can be used more broadly. For now, I'll just note the format more clearly in the comments and we can change later.

tar.add(exe_path, "exe.so")
tar.add(input_info_path, "input_info.json")
tar.add(metadata_path, "metadata.json")
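
The archive `save` writes has exactly three members. A standalone sketch of that layout with placeholder contents (the `exe.so` here is dummy bytes, not a compiled library, and the example shapes are made up):

```python
import json
import os
import tarfile
import tempfile

tmp = tempfile.mkdtemp()
members = {
    "exe.so": b"placeholder",  # export_library output in the real flow
    "input_info.json": json.dumps({"data": [[1, 3, 224, 224], "float32"]}).encode(),
    "metadata.json": json.dumps({"target": "cuda"}).encode(),
}
for name, payload in members.items():
    with open(os.path.join(tmp, name), "wb") as f:
        f.write(payload)

archive = os.path.join(tmp, "model.tar")
with tarfile.open(archive, "w") as tar:
    for name in members:
        tar.add(os.path.join(tmp, name), name)

with tarfile.open(archive) as tar:
    print(sorted(tar.getnames()))
# ['exe.so', 'input_info.json', 'metadata.json']
```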

def load(self, model_path: Union[str, Path]) -> Tuple[relax.Executable, Dict[str, Tuple[List, str]]]:
"""Load a saved OctoModel back into memory.

Parameters
----------
model_path : Union[str, Path]
The path to the saved OctoModel that will be loaded.

Returns
-------
exe : relax.Executable
A compiled executable that can be loaded and run by a relax VM.
input_info : Dict[str, Tuple[List, str]]
Information about the input names, shapes, and types for the VM.
Will be loaded from memory if possible.
"""
t = tarfile.open(model_path)
t.extractall(self._tmp_dir.relpath("."))

# Load executable.
exe_path = self._tmp_dir.relpath("exe.so")
exe = relax.Executable(tvm.runtime.load_module(exe_path))

# Load input info.
input_info_path = self._tmp_dir.relpath("input_info.json")
with open(input_info_path, "r") as fi:
input_info = json.load(fi)

# load other metadata.
metadata_path = self._tmp_dir.relpath("metadata.json")
with open(metadata_path, "r") as fi:
metadata = json.load(fi)
self.target = tvm.target.Target(metadata["target"])

return exe, input_info

def generate_inputs(self) -> Dict[str, np.array]:
"""Generate random inputs for inference or benchmarking.

Returns
-------
input_dict : Dict[str, np.array]
A dictionary mapping each input name to a randomly generated array
of the recorded shape and dtype.
"""
input_dict = {}
for name, (shape, dtype) in self.input_info.items():
input_dict[name] = np.random.normal(size=shape).astype(dtype)
return input_dict
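
The body of `generate_inputs` reduces to a dict comprehension over the recorded `(shape, dtype)` pairs. A sketch with example `input_info` data (the names and shapes here are illustrative, not from the PR):

```python
import numpy as np

input_info = {"data": ([1, 3, 8, 8], "float32"), "ids": ([1, 4], "int64")}
input_dict = {
    name: np.random.normal(size=shape).astype(dtype)
    for name, (shape, dtype) in input_info.items()
}
print({name: (arr.shape, arr.dtype.name) for name, arr in input_dict.items()})
# {'data': ((1, 3, 8, 8), 'float32'), 'ids': ((1, 4), 'int64')}
```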

def run(self, inputs: Optional[Dict[str, np.array]] = None) -> List[np.array]:
"""Perform an inference of the model.

Parameters
----------
inputs : Optional[Dict[str, np.array]]
An optional input dictionary containing the values to perform
inference with. If not provided, random values will be generated
instead.

Returns
-------
outputs : List[np.array]
The output values from the inference.
"""
# Generate random inputs if none are provided.
if inputs is None:
inputs = self.generate_inputs()

# Assign inputs.
self.vm.set_input("main", **inputs)
# Run the model.
self.vm.invoke_stateful("main")
# Get and return the outputs.
outputs = self.vm.get_outputs("main")
if isinstance(outputs, tuple):
outputs = [output.numpy() for output in outputs]
else:
outputs = [outputs.numpy()]
return outputs

def profile(self) -> tvm.runtime.profiling.Report:
"""Measures the model's performance.

Returns
-------
report : tvm.runtime.profiling.Report
A breakdown of the runtime and per layer metrics.
"""
inputs = self.generate_inputs()
self.vm.set_input("main", **inputs)
report = self.vm.profile("main")
return report
Loading