Merge pull request #59 from zjuter0126/alexnet

Alexnet TinyMs XJY

1 parent fa93595, commit d5aaaac
Showing 5 changed files with 537 additions and 0 deletions.
@@ -0,0 +1,105 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AlexNet tutorial
The sample can be run on CPU, GPU and Ascend 910 AI processor.
"""
import random
import argparse

from tinyms import context
from tinyms.data import Cifar10Dataset, download_dataset
from tinyms.vision import cifar10_transform
from tinyms.model import Model, AlexNet
from tinyms.callbacks import ModelCheckpoint, CheckpointConfig, LossMonitor
from tinyms.metrics import Accuracy
from tinyms.optimizers import Momentum
from tinyms.losses import SoftmaxCrossEntropyWithLogits

random.seed(1)
def parse_args():
    parser = argparse.ArgumentParser(description='Image classification')
    parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: CPU)')
    parser.add_argument('--dataset_path', type=str, default=None, help='Cifar10 dataset path.')
    parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
    parser.add_argument('--epoch_size', type=int, default=90, help='Epoch size.')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size.')
    parser.add_argument('--num_classes', type=int, default=10, help='Num classes.')
    parser.add_argument('--save_checkpoint_epochs', type=int, default=5,
                        help='Epoch interval at which to save checkpoints.')
    parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path.')
    args_opt = parser.parse_args()

    return args_opt
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1,
                   is_training=True):
    """Create the Cifar10 dataset for train or eval.
    Args:
        data_path: Data path
        batch_size: The number of data records in each group
        repeat_size: The number of replicated data records
        num_parallel_workers: The number of parallel workers
        is_training: Whether the dataset is used for training or evaluation
    """
    # define dataset and apply the transform func
    cifar_ds = Cifar10Dataset(data_path, num_parallel_workers=num_parallel_workers,
                              shuffle=True)
    cifar_ds = cifar10_transform.apply_ds(cifar_ds,
                                          repeat_size=repeat_size,
                                          batch_size=batch_size,
                                          num_parallel_workers=num_parallel_workers,
                                          is_training=is_training)

    return cifar_ds
if __name__ == '__main__':
    args_opt = parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)

    # download cifar10 dataset
    if not args_opt.dataset_path:
        args_opt.dataset_path = download_dataset('cifar10')
    # build the network
    net = AlexNet(args_opt.num_classes)
    net.update_parameters_name(prefix='huawei')
    model = Model(net)
    # define the loss function
    net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # define the optimizer
    net_opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model.compile(loss_fn=net_loss, optimizer=net_opt, metrics={"Accuracy": Accuracy()})

    epoch_size = args_opt.epoch_size
    batch_size = args_opt.batch_size
    cifar10_path = args_opt.dataset_path
    save_checkpoint_epochs = args_opt.save_checkpoint_epochs
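    # dataset sink mode is disabled on CPU targets, so enable it only for GPU/Ascend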
    dataset_sink_mode = args_opt.device_target != "CPU"
    if args_opt.do_eval:  # for evaluation, use model.eval
        ds_eval = create_dataset(cifar10_path, batch_size=batch_size, is_training=False)
        if args_opt.checkpoint_path:
            model.load_checkpoint(args_opt.checkpoint_path)
        acc = model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode)
        print("============== Accuracy:{} ==============".format(acc))
    else:  # for training, use model.train
        ds_train = create_dataset(cifar10_path, batch_size=batch_size)
        ckpoint_cb = ModelCheckpoint(prefix="alexnet_cifar10", config=CheckpointConfig(
            save_checkpoint_steps=save_checkpoint_epochs * ds_train.get_dataset_size(),
            keep_checkpoint_max=10))
        model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()],
                    dataset_sink_mode=dataset_sink_mode)
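One caveat with the argument parsing above: argparse with type=bool applies Python's bool() to the raw string, so any non-empty value, including the literal "False", turns evaluation on. A minimal sketch of the behavior and a common workaround (the str2bool helper is illustrative, not part of this commit):

import argparse

def str2bool(v):
    # map common string spellings onto a real boolean instead of calling bool()
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
parser.add_argument('--do_eval', type=bool, default=False)
print(parser.parse_args(['--do_eval', 'False']).do_eval)   # True: bool('False') is truthy

parser = argparse.ArgumentParser()
parser.add_argument('--do_eval', type=str2bool, default=False)
print(parser.parse_args(['--do_eval', 'False']).do_eval)   # False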
@@ -0,0 +1,105 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""densenetBC_100 tutorial
The sample can be run on CPU, GPU and Ascend 910 AI processor.
"""
import random
import argparse

from tinyms import context
from tinyms.data import Cifar10Dataset, download_dataset
from tinyms.vision import cifar10_transform
from tinyms.model import Model, densenetBC_100
from tinyms.callbacks import ModelCheckpoint, CheckpointConfig, LossMonitor
from tinyms.metrics import Accuracy
from tinyms.optimizers import Momentum
from tinyms.losses import SoftmaxCrossEntropyWithLogits

random.seed(1)
def parse_args():
    parser = argparse.ArgumentParser(description='Image classification')
    parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: CPU)')
    parser.add_argument('--dataset_path', type=str, default=None, help='Cifar10 dataset path.')
    parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
    parser.add_argument('--epoch_size', type=int, default=90, help='Epoch size.')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size.')
    parser.add_argument('--num_classes', type=int, default=10, help='Num classes.')
    parser.add_argument('--save_checkpoint_epochs', type=int, default=5,
                        help='Epoch interval at which to save checkpoints.')
    parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path.')
    args_opt = parser.parse_args()

    return args_opt
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1,
                   is_training=True):
    """Create the Cifar10 dataset for train or eval.
    Args:
        data_path: Data path
        batch_size: The number of data records in each group
        repeat_size: The number of replicated data records
        num_parallel_workers: The number of parallel workers
        is_training: Whether the dataset is used for training or evaluation
    """
    # define dataset and apply the transform func
    cifar_ds = Cifar10Dataset(data_path, num_parallel_workers=num_parallel_workers,
                              shuffle=True)
    cifar_ds = cifar10_transform.apply_ds(cifar_ds,
                                          repeat_size=repeat_size,
                                          batch_size=batch_size,
                                          num_parallel_workers=num_parallel_workers,
                                          is_training=is_training)

    return cifar_ds
if __name__ == '__main__':
    args_opt = parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)

    # download cifar10 dataset
    if not args_opt.dataset_path:
        args_opt.dataset_path = download_dataset('cifar10')
    # build the network
    net = densenetBC_100(args_opt.num_classes)
    net.update_parameters_name(prefix='huawei')
    model = Model(net)
    # define the loss function
    net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # define the optimizer
    net_opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model.compile(loss_fn=net_loss, optimizer=net_opt, metrics={"Accuracy": Accuracy()})

    epoch_size = args_opt.epoch_size
    batch_size = args_opt.batch_size
    cifar10_path = args_opt.dataset_path
    save_checkpoint_epochs = args_opt.save_checkpoint_epochs
    dataset_sink_mode = args_opt.device_target != "CPU"
    if args_opt.do_eval:  # for evaluation, use model.eval
        ds_eval = create_dataset(cifar10_path, batch_size=batch_size, is_training=False)
        if args_opt.checkpoint_path:
            model.load_checkpoint(args_opt.checkpoint_path)
        acc = model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode)
        print("============== Accuracy:{} ==============".format(acc))
    else:  # for training, use model.train
        ds_train = create_dataset(cifar10_path, batch_size=batch_size)
        ckpoint_cb = ModelCheckpoint(prefix="densenetBC_100_cifar10", config=CheckpointConfig(
            save_checkpoint_steps=save_checkpoint_epochs * ds_train.get_dataset_size(),
            keep_checkpoint_max=10))
        model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()],
                    dataset_sink_mode=dataset_sink_mode)
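The two scripts above differ only in the network constructor (AlexNet vs. densenetBC_100) and the checkpoint prefix. A sketch of how the shared logic could be factored out, assuming both constructors take the class count as their first argument as they do above (the NETWORKS registry is illustrative, not part of this commit):

from tinyms.model import Model, AlexNet, densenetBC_100

# illustrative registry: pick the network by name instead of duplicating the script
NETWORKS = {
    'alexnet': AlexNet,
    'densenetBC_100': densenetBC_100,
}

def build_model(name, num_classes=10):
    net = NETWORKS[name](num_classes)
    net.update_parameters_name(prefix='huawei')
    model = Model(net)
    return model, net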
@@ -0,0 +1,143 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
from scipy.stats import truncnorm

import tinyms as ts
from tinyms import layers, Tensor
from tinyms.layers import ReLU, MaxPool2d, Flatten, Dropout


def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
    fan_in = in_channel * kernel_size * kernel_size
    scale = 1.0
    scale /= max(1., fan_in)
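    # 0.87962566103423978 is the standard deviation of a unit normal truncated
    # to [-2, 2]; dividing by it rescales the truncated draw to the target stddev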
    stddev = (scale ** 0.5) / .87962566103423978
    mu, sigma = 0, stddev
    weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
    return ts.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))


def _weight_variable(shape, factor=0.01):
    init_value = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(init_value)


def _conv3x3(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 3, 3)
    weight = _weight_variable(weight_shape)
    return layers.Conv2d(in_channel, out_channel,
                         kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 1, 1)
    weight = _weight_variable(weight_shape)
    return layers.Conv2d(in_channel, out_channel,
                         kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 7, 7)
    weight = _weight_variable(weight_shape)
    return layers.Conv2d(in_channel, out_channel,
                         kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv11x11(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 11, 11)
    weight = _weight_variable(weight_shape)
    return layers.Conv2d(in_channel, out_channel,
                         kernel_size=11, stride=stride, padding=2, pad_mode='pad', weight_init=weight)


def _conv5x5(in_channel, out_channel, stride=1):
    weight_shape = (out_channel, in_channel, 5, 5)
    weight = _weight_variable(weight_shape)
    return layers.Conv2d(in_channel, out_channel,
                         kernel_size=5, stride=stride, padding=2, pad_mode='pad', weight_init=weight)


def _bn(channel):
    return layers.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                              gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _bn_last(channel):
    return layers.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                              gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel):
    weight_shape = (out_channel, in_channel)
    weight = _weight_variable(weight_shape)
    return layers.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


class AlexNet(layers.Layer):
    def __init__(self, class_num=1000):
        super(AlexNet, self).__init__()
        self.features = layers.SequentialLayer([
            _conv11x11(3, 64, 4),
            ReLU(),
            MaxPool2d(kernel_size=3, stride=2),
            _conv5x5(64, 192),
            ReLU(),
            MaxPool2d(kernel_size=3, stride=2),
            _conv3x3(192, 384),
            ReLU(),
            _conv3x3(384, 256),
            ReLU(),
            _conv3x3(256, 256),
            ReLU(),
            MaxPool2d(kernel_size=3, stride=2),
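            # for 224x224 inputs the feature map here is 256 x 6 x 6,
            # matching the 256*6*6 input size of the first Dense layer below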
            Flatten(),
            Dropout(),
            _fc(256*6*6, 4096),
            ReLU(),
            Dropout(),
            _fc(4096, 4096),
            ReLU(),
            _fc(4096, class_num)
        ])

    def construct(self, x):
        x = self.features(x)
        return x

def alexnet(class_num=10):
    """
    Get AlexNet neural network.

    Args:
        class_num (int): Class number.

    Returns:
        layers.Layer, layer instance of AlexNet neural network.

    Examples:
        >>> net = alexnet(class_num=10)
    """
    return AlexNet(class_num=class_num)
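As a sanity check, a minimal sketch of a forward pass through the network above, assuming the standard 224x224 AlexNet input size and the NCHW layout used by the Conv2d helpers (the random input is illustrative, not part of this commit):

import numpy as np
from tinyms import Tensor

net = AlexNet(class_num=10)
# one random 3-channel 224x224 image in NCHW layout
x = Tensor(np.random.randn(1, 3, 224, 224).astype(np.float32))
out = net(x)
print(out.shape)  # expected: (1, 10)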