-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from leonwanghui/model
Add the first version of TinyMS Model API design
- Loading branch information
Showing
29 changed files
with
1,249 additions
and
406 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# Copyright 2020-2021 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
"""Lenet5 Tutorial | ||
The sample can be run on CPU, GPU and Ascend 910 AI processor. | ||
""" | ||
import os | ||
import argparse | ||
|
||
import tinyms as ts | ||
from tinyms import context | ||
from tinyms.data import MnistDataset, download_dataset | ||
from tinyms.data.transforms import TypeCast | ||
from tinyms.vision import Inter, Resize, Rescale, HWC2CHW | ||
from tinyms.model import Model, lenet5 | ||
from tinyms.callbacks import ModelCheckpoint, CheckpointConfig, LossMonitor | ||
from tinyms.metrics import Accuracy | ||
from tinyms.optimizers import Momentum | ||
from tinyms.losses import SoftmaxCrossEntropyWithLogits | ||
|
||
|
||
def create_dataset(data_path, batch_size=32, repeat_size=1, | ||
num_parallel_workers=1): | ||
""" create Mnist dataset for train or eval. | ||
Args: | ||
data_path: Data path | ||
batch_size: The number of data records in each group | ||
repeat_size: The number of replicated data records | ||
num_parallel_workers: The number of parallel workers | ||
""" | ||
# define dataset | ||
mnist_ds = MnistDataset(data_path, num_parallel_workers=num_parallel_workers, | ||
shuffle=True) | ||
|
||
# define map operations | ||
c_trans = [ | ||
Resize((32, 32), interpolation=Inter.LINEAR), # Resize images to (32, 32) | ||
Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081), # normalize images | ||
Rescale(1.0 / 255.0, 0.0), # rescale images | ||
HWC2CHW(), # change shape from (height, width, channel) to (channel, height, width) to fit network | ||
] | ||
type_cast_op = TypeCast(ts.int32) # change data type of label to int32 to fit network | ||
|
||
# apply map operations on images | ||
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers) | ||
mnist_ds = mnist_ds.map(operations=c_trans, input_columns="image", num_parallel_workers=num_parallel_workers) | ||
# apply batch operations | ||
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True) | ||
# apply repeat operations | ||
mnist_ds = mnist_ds.repeat(repeat_size) | ||
|
||
return mnist_ds | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description='MindSpore LeNet Example') | ||
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'], | ||
help='device where the code will be implemented (default: CPU)') | ||
parser.add_argument('--dataset_path', type=str, default=None, help='Mnist dataset path.') | ||
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.') | ||
parser.add_argument('--epoch_size', type=int, default=1, help='Epoch size.') | ||
parser.add_argument('--batch_size', type=int, default=32, help='Batch size.') | ||
parser.add_argument('--checkpoint_path', type=str, default=None, help='CheckPoint file path.') | ||
args_opt = parser.parse_args() | ||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) | ||
|
||
# download mnist dataset | ||
if not args_opt.dataset_path: | ||
args_opt.dataset_path = download_dataset('mnist') | ||
# build the network | ||
net = lenet5() | ||
model = Model(net) | ||
# define the loss function | ||
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | ||
# define the optimizer | ||
net_opt = Momentum(net.trainable_params(), 0.01, 0.9) | ||
model.compile(net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | ||
|
||
epoch_size = args_opt.epoch_size | ||
batch_size = args_opt.batch_size | ||
mnist_path = args_opt.dataset_path | ||
dataset_sink_mode = not args_opt.device_target == "CPU" | ||
|
||
if args_opt.do_eval: # as for evaluation, users could use model.eval | ||
print("============== Starting Evaluating ==============") | ||
# load testing dataset | ||
ds_eval = create_dataset(os.path.join(mnist_path, "test")) | ||
# load the saved model for evaluation | ||
if args_opt.checkpoint_path: | ||
model.load_checkpoint(args_opt.checkpoint_path) | ||
acc = model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode) | ||
print("============== Accuracy:{} ==============".format(acc)) | ||
else: # as for train, users could use model.train | ||
print("============== Starting Training ==============") | ||
# load training dataset | ||
ds_train = create_dataset(os.path.join(mnist_path, "train"), batch_size=batch_size) | ||
# save the network model and parameters for subsequence fine-tuning | ||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=CheckpointConfig( | ||
save_checkpoint_steps=1875, keep_checkpoint_max=10)) | ||
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], | ||
dataset_sink_mode=dataset_sink_mode) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# Copyright 2020-2021 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
"""ResNet50 Tutorial | ||
The sample can be run on CPU, GPU and Ascend 910 AI processor. | ||
""" | ||
import random | ||
import argparse | ||
|
||
import tinyms as ts | ||
from tinyms import context | ||
from tinyms.data import Cifar10Dataset, download_dataset | ||
from tinyms.data.transforms import TypeCast | ||
from tinyms.vision import RandomCrop, RandomHorizontalFlip, Resize, Rescale, Normalize, HWC2CHW | ||
from tinyms.model import Model, resnet50 | ||
from tinyms.callbacks import ModelCheckpoint, CheckpointConfig, LossMonitor | ||
from tinyms.metrics import Accuracy | ||
from tinyms.optimizers import Momentum | ||
from tinyms.losses import SoftmaxCrossEntropyWithLogits | ||
|
||
random.seed(1) | ||
|
||
|
||
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1, | ||
training=True): | ||
""" create Cifar10 dataset for train or eval. | ||
Args: | ||
data_path: Data path | ||
batch_size: The number of data records in each group | ||
repeat_size: The number of replicated data records | ||
num_parallel_workers: The number of parallel workers | ||
""" | ||
# define dataset | ||
cifar_ds = Cifar10Dataset(data_path, num_parallel_workers=num_parallel_workers, | ||
shuffle=True) | ||
|
||
# define map operations | ||
c_trans = [] | ||
if training: | ||
c_trans += [ | ||
RandomCrop((32, 32), (4, 4, 4, 4)), | ||
RandomHorizontalFlip(prob=0.5), | ||
] | ||
c_trans += [ | ||
Resize((224, 224)), Rescale(1.0 / 255.0, 0.0), | ||
Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), | ||
HWC2CHW(), | ||
] | ||
type_cast_op = TypeCast(ts.int32) | ||
|
||
# apply map operations on images | ||
cifar_ds = cifar_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers) | ||
cifar_ds = cifar_ds.map(operations=c_trans, input_columns="image", num_parallel_workers=num_parallel_workers) | ||
# apply batch operations | ||
cifar_ds = cifar_ds.batch(batch_size=batch_size, drop_remainder=True) | ||
# apply repeat operations | ||
cifar_ds = cifar_ds.repeat(repeat_size) | ||
|
||
return cifar_ds | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser(description='Image classification') | ||
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'], | ||
help='device where the code will be implemented (default: CPU)') | ||
parser.add_argument('--dataset_path', type=str, default=None, help='Cifar10 dataset path.') | ||
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.') | ||
parser.add_argument('--epoch_size', type=int, default=1, help='Epoch size.') | ||
parser.add_argument('--batch_size', type=int, default=32, help='Batch size.') | ||
parser.add_argument('--num_classes', type=int, default=10, help='Num classes.') | ||
parser.add_argument('--checkpoint_path', type=str, default=None, help='CheckPoint file path.') | ||
args_opt = parser.parse_args() | ||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) | ||
|
||
# download cifar10 dataset | ||
if not args_opt.dataset_path: | ||
args_opt.dataset_path = download_dataset('cifar10') | ||
# build the network | ||
net = resnet50(args_opt.num_classes) | ||
model = Model(net) | ||
# define the loss function | ||
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | ||
# define the optimizer | ||
net_opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9) | ||
model.compile(loss_fn=net_loss, optimizer=net_opt, metrics={"Accuracy": Accuracy()}) | ||
|
||
epoch_size = args_opt.epoch_size | ||
batch_size = args_opt.batch_size | ||
cifar10_path = args_opt.dataset_path | ||
dataset_sink_mode = not args_opt.device_target == "CPU" | ||
if args_opt.do_eval: # as for evaluation, users could use model.eval | ||
ds_eval = create_dataset(cifar10_path, batch_size=batch_size, training=False) | ||
if args_opt.checkpoint_path: | ||
model.load_checkpoint(args_opt.checkpoint_path) | ||
acc = model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode) | ||
print("============== Accuracy:{} ==============".format(acc)) | ||
else: # as for train, users could use model.train | ||
ds_train = create_dataset(cifar10_path, batch_size=batch_size) | ||
ckpoint_cb = ModelCheckpoint(prefix="resnet_cifar10", config=CheckpointConfig( | ||
save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=35)) | ||
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], | ||
dataset_sink_mode=dataset_sink_mode) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Copyright 2021 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
import os | ||
import pytest | ||
|
||
from tinyms.data import download_dataset | ||
|
||
|
||
@pytest.mark.skip(reason="no way of currently testing this") | ||
def test_download_dataset_mnist(): | ||
download_dataset(dataset_name='mnist', local_path='/tmp') | ||
|
||
assert os.path.exists('/tmp/mnist/train') | ||
assert os.path.exists('/tmp/mnist/test') | ||
|
||
|
||
@pytest.mark.skip(reason="no way of currently testing this") | ||
def test_download_dataset_cifar10(): | ||
download_dataset(dataset_name='cifar10', local_path='/tmp') | ||
|
||
assert os.path.exists('/tmp/cifar10/cifar-10-batches-bin/batches.meta.txt') | ||
|
||
|
||
@pytest.mark.skip(reason="no way of currently testing this") | ||
def test_download_dataset_cifar100(): | ||
download_dataset(dataset_name='cifar100', local_path='/tmp') | ||
|
||
assert os.path.exists('/tmp/cifar100/cifar-100-bin/train.bin') | ||
assert os.path.exists('/tmp/cifar100/cifar-100-bin/test.bin') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Copyright 2021 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
|
||
import tinyms as ts | ||
from tinyms import context, layers | ||
from tinyms.layers import SequentialLayer | ||
from tinyms.model import Model | ||
|
||
|
||
def test_model_predict(): | ||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | ||
|
||
net = SequentialLayer([ | ||
layers.Conv2d(1, 6, 5, pad_mode='valid', weight_init="ones"), | ||
layers.ReLU(), | ||
layers.MaxPool2d(kernel_size=2, stride=2) | ||
]) | ||
model = Model(net) | ||
model.compile() | ||
z = model.predict(ts.ones((1, 1, 28, 28))) | ||
print(z.asnumpy()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright 2021 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
|
||
from mindspore.train import callback | ||
from mindspore.train.callback import * | ||
|
||
__all__ = [] | ||
__all__.extend(callback.__all__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright 2021 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
from mindspore import common | ||
from mindspore.common import * | ||
from mindspore import numpy | ||
from mindspore.numpy import * | ||
|
||
__all__ = [] | ||
__all__.extend(common.__all__) | ||
__all__.extend(numpy.__all__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright 2021 Huawei Technologies Co., Ltd | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
|
||
from mindspore import context | ||
from mindspore.context import * | ||
|
||
__all__ = [] | ||
__all__.extend(context.__all__) |
Oops, something went wrong.