Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the first version of TinyMS Model API design #5

Merged
merged 6 commits into from
Feb 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ install:
script:
- python -c "import mindspore; print(mindspore.__version__)"
- python -c "import tinyms; print(tinyms.__version__)"
- pytest

services:
- docker
112 changes: 112 additions & 0 deletions tests/st/lenet5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Lenet5 Tutorial
The sample can be run on CPU, GPU and Ascend 910 AI processor.
"""
import os
import argparse

import tinyms as ts
from tinyms import context
from tinyms.data import MnistDataset, download_dataset
from tinyms.data.transforms import TypeCast
from tinyms.vision import Inter, Resize, Rescale, HWC2CHW
from tinyms.model import Model, lenet5
from tinyms.callbacks import ModelCheckpoint, CheckpointConfig, LossMonitor
from tinyms.metrics import Accuracy
from tinyms.optimizers import Momentum
from tinyms.losses import SoftmaxCrossEntropyWithLogits


def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
""" create Mnist dataset for train or eval.
Args:
data_path: Data path
batch_size: The number of data records in each group
repeat_size: The number of replicated data records
num_parallel_workers: The number of parallel workers
"""
# define dataset
mnist_ds = MnistDataset(data_path, num_parallel_workers=num_parallel_workers,
shuffle=True)

# define map operations
c_trans = [
Resize((32, 32), interpolation=Inter.LINEAR), # Resize images to (32, 32)
Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081), # normalize images
Rescale(1.0 / 255.0, 0.0), # rescale images
HWC2CHW(), # change shape from (height, width, channel) to (channel, height, width) to fit network
]
type_cast_op = TypeCast(ts.int32) # change data type of label to int32 to fit network

# apply map operations on images
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=c_trans, input_columns="image", num_parallel_workers=num_parallel_workers)
# apply batch operations
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
# apply repeat operations
mnist_ds = mnist_ds.repeat(repeat_size)

return mnist_ds


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: CPU)')
parser.add_argument('--dataset_path', type=str, default=None, help='Mnist dataset path.')
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
parser.add_argument('--epoch_size', type=int, default=1, help='Epoch size.')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='CheckPoint file path.')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)

# download mnist dataset
if not args_opt.dataset_path:
args_opt.dataset_path = download_dataset('mnist')
# build the network
net = lenet5()
model = Model(net)
# define the loss function
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define the optimizer
net_opt = Momentum(net.trainable_params(), 0.01, 0.9)
model.compile(net_loss, net_opt, metrics={"Accuracy": Accuracy()})

epoch_size = args_opt.epoch_size
batch_size = args_opt.batch_size
mnist_path = args_opt.dataset_path
dataset_sink_mode = not args_opt.device_target == "CPU"

if args_opt.do_eval: # as for evaluation, users could use model.eval
print("============== Starting Evaluating ==============")
# load testing dataset
ds_eval = create_dataset(os.path.join(mnist_path, "test"))
# load the saved model for evaluation
if args_opt.checkpoint_path:
model.load_checkpoint(args_opt.checkpoint_path)
acc = model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode)
print("============== Accuracy:{} ==============".format(acc))
else: # as for train, users could use model.train
print("============== Starting Training ==============")
# load training dataset
ds_train = create_dataset(os.path.join(mnist_path, "train"), batch_size=batch_size)
# save the network model and parameters for subsequence fine-tuning
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=CheckpointConfig(
save_checkpoint_steps=1875, keep_checkpoint_max=10))
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()],
dataset_sink_mode=dataset_sink_mode)
113 changes: 113 additions & 0 deletions tests/st/resnet50.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet50 Tutorial
The sample can be run on CPU, GPU and Ascend 910 AI processor.
"""
import random
import argparse

import tinyms as ts
from tinyms import context
from tinyms.data import Cifar10Dataset, download_dataset
from tinyms.data.transforms import TypeCast
from tinyms.vision import RandomCrop, RandomHorizontalFlip, Resize, Rescale, Normalize, HWC2CHW
from tinyms.model import Model, resnet50
from tinyms.callbacks import ModelCheckpoint, CheckpointConfig, LossMonitor
from tinyms.metrics import Accuracy
from tinyms.optimizers import Momentum
from tinyms.losses import SoftmaxCrossEntropyWithLogits

random.seed(1)


def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1,
training=True):
""" create Cifar10 dataset for train or eval.
Args:
data_path: Data path
batch_size: The number of data records in each group
repeat_size: The number of replicated data records
num_parallel_workers: The number of parallel workers
"""
# define dataset
cifar_ds = Cifar10Dataset(data_path, num_parallel_workers=num_parallel_workers,
shuffle=True)

# define map operations
c_trans = []
if training:
c_trans += [
RandomCrop((32, 32), (4, 4, 4, 4)),
RandomHorizontalFlip(prob=0.5),
]
c_trans += [
Resize((224, 224)), Rescale(1.0 / 255.0, 0.0),
Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
HWC2CHW(),
]
type_cast_op = TypeCast(ts.int32)

# apply map operations on images
cifar_ds = cifar_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
cifar_ds = cifar_ds.map(operations=c_trans, input_columns="image", num_parallel_workers=num_parallel_workers)
# apply batch operations
cifar_ds = cifar_ds.batch(batch_size=batch_size, drop_remainder=True)
# apply repeat operations
cifar_ds = cifar_ds.repeat(repeat_size)

return cifar_ds


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: CPU)')
parser.add_argument('--dataset_path', type=str, default=None, help='Cifar10 dataset path.')
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
parser.add_argument('--epoch_size', type=int, default=1, help='Epoch size.')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size.')
parser.add_argument('--num_classes', type=int, default=10, help='Num classes.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='CheckPoint file path.')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)

# download cifar10 dataset
if not args_opt.dataset_path:
args_opt.dataset_path = download_dataset('cifar10')
# build the network
net = resnet50(args_opt.num_classes)
model = Model(net)
# define the loss function
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
# define the optimizer
net_opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
model.compile(loss_fn=net_loss, optimizer=net_opt, metrics={"Accuracy": Accuracy()})

epoch_size = args_opt.epoch_size
batch_size = args_opt.batch_size
cifar10_path = args_opt.dataset_path
dataset_sink_mode = not args_opt.device_target == "CPU"
if args_opt.do_eval: # as for evaluation, users could use model.eval
ds_eval = create_dataset(cifar10_path, batch_size=batch_size, training=False)
if args_opt.checkpoint_path:
model.load_checkpoint(args_opt.checkpoint_path)
acc = model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode)
print("============== Accuracy:{} ==============".format(acc))
else: # as for train, users could use model.train
ds_train = create_dataset(cifar10_path, batch_size=batch_size)
ckpoint_cb = ModelCheckpoint(prefix="resnet_cifar10", config=CheckpointConfig(
save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=35))
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()],
dataset_sink_mode=dataset_sink_mode)
41 changes: 41 additions & 0 deletions tests/ut/data/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest

from tinyms.data import download_dataset


@pytest.mark.skip(reason="no way of currently testing this")
def test_download_dataset_mnist():
download_dataset(dataset_name='mnist', local_path='/tmp')

assert os.path.exists('/tmp/mnist/train')
assert os.path.exists('/tmp/mnist/test')


@pytest.mark.skip(reason="no way of currently testing this")
def test_download_dataset_cifar10():
download_dataset(dataset_name='cifar10', local_path='/tmp')

assert os.path.exists('/tmp/cifar10/cifar-10-batches-bin/batches.meta.txt')


@pytest.mark.skip(reason="no way of currently testing this")
def test_download_dataset_cifar100():
download_dataset(dataset_name='cifar100', local_path='/tmp')

assert os.path.exists('/tmp/cifar100/cifar-100-bin/train.bin')
assert os.path.exists('/tmp/cifar100/cifar-100-bin/test.bin')
33 changes: 33 additions & 0 deletions tests/ut/model/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import tinyms as ts
from tinyms import context, layers
from tinyms.layers import SequentialLayer
from tinyms.model import Model


def test_model_predict():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

net = SequentialLayer([
layers.Conv2d(1, 6, 5, pad_mode='valid', weight_init="ones"),
layers.ReLU(),
layers.MaxPool2d(kernel_size=2, stride=2)
])
model = Model(net)
model.compile()
z = model.predict(ts.ones((1, 1, 28, 28)))
print(z.asnumpy())
3 changes: 3 additions & 0 deletions tinyms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# ============================================================================
""".. TinyMS package."""
from .version import __version__
from . import common
from .common import *

__all__ = []
__all__.extend(__version__)
__all__.extend(common.__all__)
20 changes: 20 additions & 0 deletions tinyms/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore.train import callback
from mindspore.train.callback import *

__all__ = []
__all__.extend(callback.__all__)
22 changes: 22 additions & 0 deletions tinyms/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import common
from mindspore.common import *
from mindspore import numpy
from mindspore.numpy import *

__all__ = []
__all__.extend(common.__all__)
__all__.extend(numpy.__all__)
20 changes: 20 additions & 0 deletions tinyms/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from mindspore import context
from mindspore.context import *

__all__ = []
__all__.extend(context.__all__)
Loading