forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TUTORIAL] Move mobilenet to tutorial, fix precompute_prune (apache#35)
* [TUTORIAL] Move mobilenet to tutorial, fix precompute_prune * Some language improvements
- Loading branch information
Showing
15 changed files
with
271 additions
and
135 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
doxygen | ||
_build | ||
gen_modules | ||
tutorials |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
The documentation of nnvm is generated with recommonmark and sphinx. | ||
|
||
- pip install sphinx>=1.5.5 sphinx-gallery sphinx_rtd_theme matplotlib Image recommonmark | ||
- Build tvm first in the root folder. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,4 +10,5 @@ Contents | |
|
||
self | ||
top | ||
tutorials/index | ||
dev/index |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
NNVM Examples | ||
============= | ||
This folder contains example snippets of running NNVM Compilation. | ||
|
||
- See also [Tutorials](tutorials) for tutorials with detailed explainations. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
"""Utilities for testcase""" | ||
"""Utilities for testing and benchmarks""" | ||
from __future__ import absolute_import as _abs | ||
|
||
from .config import ctx_list | ||
from . import mobilenet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
"""Configuration about tests""" | ||
from __future__ import absolute_import as _abs | ||
|
||
import os | ||
import tvm | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
"""Helper utility to get mobilenet workload for testing.""" | ||
# pylint: disable=invalid-name | ||
from __future__ import absolute_import as _abs | ||
|
||
import numpy as np | ||
import tvm | ||
from .. compiler import graph_util | ||
from .. import graph | ||
from .. import symbol as sym | ||
|
||
def conv_block(data, name, channels, | ||
kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), | ||
epsilon=1e-5): | ||
"""Helper function to construct conv-bn-relu""" | ||
# convolution + bn + relu | ||
conv = sym.conv2d(data=data, channels=channels, | ||
kernel_size=kernel_size, strides=strides, | ||
padding=padding, use_bias=False, | ||
layout="NCHW", name=name + "_conv") | ||
bn = sym.batch_norm(data=conv, epsilon=epsilon, name=name + "_bn") | ||
act = sym.relu(data=bn, name=name + "_relu") | ||
return act | ||
|
||
def separable_conv_block(data, name, depthwise_channels, | ||
pointwise_channels, kernel_size=(3, 3), | ||
downsample=False, padding=(1, 1), | ||
epsilon=1e-5): | ||
"""Helper function to get a separable conv block""" | ||
if downsample: | ||
strides = (2, 2) | ||
else: | ||
strides = (1, 1) | ||
# depthwise convolution + bn + relu | ||
conv1 = sym.conv2d(data=data, channels=depthwise_channels, | ||
groups=depthwise_channels, kernel_size=kernel_size, strides=strides, | ||
padding=padding, use_bias=False, layout="NCHW", name=name + "_conv1") | ||
bn1 = sym.batch_norm(data=conv1, epsilon=epsilon, name=name + "_bn1") | ||
act1 = sym.relu(data=bn1, name=name + "_relu1") | ||
# pointwise convolution + bn + relu | ||
conv2 = sym.conv2d(data=act1, channels=pointwise_channels, kernel_size=(1, 1), strides=(1, 1), | ||
padding=(0, 0), use_bias=False, layout="NCHW", name=name + "_conv2") | ||
bn2 = sym.batch_norm(data=conv2, epsilon=epsilon, name=name + "_bn2") | ||
act2 = sym.relu(data=bn2, name=name + "_relu2") | ||
return act2 | ||
|
||
def mobile_net(num_classes=1000, alpha=1.0, is_shallow=False): | ||
"""Function to construct a MobileNet""" | ||
data = sym.Variable("data") | ||
body = conv_block(data, "conv_block_1", int(32*alpha), strides=(2, 2)) | ||
body = separable_conv_block(body, "separable_conv_block_1", | ||
int(32*alpha), int(64*alpha)) | ||
body = separable_conv_block(body, "separable_conv_block_2", | ||
int(64*alpha), int(128*alpha), downsample=True) | ||
body = separable_conv_block(body, "separable_conv_block_3", | ||
int(128*alpha), int(128*alpha)) | ||
body = separable_conv_block(body, "separable_conv_block_4", | ||
int(128*alpha), int(256*alpha), downsample=True) | ||
body = separable_conv_block(body, "separable_conv_block_5", | ||
int(256*alpha), int(256*alpha)) | ||
body = separable_conv_block(body, "separable_conv_block_6", | ||
int(256*alpha), int(512*alpha), downsample=True) | ||
if is_shallow: | ||
body = separable_conv_block(body, "separable_conv_block_7", | ||
int(512*alpha), int(1024*alpha), downsample=True) | ||
body = separable_conv_block(body, "separable_conv_block_8", | ||
int(1024*alpha), int(1024*alpha)) | ||
else: | ||
for i in range(7, 12): | ||
body = separable_conv_block(body, "separable_conv_block_%d" % i, | ||
int(512*alpha), int(512*alpha)) | ||
body = separable_conv_block(body, "separable_conv_block_12", | ||
int(512*alpha), int(1024*alpha), downsample=True) | ||
body = separable_conv_block(body, "separable_conv_block_13", | ||
int(1024*alpha), int(1024*alpha)) | ||
pool = sym.global_avg_pool2d(data=body, name="pool") | ||
flatten = sym.flatten(data=pool, name="flatten") | ||
fc = sym.dense(data=flatten, units=num_classes, use_bias=False, name="fc") | ||
softmax = sym.softmax(data=fc, name="softmax") | ||
return softmax | ||
|
||
|
||
def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype="float32"): | ||
"""Get benchmark workload for mobilenet | ||
Parameters | ||
---------- | ||
batch_size : int | ||
The batch size used in the model | ||
num_classes : int, optional | ||
Number of claseses | ||
image_shape : tuple, optional | ||
The input image shape | ||
dtype : str, optional | ||
The data type | ||
Returns | ||
------- | ||
net : nnvm.Symbol | ||
The computational graph | ||
params : dict of str to NDArray | ||
The parameters. | ||
""" | ||
image_shape = (3, 224, 224) | ||
data_shape = (batch_size,) + image_shape | ||
net = mobile_net(num_classes=num_classes, alpha=1.0, is_shallow=False) | ||
params = {} | ||
g = graph.create(net) | ||
input_shapes, _ = graph_util.infer_shape(g, data=data_shape) | ||
shape_dict = dict(zip(g.index.input_names, input_shapes)) | ||
for k, v in shape_dict.items(): | ||
if k == "data": | ||
continue | ||
# Specially generate non-negative parameters. | ||
if k.endswith("gamma"): | ||
init = np.random.uniform(0.9, 1, size=v) | ||
elif k.endswith("var"): | ||
init = np.random.uniform(0.9, 1, size=v) | ||
else: | ||
init = np.random.uniform(-0.1, 0.1, size=v) | ||
params[k] = tvm.nd.array(init.astype(dtype), ctx=tvm.cpu(0)) | ||
return net, params |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.