1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/dnn.hpp>
#include <opencv2/opencv.hpp>
// File-wide using-directives: names from cv, cv::dnn and std are used unqualified below.
using namespace cv; using namespace cv::dnn; using namespace std;
// Input file paths (relative to the working directory): modelTxt and modelBin are passed to
// readNetFromCaffe() in main(); labelFile is opened by readClasslabels().
String modelTxt = "bvlc_googlenet.prototxt"; String modelBin = "bvlc_googlenet.caffemodel"; String labelFile = "classification_classes_ILSVRC2012.txt";
// Forward declaration; the definition follows main().
vector<String> readClasslabels();
int main(int argc, char** argv) { Mat testImage = imread("dog.jpg"); if (testImage.empty()) { printf("could not load image...\n"); return -1; } Net net = dnn::readNetFromCaffe(modelTxt, modelBin); if (net.empty()) { std::cerr << "Can't load network by using the following files: " << std::endl; std::cerr << "prototxt: " << modelTxt << std::endl; std::cerr << "caffemodel: " << modelBin << std::endl; return -1; } net.setPreferableBackend(dnn::DNN_BACKEND_OPENCV); net.setPreferableTarget(dnn::DNN_TARGET_CPU); vector<String> labels = readClasslabels(); Mat inputBlob = blobFromImage(testImage, 1, Size(224, 224), Scalar(104, 117, 123)); Mat prob; for (int i = 0; i < 10; i++) { net.setInput(inputBlob, "data"); prob = net.forward("prob"); } Mat probMat = prob.reshape(1, 1); Point classNumber; double classProb; minMaxLoc(probMat, NULL, &classProb, NULL, &classNumber); int classIdx = classNumber.x; printf("\n current image classification : %s, possible : %.2f \n", labels.at(classIdx).c_str(), classProb); putText(testImage, labels.at(classIdx), Point(20, 20), FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0, 0, 255), 2, 8); imshow("Image Category", testImage); waitKey(0); return 0; }
/// Loads the class labels from the file named by the global `labelFile`.
///
/// Each non-empty line contributes one label: the text following the first
/// space on the line, or the whole line when it contains no space. A trailing
/// carriage return (CRLF line endings) is removed before the line is examined.
///
/// @return The labels in file order.
/// @note Prints a message to std::cerr and terminates the process with
///       exit(-1) if the file cannot be opened.
vector<String> readClasslabels() {
    std::ifstream fp(labelFile);
    if (!fp.is_open()) {
        std::cerr << "File with classes labels not found: " << labelFile << std::endl;
        exit(-1);
    }

    std::vector<String> classNames;
    std::string name;
    // Loop on the read itself rather than on eof(): a stream that enters a
    // failed/bad state without reaching end-of-file would otherwise spin forever.
    while (std::getline(fp, name)) {
        // Drop the '\r' left behind by Windows line endings so it neither ends
        // up in the label nor makes a blank line look non-empty.
        if (!name.empty() && name.back() == '\r') {
            name.pop_back();
        }
        if (name.empty()) {
            continue;
        }
        // find() returns npos when there is no space; npos + 1 wraps to 0, so
        // the whole line is kept in that case.
        classNames.push_back(name.substr(name.find(' ') + 1));
    }
    // fp is closed by its destructor.
    return classNames;
}
|