-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathstreetnet.py
More file actions
140 lines (114 loc) · 4.44 KB
/
streetnet.py
File metadata and controls
140 lines (114 loc) · 4.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import argparse
import json
import matplotlib.pyplot as plt
import mss
import numpy as np
import torch
from PIL import Image, ImageTk, ImageDraw
from tkinter import Tk, Canvas, NW
from torchvision import transforms
class SaveFeatures:
    """Capture a module's forward-pass output through a forward hook.

    After every forward pass of the hooked module, ``features`` holds that
    module's output as a NumPy array on the CPU.
    """

    # Most recently captured output; stays None until the hooked module runs.
    features = None

    def __init__(self, m):
        """Register a forward hook on module ``m``."""
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Hook callback: store ``output`` as a CPU NumPy array."""
        # detach() is the supported way to drop the autograd graph before
        # converting to NumPy; the legacy ``.data`` attribute is discouraged.
        self.features = output.detach().cpu().numpy()

    def remove(self):
        """Unregister the hook so later forward passes are not captured."""
        self.hook.remove()
def take_screenshot(monitor=None):
    """Grab a screen region and turn it into a normalised model input.

    The capture is resized to 580x360, centre-cropped to 522x324, converted
    to a tensor and normalised with fixed per-channel mean/std constants.

    Args:
        monitor: Optional mss region dict with "left", "top", "width" and
            "height" keys. When omitted, the region originally hard-coded
            in this function is used.

    Returns:
        A ``(image_tensor, image)`` tuple: the normalised tensor and the
        cropped PIL image it was built from.
    """
    if monitor is None:
        monitor = {"left": 100, "top": 130, "width": 1305, "height": 810}

    with mss.mss() as sct:
        sct_img = sct.grab(monitor)
        # The image is created in RGB mode already, so no convert() is needed.
        image = Image.frombytes("RGB", sct_img.size, sct_img.rgb)

    width, height = 29 * 20, 18 * 20
    image = image.resize((width, height), Image.Resampling.LANCZOS)

    new_width, new_height = 29 * 18, 18 * 18
    # Integer arithmetic keeps the crop box on exact pixel coordinates.
    left = (width - new_width) // 2
    top = (height - new_height) // 2
    image = image.crop(
        (left, top, left + new_width, top + new_height)
    )  # centre crop

    image_tensor = transforms.ToTensor()(image)
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    image_tensor = normalize(image_tensor)
    return image_tensor, image
def make_prediction():
    """Classify the current screenshot and overlay a class-activation heatmap.

    Uses the module-level ``model``, ``class_names`` and
    ``activated_features`` objects.

    Returns:
        A ``(text_label, image)`` tuple. ``text_label`` has one line per top
        prediction in the form "name probability% - raw score". ``image`` is
        a copy of the screenshot with the heatmap for the top-ranked class
        blended on top.
    """
    image_tensor, image = take_screenshot()
    # Put the input on whatever device the model's parameters live on.
    image_tensor = image_tensor.to(next(model.parameters()).device)
    # Make image tensor into a batch of size one, shape is [1,3,H,W]
    data = torch.unsqueeze(image_tensor, dim=0)

    # Inference only: skip building an autograd graph on every refresh.
    with torch.no_grad():
        outputs = model(data)
        probabilities = torch.softmax(outputs, dim=1)

    # Never ask for more entries than there are classes.
    top_k = min(5, outputs.shape[1])
    preds, _ = outputs.topk(top_k)
    softmax_preds, indices = probabilities.topk(top_k)
    text_label = "".join(
        f"{class_names[indices[0, i].item()]} "
        f"{round(softmax_preds[0, i].item() * 100, 2)}% - "
        f"{round(preds[0, i].item(), 2)}\n"
        for i in range(top_k)
    )

    # Heatmap: weight the hooked layer's feature maps by the first parameter
    # tensor of model.fc, using the row of the top-ranked class.
    weight_softmax = np.squeeze(next(model.fc.parameters()).detach().cpu().numpy())
    _, nc, h, w = activated_features.features.shape
    cam = weight_softmax[indices[0, 0].item()].dot(
        activated_features.features.reshape((nc, h * w))
    )
    cam = cam.reshape(h, w)
    cam = cam - np.min(cam)
    peak = np.max(cam)
    # A constant activation map would otherwise produce 0/0 = NaN.
    cam_img = cam / peak if peak > 0 else np.zeros_like(cam)
    cm = plt.get_cmap("jet")
    cam_img = cm(cam_img)
    heatmap_image = Image.fromarray((cam_img[:, :, :3] * 255).astype(np.uint8))
    heatmap_image = heatmap_image.resize(image.size, Image.Resampling.LANCZOS)

    # Overlay heatmap on top of original image. A uniform mask value of 100
    # (out of 255) blends the heatmap in at partial opacity everywhere.
    mask_im = Image.new("L", image.size, 100)
    image.paste(heatmap_image, (0, 0), mask_im)
    return text_label, image.copy()
def update_window():
    """Redraw the canvas with a fresh prediction, then reschedule itself.

    Clears the canvas, draws the prediction text and a 290x180 preview of
    the overlay image, and asks Tk to call this function again in 500 ms.
    """
    # The PhotoImage is stored in a module-level global -- presumably so it
    # stays referenced while Tk displays it.
    global preview_image

    label_text, overlay = make_prediction()
    canvas.delete("all")

    thumbnail_size = (58 * 5, 36 * 5)
    thumbnail = overlay.resize(thumbnail_size)
    preview_image = ImageTk.PhotoImage(thumbnail)

    canvas.create_text(9, 9, text=label_text, anchor=NW)
    canvas.create_image(9, 90, image=preview_image, anchor=NW)

    # reschedule event
    root.after(500, update_window)
if __name__ == "__main__":
    # Script entry point: parse arguments, load the class names and model,
    # hook the final convolutional stage, then run the Tk preview window.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "model_path",
        help="path to model, currently models/model_50.pt or models/model_18.pt",
    )
    parser.add_argument(
        "--use-cuda", action="store_true", help="run on gpu instead of cpu"
    )
    args = parser.parse_args()

    # Use parser.error rather than assert: asserts are stripped under
    # ``python -O``, which would silently skip this validation.
    if args.use_cuda and not torch.cuda.is_available():
        parser.error("Requested to use cuda, however cuda is not available")
    device = "cuda" if args.use_cuda else "cpu"

    countries_path = "countries.json"
    # Context manager guarantees the file handle is closed after loading.
    with open(countries_path, encoding="utf-8") as countries_file:
        class_names = json.load(countries_file)

    try:
        # NOTE: torch.load unpickles the file, which can execute arbitrary
        # code -- only load model files from a trusted source.
        model = torch.load(args.model_path, map_location=device)
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"No such file or directory: '{args.model_path}'. Did you mean 'models/{args.model_path}'?"
        ) from e
    model.eval()

    # Record layer4's output on every forward pass; make_prediction reads it
    # to build the heatmap.
    final_layer = model.layer4
    activated_features = SaveFeatures(final_layer)

    root = Tk()
    canvas = Canvas(root, width=307, height=280)
    # Written by update_window on every refresh.
    preview_image = None
    root.title("StreetNet AI")
    root.resizable(False, False)
    root.geometry("307x280+1573+200")
    canvas.pack()
    canvas.create_text(50, 10, text="Loading...")
    root.wm_attributes("-topmost", 1)
    # First refresh after 500 ms; update_window reschedules itself after that.
    root.after(500, update_window)
    root.mainloop()