上一个教程: 如何在浏览器中运行深度网络
下一个教程: 如何运行自定义 OCR 模型
| |
| 原始作者 | Dmitry Kurtaev |
| 兼容性 | OpenCV >= 3.4.1 |
简介
深度学习是一个快速发展的领域。构建神经网络的新方法通常会引入新类型的层。这些层可能是对现有层的修改,也可能是对优秀研究思想的实现。
OpenCV 允许从不同的深度学习框架导入和运行网络。其中包含了许多最流行的层。然而,你可能会遇到一个问题:你的网络无法使用 OpenCV 导入,因为网络中的某些层可能未在 OpenCV 的深度学习引擎中实现。
第一个解决方案是在 https://github.com/opencv/opencv/issues 创建一个功能请求,其中提及模型的来源和新层的类型等详细信息。如果 OpenCV 社区有此需求,则可以实现新层。
第二种方法是定义一个自定义层,以便 OpenCV 的深度学习引擎知道如何使用它。本教程旨在向您展示深度学习模型导入定制的过程。
在 C++ 中定义自定义层
深度学习层是网络管道的基本组成部分。它连接到输入数据块(blob)并产生结果到输出数据块(blob)。其中包含经过训练的权重和超参数。层的名称、类型、权重和超参数存储在训练期间由原生框架生成的文件中。如果 OpenCV 遇到未知的层类型,它将在尝试读取模型时抛出异常
Unspecified error: Can't create layer "layer_name" of type "MyType" in function getLayerInstance
要正确导入模型,您必须从 cv::dnn::Layer 派生一个类,并包含以下方法
{
public:
const int requiredOutputs,
std::vector<std::vector<int> > &outputs,
std::vector<std::vector<int> > &internals)
const CV_OVERRIDE;
};
并在导入前注册它
static inline void loadNet()
{
- 注意
MyType 是抛出异常中未实现层的类型。
让我们看看所有这些方法的作用
从 cv::dnn::LayerParams 中检索超参数。如果您的层具有可训练的权重,它们将已经存储在 Layer 的成员 cv::dnn::Layer::blobs 中。
此方法应创建您的层的一个实例,并返回一个包含它的 cv::Ptr。
const int requiredOutputs,
std::vector<std::vector<int> > &outputs,
std::vector<std::vector<int> > &internals)
const CV_OVERRIDE;
根据输入形状返回层的输出形状。您可以使用 internals 请求额外内存。
在此处实现层的逻辑。为给定输入计算输出。
- 注意
- OpenCV 管理为层分配的内存。在大多数情况下,相同的内存可以在层之间重用。因此,您的
forward 实现不应依赖于 forward 的第二次调用将在 outputs 和 internals 中拥有相同的数据。
方法链如下:OpenCV 深度学习引擎调用一次 create 方法,然后为每个创建的层调用 getMemoryShapes,然后您可以在 cv::dnn::Layer::finalize 中根据已知的输入维度进行一些准备。网络初始化后,对于网络的每个输入,仅调用 forward 方法。
- 注意
- 输入数据块(blob)的大小(例如高度、宽度或批次大小)变化会导致 OpenCV 重新分配所有内部内存。这会导致效率差距。请尝试使用固定的批次大小和图像维度来初始化和部署模型。
示例:来自 Caffe 的自定义层
让我们从 https://github.com/cdmh/deeplab-public 创建一个自定义层 Interp。它只是一个简单的尺寸调整层,接收大小为 N x C x Hi x Wi 的输入数据块(blob),并返回大小为 N x C x Ho x Wo 的输出数据块(blob),其中 N 是批次大小,C 是通道数,Hi x Wi 和 Ho x Wo 分别是输入和输出的 高 x 宽。此层没有可训练的权重,但它有超参数来指定输出大小。
例如,
layer {
name: "output"
type: "Interp"
bottom: "input"
top: "output"
interp_param {
height: 9
width: 8
}
}
这样我们的实现可能看起来像这样
{
public:
{
outWidth =
params.get<
int>(
"width", 0);
outHeight =
params.get<
int>(
"height", 0);
}
{
}
const int requiredOutputs,
std::vector<std::vector<int> > &outputs,
std::vector<std::vector<int> > &internals)
const CV_OVERRIDE
{
CV_UNUSED(requiredOutputs); CV_UNUSED(internals);
std::vector<int> outShape(4);
outShape[0] = inputs[0][0];
outShape[1] = inputs[0][1];
outShape[2] = outHeight;
outShape[3] = outWidth;
outputs.assign(1, outShape);
return false;
}
{
if (inputs_arr.depth() ==
CV_16S)
{
return;
}
std::vector<cv::Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
const float* inpData = (
float*)inp.
data;
float* outData = (float*)out.data;
const int batchSize = inp.size[0];
const int numChannels = inp.size[1];
const int inpHeight = inp.size[2];
const int inpWidth = inp.size[3];
const float rheight = (outHeight > 1) ? static_cast<float>(inpHeight - 1) / (outHeight - 1) : 0.f;
const float rwidth = (outWidth > 1) ? static_cast<float>(inpWidth - 1) / (outWidth - 1) : 0.f;
for (int h2 = 0; h2 < outHeight; ++h2)
{
const float h1r = rheight * h2;
const int h1 = static_cast<int>(h1r);
const int h1p = (h1 < inpHeight - 1) ? 1 : 0;
const float h1lambda = h1r - h1;
const float h0lambda = 1.f - h1lambda;
for (int w2 = 0; w2 < outWidth; ++w2)
{
const float w1r = rwidth * w2;
const int w1 = static_cast<int>(w1r);
const int w1p = (w1 < inpWidth - 1) ? 1 : 0;
const float w1lambda = w1r - w1;
const float w0lambda = 1.f - w1lambda;
const float* pos1 = inpData + h1 * inpWidth + w1;
float* pos2 = outData + h2 * outWidth + w2;
for (int c = 0; c < batchSize * numChannels; ++c)
{
pos2[0] =
h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p]) +
h1lambda * (w0lambda * pos1[h1p * inpWidth] + w1lambda * pos1[h1p * inpWidth + w1p]);
pos1 += inpWidth * inpHeight;
pos2 += outWidth * outHeight;
}
}
}
}
private:
int outWidth, outHeight;
};
接下来我们需要注册新的层类型并尝试导入模型。
示例:来自 TensorFlow 的自定义层
这是一个导入包含 tf.image.resize_bilinear 操作的网络的示例。这也是一个尺寸调整操作,但其实现与 OpenCV 或上述 Interp 不同。
让我们创建一个单层网络
inp = tf.placeholder(tf.float32, [2, 3, 4, 5], 'input')
resized = tf.image.resize_bilinear(inp, size=[9, 8], name='resize_bilinear')
OpenCV 以以下方式看待 TensorFlow 图
node {
name: "input"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
}
node {
name: "resize_bilinear/size"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
dim {
size: 2
}
}
tensor_content: "\t\000\000\000\010\000\000\000"
}
}
}
}
node {
name: "resize_bilinear"
op: "ResizeBilinear"
input: "input:0"
input: "resize_bilinear/size"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "align_corners"
value {
b: false
}
}
}
library {
}
从 TensorFlow 导入自定义层旨在将所有层的 attr 放入 cv::dnn::LayerParams,但将输入 Const 数据块(blobs)放入 cv::dnn::Layer::blobs。在我们的例子中,调整大小的输出形状将存储在层的 blobs[0] 中。
{
public:
{
for (size_t i = 0; i < blobs.size(); ++i)
if (blobs.size() == 1)
{
outHeight = blobs[0].at<int>(0, 0);
outWidth = blobs[0].at<int>(0, 1);
factorHeight = factorWidth = 0;
}
else
{
factorHeight = blobs[0].at<int>(0, 0);
factorWidth = blobs[1].at<int>(0, 0);
outHeight = outWidth = 0;
}
}
{
}
const int,
std::vector<std::vector<int> > &outputs,
{
std::vector<int> outShape(4);
outShape[0] = inputs[0][0];
outShape[1] = inputs[0][1];
outShape[2] = outHeight != 0 ? outHeight : (inputs[0][2] * factorHeight);
outShape[3] = outWidth != 0 ? outWidth : (inputs[0][3] * factorWidth);
outputs.assign(1, outShape);
return false;
}
{
std::vector<cv::Mat> outputs;
outputs_arr.getMatVector(outputs);
if (!outWidth && !outHeight)
{
outHeight = outputs[0].size[2];
outWidth = outputs[0].size[3];
}
}
{
if (inputs_arr.depth() ==
CV_16S)
{
return;
}
std::vector<cv::Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
const float* inpData = (
float*)inp.
data;
float* outData = (float*)out.data;
const int batchSize = inp.size[0];
const int numChannels = inp.size[1];
const int inpHeight = inp.size[2];
const int inpWidth = inp.size[3];
float heightScale = static_cast<float>(inpHeight) / outHeight;
float widthScale = static_cast<float>(inpWidth) / outWidth;
for (int b = 0; b < batchSize; ++b)
{
for (int y = 0; y < outHeight; ++y)
{
float input_y = y * heightScale;
int y0 = static_cast<int>(std::floor(input_y));
int y1 = std::min(y0 + 1, inpHeight - 1);
for (int x = 0; x < outWidth; ++x)
{
float input_x = x * widthScale;
int x0 = static_cast<int>(std::floor(input_x));
int x1 = std::min(x0 + 1, inpWidth - 1);
for (int c = 0; c < numChannels; ++c)
{
float interpolation =
inpData[offset(inp.size, c, x0, y0, b)] * (1 - (input_y - y0)) * (1 - (input_x - x0)) +
inpData[offset(inp.size, c, x0, y1, b)] * (input_y - y0) * (1 - (input_x - x0)) +
inpData[offset(inp.size, c, x1, y0, b)] * (1 - (input_y - y0)) * (input_x - x0) +
inpData[offset(inp.size, c, x1, y1, b)] * (input_y - y0) * (input_x - x0);
outData[offset(out.size, c, x, y, b)] = interpolation;
}
}
}
}
}
private:
static inline int offset(
const cv::MatSize& size,
int c,
int x,
int y,
int b)
{
}
int outWidth, outHeight, factorWidth, factorHeight;
};
接下来我们注册一个层并尝试导入模型。
在 Python 中定义自定义层
以下示例展示了如何在 Python 中自定义 OpenCV 的层。
让我们考虑 Holistically-Nested Edge Detection 深度学习模型。该模型与当前版本的 Caffe 框架相比,只有一个不同之处。Crop 层接收两个输入数据块(blob),并裁剪第一个数据块以匹配第二个数据块的空间维度,过去是从中心裁剪。现在 Caffe 的层是从左上角裁剪的。因此,使用最新版本的 Caffe 或 OpenCV,您将得到带有填充边界的偏移结果。
接下来我们将把 OpenCV 中执行左上角裁剪的 Crop 层替换为一个中心裁剪的层。
- 创建一个包含
getMemoryShapes 和 forward 方法的类
class CropLayer(object)
def __init__(self, params, blobs)
self.xstart = 0
self.xend = 0
self.ystart = 0
self.yend = 0
def getMemoryShapes(self, inputs)
inputShape, targetShape = inputs[0], inputs[1]
batchSize, numChannels = inputShape[0], inputShape[1]
height, width = targetShape[2], targetShape[3]
self.ystart = (inputShape[2] - targetShape[2]) // 2
self.xstart = (inputShape[3] - targetShape[3]) // 2
self.yend = self.ystart + height
self.xend = self.xstart + width
return [[batchSize, numChannels, height, width]]
def forward(self, inputs)
return [inputs[0][:,:,self.ystart:self.yend,self.xstart:self.xend]]
- 注意
- 两个方法都应返回列表。
cv.dnn_registerLayer('Crop', CropLayer)
就是这样!我们已经将 OpenCV 已实现的层替换为自定义层。您可以在源代码中找到完整的脚本。