Skip to content

Commit

Permalink
use torchvision's mobilenet_v2 instead of mxnet
Browse files Browse the repository at this point in the history
  • Loading branch information
mshr-h committed Sep 5, 2024
1 parent 73b138b commit 3a320e7
Showing 1 changed file with 26 additions and 11 deletions.
37 changes: 26 additions & 11 deletions apps/ios_rpc/tests/ios_rpc_mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import coremltools
import numpy as np
import tvm
from mxnet import gluon
from PIL import Image
from tvm import relay, rpc
from tvm.contrib import coreml_runtime, graph_executor, utils, xcode
Expand Down Expand Up @@ -51,6 +50,8 @@ def compile_metal(src, target):


def prepare_input():
from torchvision import transforms

img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_name = "cat.png"
synset_url = "".join(
Expand All @@ -62,22 +63,36 @@ def prepare_input():
]
)
synset_name = "imagenet1000_clsid_to_human.txt"
img_path = download_testdata(img_url, "cat.png", module="data")
img_path = download_testdata(img_url, img_name, module="data")
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
synset = eval(f.read())
image = Image.open(img_path).resize((224, 224))
input_image = Image.open(img_path)

image = np.array(image) - np.array([123.0, 117.0, 104.0])
image /= np.array([58.395, 57.12, 57.375])
image = image.transpose((2, 0, 1))
image = image[np.newaxis, :]
return image.astype("float32"), synset
preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)
return input_batch.detach().cpu().numpy(), synset


def get_model(model_name, data_shape):
gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
import torch
import torchvision

torch_model = getattr(torchvision.models, model_name)(weights="IMAGENET1K_V1").eval()
input_data = torch.randn(data_shape)
scripted_model = torch.jit.trace(torch_model, input_data)

input_infos = [("data", input_data.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, input_infos)

# we want a probability so add a softmax operator
func = mod["main"]
func = relay.Function(
Expand All @@ -90,7 +105,7 @@ def get_model(model_name, data_shape):
def test_mobilenet(host, port, key, mode):
temp = utils.tempdir()
image, synset = prepare_input()
model, params = get_model("mobilenetv2_1.0", image.shape)
model, params = get_model("mobilenet_v2", image.shape)

def run(mod, target):
with relay.build_config(opt_level=3):
Expand Down

0 comments on commit 3a320e7

Please sign in to comment.