Merge pull request #17 from hellowaywewe/tests
add mobilenet_v2 support
leonwanghui authored Mar 1, 2021
2 parents a63522c + 1832b9c commit 4f80f52
Showing 9 changed files with 325 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .gitignore
@@ -135,4 +135,4 @@ dmypy.json
# auto-generated py file
*/version.py

-**/.DS_Store
+**/.DS_Store
124 changes: 124 additions & 0 deletions tests/st/mobilenetv2.py
@@ -0,0 +1,124 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""MobileNetV2 Tutorial
The sample can be run on GPU and Ascend 910 AI processors
"""


import argparse

from tinyms import context
from tinyms.data import Cifar10Dataset, download_dataset
from tinyms.vision import cifar10_transform
from tinyms.model import Model, MobileNetV2
from tinyms.metrics import Accuracy
from tinyms.optimizers import Momentum
from tinyms.losses import SoftmaxCrossEntropyWithLogits, CrossEntropyWithLabelSmooth
from tinyms.utils.train.loss_manager import FixedLossScaleManager
from tinyms.utils.train.lr_generator import mobilenetv2_lr
from tinyms.utils.train.cb_config import mobilenetv2_cb


def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=4, training=True):
"""create Cifar10 dataset for train or eval.
Args:
data_path: Data path
batch_size: The number of data records in each group
repeat_size: The number of replicated data records
num_parallel_workers: The number of parallel workers
"""
# define cifar_10 dataset and apply the transform func
cifar10_ds = Cifar10Dataset(data_path,
num_parallel_workers=num_parallel_workers,
shuffle=True)
cifar10_ds = cifar10_transform.apply_ds(cifar10_ds,
repeat_size=repeat_size,
batch_size=batch_size,
training=training)

return cifar10_ds


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='MobileNetV2 Image classification')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: CPU)')
parser.add_argument('--dataset_path', type=str, default=None, help='Cifar10 dataset path.')
    parser.add_argument('--num_classes', type=int, default=10, help='Number of classes.')
    parser.add_argument('--label_smooth', type=float, default=0.1, help='Label smoothing factor.')
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
parser.add_argument('--epoch_size', type=int, default=100, help='Epoch size.')
parser.add_argument('--batch_size', type=int, default=150, help='Batch size.')
parser.add_argument('--is_saving_checkpoint', type=bool, default=True, help='Whether to save checkpoint.')
    parser.add_argument('--save_checkpoint_epochs', type=int, default=10,
                        help='Specify the epoch interval at which checkpoints are saved.')
parser.add_argument('--checkpoint_path', type=str, default="", help='Checkpoint file path.')
args_opt = parser.parse_args()

    # download the cifar10 dataset if no dataset path is given
    if not args_opt.dataset_path:
        args_opt.dataset_path = download_dataset('cifar10')

    # declare common variables and assign the args_opt values to them
    epoch_size = args_opt.epoch_size
    batch_size = args_opt.batch_size
    cifar10_path = args_opt.dataset_path

    # set the runtime environment
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    dataset_sink_mode = args_opt.device_target != "CPU"

# build the network
net = MobileNetV2(args_opt.num_classes)
model = Model(net)

# create cifar10 dataset for training
ds_train = create_dataset(cifar10_path, batch_size=batch_size)

# define the loss function
if args_opt.label_smooth > 0:
loss = CrossEntropyWithLabelSmooth(smooth_factor=args_opt.label_smooth,
num_classes=args_opt.num_classes)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

# get learning rate
step_size = ds_train.get_dataset_size()
    lr = mobilenetv2_lr(global_step=0, lr_init=0.0, lr_end=0.0, lr_max=0.8,
                        warmup_epochs=0, total_epochs=epoch_size, steps_per_epoch=step_size)

    # define the optimizer; the fixed loss scale (1024) matches the loss_scale
    # passed to Momentum so that gradients are scaled back correctly
    loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                   learning_rate=lr, momentum=0.9, weight_decay=4e-5, loss_scale=1024)
model.compile(loss_fn=loss, optimizer=opt, metrics={"Accuracy": Accuracy()}, loss_scale_manager=loss_scale)

    if args_opt.do_eval:  # for evaluation, users can call model.eval
# create cifar10 dataset for eval
ds_eval = create_dataset(cifar10_path, batch_size=batch_size, training=False)
if args_opt.checkpoint_path:
model.load_checkpoint(args_opt.checkpoint_path)
acc = model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode)
print("============== Accuracy:{} ==============".format(acc))
    else:  # for training, users can call model.train
        # configure the checkpoint callbacks and start the training job
        ckpoint_cb = mobilenetv2_cb(device_target=args_opt.device_target,
                                    lr=lr,
                                    is_saving_checkpoint=args_opt.is_saving_checkpoint,
                                    save_checkpoint_epochs=args_opt.save_checkpoint_epochs,
                                    step_size=step_size)
model.train(epoch_size, ds_train, callbacks=ckpoint_cb, dataset_sink_mode=dataset_sink_mode)
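
For reference, a typical invocation of this script might look as follows (the device target, dataset directory and checkpoint file name are illustrative placeholders; when --dataset_path is omitted, the script downloads CIFAR-10 automatically via download_dataset):

    # train on GPU with an explicit dataset directory
    python tests/st/mobilenetv2.py --device_target=GPU --dataset_path=/path/to/cifar10

    # evaluate a previously saved checkpoint
    python tests/st/mobilenetv2.py --device_target=GPU --dataset_path=/path/to/cifar10 --do_eval=True --checkpoint_path=/path/to/mobilenetv2.ckpt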
57 changes: 57 additions & 0 deletions tinyms/callbacks.py
@@ -13,8 +13,65 @@
# limitations under the License.
# ============================================================================


import time
import numpy as np

from mindspore import Tensor
from mindspore.train import callback
from mindspore.train.callback import *


__all__ = ['LossTimeMonitor']
__all__.extend(callback.__all__)


class LossTimeMonitor(Callback):
"""
Monitor loss and time.
Args:
lr_init (numpy array): train lr
Returns:
None
Examples:
>>> LossTimeMonitor(100,lr_init=Tensor([0.05]*100).asnumpy())
"""

    def __init__(self, lr_init=None):
        super(LossTimeMonitor, self).__init__()
        self.lr_init = lr_init
        # guard against the default None so the monitor can be built lazily
        self.lr_init_len = len(lr_init) if lr_init is not None else 0

def epoch_begin(self, run_context):
self.losses = []
self.epoch_time = time.time()

def epoch_end(self, run_context):
cb_params = run_context.original_args()

        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / cb_params.batch_num
        print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms, avg loss: {:5.3f}".format(
            epoch_mseconds, per_step_mseconds, np.mean(self.losses)))

def step_begin(self, run_context):
self.step_time = time.time()

def step_end(self, run_context):
cb_params = run_context.original_args()
step_mseconds = (time.time() - self.step_time) * 1000
step_loss = cb_params.net_outputs

if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
step_loss = step_loss[0]
if isinstance(step_loss, Tensor):
step_loss = np.mean(step_loss.asnumpy())

self.losses.append(step_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num

print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format(
cb_params.cur_epoch_num -
1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
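
As a usage sketch (model, ds_train, epoch_size and the per-step learning rate array lr are assumed to be set up as in tests/st/mobilenetv2.py above), the monitor is passed to model.train through the callbacks argument; note that lr must contain one entry per training step, since step_end indexes it with cur_step_num - 1:

    from tinyms.callbacks import LossTimeMonitor

    monitor = LossTimeMonitor(lr_init=lr.asnumpy())
    model.train(epoch_size, ds_train, callbacks=[monitor], dataset_sink_mode=False)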
38 changes: 38 additions & 0 deletions tinyms/losses.py
@@ -15,6 +15,44 @@

from mindspore.nn import loss
from mindspore.nn.loss import *
from mindspore.nn.loss.loss import _Loss
from tinyms import Tensor, dtype
import tinyms.primitives as P


__all__ = ['CrossEntropyWithLabelSmooth']
__all__.extend(loss.__all__)


class CrossEntropyWithLabelSmooth(_Loss):
"""
CrossEntropyWith LabelSmooth.
Args:
smooth_factor (float): smooth factor. Default is 0.
num_classes (int): number of classes. Default is 1000.
Returns:
None.
Examples:
>>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000)
"""

def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropyWithLabelSmooth, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, dtype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), dtype.float32)
self.ce = SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
self.cast = P.Cast()

def construct(self, logit, label):
one_hot_label = self.onehot(self.cast(label, dtype.int32), P.shape(logit)[1],
self.on_value, self.off_value)
out_loss = self.ce(logit, one_hot_label)
out_loss = self.mean(out_loss, 0)
return out_loss
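
For intuition, a small NumPy sketch of the smoothed target distribution that the OneHot operator above builds (smooth_factor=0.1, num_classes=10, true class 3):

    import numpy as np

    smooth_factor, num_classes, label = 0.1, 10, 3
    target = np.full(num_classes, smooth_factor / (num_classes - 1))  # off_value
    target[label] = 1.0 - smooth_factor                               # on_value
    print(target.round(4))        # 0.0111 for every wrong class, 0.9 at index 3
    print(target.sum().round(6))  # 1.0, so the target is still a valid distribution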

Empty file added tinyms/utils/__init__.py
Empty file added tinyms/utils/train/__init__.py
31 changes: 31 additions & 0 deletions tinyms/utils/train/cb_config.py
@@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from tinyms.callbacks import ModelCheckpoint, CheckpointConfig, LossTimeMonitor


def mobilenetv2_cb(device_target, lr, is_saving_checkpoint, save_checkpoint_epochs, step_size):
    # start from an empty list so the checkpoint callback can be appended
    # even when no loss monitor is registered (e.g. on Ascend)
    cb = []
if device_target in ("CPU", "GPU"):
cb = [LossTimeMonitor(lr_init=lr.asnumpy())]

if is_saving_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_epochs * step_size,
keep_checkpoint_max=10)
ckpt_save_dir = "./"
ckpt_cb = ModelCheckpoint(prefix="mobilenetv2_cifar10", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
return cb
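
A minimal wiring sketch, mirroring tests/st/mobilenetv2.py above (lr, ds_train, model and epoch_size are assumed to already exist):

    from tinyms.utils.train.cb_config import mobilenetv2_cb

    cb = mobilenetv2_cb(device_target="GPU", lr=lr, is_saving_checkpoint=True,
                        save_checkpoint_epochs=10, step_size=ds_train.get_dataset_size())
    model.train(epoch_size, ds_train, callbacks=cb, dataset_sink_mode=True)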

17 changes: 17 additions & 0 deletions tinyms/utils/train/loss_manager.py
@@ -0,0 +1,17 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore.train.loss_scale_manager import *
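
This module re-exports MindSpore's loss scale managers under the tinyms namespace. As used in tests/st/mobilenetv2.py above, FixedLossScaleManager keeps a constant scale; with drop_overflow_update=False, the same loss_scale value should also be passed to the optimizer so that gradients are scaled back correctly:

    from tinyms.utils.train.loss_manager import FixedLossScaleManager

    loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False)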

57 changes: 57 additions & 0 deletions tinyms/utils/train/lr_generator.py
@@ -0,0 +1,57 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""learning rate generator"""
import math
import tinyms as ts


def mobilenetv2_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""
generate learning rate array
Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
            lr = lr_end + (lr_max - lr_end) * \
                (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)

current_step = global_step
lr_each_step = ts.array(lr_each_step, dtype=ts.float32)
learning_rate = lr_each_step[current_step:]

return learning_rate
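
A quick sanity check of the schedule under the same settings as tests/st/mobilenetv2.py above (steps_per_epoch is an illustrative value): with no warmup the first entry sits at lr_max, and the cosine decay drives the last entry toward lr_end.

    from tinyms.utils.train.lr_generator import mobilenetv2_lr

    lr = mobilenetv2_lr(global_step=0, lr_init=0.0, lr_end=0.0, lr_max=0.8,
                        warmup_epochs=0, total_epochs=100, steps_per_epoch=390)
    arr = lr.asnumpy()
    print(arr.shape, arr[0], arr[-1])  # (39000,), 0.8, ~0.0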

