-
Notifications
You must be signed in to change notification settings - Fork 0
/
api.py
98 lines (79 loc) · 2.65 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import torch
import albumentations
import pretrainedmodels
import numpy as np
import torch.nn as nn
from flask import Flask
from flask import request
from flask import render_template
from torch.nn import functional as F
from wtfml.data_loaders.image import ClassificationLoader
from wtfml.engine import Engine
app = Flask(__name__)

# Directory where uploaded images are saved. Defaults to Flask's static folder
# (``<app root>/static``) rather than a machine-specific absolute path, so the
# app also works outside the original author's machine (e.g. in Docker).
# Override with the UPLOAD_FOLDER environment variable.
UPLOAD_FOLDER = os.environ.get("UPLOAD_FOLDER", app.static_folder)

# Torch device used for inference: "cpu" by default (e.g. with Docker); set the
# DEVICE environment variable to a GPU/TPU device string to change it.
DEVICE = os.environ.get("DEVICE", "cpu")

# Loaded network; it is only assigned by the ``__main__`` entry point, so it
# stays None when this module is merely imported.
MODEL = None
class SEResNext50_32x4d(nn.Module):
    """SE-ResNeXt50 (32x4d) backbone topped with a single-output sigmoid head.

    Args:
        pretrained: weight set forwarded to the ``pretrainedmodels`` factory
            (``"imagenet"`` by default; the entry point passes ``None`` and then
            loads its own state dict).
    """

    def __init__(self, pretrained="imagenet"):
        super().__init__()
        backbone_factory = pretrainedmodels.__dict__["se_resnext50_32x4d"]
        self.model = backbone_factory(pretrained=pretrained)
        # Linear head expects the 2048 pooled backbone features.
        self.out = nn.Linear(2048, 1)

    def forward(self, image, targets):
        """Return ``(probabilities, loss)`` for a batch of images.

        ``targets`` is accepted but unused (presumably kept so the training /
        inference engine can call ``model(**batch)``); the loss is always ``0``.
        """
        # Unpacking four dims also asserts the input is a 4-D batch.
        batch_size, _, _, _ = image.shape
        feature_map = self.model.features(image)
        pooled = F.adaptive_avg_pool2d(feature_map, 1).reshape(batch_size, -1)
        probabilities = torch.sigmoid(self.out(pooled))
        return probabilities, 0
def predict(image_path, model):
    """Return the model's prediction(s) for a single image file.

    Args:
        image_path: path of the image on disk, handed to ``ClassificationLoader``.
        model: the loaded network (e.g. ``SEResNext50_32x4d``); must not be None.

    Returns:
        A flat numpy array built by stacking the batches returned by
        ``Engine.predict`` (one value for the single image here).

    Raises:
        ValueError: if ``model`` is None, which happens when the module-level
            ``MODEL`` was never loaded (it is only set by the ``__main__`` block).
    """
    if model is None:
        raise ValueError(
            "predict() needs a loaded model; MODEL is only set when api.py is run directly"
        )

    # ImageNet channel statistics; presumably match the training-time
    # normalisation (TODO confirm against main.py).
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    # Inference preprocessing must be deterministic: the former random Flip()
    # meant the same upload could score differently on each request.
    # Normalize already defaults to p=1.0, so always_apply is unnecessary
    # (and deprecated in recent albumentations releases).
    test_aug = albumentations.Compose(
        [albumentations.Normalize(mean, std, max_pixel_value=255.0)]
    )

    test_dataset = ClassificationLoader(
        image_paths=[image_path],
        targets=[0],  # dummy label; presumably required by the loader interface
        resize=None,
        augmentations=test_aug,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
    )
    predictions = Engine.predict(test_loader, model, DEVICE)
    return np.vstack(predictions).ravel()
@app.route("/", methods=["GET", "POST"])
def upload_predict():
    """Serve the upload form; on POST, save the image and render its prediction.

    The uploaded file name comes from the client, so it is reduced to its final
    path component before being joined onto ``UPLOAD_FOLDER``. Otherwise a name
    such as ``"../../api.py"`` or an absolute path would be written outside the
    upload folder (``os.path.join`` discards the base for absolute paths).
    """
    if request.method == "POST":
        image_file = request.files["image"]
        if image_file:
            filename = os.path.basename(image_file.filename)
            # Names that are empty or refer to a directory cannot be saved safely.
            if filename not in ("", ".", ".."):
                image_location = os.path.join(UPLOAD_FOLDER, filename)
                image_file.save(image_location)
                pred = predict(image_location, MODEL)[0]
                return render_template(
                    "index.html", prediction=pred, image_loc=filename
                )
    return render_template("index.html", prediction=0, image_loc=None)
# Entry point. "model0.bin" is the author's trained checkpoint; to use your own,
# train a model with main.py and point the path below at the resulting file.
if __name__ == "__main__":
    MODEL = SEResNext50_32x4d(pretrained=None)
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    MODEL.load_state_dict(torch.load("model0.bin", map_location=torch.device(DEVICE)))
    MODEL.to(DEVICE)
    # Put the network in inference mode explicitly (Engine.predict may also do
    # this; calling it here does not depend on that).
    MODEL.eval()
    # Flask debug mode enables the interactive Werkzeug debugger, which must not
    # be reachable from other machines; the server binds to 0.0.0.0, so debug is
    # opt-in via FLASK_DEBUG=1 instead of always on.
    app.run(
        host="0.0.0.0",
        port=12000,
        debug=os.environ.get("FLASK_DEBUG") == "1",
    )