Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding tutorials for resnet50 and lenet5 #11

Merged
merged 20 commits into from
Mar 1, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions tests/ut/vision/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,10 @@ def test_mnist_transform():
def test_mnist_transform_postprocess():
input = np.array([[10, 1, 4, 2, 5, 18, -10, -4, 3, 7]])
label = mnist_transform.postprocess(input)
assert label == 5
assert label == 'TOP1: 5, score: 0.99964439868927001953'

label = mnist_transform.postprocess(input, strategy='TOP5_CLASS')
expected = {'label': [5, 0, 9, 4, 2],
'score': [18, 10, 7, 5, 4]}
expected = "TOP1: 5, score: 0.99964439868927001953\nTOP2: 0, score: 0.00033534335670992732\nTOP3: 9, score: 0.00001669576158747077\nTOP4: 4, score: 0.00000225952544496977\nTOP5: 2, score: 0.00000083123296690246\n"
assert label == expected


Expand All @@ -51,11 +50,10 @@ def test_cifar10_transform():
def test_cifar10_transform_postprocess():
input = np.array([[10, 1, 4, 2, 5, 18, -10, -4, 3, 7]])
label = cifar10_transform.postprocess(input)
assert label == 'dog'
assert label == 'TOP1: dog, score: 0.99964439868927001953'

label = cifar10_transform.postprocess(input, strategy='TOP5_CLASS')
expected = {'label': ['dog', 'airplane', 'truck', 'deer', 'bird'],
'score': [18, 10, 7, 5, 4]}
expected = "TOP1: dog, score: 0.99964439868927001953\nTOP2: airplane, score: 0.00033534335670992732\nTOP3: truck, score: 0.00001669576158747077\nTOP4: deer, score: 0.00000225952544496977\nTOP5: bird, score: 0.00000083123296690246\n"
assert label == expected


Expand All @@ -68,13 +66,7 @@ def test_imagefolder_transform():
def test_imagefolder_transform_postprocess():
input = np.array([[10, 4, 2, 5, 18, -10, -4, 3, 7]])
label = imagefolder_transform.postprocess(input)
assert label == 'Entoloma霍氏粉褶菌,伞菌目,粉褶菌科,粉褶菌属,主要分布于新西兰北岛和南岛西部,有毒'

assert label == 'TOP1: Entoloma霍氏粉褶菌,伞菌目,粉褶菌科,粉褶菌属,主要分布于新西兰北岛和南岛西部,有毒, score: 0.99964439868927001953'
label = imagefolder_transform.postprocess(input, strategy='TOP5_CLASS')
expected = {'label': ['Entoloma霍氏粉褶菌,伞菌目,粉褶菌科,粉褶菌属,主要分布于新西兰北岛和南岛西部,有毒',
'Agaricus双孢蘑菇,伞菌目,蘑菇科,蘑菇属,广泛分布于北半球温带,无毒',
'Suillus乳牛肝菌,牛肝菌目,乳牛肝菌科,乳牛肝菌属,分布于吉林、辽宁、山西、安徽、江西、浙江、湖南、四川、贵州等地,无毒',
'Cortinarius掷丝膜菌,伞菌目,丝膜菌科,丝膜菌属,分布于湖南等地(夏秋季在山毛等阔叶林地上生长)',
'Amanita毒蝇伞,伞菌目,鹅膏菌科,鹅膏菌属,主要分布于我国黑龙江、吉林、四川、西藏、云南等地,有毒'],
'score': [18, 10, 7, 5, 4]}
expected = "TOP1: Entoloma霍氏粉褶菌,伞菌目,粉褶菌科,粉褶菌属,主要分布于新西兰北岛和南岛西部,有毒, score: 0.99964439868927001953\nTOP2: Agaricus双孢蘑菇,伞菌目,蘑菇科,蘑菇属,广泛分布于北半球温带,无毒, score: 0.00033534335670992732\nTOP3: Suillus乳牛肝菌,牛肝菌目,乳牛肝菌科,乳牛肝菌属,分布于吉林、辽宁、山西、安徽、江西、浙江、湖南、四川、贵州等地,无毒, score: 0.00001669576158747077\nTOP4: Cortinarius掷丝膜菌,伞菌目,丝膜菌科,丝膜菌属,分布于湖南等地(夏秋季在山毛等阔叶林地上生长), score: 0.00000225952544496977\nTOP5: Amanita毒蝇伞,伞菌目,鹅膏菌科,鹅膏菌属,主要分布于我国黑龙江、吉林、四川、西藏、云南等地,有毒, score: 0.00000083123296690246\n"
assert label == expected
17 changes: 15 additions & 2 deletions tinyms/serving/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import json
import sys
import requests
import numpy as np
from PIL import Image
from tinyms.vision import mnist_transform, cifar10_transform, imagefolder_transform

Expand All @@ -33,7 +34,7 @@ def list_servables():
print(res_body['servables'])


def predict(img_path, servable_name, dataset_name="mnist"):
def predict(img_path, servable_name, dataset_name="mnist", strategy="TOP1_CLASS"):
# TODO: The preprocess would be moved to data module later
# check if dataset_name and img_path are valid
if dataset_name not in ("mnist", "cifar10", "imagenet2012"):
Expand All @@ -42,6 +43,9 @@ def predict(img_path, servable_name, dataset_name="mnist"):
if not os.path.isfile(img_path):
print("The image path "+img_path+" not exist!")
sys.exit(0)
if strategy not in ("TOP1_CLASS", "TOP5_CLASS"):
print("Currently strategy only supports `TOP1_CLASS` and `TOP5_CLASS`!")
sys.exit(0)

img_data = Image.open(img_path)
if dataset_name == "mnist":
Expand Down Expand Up @@ -69,4 +73,13 @@ def predict(img_path, servable_name, dataset_name="mnist"):
elif res_body['status'] != 0:
leonwanghui marked this conversation as resolved.
Show resolved Hide resolved
print(res_body['err_msg'])
else:
print(res_body['instance'])
instance = res_body['instance']
if dataset_name == "mnist":
data = mnist_transform.postprocess(np.array(json.loads(instance['data'])), strategy)
print(data)
elif dataset_name == "imagenet2012":
data = imagefolder_transform.postprocess(np.array(json.loads(instance['data'])), strategy)
print(data)
else:
data = cifar10_transform.postprocess(np.array(json.loads(instance['data'])), strategy)
print(data)
4 changes: 2 additions & 2 deletions tinyms/serving/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from .server import start_server, shutdown
from .server import start_server, shutdown, run_flask

__all__ = ["start_server", "shutdown"]
__all__ = ["start_server", "shutdown", "run_flask"]
15 changes: 10 additions & 5 deletions tinyms/serving/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import subprocess

from flask import request, Flask, jsonify
from ..servable import predict, servable_search

Expand All @@ -37,13 +39,16 @@ def list_servables():
return jsonify(servable_search())


def start_server(host='127.0.0.1', port=5000):
def run_flask(host='127.0.0.1', port=5000):
app.run(host=host, port=port)


def start_server():
cmd = ['python -c "from tinyms.serving import run_flask; run_flask()"']
server_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, shell=True)


def shutdown():
func = request.environ.get('werkzeug.server.shutdown')
if func is None:
raise RuntimeError('Not running with the Werkzeug Server')
func()
server_pid = subprocess.getoutput("netstat -anp | grep 5000 | awk '{printf $7}' | cut -d/ -f1")
subprocess.run("kill -9 " + str(server_pid) + "", shell=True)
return 'Server shutting down...'
Loading