-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
72 lines (55 loc) · 2.1 KB
/
main.py
File metadata and controls
72 lines (55 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from flask import Flask
from flask import render_template
from flask import request
import os
import logging
from werkzeug.utils import secure_filename
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.resnet50 import ResNet50
from keras.applications.resnet50 import preprocess_input
from keras.applications.resnet50 import decode_predictions
app = Flask(__name__)

# Configure root logging for the process at DEBUG level.
logging.basicConfig(level=logging.DEBUG)

# Build the ResNet50 model with ImageNet weights once, at import time, so
# every request reuses the same instance.
model = ResNet50(weights='imagenet')

# Directory (next to this source file) where uploaded images are saved.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
IMAGE_DIR = os.path.join(BASE_DIR, 'images')

# Ensure the images directory exists. exist_ok=True replaces the racy
# "os.path.exists() then os.makedirs()" pair: creation is a single call that
# does not fail if another process created the directory in between.
os.makedirs(IMAGE_DIR, exist_ok=True)
@app.route('/', methods=['GET'])
def hello_world():
    """Render the index page without a prediction."""
    template_name = 'index.html'
    return render_template(template_name)
@app.route('/', methods=['POST'])
def predict():
    """Classify an uploaded image with the module-level ResNet50 model.

    Reads the file from the ``imagefile`` field of the multipart form and
    saves it under IMAGE_DIR using a sanitized form of the client-supplied
    name. It then runs it through the model.

    Returns:
        On success, ``index.html`` rendered with ``prediction`` set to
        ``"<label> (<score>%)"`` for the top-scoring class. Otherwise a
        plain-text error message: a missing field, an empty filename, a
        filename that sanitizes to nothing, or the text of any exception
        raised while saving or classifying.
    """
    if 'imagefile' not in request.files:
        return "No file part"
    imagefile = request.files['imagefile']
    if imagefile.filename == '':
        return "No selected file"

    # secure_filename() strips path components and unsafe characters. It can
    # return '' (e.g. for ".."), which would make image_path the directory
    # itself, so reject that explicitly.
    filename = secure_filename(imagefile.filename)
    if not filename:
        return "Invalid file name"
    image_path = os.path.join(IMAGE_DIR, filename)
    logging.debug('Saving image to: %s', image_path)
    # NOTE(review): the stored path depends only on the client's filename.
    # Two concurrent uploads with the same name write to the same file and
    # may classify each other's image; consider a unique name if that matters.
    try:
        imagefile.save(image_path)
        classification = _classify_image(image_path)
    except Exception as e:
        # Request-handler boundary: keep the traceback in the server log
        # instead of discarding it, then report the error text as before.
        logging.exception('Failed to classify image %s', image_path)
        return str(e)
    return render_template('index.html', prediction=classification)


def _classify_image(image_path):
    """Return ``"<label> (<score>%)"`` for the model's top prediction on an image file."""
    # Resize to 224x224, the input size used for ResNet50 here.
    image = load_img(image_path, target_size=(224, 224))
    image = img_to_array(image)
    # Prepend a batch dimension of size 1: (H, W, C) -> (1, H, W, C).
    image = image.reshape((1,) + image.shape)
    image = preprocess_input(image)
    yhat = model.predict(image)
    # decode_predictions yields, per sample, (class_id, description, score)
    # tuples ordered best-first; take the top one for the only sample.
    _, description, score = decode_predictions(yhat)[0][0]
    return '%s (%.2f%%)' % (description, score * 100)
if __name__ == '__main__':
    # Serve on all interfaces. The port comes from the PORT environment
    # variable when it is set, otherwise 5000.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 5000)))