Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AlexNet-Hub #93

Merged
merged 2 commits into from
Jun 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions tests/st/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tinyms import context
from tinyms.data import Cifar10Dataset, download_dataset
from tinyms.vision import cifar10_transform
from tinyms.model import Model, AlexNet
from tinyms.model import Model, alexnet
from tinyms.callbacks import ModelCheckpoint, CheckpointConfig, LossMonitor
from tinyms.metrics import Accuracy
from tinyms.optimizers import Momentum
Expand All @@ -31,17 +31,21 @@


def parse_args():
parser = argparse.ArgumentParser(description='Image classification')
parser = argparse.ArgumentParser(description='TinyMS AlexNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: CPU)')
parser.add_argument('--dataset_path', type=str, default=None, help='Cifar10 dataset path.')
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
parser.add_argument('--epoch_size', type=int, default=90, help='Epoch size.')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size.')
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
parser.add_argument('--num_classes', type=int, default=10, help='Num classes.')
parser.add_argument('--load_pretrained', type=str, choices=['hub', 'local'], default='local',
help='Specify where to load pretrained model, only valid in do_eval mode. (default: local)')
parser.add_argument('--save_checkpoint_epochs', type=int, default=5,
help='Specify epochs interval to save each checkpoints.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path.')
parser.add_argument('--hub_uid', type=str, default=None,
help='Model asset uid. Only valid when load_pretrained is `hub`.')
args_opt = parser.parse_args()

return args_opt
Expand Down Expand Up @@ -76,8 +80,12 @@ def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers
if not args_opt.dataset_path:
args_opt.dataset_path = download_dataset('cifar10')
# build the network
net = AlexNet(args_opt.num_classes)
net.update_parameters_name(prefix='huawei')
if args_opt.do_eval and args_opt.load_pretrained == 'hub':
from tinyms import hub
net = hub.load(args_opt.hub_uid)
else:
net = alexnet(class_num=args_opt.num_classes)
net.update_parameters_name(prefix='huawei')
model = Model(net)
# define the loss function
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
Expand All @@ -90,10 +98,13 @@ def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers
cifar10_path = args_opt.dataset_path
save_checkpoint_epochs = args_opt.save_checkpoint_epochs
dataset_sink_mode = not args_opt.device_target == "CPU"

if args_opt.do_eval: # as for evaluation, users could use model.eval
ds_eval = create_dataset(cifar10_path, batch_size=batch_size, is_training=False)
if args_opt.checkpoint_path:
model.load_checkpoint(args_opt.checkpoint_path)
# load the saved model for evaluation
if args_opt.load_pretrained == 'local':
if args_opt.checkpoint_path:
model.load_checkpoint(args_opt.checkpoint_path)
acc = model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode)
print("============== Accuracy:{} ==============".format(acc))
else: # as for train, users could use model.train
Expand Down
2 changes: 1 addition & 1 deletion tests/st/cycle_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


def parse_args():
parser = argparse.ArgumentParser(description='MindSpore Cycle GAN Example')
parser = argparse.ArgumentParser(description='TinyMS Cycle GAN Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: CPU)')
parser.add_argument('--dataset_path', type=str, default=None, help='cityscape dataset path.')
Expand Down
2 changes: 1 addition & 1 deletion tests/st/lenet5.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


def parse_args():
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser = argparse.ArgumentParser(description='TinyMS LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented. (default: CPU)')
parser.add_argument('--dataset_path', type=str, default=None, help='Mnist dataset path.')
Expand Down
27 changes: 27 additions & 0 deletions tinyms/hub/assets/tinyms/0.2/alexnet_v1_cifar10.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
##################################################################################################
# AlexNet with Cifar10
##################################################################################################
---
model-name: alexnet
backbone-name: alexnet
module-type: cv-classification
fine-tunable: False
input-shape: [3, 224, 224]
model-version: v1
train-dataset: cifar10
train-backend: GPU
accuracy: 0.887
author: TinyMS team
update-time: 2021-06-04
user-id: TinyMS
used-for: inference
infer-backend:
- CPU
- GPU
tinyms-version: 0.2
asset:
file-format: ckpt
asset-link: https://tinyms-hub.obs.cn-north-4.myhuaweicloud.com/tinyms/0.2/alexnet_v1_cifar10/alexnet.ckpt
asset-sha256: 941cbd3dcf9af40decf992a67c3635f77e1a797ca0c69e9cc4dae23bb3300bb4
license: Apache-2.0
summary: AlexNet model used to classify the 10 classes of Cifar10 dataset.
1 change: 1 addition & 0 deletions tinyms/hub/hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tinyms import model

MODEL_HUB = ed({
"alexnet_v1": model.alexnet,
"lenet5_v1": model.lenet5,
"resnet50_v1": model.resnet50,
"mobilenet_v2": model.mobilenetv2,
Expand Down