diff --git a/.gitignore b/.gitignore
index 822596cf..6c91ee2d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -135,4 +135,4 @@ dmypy.json
 
 # auto-generated py file
 */version.py
-**/.DS_Store
\ No newline at end of file
+**/.DS_Store
diff --git a/tests/st/mobilenetv2.py b/tests/st/mobilenetv2.py
new file mode 100644
index 00000000..197a67dc
--- /dev/null
+++ b/tests/st/mobilenetv2.py
@@ -0,0 +1,124 @@
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""MobileNetV2 Tutorial
+The sample can be run on CPU, GPU and Ascend 910 AI processors.
+"""
+
+
+import argparse
+
+from tinyms import context
+from tinyms.data import Cifar10Dataset, download_dataset
+from tinyms.vision import cifar10_transform
+from tinyms.model import Model, MobileNetV2
+from tinyms.metrics import Accuracy
+from tinyms.optimizers import Momentum
+from tinyms.losses import SoftmaxCrossEntropyWithLogits, CrossEntropyWithLabelSmooth
+from tinyms.utils.train.loss_manager import FixedLossScaleManager
+from tinyms.utils.train.lr_generator import mobilenetv2_lr
+from tinyms.utils.train.cb_config import mobilenetv2_cb
+
+
+def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=4, training=True):
+    """Create the Cifar10 dataset for training or evaluation.
+
+    Args:
+        data_path: Data path.
+        batch_size: The number of data records in each group.
+        repeat_size: The number of replicated data records.
+        num_parallel_workers: The number of parallel workers.
+        training: Whether the dataset is used for training or evaluation.
+    """
+    # define the cifar10 dataset and apply the transform func
+    cifar10_ds = Cifar10Dataset(data_path,
+                                num_parallel_workers=num_parallel_workers,
+                                shuffle=True)
+    cifar10_ds = cifar10_transform.apply_ds(cifar10_ds,
+                                            repeat_size=repeat_size,
+                                            batch_size=batch_size,
+                                            training=training)
+
+    return cifar10_ds
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='MobileNetV2 Image classification')
+    parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
+                        help='device where the code will be implemented (default: CPU)')
+    parser.add_argument('--dataset_path', type=str, default=None, help='Cifar10 dataset path.')
+    parser.add_argument('--num_classes', type=int, default=10, help='Num classes.')
+    parser.add_argument('--label_smooth', type=float, default=0.1, help='Label smoothing factor.')
+    parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
+    parser.add_argument('--epoch_size', type=int, default=100, help='Epoch size.')
+    parser.add_argument('--batch_size', type=int, default=150, help='Batch size.')
+    parser.add_argument('--is_saving_checkpoint', type=bool, default=True, help='Whether to save checkpoint.')
+    parser.add_argument('--save_checkpoint_epochs', type=int, default=10,
+                        help='Specify the epoch interval for saving checkpoints.')
+    parser.add_argument('--checkpoint_path', type=str, default="", help='Checkpoint file path.')
+    args_opt = parser.parse_args()
+
+    # set the runtime environment
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
+    dataset_sink_mode = args_opt.device_target != "CPU"
+
+    # download the cifar10 dataset if no path is provided
+    if not args_opt.dataset_path:
+        args_opt.dataset_path = download_dataset('cifar10')
+
+    # declare common variables and assign them from args_opt
+    epoch_size = args_opt.epoch_size
+    batch_size = args_opt.batch_size
+    cifar10_path = args_opt.dataset_path
+
+    # build the network
+    net = MobileNetV2(args_opt.num_classes)
+    model = Model(net)
+
+    # create the cifar10 dataset for training
+    ds_train = create_dataset(cifar10_path, batch_size=batch_size)
+
+    # define the loss function
+    if args_opt.label_smooth > 0:
+        loss = CrossEntropyWithLabelSmooth(smooth_factor=args_opt.label_smooth,
+                                           num_classes=args_opt.num_classes)
+    else:
+        loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+
+    # get the learning rate
+    step_size = ds_train.get_dataset_size()
+    lr = mobilenetv2_lr(global_step=0, lr_init=.0, lr_end=.0, lr_max=0.8, warmup_epochs=0,
+                        total_epochs=epoch_size,
+                        steps_per_epoch=step_size)
+
+    # define the optimizer
+    loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False)
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 4e-5, 1024)
+    model.compile(loss_fn=loss, optimizer=opt, metrics={"Accuracy": Accuracy()}, loss_scale_manager=loss_scale)
+
+    if args_opt.do_eval:  # for evaluation, users could use model.eval
+        # create the cifar10 dataset for evaluation
+        ds_eval = create_dataset(cifar10_path, batch_size=batch_size, training=False)
+        if args_opt.checkpoint_path:
+            model.load_checkpoint(args_opt.checkpoint_path)
+        acc = model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode)
+        print("============== Accuracy:{} ==============".format(acc))
+    else:  # for training, users could use model.train
+        # configure the checkpoint callbacks and run the training job
+        ckpoint_cb = mobilenetv2_cb(device_target=args_opt.device_target,
+                                    lr=lr,
+                                    is_saving_checkpoint=args_opt.is_saving_checkpoint,
+                                    save_checkpoint_epochs=args_opt.save_checkpoint_epochs,
+                                    step_size=step_size)
+        model.train(epoch_size, ds_train, callbacks=ckpoint_cb, dataset_sink_mode=dataset_sink_mode)
diff --git a/tinyms/callbacks.py b/tinyms/callbacks.py
index a2a64ac3..e010772d 100644
--- a/tinyms/callbacks.py
+++ b/tinyms/callbacks.py
@@ -13,8 +13,65 @@
 # limitations under the License.
 # ============================================================================
 
+
+import time
+import numpy as np
+
+from mindspore import Tensor
 from mindspore.train import callback
 from mindspore.train.callback import *
+
 
 __all__ = []
 __all__.extend(callback.__all__)
+
+
+class LossTimeMonitor(Callback):
+    """
+    Monitor the loss and the time cost of each step.
+
+    Args:
+        lr_init (numpy.ndarray): The learning rate used for every training step.
+
+    Returns:
+        None
+
+    Examples:
+        >>> LossTimeMonitor(lr_init=Tensor([0.05] * 100).asnumpy())
+    """
+
+    def __init__(self, lr_init=None):
+        super(LossTimeMonitor, self).__init__()
+        self.lr_init = lr_init
+        self.lr_init_len = len(lr_init)
+
+    def epoch_begin(self, run_context):
+        self.losses = []
+        self.epoch_time = time.time()
+
+    def epoch_end(self, run_context):
+        cb_params = run_context.original_args()
+
+        epoch_mseconds = (time.time() - self.epoch_time) * 1000
+        per_step_mseconds = epoch_mseconds / cb_params.batch_num
+        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
+                                                                                      per_step_mseconds,
+                                                                                      np.mean(self.losses)))
+
+    def step_begin(self, run_context):
+        self.step_time = time.time()
+
+    def step_end(self, run_context):
+        cb_params = run_context.original_args()
+        step_mseconds = (time.time() - self.step_time) * 1000
+        step_loss = cb_params.net_outputs
+
+        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
+            step_loss = step_loss[0]
+        if isinstance(step_loss, Tensor):
+            step_loss = np.mean(step_loss.asnumpy())
+
+        self.losses.append(step_loss)
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
+
+        print("epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], loss: [{:5.3f}/{:5.3f}], time: [{:5.3f}], lr: [{:5.3f}]".format(
+            cb_params.cur_epoch_num - 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num,
+            step_loss, np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
diff --git a/tinyms/losses.py b/tinyms/losses.py
index 388dbf8b..dca46ab5 100644
--- a/tinyms/losses.py
+++ b/tinyms/losses.py
@@ -15,6 +15,44 @@
 from mindspore.nn import loss
 from mindspore.nn.loss import *
+from mindspore.nn.loss.loss import _Loss
+from tinyms import Tensor, dtype
+import tinyms.primitives as P
+
 __all__ = []
 __all__.extend(loss.__all__)
+
+
+class CrossEntropyWithLabelSmooth(_Loss):
+    """
+    CrossEntropyWithLabelSmooth.
+
+    Args:
+        smooth_factor (float): Smooth factor. Default is 0.
+        num_classes (int): Number of classes. Default is 1000.
+
+    Returns:
+        Tensor, the computed cross entropy loss.
+
+    Examples:
+        >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000)
+    """
+
+    def __init__(self, smooth_factor=0., num_classes=1000):
+        super(CrossEntropyWithLabelSmooth, self).__init__()
+        self.onehot = P.OneHot()
+        self.on_value = Tensor(1.0 - smooth_factor, dtype.float32)
+        self.off_value = Tensor(1.0 * smooth_factor /
+                                (num_classes - 1), dtype.float32)
+        self.ce = SoftmaxCrossEntropyWithLogits()
+        self.mean = P.ReduceMean(False)
+        self.cast = P.Cast()
+
+    def construct(self, logit, label):
+        one_hot_label = self.onehot(self.cast(label, dtype.int32), P.shape(logit)[1],
+                                    self.on_value, self.off_value)
+        out_loss = self.ce(logit, one_hot_label)
+        out_loss = self.mean(out_loss, 0)
+        return out_loss
diff --git a/tinyms/utils/__init__.py b/tinyms/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tinyms/utils/train/__init__.py b/tinyms/utils/train/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tinyms/utils/train/cb_config.py b/tinyms/utils/train/cb_config.py
new file mode 100644
index 00000000..6bf9955a
--- /dev/null
+++ b/tinyms/utils/train/cb_config.py
@@ -0,0 +1,31 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from tinyms.callbacks import ModelCheckpoint, CheckpointConfig, LossTimeMonitor
+
+
+def mobilenetv2_cb(device_target, lr, is_saving_checkpoint, save_checkpoint_epochs, step_size):
+    """Build the callback list used by the MobileNetV2 training script."""
+    # start from an empty list so the checkpoint callback can always be appended
+    cb = []
+    if device_target in ("CPU", "GPU"):
+        cb = [LossTimeMonitor(lr_init=lr.asnumpy())]
+
+    if is_saving_checkpoint:
+        config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_epochs * step_size,
+                                     keep_checkpoint_max=10)
+        ckpt_save_dir = "./"
+        ckpt_cb = ModelCheckpoint(prefix="mobilenetv2_cifar10", directory=ckpt_save_dir, config=config_ck)
+        cb += [ckpt_cb]
+    return cb
diff --git a/tinyms/utils/train/loss_manager.py b/tinyms/utils/train/loss_manager.py
new file mode 100644
index 00000000..305c8b99
--- /dev/null
+++ b/tinyms/utils/train/loss_manager.py
@@ -0,0 +1,17 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from mindspore.train.loss_scale_manager import *
diff --git a/tinyms/utils/train/lr_generator.py b/tinyms/utils/train/lr_generator.py
new file mode 100644
index 00000000..127ecfaa
--- /dev/null
+++ b/tinyms/utils/train/lr_generator.py
@@ -0,0 +1,57 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Learning rate generator."""
+import math
+import tinyms as ts
+
+
+def mobilenetv2_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
+    """
+    Generate the per-step learning rate array: linear warmup followed by cosine decay.
+
+    Args:
+        global_step(int): step at which training (re)starts; the returned array begins at this step.
+        lr_init(float): initial learning rate used for warmup.
+        lr_end(float): final learning rate of the cosine decay.
+        lr_max(float): maximum learning rate.
+        warmup_epochs(int): number of warmup epochs.
+        total_epochs(int): total number of training epochs.
+        steps_per_epoch(int): number of steps in one epoch.
+
+    Returns:
+        Tensor, learning rate array starting from global_step.
+    """
+    lr_each_step = []
+    total_steps = steps_per_epoch * total_epochs
+    warmup_steps = steps_per_epoch * warmup_epochs
+    for i in range(total_steps):
+        if i < warmup_steps:
+            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
+        else:
+            lr = lr_end + \
+                (lr_max - lr_end) * \
+                (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
+        if lr < 0.0:
+            lr = 0.0
+        lr_each_step.append(lr)
+
+    current_step = global_step
+    lr_each_step = ts.array(lr_each_step, dtype=ts.float32)
+    learning_rate = lr_each_step[current_step:]
+
+    return learning_rate
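
A minimal usage sketch of the new learning-rate helper (not part of the patch itself); the step counts below are illustrative only, and the snippet assumes the module is importable once this change is applied:

from tinyms.utils.train.lr_generator import mobilenetv2_lr

# 2 warmup epochs followed by cosine decay over 10 epochs of 5 steps each -> 50 values
lr = mobilenetv2_lr(global_step=0, lr_init=0.0, lr_end=0.0, lr_max=0.8,
                    warmup_epochs=2, total_epochs=10, steps_per_epoch=5)
print(lr.shape)          # (50,) - one learning rate per remaining training step
print(lr.asnumpy()[:3])  # linear ramp from lr_init towards lr_max during warmup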