Full width home advertisement

OpenCV

HTML

Post Page Advertisement [Top]

  Today's post explains how OpenCV's SVM can be used for handwritten digit classification. To learn more about SVMs, see the OpenCV documentation. Here I explain how an SVM is trained on handwritten digits and then used to classify a test sample. For simplicity I use only two digits, 1 and 0. A demo of this tutorial can be found in the video below.

OpenCV SVM Handwritten Digits Classification
  As a first step, create the training samples: images of the digits 0 and 1, stored in separate directories. I have stored the data in the folders data/one/ and data/zero/. Below are the sample images I used for the demo; they can be downloaded from my GitHub page here.

Data to Train




Initialise SVM


// Handle to the SVM model (cv::Ptr is OpenCV's reference-counted smart pointer).
Ptr<SVM> svm;
//Create the SVM
svm = SVM::create();
// C_SVC: C-support vector classification with n classes (here 0 and 1).
svm->setType(ml::SVM::C_SVC);
// No kernel mapping: the classifier works directly on the raw pixel features.
svm->setKernel(ml::SVM::LINEAR);
// Stop after at most 100 iterations; with type MAX_ITER the epsilon (1e-6) is not used.
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));

Create Training Data

 Here we load the handwritten digits from disk and assign the corresponding labels. Note the flag 0 passed to imread, which loads the image as grayscale. This step also resizes each image, converts it to floating point and reshapes it into a single row, which is the sample layout the SVM is trained on.

   // Training set: data/one/1..10.png (label 1) and data/zero/1..10.png (label 0).
   Mat trainingDataMat;   // one CV_32F sample per row
   Mat label_array;       // one int (CV_32S) label per sample
   String imgName;

   for(int i=1;i<11;i++){

       //Create one data and label
       imgName = format("data/one/%d.png",i);
       Mat src = imread(imgName,0); //load image as gray scale
       // Fail with a readable message instead of an obscure assertion inside resize().
       if (src.empty())
           CV_Error(Error::StsObjectNotFound, "Cannot read " + imgName);
       Mat tmp1, tmp2;
       resize(src,tmp1, Size(72,72), 0,0,INTER_LINEAR );
       tmp1.convertTo(tmp2,CV_32FC1);
       trainingDataMat.push_back(tmp2.reshape(1,1)); // 72x72 image -> 1 x 5184 row
       label_array.push_back(1);

       //Create zero data and label
       imgName = format("data/zero/%d.png",i);
       src = imread(imgName,0);
       if (src.empty())
           CV_Error(Error::StsObjectNotFound, "Cannot read " + imgName);
       resize(src,tmp1, Size(72,72), 0,0,INTER_LINEAR );
       tmp1.convertTo(tmp2,CV_32FC1);
       trainingDataMat.push_back(tmp2.reshape(1,1));
       label_array.push_back(0);

   }

   Mat labelsMat;
   labelsMat=label_array.reshape(1,1); // N x 1 label column -> 1 x N row

Train SVM

 The step above gives us the pre-processed training data: the input samples and their corresponding labels.

     svm->train( trainingDataMat , ml::ROW_SAMPLE , labelsMat );

Predict the input

 Before making a prediction, we have to apply the same pre-processing steps that were performed during training: convert the input image to grayscale, resize it, convert it to float and reshape it into a single row.

// Load the test image in color (flag 1) so the prediction can be drawn on it in red.
Mat src = imread("testImg.png",1);
if (src.empty())
    CV_Error(Error::StsObjectNotFound, "Cannot read testImg.png");

// Same pre-processing as the training data: gray-scale, 72x72, float, one row.
Mat gray;
cvtColor(src,gray,COLOR_BGR2GRAY);
Mat tmp1, tmp2;
resize(gray,tmp1, Size(72,72), 0,0,INTER_LINEAR );
tmp1.convertTo(tmp2,CV_32FC1);

// For a classifier, predict() returns the predicted class label (0 or 1 here) as a float.
float prediction = svm->predict(tmp2.reshape(1,1));
// putText's thickness parameter is an integer number of pixels.
if(prediction ==1){
      putText(src,"prediction = 1", Point(10, 50), FONT_HERSHEY_PLAIN, 5, CV_RGB(255,0,0), 2);
}
else if(prediction == 0){
      putText(src,"prediction = 0", Point(10, 50), FONT_HERSHEY_PLAIN, 5, CV_RGB(255,0,0), 2);
}

Full SourceCode

 And finally, here is the full source code. It trains the SVM model and then lets the user draw a digit with the mouse in a window; the SVM predicts whether the digit is 0 or 1.

#include <iostream>

#include <opencv2/ml.hpp>
#include <opencv2/opencv.hpp>

using namespace cv;
using namespace std;
using namespace cv::ml;

// Title of the drawing window; also the key used to attach the mouse callback.
const char *const windowName = "SVM Example";
// SVM model, trained in main() and queried by predictDigit().
Ptr<SVM> svm;
// True between a left-button press and its release in the drawing window.
bool Clicked = false;
// 240x240 BGR canvas the user draws on.
Mat srcImg;

void predictDigit() {

  Mat gray;
  // String imgName = format("data/test/%d.png",2);
  // srcImg = imread(imgName,1);
  cvtColor(srcImg, gray, COLOR_BGR2GRAY);
  Mat tmp1, tmp2;
  resize(gray, tmp1, Size(72, 72), 0, 0, INTER_LINEAR);
  tmp1.convertTo(tmp2, CV_32FC1);

  float prediction = svm->predict(tmp2.reshape(1, 1));

  Mat out = Mat(240, 240, CV_8UC3, Scalar(255, 255, 255));
  if (prediction == 1) {
    putText(out, "1", Point(70, 190), FONT_HERSHEY_PLAIN, 14, CV_RGB(0, 200, 0),
            15.8);
  } else if (prediction == 0) {
    putText(out, "0", Point(70, 190), FONT_HERSHEY_PLAIN, 14, CV_RGB(0, 200, 0),
            15.8);
  }

  imshow("Prediction", out);
}
/**
 * Mouse callback for the drawing window.
 *
 * - Left button down: start a new drawing on a blank white canvas.
 * - Mouse move while the button is held: paint a filled black disc of radius 10.
 * - Left button up: classify the drawing with predictDigit().
 *
 * @param event OpenCV mouse event code (EVENT_*).
 * @param x     cursor x position in the window.
 * @param y     cursor y position in the window.
 * The flags and userdata arguments are unused.
 */
void onMouseAction(int event, int x, int y, int /*flags*/, void * /*userdata*/) {
  switch (event) {
  case EVENT_LBUTTONDOWN:
    Clicked = true;
    srcImg = Mat(240, 240, CV_8UC3, Scalar(255, 255, 255));
    // Show the cleared canvas immediately; otherwise the previous drawing
    // stays on screen until the mouse moves.
    imshow(windowName, srcImg);
    break;

  case EVENT_LBUTTONUP:
    // Only classify when the press was seen by this window; a stray release
    // would otherwise classify a stale canvas.
    if (Clicked) {
      predictDigit();
      Clicked = false;
    }
    break;

  case EVENT_MOUSEMOVE:
    if (Clicked) {
      circle(srcImg, Point(x, y), 10, Scalar(0, 0, 0), -1);
      imshow(windowName, srcImg);
    }
    break;

  default:
    break;
  }
}

int main() {
  Mat trainingDataMat;
  Mat label_array;
  String imgName;

  for (int i = 1; i < 11; i++) {

    // Create one data and label
    imgName = format("data/one/%d.png", i);
    Mat src = imread(imgName, 0);
    Mat tmp1, tmp2;
    resize(src, tmp1, Size(72, 72), 0, 0, INTER_LINEAR);
    tmp1.convertTo(tmp2, CV_32FC1);
    trainingDataMat.push_back(tmp2.reshape(1, 1));
    label_array.push_back(1);

    // Create zero data and label
    imgName = format("data/zero/%d.png", i);
    src = imread(imgName, 0);
    resize(src, tmp1, Size(72, 72), 0, 0, INTER_LINEAR);
    tmp1.convertTo(tmp2, CV_32FC1);
    trainingDataMat.push_back(tmp2.reshape(1, 1));
    label_array.push_back(0);
  }

  Mat labelsMat;
  labelsMat = label_array.reshape(1, 1); // make continuous

  // Create the SVM
  svm = SVM::create();

  // Set up SVM's parameters
  svm->setType(ml::SVM::C_SVC);
  svm->setKernel(ml::SVM::LINEAR);
  // svm->setGamma(3);
  svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));

  // Ptr tData = ml::TrainData::create(trainingDataMat,
  // ml::SampleTypes::ROW_SAMPLE, labelsMat);
  // svm->train(tData);
  svm->train(trainingDataMat, ml::ROW_SAMPLE, labelsMat);

  srcImg = Mat(240, 240, CV_8UC3, Scalar(255, 255, 255));
  namedWindow(windowName, 1);
  setMouseCallback(windowName, onMouseAction, NULL);
  imshow(windowName, srcImg);

  while (1) {
    char c = waitKey();
    if (c == 27)
      break;
  }
}

Demo Video



No comments:

Post a Comment

Bottom Ad [Post Page]

| Designed by Colorlib