OpenCV 4.12.0
开源计算机视觉
加载中...
搜索中...
无匹配项
samples/dnn/text_detection.cpp
/*
文本检测模型: https://github.com/argman/EAST
下载链接: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
文本识别模型可以直接在这里下载
下载链接: https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing
和 doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown
如何从 PyTorch 权重 (.pth) 转换为 onnx:
使用这里的类: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
import torch
from models.crnn import CRNN
model = CRNN(32, 1, 37, 256)
model.load_state_dict(torch.load('crnn.pth'))
dummy_input = torch.randn(1, 1, 32, 100)
torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
更多信息请参考 doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown 和 doc/tutorials/dnn/dnn_OCR/dnn_OCR.markdown
*/
#include <iostream>
#include <fstream>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn.hpp>
using namespace cv;
using namespace cv::dnn;
// Option table consumed by CommandLineParser in main().
// Each entry has the form "{ name [alias] | default value | help text }".
const char* keys =
    "{ help h | | Print help message. }"
    "{ input i | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
    "{ detModel dmp | | Path to a binary .pb file contains trained detector network.}"
    "{ width | 320 | Preprocess input image by resizing to a specific width. It should be a multiple of 32. }"
    "{ height | 320 | Preprocess input image by resizing to a specific height. It should be a multiple of 32. }"
    "{ thr | 0.5 | Confidence threshold. }"
    "{ nms | 0.4 | Non-maximum suppression threshold. }"
    "{ recModel rmp | | Path to a binary .onnx file contains trained CRNN text recognition model. "
    "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"
    "{ RGBInput rgb |0| 0: imread with flags=IMREAD_GRAYSCALE; 1: imread with flags=IMREAD_COLOR. }"
    // The file given here is read line by line and handed to the recognizer as its
    // vocabulary, so the help text describes a vocabulary file rather than benchmarks.
    "{ vocabularyPath vp | alphabet_36.txt | Path to the vocabulary file of the recognition model (one entry per line). "
    "Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}";
// Forward declaration; the definition follows main(). Warps the region of `frame`
// bounded by the four points in `vertices` into `result` via a perspective transform.
void fourPointsTransform(const Mat& frame, const Point2f vertices[], Mat& result);
int main(int argc, char** argv)
{
// 解析命令行参数。
CommandLineParser parser(argc, argv, keys);
parser.about("使用此脚本运行TensorFlow实现 (https://github.com/argman/EAST) 的 "
"EAST: 一种高效且准确的场景文本检测器 (https://arxiv.org/abs/1704.03155v2)");
if (argc == 1 || parser.has("help"))
{
parser.printMessage();
return 0;
}
float confThreshold = parser.get<float>("thr");
float nmsThreshold = parser.get<float>("nms");
int width = parser.get<int>("width");
int height = parser.get<int>("height");
int imreadRGB = parser.get<int>("RGBInput");
String detModelPath = parser.get<String>("detModel");
String recModelPath = parser.get<String>("recModel");
String vocPath = parser.get<String>("vocabularyPath");
if (!parser.check())
{
parser.printErrors();
return 1;
}
// 加载网络。
CV_Assert(!detModelPath.empty() && !recModelPath.empty());
TextDetectionModel_EAST detector(detModelPath);
detector.setConfidenceThreshold(confThreshold)
.setNMSThreshold(nmsThreshold);
TextRecognitionModel recognizer(recModelPath);
// 加载词汇表
CV_Assert(!vocPath.empty());
std::ifstream vocFile;
vocFile.open(samples::findFile(vocPath));
CV_Assert(vocFile.is_open());
String vocLine;
std::vector<String> vocabulary;
while (std::getline(vocFile, vocLine)) {
vocabulary.push_back(vocLine);
}
recognizer.setVocabulary(vocabulary);
recognizer.setDecodeType("CTC-greedy");
// 识别参数
double recScale = 1.0 / 127.5;
Scalar recMean = Scalar(127.5, 127.5, 127.5);
Size recInputSize = Size(100, 32);
recognizer.setInputParams(recScale, recInputSize, recMean);
// 检测参数
double detScale = 1.0;
Size detInputSize = Size(width, height);
Scalar detMean = Scalar(123.68, 116.78, 103.94);
bool swapRB = true;
detector.setInputParams(detScale, detInputSize, detMean, swapRB);
// 打开视频文件或图像文件或摄像头流。
bool openSuccess = parser.has("input") ? cap.open(parser.get<String>("input")) : cap.open(0);
CV_Assert(openSuccess);
static const std::string kWinName = "EAST: 一种高效且准确的场景文本检测器";
Mat frame;
while (waitKey(1) < 0)
{
cap >> frame;
if (frame.empty())
{
break;
}
std::cout << frame.size << std::endl;
// 检测
std::vector< std::vector<Point> > detResults;
detector.detect(frame, detResults);
Mat frame2 = frame.clone();
if (detResults.size() > 0) {
// 文本识别
Mat recInput;
if (!imreadRGB) {
cvtColor(frame, recInput, cv::COLOR_BGR2GRAY);
} else {
recInput = frame;
}
std::vector< std::vector<Point> > contours;
for (uint i = 0; i < detResults.size(); i++)
{
const auto& quadrangle = detResults[i];
CV_CheckEQ(quadrangle.size(), (size_t)4, "");
contours.emplace_back(quadrangle);
std::vector<Point2f> quadrangle_2f;
for (int j = 0; j < 4; j++)
quadrangle_2f.emplace_back(quadrangle[j]);
Mat cropped;
fourPointsTransform(recInput, &quadrangle_2f[0], cropped);
std::string recognitionResult = recognizer.recognize(cropped);
std::cout << i << ": '" << recognitionResult << "'" << std::endl;
putText(frame2, recognitionResult, quadrangle[3], FONT_HERSHEY_SIMPLEX, 1.5, Scalar(0, 0, 255), 2);
}
polylines(frame2, contours, true, Scalar(0, 255, 0), 2);
}
imshow(kWinName, frame2);
}
return 0;
}
void fourPointsTransform(const Mat& frame, const Point2f vertices[], Mat& result)
{
const Size outputSize = Size(100, 32);
Point2f targetVertices[4] = {
Point(0, outputSize.height - 1),
Point(0, 0), Point(outputSize.width - 1, 0),
Point(outputSize.width - 1, outputSize.height - 1)
};
Mat rotationMatrix = getPerspectiveTransform(vertices, targetVertices);
warpPerspective(frame, result, rotationMatrix, outputSize);
}
#define CV_CheckEQ(v1, v2, msg)
支持以下类型的值:int、float、double。
定义 check.hpp:118
如果数组没有元素,则返回 true。
int64_t int64
n 维密集数组类
定义 mat.hpp:830
CV_NODISCARD_STD Mat clone() const
创建数组及其底层数据的完整副本。
MatSize size
定义 mat.hpp:2187
用于指定图像或矩形大小的模板类。
定义 types.hpp:335
_Tp height
高度
定义 types.hpp:363
_Tp width
宽度
定义 types.hpp:362
用于从视频文件、图像序列或摄像头捕获视频的类。
定义 videoio.hpp:772
virtual bool open(const String &filename, int apiPreference=CAP_ANY)
打开一个视频文件或捕获设备或IP视频流以进行视频捕获。
此类表示与EAST模型兼容的文本检测DL网络的高级API。
定义 dnn.hpp:1840
此类表示文本识别网络的高级API。
定义 dnn.hpp:1684
std::string String
定义 cvstd.hpp:151
uint32_t uint
定义 interface.h:42
#define CV_Assert(expr)
在运行时检查条件,如果失败则抛出异常。
定义 base.hpp:423
void imshow(const String &winname, InputArray mat)
在指定窗口中显示图像。
int waitKey(int delay=0)
等待按键按下。
void cvtColor(InputArray src, OutputArray dst, int code, int dstCn=0, AlgorithmHint hint=cv::ALGO_HINT_DEFAULT)
将图像从一个颜色空间转换为另一个颜色空间。
@ COLOR_BGR2GRAY
在RGB/BGR和灰度之间转换,颜色转换
定义 imgproc.hpp:557
void putText(InputOutputArray img, const String &text, Point org, int fontFace, double fontScale, Scalar color, int thickness=1, int lineType=LINE_8, bool bottomLeftOrigin=false)
绘制文本字符串。
void polylines(InputOutputArray img, InputArrayOfArrays pts, bool isClosed, const Scalar &color, int thickness=1, int lineType=LINE_8, int shift=0)
绘制多条多边形曲线。
Mat getPerspectiveTransform(InputArray src, InputArray dst, int solveMethod=DECOMP_LU)
计算来自对应的四对点的透视变换。
void warpPerspective(InputArray src, OutputArray dst, InputArray M, Size dsize, int flags=INTER_LINEAR, int borderMode=BORDER_CONSTANT, const Scalar &borderValue=Scalar())
对图像应用透视变换。
int main(int argc, char *argv[])
定义 highgui_qt.cpp:3
定义 all_layers.hpp:47
定义 core.hpp:107