import java.util.Random;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.core.TermCriteria;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import org.opencv.ml.Ml;
import org.opencv.ml.SVM;
/**
 * Demonstrates an OpenCV support vector machine on data that is not linearly separable.
 *
 * <p>The program generates two classes of random 2-D points whose x ranges partially overlap,
 * trains a linear-kernel C-SVC on them, then renders the predicted region for every pixel,
 * the training points, and the support vectors. The picture is written to {@code result.png}
 * and shown in a HighGui window.
 */
public class NonLinearSVMsDemo {
    /** Number of training samples generated for each of the two classes. */
    public static final int NTRAINING_SAMPLES = 100;

    /** Fraction of each class's samples that is drawn from its own, non-overlapping x band. */
    public static final float FRAC_LINEAR_SEP = 0.9f;

    /**
     * Runs the demo end to end: data generation, training, rendering, saving and display.
     *
     * @param args command-line arguments; not used
     */
    public static void main(String[] args) {
        // The OpenCV native library must be loaded before any Mat is created.
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);

        System.out.println("\n--------------------------------------------------------------------------");
        System.out.println("此程序展示了用于非线性可分离数据的支持向量机。");
        System.out.println("--------------------------------------------------------------------------\n");

        // Canvas used for the visual representation.
        int width = 512, height = 512;
        Mat image = Mat.zeros(height, width, CvType.CV_8UC3);

        // ---- 1. Build the training set: one (x, y) sample per row, one label per row.
        Mat trainData = new Mat(2 * NTRAINING_SAMPLES, 2, CvType.CV_32F);
        Mat labels = new Mat(2 * NTRAINING_SAMPLES, 1, CvType.CV_32S);

        Random rng = new Random(100); // fixed seed, so every run generates the same data
        int nLinearSamples = (int) (FRAC_LINEAR_SEP * NTRAINING_SAMPLES);

        // NOTE: the order of the fill calls below fixes the order in which values are
        // pulled from rng; reordering them would change the generated data set.

        // Separable part of class 1: x in [0, 0.4 * width), y in [0, height).
        Mat separableClass1 = trainData.rowRange(0, nLinearSamples);
        fillColumnUniform(separableClass1, 0, rng, 0, 0.4f * width);
        fillColumnUniform(separableClass1, 1, rng, 0, height);

        // Separable part of class 2: x in [0.6 * width, width), y in [0, height).
        Mat separableClass2 =
                trainData.rowRange(2 * NTRAINING_SAMPLES - nLinearSamples, 2 * NTRAINING_SAMPLES);
        fillColumnUniform(separableClass2, 0, rng, 0.6 * width, width);
        fillColumnUniform(separableClass2, 1, rng, 0, height);

        // Overlapping part: the rows in between hold the tail of class 1 followed by the
        // head of class 2, all with x in [0.4 * width, 0.6 * width).
        Mat overlapping = trainData.rowRange(nLinearSamples, 2 * NTRAINING_SAMPLES - nLinearSamples);
        fillColumnUniform(overlapping, 0, rng, 0.4 * width, 0.6 * width);
        fillColumnUniform(overlapping, 1, rng, 0, height);

        // First half of the rows is class 1, second half is class 2.
        labels.rowRange(0, NTRAINING_SAMPLES).setTo(new Scalar(1)); // class 1
        labels.rowRange(NTRAINING_SAMPLES, 2 * NTRAINING_SAMPLES).setTo(new Scalar(2)); // class 2

        // ---- 2. Configure and train the SVM.
        System.out.println("开始训练过程");
        SVM svm = SVM.create();
        svm.setType(SVM.C_SVC);
        svm.setC(0.1);
        svm.setKernel(SVM.LINEAR);
        svm.setTermCriteria(new TermCriteria(TermCriteria.MAX_ITER, (int) 1e7, 1e-6));
        svm.train(trainData, Ml.ROW_SAMPLE, labels);
        System.out.println("训练过程完成");

        // ---- 3. Colour every pixel according to the class the SVM predicts for it.
        paintDecisionRegions(image, svm);

        // ---- 4. Draw the training samples as filled circles, one colour per class.
        int lineType = Imgproc.LINE_8;
        float[] samples = new float[(int) (trainData.total() * trainData.channels())];
        trainData.get(0, 0, samples);
        int sampleStride = trainData.cols();
        // Thickness -1 asks OpenCV for a filled circle.
        drawCircles(image, samples, sampleStride, 0, NTRAINING_SAMPLES,
                3, new Scalar(0, 255, 0), -1, lineType);
        drawCircles(image, samples, sampleStride, NTRAINING_SAMPLES, 2 * NTRAINING_SAMPLES,
                3, new Scalar(255, 0, 0), -1, lineType);

        // ---- 5. Ring every support vector with a larger grey outline.
        Mat sv = svm.getUncompressedSupportVectors();
        float[] svData = new float[(int) (sv.total() * sv.channels())];
        sv.get(0, 0, svData);
        drawCircles(image, svData, sv.cols(), 0, sv.rows(),
                6, new Scalar(128, 128, 128), 2, lineType);

        Imgcodecs.imwrite("result.png", image); // save the image
        HighGui.imshow("非线性训练数据的SVM", image); // show it to the user
        HighGui.waitKey();
        System.exit(0);
    }

    /**
     * Overwrites one column of {@code rows} with uniformly distributed random values.
     *
     * <p>Values are drawn as doubles in {@code [low, high)} and narrowed to float before
     * being stored. Exactly one value per element of the column is consumed from {@code rng}.
     *
     * @param rows   matrix (typically a row-range view) whose column is filled in place
     * @param column zero-based index of the column to fill
     * @param rng    random source
     * @param low    inclusive lower bound
     * @param high   exclusive upper bound
     */
    private static void fillColumnUniform(Mat rows, int column, Random rng, double low, double high) {
        Mat target = rows.colRange(column, column + 1);
        int count = (int) (target.total() * target.channels());
        double[] drawn = rng.doubles(count, low, high).toArray();
        float[] values = new float[count];
        for (int k = 0; k < count; k++) {
            values[k] = (float) drawn[k];
        }
        target.put(0, 0, values);
    }

    /**
     * Classifies the coordinate of every pixel of {@code image} and writes the result colours.
     *
     * <p>Pixels predicted as label 1 get the value 100 in channel 1, pixels predicted as
     * label 2 get 100 in channel 0, and every other channel value is 0. The whole image
     * content is replaced.
     *
     * @param image 3-channel 8-bit canvas to overwrite
     * @param svm   trained classifier, queried with (x, y) = (column, row)
     */
    private static void paintDecisionRegions(Mat image, SVM svm) {
        // Read the geometry once: each accessor is a native call and the values never
        // change while the loops run.
        final int rowCount = image.rows();
        final int colCount = image.cols();
        final int channelCount = image.channels();

        // Fresh array, so every channel starts at 0 and only non-zero channels need writing.
        byte[] pixels = new byte[(int) (image.total() * channelCount)];
        Mat query = new Mat(1, 2, CvType.CV_32F);
        float[] queryData = new float[(int) (query.total() * query.channels())];

        for (int y = 0; y < rowCount; y++) {
            for (int x = 0; x < colCount; x++) {
                queryData[0] = x;
                queryData[1] = y;
                query.put(0, 0, queryData);
                float response = svm.predict(query);

                int base = (y * colCount + x) * channelCount;
                if (response == 1) {
                    pixels[base + 1] = 100;
                } else if (response == 2) {
                    pixels[base] = 100;
                }
            }
        }
        image.put(0, 0, pixels);
    }

    /**
     * Draws one circle per point for the points with index in {@code [from, to)}.
     *
     * @param image     canvas to draw on
     * @param coords    flattened point data; point {@code i} has x at {@code i * stride} and
     *                  y at {@code i * stride + 1}
     * @param stride    number of floats per point in {@code coords}
     * @param from      index of the first point to draw (inclusive)
     * @param to        index one past the last point to draw (exclusive)
     * @param radius    circle radius in pixels
     * @param color     circle colour
     * @param thickness outline thickness as passed to {@code Imgproc.circle}
     * @param lineType  line type as passed to {@code Imgproc.circle}
     */
    private static void drawCircles(Mat image, float[] coords, int stride, int from, int to,
            int radius, Scalar color, int thickness, int lineType) {
        for (int i = from; i < to; i++) {
            float px = coords[i * stride];
            float py = coords[i * stride + 1];
            Imgproc.circle(image, new Point(px, py), radius, color, thickness, lineType, 0);
        }
    }
}
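// Reference notes carried over from the OpenCV C++ documentation:
//   Point  is a typedef of Point2i           (definition: types.hpp:209)
//   Scalar is a typedef of Scalar_<double>   (definition: types.hpp:709)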