-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdnn_maskrcnn.cpp
More file actions
131 lines (113 loc) · 5.16 KB
/
dnn_maskrcnn.cpp
File metadata and controls
131 lines (113 loc) · 5.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include <opencv2/highgui.hpp> // cv::imshow(...), ...
#include <opencv2/dnn/dnn.hpp> // object detect functions
#include <opencv2/imgproc.hpp> // cv::rectangle(...)
#include <random>
#include <string>
#include <iostream>
#include <filesystem>
#include <stdexcept>
#include <fstream>
/// Path of the input image that main() loads and runs segmentation on.
const std::filesystem::path IMAGE_PATH{"resources/bike_person.jpg"};
int main()
{
// Files we need for object detection
std::filesystem::path config_path = "resources/mask_rcnn_inception_v2_coco_2018_01_28.pbtxt";
std::filesystem::path model_path = "resources/mask_rcnn_inception_v2_coco_2018_01_28_weights.pb";
std::filesystem::path label_path = "resources/coco_names.txt";
if(!std::filesystem::exists(config_path))
throw std::runtime_error("Could not find configuration .pbtxt file");
if(!std::filesystem::exists(model_path))
throw std::runtime_error("Could not find model weights .pb file");
if( !std::filesystem::exists(label_path))
throw std::runtime_error("Could not find label names file");
// Load image to process
if( !std::filesystem::exists(IMAGE_PATH))
throw std::runtime_error("Could not find image file");
cv::Mat img = cv::imread(IMAGE_PATH.string());
int height = img.size[0], width = img.size[1];
// Load and set-up model
cv::dnn::Net net = cv::dnn::readNetFromTensorflow(model_path.string(),config_path.string());
cv::Mat blob; // Create 4D Mat tensor -> blob
cv::dnn::blobFromImage(img,blob,1.0,cv::Size(),cv::Scalar(),true);
net.setInput(blob);
std::vector<cv::Mat> output;
std::vector<std::string> output_names = {"detection_out_final","detection_masks"};
net.forward(output,output_names);
// Load label names
std::ifstream label_file(label_path);
std::vector<std::string> label_names;
std::string label_name;
while(label_file >> label_name)
label_names.push_back(label_name);
if(label_names.empty())
throw std::runtime_error("Failed to load any label names");
// Segment objects
cv::Mat allboxes = output[0].reshape(1,output[0].total()/7);
cv::Mat allmasks = output[1];
const cv::Size mask_size(allmasks.size[2],allmasks.size[3]);
float conf_threshold = 0.6f;
std::vector<int> class_ids;
std::vector<float> confidences;
std::vector<cv::Rect> boxes;
std::vector<cv::Mat> masks;
for(int i = 0; i < allboxes.size[0]; i++){
float confidence = allboxes.at<float>(i, 2);
if (confidence > conf_threshold){
int class_id = static_cast<int>(allboxes.at<float>(i, 1));
class_ids.push_back(class_id);
confidences.push_back(confidence);
int x = static_cast<int>(width * allboxes.at<float>(i, 3));
int y = static_cast<int>(height * allboxes.at<float>(i, 4));
int x2 = static_cast<int>(width * allboxes.at<float>(i, 5));
int y2 = static_cast<int>(height * allboxes.at<float>(i, 6));
boxes.emplace_back(cv::Point2i(x,y),cv::Point2i(x2,y2));
masks.emplace_back(mask_size,CV_32FC1,allmasks.ptr<float>(i,class_id));
}
}
// Perform non-maximum suppression (of repeatedly detected objects)
std::vector<int> indices;
float nms_threshold = 0.35f; // higher -> keep more boxes
cv::dnn::NMSBoxes(boxes,confidences,conf_threshold,nms_threshold,indices);
// Create colors for each class label
std::vector<cv::Scalar> colors(label_names.size());
std::random_device rd;
std::default_random_engine engine(rd());
std::uniform_int_distribution<int> distr(0,255);
for(int i = 0; i < label_names.size(); i++)
colors[i] = cv::Scalar(distr(engine),distr(engine),distr(engine));
// Draw boxes around detected objects and name them
double font_scale = 1.25;
int font_thickness = 1;
int line_thickness = 2;
int contour_thickness = 2;
double mask_threshold = 0.4;
for(auto i : indices) {
// Draw the bounding box
cv::Rect& box = boxes[i];
cv::rectangle(img,box,cv::Scalar(0,255,255),line_thickness);
// Label the bounding box
int class_id = class_ids[i];
std::string label_name(label_names[class_id]);
cv::putText(img,label_name,cv::Point(box.x+10,box.y+40),\
cv::HersheyFonts::FONT_HERSHEY_TRIPLEX,font_scale,\
cv::Scalar(255,255,0),font_thickness,cv::LineTypes::LINE_AA);
// Draw the segmentation mask
cv::Scalar color = colors[class_id];
// Region of interest (current object box/add color)
cv::Mat roi = 0.4*color + 0.6*img(box);
// Process mask
cv::Mat& mask = masks[i];
cv::resize(mask,mask,cv::Size(box.width, box.height));
cv::threshold(mask,mask,mask_threshold,255,cv::THRESH_BINARY);
mask.convertTo(mask,CV_8UC1);
std::vector<cv::Mat> contours;
cv::Mat hierarchy;
cv::findContours(mask,contours,hierarchy,cv::RETR_EXTERNAL,cv::CHAIN_APPROX_TC89_KCOS);
cv::drawContours(roi,contours,-1,color,contour_thickness,cv::LINE_AA,hierarchy);
roi.copyTo(img(box),mask);
}
cv::namedWindow("Image",cv::WindowFlags::WINDOW_KEEPRATIO);
cv::imshow("Image",img);
cv::waitKey(0);
return 0;
}