Today's post explains how the OpenCV SVM can be used for handwritten digit classification. To learn more about SVMs, see the OpenCV documentation. Here I explain how an SVM is trained on handwritten digits and later used to classify a test sample. For simplicity I used only two digits, 1 and 0. The demo for this tutorial can be found in the video below.
As a first step, create the training samples: images of the digits 0 and 1, stored in separate directories. I have stored the data in the folders data/one/ and data/zero/. Below is a sample image I used for the demo; it can be downloaded from my GitHub page here.
Data to Train
Initialise SVM
// Handle to the classifier (the template argument is required: Ptr<SVM>).
Ptr<SVM> svm;
//Create the SVM
svm = SVM::create();
// C_SVC: classification into discrete classes (here the labels 0 and 1).
svm->setType(ml::SVM::C_SVC);
// Linear kernel: no kernel parameters to tune.
svm->setKernel(ml::SVM::LINEAR);
// Only the MAX_ITER flag is set, so training stops after at most 100
// iterations; the epsilon value (1e-6) is supplied but not enabled as a
// stopping condition.
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));
Create Training Data
Here we load the handwritten digits from disk and assign the corresponding labels. Note the flag 0 passed to imread, which loads the image as gray-scale. This step includes resizing the image, converting it to floating point, and flattening it into a single continuous row for training the SVM.
Mat trainingDataMat;  // one row per sample: a flattened 72x72 float image
Mat label_array;      // one entry per sample: its class label (1 or 0)
String imgName;
// The samples are named 1.png .. 10.png inside each class directory.
for (int i = 1; i < 11; i++) {
    Mat src, tmp1, tmp2;

    //Create one data and label
    imgName = format("data/one/%d.png", i);
    src = imread(imgName, 0);  //load image as gray scale
    // imread returns an empty Mat when the file cannot be read; fail here
    // with the file name instead of deep inside resize().
    if (src.empty())
        CV_Error(Error::StsObjectNotFound, "Cannot read training image " + imgName);
    resize(src, tmp1, Size(72, 72), 0, 0, INTER_LINEAR);
    tmp1.convertTo(tmp2, CV_32FC1);
    trainingDataMat.push_back(tmp2.reshape(1, 1));  // flatten to a single row
    label_array.push_back(1);

    //Create zero data and label
    imgName = format("data/zero/%d.png", i);
    src = imread(imgName, 0);
    if (src.empty())
        CV_Error(Error::StsObjectNotFound, "Cannot read training image " + imgName);
    resize(src, tmp1, Size(72, 72), 0, 0, INTER_LINEAR);
    tmp1.convertTo(tmp2, CV_32FC1);
    trainingDataMat.push_back(tmp2.reshape(1, 1));
    label_array.push_back(0);
}
Mat labelsMat;
labelsMat = label_array.reshape(1, 1);  // view the labels as a single row
Train SVM
From the above step we have the pre-processed training data, that is, the input samples and their corresponding labels.
// Fit the SVM. ROW_SAMPLE: each row of trainingDataMat is one sample;
// labelsMat supplies the class label for each of those rows.
svm->train( trainingDataMat , ml::ROW_SAMPLE , labelsMat );
Predict the input
Before doing the prediction, we have to apply the same pre-processing steps that were performed during training: convert the input image to grayscale, resize it, convert it to float, and reshape it into a single row.
Mat src = imread("testImg.png",1);
Mat gray;
cvtColor(src,gray,COLOR_BGR2GRAY);
Mat tmp1, tmp2;
resize(gray,tmp1, Size(72,72), 0,0,INTER_LINEAR );
tmp1.convertTo(tmp2,CV_32FC1);
float prediction = svm->predict(tmp2.reshape(1,1));
if(prediction ==1){
putText(src,"prediction = 1", Point(10, 50), FONT_HERSHEY_PLAIN, 5, CV_RGB(255,0,0), 2.8);
}
else if(prediction == 0){
putText(src,"prediction = 0", Point(10, 50), FONT_HERSHEY_PLAIN, 5, CV_RGB(255,0,0), 2.8);
}
Full SourceCode
And finally, here is the full source code, which trains the SVM model and lets the user draw a digit with the mouse in a window; the SVM then predicts whether the digit is 0 or 1.
#include <iostream>
#include <opencv2/opencv.hpp>
using namespace cv;
using namespace std;
using namespace cv::ml;
// Title of the drawing window; also the name passed to imshow and
// setMouseCallback.
const char *windowName = "SVM Example";
// Classifier shared by main() (training) and predictDigit() (inference).
Ptr<SVM> svm;
// True while the left mouse button is held down, i.e. a stroke is in progress.
bool Clicked = false;
// Canvas the user draws on: a white 240x240 3-channel image.
Mat srcImg;
void predictDigit() {
Mat gray;
// String imgName = format("data/test/%d.png",2);
// srcImg = imread(imgName,1);
cvtColor(srcImg, gray, COLOR_BGR2GRAY);
Mat tmp1, tmp2;
resize(gray, tmp1, Size(72, 72), 0, 0, INTER_LINEAR);
tmp1.convertTo(tmp2, CV_32FC1);
float prediction = svm->predict(tmp2.reshape(1, 1));
Mat out = Mat(240, 240, CV_8UC3, Scalar(255, 255, 255));
if (prediction == 1) {
putText(out, "1", Point(70, 190), FONT_HERSHEY_PLAIN, 14, CV_RGB(0, 200, 0),
15.8);
} else if (prediction == 0) {
putText(out, "0", Point(70, 190), FONT_HERSHEY_PLAIN, 14, CV_RGB(0, 200, 0),
15.8);
}
imshow("Prediction", out);
}
// Mouse callback for the drawing window.
//   left button down : start a new stroke on a fresh white canvas
//   mouse move       : while the button is held, paint a black dot at (x, y)
//   left button up   : classify the drawing and end the stroke
// The flags and user-data arguments are not used.
void onMouseAction(int event, int x, int y, int /*flags*/, void * /*userdata*/) {
    switch (event) {
    case EVENT_LBUTTONDOWN:
        Clicked = true;
        srcImg = Mat(240, 240, CV_8UC3, Scalar(255, 255, 255));
        // Paint the press point and refresh the window right away; otherwise
        // the old drawing stays visible until the first move event, and a
        // click without a drag would be classified as a blank image.
        circle(srcImg, Point(x, y), 10, Scalar(0, 0, 0), -1);
        imshow(windowName, srcImg);
        break;
    case EVENT_LBUTTONUP:
        predictDigit();
        Clicked = false;
        break;
    case EVENT_MOUSEMOVE:
        if (Clicked) {
            circle(srcImg, Point(x, y), 10, Scalar(0, 0, 0), -1);
            imshow(windowName, srcImg);
        }
        break;
    default:
        break;
    }
}
int main() {
Mat trainingDataMat;
Mat label_array;
String imgName;
for (int i = 1; i < 11; i++) {
// Create one data and label
imgName = format("data/one/%d.png", i);
Mat src = imread(imgName, 0);
Mat tmp1, tmp2;
resize(src, tmp1, Size(72, 72), 0, 0, INTER_LINEAR);
tmp1.convertTo(tmp2, CV_32FC1);
trainingDataMat.push_back(tmp2.reshape(1, 1));
label_array.push_back(1);
// Create zero data and label
imgName = format("data/zero/%d.png", i);
src = imread(imgName, 0);
resize(src, tmp1, Size(72, 72), 0, 0, INTER_LINEAR);
tmp1.convertTo(tmp2, CV_32FC1);
trainingDataMat.push_back(tmp2.reshape(1, 1));
label_array.push_back(0);
}
Mat labelsMat;
labelsMat = label_array.reshape(1, 1); // make continuous
// Create the SVM
svm = SVM::create();
// Set up SVM's parameters
svm->setType(ml::SVM::C_SVC);
svm->setKernel(ml::SVM::LINEAR);
// svm->setGamma(3);
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));
// Ptr tData = ml::TrainData::create(trainingDataMat,
// ml::SampleTypes::ROW_SAMPLE, labelsMat);
// svm->train(tData);
svm->train(trainingDataMat, ml::ROW_SAMPLE, labelsMat);
srcImg = Mat(240, 240, CV_8UC3, Scalar(255, 255, 255));
namedWindow(windowName, 1);
setMouseCallback(windowName, onMouseAction, NULL);
imshow(windowName, srcImg);
while (1) {
char c = waitKey();
if (c == 27)
break;
}
}
No comments:
Post a Comment