[NNVM] Add symbol squeezenet #1436

Merged: 5 commits, Jul 15, 2018
1 change: 1 addition & 0 deletions nnvm/python/nnvm/testing/__init__.py
@@ -7,6 +7,7 @@
from . import mlp
from . import resnet
from . import vgg
from . import squeezenet
from . import dcgan
from . import dqn
from . import yolo2_detection
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/testing/mobilenet.py
@@ -86,7 +86,7 @@ def get_workload(batch_size, num_classes=1000, image_shape=(3, 224, 224), dtype=
The batch size used in the model

num_classes : int, optional
Number of claseses
Number of classes

image_shape : tuple, optional
The input image shape
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/testing/resnet.py
@@ -199,7 +199,7 @@ def get_workload(batch_size=1, num_classes=1000, num_layers=18,
The batch size used in the model

num_classes : int, optional
Number of claseses
Number of classes

num_layers : int, optional
Number of layers
132 changes: 132 additions & 0 deletions nnvm/python/nnvm/testing/squeezenet.py
@@ -0,0 +1,132 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=unused-argument

"""
Symbol of SqueezeNet

Reference:
Iandola, Forrest N., et al.
"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016).
"""

from .. import symbol as sym
from . utils import create_workload

# Helpers
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
    # Fire module: a 1x1 "squeeze" convolution feeding parallel 1x1 and 3x3
    # "expand" convolutions whose outputs are concatenated along channels.
    net = _make_fire_conv(net, squeeze_channels, 1, 0)

    left = _make_fire_conv(net, expand1x1_channels, 1, 0)
    right = _make_fire_conv(net, expand3x3_channels, 3, 1)
    # NOTE : Assume NCHW layout here
    net = sym.concatenate(left, right, axis=1)

    return net

def _make_fire_conv(net, channels, kernel_size, padding=0):
net = sym.conv2d(net, channels=channels, kernel_size=(kernel_size, kernel_size),
padding=(padding, padding))
net = sym.relu(net)
return net

# Net
def get_symbol(num_classes, version, **kwargs):
"""Get symbol of SqueezeNet

Parameters
----------
num_classes: int
The number of classification results

version : str, optional
"1.0" or "1.1" of SqueezeNet
"""
    assert version in ['1.0', '1.1'], ("Unsupported SqueezeNet version {version}: "
                                       "1.0 or 1.1 expected".format(version=version))
net = sym.Variable("data")
if version == '1.0':
net = sym.conv2d(net, channels=96, kernel_size=(7, 7), strides=(2, 2), padding=(3, 3))
net = sym.relu(net)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 32, 128, 128)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 64, 256, 256)
else:
net = sym.conv2d(net, channels=64, kernel_size=(3, 3), strides=(2, 2), padding=(1, 1))
net = sym.relu(net)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 32, 128, 128)
net = sym.max_pool2d(net, pool_size=(3, 3), strides=(2, 2))
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = _make_fire(net, 64, 256, 256)
net = sym.dropout(net, rate=0.5)
net = sym.conv2d(net, channels=num_classes, kernel_size=(1, 1))
net = sym.relu(net)
net = sym.global_avg_pool2d(net)
net = sym.flatten(net)
return sym.softmax(net)

def get_workload(batch_size=1, num_classes=1000, version='1.0',
image_shape=(3, 224, 224), dtype="float32", **kwargs):
"""Get benchmark workload for resnet

Parameters
----------
batch_size : int
The batch size used in the model

num_classes : int, optional
Number of classes

version : str, optional
"1.0" or "1.1" of SqueezeNet

image_shape : tuple, optional
The input image shape

dtype : str, optional
The data type

kwargs : dict
Extra arguments

Returns
-------
net : nnvm.Symbol
The computational graph

params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_classes=num_classes, version=version, **kwargs)
return create_workload(net, batch_size, image_shape, dtype)
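
For context, the new testing module follows the same get_workload pattern as resnet and vgg. The snippet below is a minimal usage sketch, not part of this diff; the "data" input name and the llvm target are assumptions based on the sym.Variable("data") defined in get_symbol above.

import nnvm.compiler
import nnvm.testing

# Build the SqueezeNet v1.1 workload: `net` is the NNVM symbol, `params`
# holds the randomly initialized weights produced by create_workload.
net, params = nnvm.testing.squeezenet.get_workload(batch_size=1, version='1.1')

# Compile for CPU; the input name "data" matches the Variable in get_symbol.
shape = {"data": (1, 3, 224, 224)}
graph, lib, params = nnvm.compiler.build(net, target="llvm", shape=shape, params=params)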
12 changes: 10 additions & 2 deletions nnvm/tests/python/frontend/mxnet/model_zoo/__init__.py
@@ -1,9 +1,10 @@
"""MXNet and NNVM model zoo."""
from __future__ import absolute_import
from . import mlp, resnet, vgg, dqn, dcgan
from . import mlp, resnet, vgg, dqn, dcgan, squeezenet
import nnvm.testing

__all__ = ['mx_mlp', 'nnvm_mlp', 'mx_resnet', 'nnvm_resnet', 'mx_vgg', 'nnvm_vgg']
__all__ = ['mx_mlp', 'nnvm_mlp', 'mx_resnet', 'nnvm_resnet', 'mx_vgg', 'nnvm_vgg',
'mx_squeezenet', 'nnvm_squeezenet']

_num_class = 1000

@@ -27,6 +28,13 @@
nnvm_vgg[num_layer] = nnvm.testing.vgg.get_workload(
1, _num_class, num_layers=num_layer)[0]

# squeezenet
mx_squeezenet = {}
nnvm_squeezenet = {}
for version in ['1.0', '1.1']:
mx_squeezenet[version] = squeezenet.get_symbol(version=version)
nnvm_squeezenet[version] = nnvm.testing.squeezenet.get_workload(1, version=version)[0]

# dqn
mx_dqn = dqn.get_symbol()
nnvm_dqn = nnvm.testing.dqn.get_workload(1)[0]
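
As with the existing resnet and vgg entries, the new dictionaries are keyed by the version string. A rough usage sketch (hypothetical, assuming the test package is importable as model_zoo):

# Hypothetical usage from the test directory; `model_zoo` is the package
# defined by this __init__.py.
from model_zoo import mx_squeezenet, nnvm_squeezenet

mx_sym = mx_squeezenet['1.0']      # MXNet symbol for SqueezeNet v1.0
nnvm_sym = nnvm_squeezenet['1.1']  # NNVM symbol for SqueezeNet v1.1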
76 changes: 76 additions & 0 deletions nnvm/tests/python/frontend/mxnet/model_zoo/squeezenet.py
@@ -0,0 +1,76 @@
"""
Symbol of SqueezeNet

Reference:
Iandola, Forrest N., et al.
"Squeezenet: Alexnet-level accuracy with 50x fewer parameters and< 0.5 mb model size." (2016).
"""

import mxnet as mx

# Helpers
def _make_fire(net, squeeze_channels, expand1x1_channels, expand3x3_channels):
net = _make_fire_conv(net, squeeze_channels, 1, 0)

left = _make_fire_conv(net, expand1x1_channels, 1, 0)
right = _make_fire_conv(net, expand3x3_channels, 3, 1)
# NOTE : Assume NCHW layout here
net = mx.sym.concat(left, right, dim=1)

return net

def _make_fire_conv(net, channels, kernel_size, padding=0):
net = mx.sym.Convolution(net, num_filter=channels, kernel=(kernel_size, kernel_size),
pad=(padding, padding))
net = mx.sym.Activation(net, act_type='relu')
return net

# Net
def get_symbol(num_classes=1000, version='1.0', **kwargs):
"""Get symbol of SqueezeNet

Parameters
----------
num_classes: int
The number of classification results

version : str, optional
"1.0" or "1.1" of SqueezeNet
"""
    assert version in ['1.0', '1.1'], ("Unsupported SqueezeNet version {version}: "
                                       "1.0 or 1.1 expected".format(version=version))
net = mx.sym.Variable("data")
if version == '1.0':
net = mx.sym.Convolution(net, num_filter=96, kernel=(7, 7), stride=(2, 2), pad=(3, 3))
net = mx.sym.Activation(net, act_type='relu')
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 32, 128, 128)
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 64, 256, 256)
else:
net = mx.sym.Convolution(net, num_filter=64, kernel=(3, 3), stride=(2, 2), pad=(1, 1))
net = mx.sym.Activation(net, act_type='relu')
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 16, 64, 64)
net = _make_fire(net, 16, 64, 64)
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 32, 128, 128)
net = _make_fire(net, 32, 128, 128)
net = mx.sym.Pooling(data=net, kernel=(3, 3), pool_type='max', stride=(2, 2))
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 48, 192, 192)
net = _make_fire(net, 64, 256, 256)
net = _make_fire(net, 64, 256, 256)
net = mx.sym.Dropout(net, p=0.5)
net = mx.sym.Convolution(net, num_filter=num_classes, kernel=(1, 1))
net = mx.sym.Activation(net, act_type='relu')
net = mx.sym.Pooling(data=net, global_pool=True, kernel=(13, 13), pool_type='avg')
net = mx.sym.flatten(net)
return mx.sym.softmax(net)
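
A quick way to sanity-check the MXNet symbol is to infer its output shape. The snippet below is illustrative only and assumes a single 224x224 RGB input in NCHW layout.

# Illustrative check, not part of the test suite.
sym = get_symbol(num_classes=1000, version='1.1')
print(sym.list_arguments()[:4])                    # data plus the first conv weight/bias names
_, out_shapes, _ = sym.infer_shape(data=(1, 3, 224, 224))
print(out_shapes)                                  # expected: [(1, 1000)]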
8 changes: 8 additions & 0 deletions nnvm/tests/python/frontend/mxnet/test_graph.py
@@ -32,6 +32,13 @@ def test_resnet():
nnvm_sym = model_zoo.nnvm_resnet[n]
compare_graph(from_mx_sym, nnvm_sym)

def test_squeezenet():
for version in ['1.0', '1.1']:
mx_sym = model_zoo.mx_squeezenet[version]
from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_squeezenet[version]
compare_graph(from_mx_sym, nnvm_sym)

def test_dqn():
mx_sym = model_zoo.mx_dqn
from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
@@ -62,3 +69,4 @@ def compose(F, **kwargs):
test_multi_outputs()
test_dqn()
test_dcgan()
test_squeezenet()
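
compare_graph is a helper defined elsewhere in this test file and is not shown in the diff. A rough sketch of the kind of structural check it performs (an assumption, not the file's actual implementation) is:

import nnvm.graph

def _compare_graph_sketch(sym1, sym2):
    # Create graphs from both symbols and compare their textual IR;
    # structurally identical networks should serialize identically.
    g1 = nnvm.graph.create(sym1)
    g2 = nnvm.graph.create(sym2)
    assert g1.ir() == g2.ir()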
2 changes: 1 addition & 1 deletion tutorials/autotvm/tune_cuda_conv2d.py
@@ -21,7 +21,7 @@
# ---------------------------------
# There are plenty of useful schedule primitives in tvm. You can also find
# some tutorials that describe them in more details, such as
# (1). :doc:``Optimizing Conv2d on NVIDIA GPU <../optimize/opt_conv_cuda>`
# (1). :ref:`opt-conv-gpu`
# (2). `Optimizing DepthwiseConv on NVIDIA GPU <https://tvm.ai/2017/08/22/Optimize-Deep-Learning-GPU-Operators-with-TVM-A-Depthwise-Convolution-Example.html>`_
#
# However, their implementations are manually tuned for some special input
2 changes: 1 addition & 1 deletion tutorials/nnvm/imagenet_inference_gpu.py
@@ -20,7 +20,7 @@
# To get the maximum performance, we need to enable nvcc's compiler hook.
# This usually gives better performance than nvrtc mode.

@tvm.register_func
@tvm.register_func("tvm_callback_cuda_compile", override=True)
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx")
return ptx
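
The only change here is registering the callback under an explicit name with override=True. The sketch below (illustrative, with a hypothetical function name) shows why: registering the same global function name a second time raises an error unless override is set.

import tvm

# Hypothetical example: re-registering "my_cuda_hook" would raise without override=True.
@tvm.register_func("my_cuda_hook", override=True)
def my_cuda_hook(code):
    return code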
5 changes: 4 additions & 1 deletion tutorials/optimize/opt_conv_cuda.py
@@ -1,4 +1,7 @@
"""How to optimize convolution on GPU
"""
.. _opt-conv-gpu:

How to optimize convolution on GPU
==================================
**Author**: `Haichen Shen <https://homes.cs.washington.edu/~haichen/>`_
