-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.py
123 lines (94 loc) · 3.2 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import torchvision
from PIL import *
import torchvision.transforms as transforms
import pickle
import torch.nn as nn
import torch.nn.functional as F
from PIL import *
import numpy as np
from flask import *
import requests
from io import BytesIO
class Net(nn.Module):
    """Small from-scratch CNN classifier for 224x224 RGB images.

    Three conv/pool stages reduce a (batch, 3, 224, 224) input to a
    (batch, 128, 7, 7) feature map. That map is flattened and classified by
    two fully-connected layers, with dropout applied before each of them.

    Args:
        num_classes: number of output logits. Defaults to 12, the number of
            class names the prediction code in this module maps outputs to.
            NOTE(review): must match the ``fc2`` shape stored in the loaded
            checkpoint — confirm.
    """

    def __init__(self, num_classes=12):
        super().__init__()
        # Attribute names are kept unchanged: they determine the state_dict
        # keys that load_state_dict matches against the saved checkpoint.
        self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        # pool
        self.pool = nn.MaxPool2d(2, 2)
        # fully-connected
        self.fc1 = nn.Linear(7 * 7 * 128, 500)
        self.fc2 = nn.Linear(500, num_classes)
        # drop-out
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Expects ``x`` of shape (batch, 3, 224, 224). Inputs whose spatial size
        does not reduce to 7x7 raise a shape error at ``fc1``.
        """
        # 224 -> conv1 (stride 2) 112 -> pool 56
        x = self.pool(F.relu(self.conv1(x)))
        # 56 -> conv2 (stride 2) 28 -> pool 14
        x = self.pool(F.relu(self.conv2(x)))
        # 14 -> conv3 14 -> pool 7, giving (batch, 128, 7, 7)
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample, keeping the batch axis, so a wrongly sized
        # input fails loudly instead of being regrouped across samples.
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
# Load the pickled model architecture, then overwrite its weights from the
# state-dict checkpoint, mapped onto the CPU.
# SECURITY(review): pickle.load can execute arbitrary code embedded in the
# file; only load model artifacts from a trusted source. Presumably the pickle
# refers to the Net class above, so Net must be defined before this runs.
with open(r"saved_model\model_arc.pkl", "rb") as model_file:
    model = pickle.load(model_file)
# NOTE(review): the pickle path above is relative to the current working
# directory, while this checkpoint path is an absolute, machine-specific
# Windows path. Consider deriving both from one configurable model directory.
checkpoint = torch.load(r"D:\Projects\PlantsvsWeeds\saved_model\model_scratch.pt",
                        map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)
# Output labels, indexed by the position of the model's largest logit.
# NOTE(review): presumably this is the label order used at training time —
# confirm against the training code.
# The apostrophe is spelled as the escape \u2019 (RIGHT SINGLE QUOTATION MARK)
# so the label survives re-encoding of this source file.
CLASS_NAMES = (
    'Black-grass',
    'Charlock',
    'Cleavers',
    'Common Chickweed',
    'Common wheat',
    'Fat Hen',
    'Loose Silky-bent',
    'Maize',
    'Scentless Mayweed',
    'Shepherd\u2019s Purse',
    'Small-flowered Cranesbill',
    'Sugar beet',
)

# Resize to 224x224, convert to a float tensor and normalize each channel
# with the commonly used ImageNet mean/std.
# NOTE(review): assumes the model was trained with this same preprocessing.
_PREDICTION_TRANSFORM = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def _load_input_image(img):
    """Convert a PIL image into a normalized (1, 3, 224, 224) batch tensor."""
    # convert('RGB') drops alpha/palette modes, so ToTensor yields 3 channels.
    return _PREDICTION_TRANSFORM(img.convert('RGB')).unsqueeze(0)


def _predict_image(net, class_names, img):
    """Return the entry of ``class_names`` at the index of net's largest logit."""
    batch = _load_input_image(img)
    net = net.cpu()
    net.eval()  # disable dropout for inference
    with torch.no_grad():  # inference only: don't build an autograd graph
        logits = net(batch)
    return class_names[int(torch.argmax(logits))]


def type(image_path):
    """Predict the plant-seedling class of an image.

    Despite its name, ``image_path`` must be an already-opened PIL image, not
    a filesystem path; the Flask handler in this module passes the result of
    ``Image.open``. Classification uses the module-level ``model`` loaded at
    import time.

    NOTE(review): this function shadows the builtin ``type`` for the rest of
    the module; the name is kept because existing callers use it.

    Returns:
        The predicted class name, one of ``CLASS_NAMES``.
    """
    return _predict_image(model, CLASS_NAMES, image_path)
# print(type(image_path))
# --------------------------------------------- #
# ----------- Flask App Starts Here ----------- #
# --------------------------------------------- #
app = Flask(__name__)

# Seconds to wait on the remote image server. requests applies this to both
# the connect and the read timeout; without it requests.get can block a
# worker indefinitely.
_IMAGE_FETCH_TIMEOUT_S = 10


@app.route('/', methods=['POST', 'GET'])
def main():
    """Classify the image at a client-supplied URL.

    Expects a JSON body of the form ``{"data": "<image URL>"}``. The image is
    downloaded, decoded with PIL and classified with ``type``. The predicted
    class name is returned as a JSON string.

    Error responses carry a JSON ``{"error": ...}`` body:
      * 400 if the body is not a JSON object with a string ``"data"`` field,
        or if the downloaded bytes are not a readable image;
      * 502 if the image URL cannot be fetched or returns an HTTP error status.
    """
    # Import explicitly: the module-level `from PIL import *` only binds
    # `Image` if PIL.Image happened to be imported elsewhere first.
    from PIL import Image

    data = request.get_json(silent=True)
    img_url = data.get("data") if isinstance(data, dict) else None
    if not isinstance(img_url, str):
        return jsonify(error='expected a JSON object with a string "data" image URL'), 400

    # SECURITY(review): fetching an arbitrary client-supplied URL from the
    # server is a server-side request forgery risk. Restrict schemes/hosts
    # before exposing this endpoint beyond a trusted network.
    try:
        response = requests.get(img_url, timeout=_IMAGE_FETCH_TIMEOUT_S)
        response.raise_for_status()
    except requests.RequestException:
        return jsonify(error="could not fetch the image URL"), 502

    try:
        img_org = Image.open(BytesIO(response.content))
        img_org.load()  # Image.open is lazy; decode now so bad data fails here
    except OSError:  # PIL's UnidentifiedImageError is an OSError subclass
        return jsonify(error="the URL did not return a readable image"), 400

    output = type(img_org)
    return jsonify(output)
if __name__ == "__main__":
    # Run Flask's development server. Debug mode (Werkzeug's interactive
    # debugger and reloader) is no longer forced on; Flask.run() still honors
    # the FLASK_DEBUG environment variable, so set FLASK_DEBUG=1 for local
    # development only.
    app.run()