-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbackend.py
More file actions
127 lines (96 loc) · 4.07 KB
/
backend.py
File metadata and controls
127 lines (96 loc) · 4.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import base64
import os
import random
import shutil
import uuid

from flask import Flask, jsonify, request
from flask_cors import CORS, cross_origin

import src.AI_detect as AI_detect
import src.extract_images as extract_images
# Flask application object; the route decorators below register on it.
app = Flask(__name__)
# Apply flask-cors to the whole app (its default resource pattern covers every
# route); each handler below is additionally decorated with @cross_origin().
cors = CORS(app)
# NOTE(review): flask-cors uses this key in its @cross_origin example to allow
# the Content-Type request header. It is set after CORS(app) was constructed,
# so confirm it is read at request time rather than at CORS(app) init.
app.config['CORS_HEADERS'] = 'Content-Type'
@app.route("/recognize/images/", methods=['POST'])
@cross_origin()
def recognize_images_post():
    """Run the model on uploaded files and return the predictions.

    Expects a multipart form with:
        filenames: one or more relative paths; each must also be the field
            name of an uploaded file in ``request.files``. Sub-directories in
            a name are recreated under the per-request upload directory.
        classes: the classification classes (one or more values). With exactly
            one class the whole batch goes to ``AI_detect.predict_text``;
            otherwise each extracted image goes to ``AI_detect.predict_photo``.
        fast: optional; "true" (case-insensitive) is passed on as the fast flag
            of ``extract_images.main``. Anything else, or absent, means False.

    Returns:
        The ``AI_detect.predict_text`` result when a single class is given;
        otherwise a dict mapping each extracted image path (with the working
        directory prefix stripped) to its ``AI_detect.predict_photo`` result.
        On invalid input, a JSON ``{"error": ...}`` body with status 400.
    """
    filenames = request.form.getlist("filenames")
    if not filenames:
        return jsonify(error="No files found"), 400
    if any(fname not in request.files for fname in filenames):
        return jsonify(error="Files and filenames doesn't match"), 400
    # A FileStorage object never compares equal to ''; an empty file input is
    # signalled by an empty ``filename`` on the upload.
    if any(request.files[fname].filename == '' for fname in filenames):
        return jsonify(error="No selected file"), 400
    # Validate everything before writing to disk so a rejected request leaves
    # nothing behind to clean up.
    classes = request.form.getlist("classes")
    if not classes:
        return jsonify(error="Missing classification classes"), 400
    # A missing "fast" field defaults to False instead of raising AttributeError.
    fast = (request.form.get("fast") or "").lower() == 'true'

    # uuid4 names avoid collisions between concurrent requests (and
    # random.randint(0, 1e9) raises TypeError on Python 3.12+).
    stored = os.path.join(".uploads", uuid.uuid4().hex)
    basedir = os.path.join(".extracted", uuid.uuid4().hex)

    # Client-supplied names may contain ".." components or be absolute (which
    # makes os.path.join discard `stored`); reject any name that does not
    # resolve to a path strictly inside the upload directory.
    stored_root = os.path.abspath(stored)
    targets = []
    for fname in filenames:
        filepath = os.path.normpath(os.path.join(stored, fname))
        resolved = os.path.abspath(filepath)
        try:
            inside = (resolved != stored_root
                      and os.path.commonpath([stored_root, resolved]) == stored_root)
        except ValueError:  # e.g. paths on different drives on Windows
            inside = False
        if not inside:
            return jsonify(error="Invalid filename"), 400
        targets.append((fname, filepath))

    try:
        # Save the uploads, recreating any sub-directories in their names.
        os.makedirs(stored)
        for fname, filepath in targets:
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            request.files[fname].save(filepath)
        # Convert the uploads to images under basedir and collect them.
        os.makedirs(basedir)
        extract_images.main(basedir, stored, fast)
        image_paths = extract_images.find_files(basedir)
        # Leading characters stripped from each extracted path: the lengths of
        # basedir and stored plus two separators.
        # NOTE(review): this presumes extract_images mirrors the upload tree as
        # "<basedir>/<stored>/..."; confirm against extract_images.main.
        prefix_len = len(basedir) + len(stored) + 2
        if len(classes) == 1:
            # NOTE(review): the GET route passes the single class as a plain
            # string while this passes a one-element list; confirm which form
            # AI_detect.predict_text expects.
            return AI_detect.predict_text(image_paths, classes, prefix_len)
        return {img[prefix_len:]: AI_detect.predict_photo(img, classes)
                for img in image_paths}
    finally:
        # Always remove the per-request working directories, even on error.
        shutil.rmtree(stored, ignore_errors=True)
        shutil.rmtree(basedir, ignore_errors=True)
@app.route("/recognize/images/", methods=['GET'])
@cross_origin()
def recognize_images_get():
    """Run the model on a server-side path instead of uploaded files.

    Query string:
        path: source path handed to ``extract_images.main`` (intended to be a
            directory).
        classes: the classification classes (one or more values). Per the
            original author's note, when only one class is received the result
            is a probability distribution over the images (produced by
            ``AI_detect.predict_text``) instead of per-image predictions.
        fast: optional; "true" (case-insensitive) is passed on as the fast flag
            of ``extract_images.main``. Anything else, or absent, means False.

    Returns:
        The ``AI_detect.predict_text`` result when a single class is given;
        otherwise a dict ``{image_name: prediction}`` built from
        ``AI_detect.predict_photo``, keyed by each extracted path with the
        working-directory prefix stripped.
        On invalid input, a JSON ``{"error": ...}`` body with status 400.
    """
    # TODO: support a URL instead of a path, or add a module that downloads
    # the directory.
    # NOTE(review): `path` comes straight from the client and is read from the
    # server's filesystem without restriction, so any caller can make the
    # server process arbitrary local paths. Restrict it to an allowed root
    # before exposing this endpoint.
    path = request.args.get("path")
    if path is None:
        # Same {"error": ...} shape as every other error response in this file.
        return jsonify(error="Path wasn't received"), 400
    possible_classes = request.args.getlist("classes")
    if not possible_classes:
        # Mirror the POST route instead of handing an empty class list to the model.
        return jsonify(error="Missing classification classes"), 400
    # A missing "fast" parameter defaults to False instead of raising AttributeError.
    fast = (request.args.get("fast") or "").lower() == 'true'

    # Convert everything to images inside a fresh per-request basedir. uuid4
    # avoids collisions (and random.randint(0, 1e9) raises TypeError on 3.12+).
    basedir = os.path.join(".extracted", uuid.uuid4().hex)
    os.makedirs(basedir)
    try:
        extract_images.main(basedir, path, fast)
        image_paths = extract_images.find_files(basedir)
        # Strip basedir plus one separator; assumes find_files returns paths
        # prefixed by "<basedir>/" (TODO confirm against extract_images).
        prefix_len = len(basedir) + 1
        if len(possible_classes) == 1:
            # NOTE(review): the POST route passes a one-element list here while
            # this passes the bare string; confirm which form predict_text expects.
            return AI_detect.predict_text(image_paths, possible_classes[0], prefix_len)
        return {img[prefix_len:]: AI_detect.predict_photo(img, possible_classes)
                for img in image_paths}
    finally:
        # Always remove the working directory, even if extraction or the model fails.
        shutil.rmtree(basedir, ignore_errors=True)
@app.route("/query/chatgpt/", methods=['GET'])
@cross_origin()
def query_chatgpt():
    """Forward a text query through ``AI_detect.query_text`` and return its answer.

    Per the original note, the backing model is ChatGPT 3.5.

    Query string:
        query: the text to send; required, at most 400 characters.

    Returns:
        Whatever ``AI_detect.query_text`` returns, unchanged; or a JSON
        ``{"error": ...}`` body with status 400 when the query is missing or
        too long.
    """
    max_len = 400
    text = request.args.get("query")

    problem = None
    if text is None:
        problem = "Missing Query"
    elif len(text) > max_len:
        problem = "Query exceeds maximal length (400)"

    if problem is not None:
        return jsonify(error=problem), 400
    return AI_detect.query_text(text)
if __name__ == "__main__":
    # Serve HTTPS on all interfaces, port 54362, using the certificate and
    # private key stored under keys/.
    # NOTE(review): app.run starts Flask's built-in development server; the
    # Flask docs recommend a production WSGI server for deployment.
    tls_pair = ("keys/public.crt", "keys/private.key")
    app.run(host="0.0.0.0", port=54362, ssl_context=tls_pair)