Skip to content

Commit

Permalink
fix preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
mshr-h committed Sep 1, 2024
1 parent 12334f9 commit 9c1603f
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions apps/ios_rpc/tests/ios_rpc_mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def compile_metal(src, target):


def prepare_input():
from torchvision import transforms

img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_name = "cat.png"
synset_url = "".join(
Expand All @@ -65,13 +67,19 @@ def prepare_input():
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
synset = eval(f.read())
image = Image.open(img_path).resize((224, 224))
input_image = Image.open(img_path)

image = np.array(image) - np.array([123.0, 117.0, 104.0])
image /= np.array([58.395, 57.12, 57.375])
image = image.transpose((2, 0, 1))
image = image[np.newaxis, :]
return image.astype("float32"), synset
preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)
return input_batch.detach().cpu().numpy(), synset


def get_model(model_name, data_shape):
Expand Down

0 comments on commit 9c1603f

Please sign in to comment.