手写识别之朴素贝叶斯

导读:本篇文章讲解 手写识别之朴素贝叶斯,希望对大家有帮助,欢迎收藏,转发!站点地址:www.bmabk.com

       在前面的几篇文章中,都对手写识别进行了一些讲解,这里主要是介绍一下通过另外一种方法来进行识别———朴素贝叶斯。。自己也是处于机器学习路上的一名新手,如果有什么讲解不对的话,欢迎大家进行交流,可以把建议写到下面。。。。

      好了,不多说,进入正题。。朴素贝叶斯,我相信,搞机器学习的人都不会陌生,关于它的一些基本概念我就不说了,如果还有什么不明白的地方,可以去百度查查它的理论知识。我主要就是对于手写识别来进行针对性的讲解。

    就把朴素贝叶斯中,最为关键的公式贴出来:

    手写识别之朴素贝叶斯

    一:关于训练集数据

       这部分,我在前面的文章中,进行了讲解,而且数据集我也分享到了百度云,如果有需要的可以翻看一下前面的那篇神经网络的文章进行下载。

    二:朴素贝叶斯的实践

       这里讲解一下,大概的步骤吧。其实了解朴素贝叶斯的话,应该很好理解如何进行实施的,毕竟这算法的优点就是通过概率来预测的这么一种简单的方法。  (就把自己做课程报告中PPT写的东西贴出来)

         手写识别之朴素贝叶斯

         手写识别之朴素贝叶斯

        我想,如果了解朴素贝叶斯的基本理念再加上上面的一些提示,那么应该就知道如何进行实施了。。下面就是代码(Java语言)::

        1:读取训练集数据(.csv后缀的文件)

            

package beiyesifenleiqi;

/*
 * 读取后缀为csv的excell文件
 * 
 */


import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public class CSVFileUtil {
	 private String fileName = null;
     private BufferedReader br = null;
     private List<String> list = new ArrayList<String>();

     public CSVFileUtil() {

     }

     public CSVFileUtil(String fileName) throws Exception {
             this.fileName = fileName;
             br = new BufferedReader(new FileReader(fileName));
             String stemp;
             while ((stemp = br.readLine()) != null) {
                     list.add(stemp);
             }
     }

     public List getList() {
             return list;
     }
     /**
      * 获取行数
      * @return
      */
     public int getRowNum() {
             return list.size();
     }
     /**
      * 获取列数
      * @return
      */
     public int getColNum() {
             if (!list.toString().equals("[]")) {
                     if (list.get(0).toString().contains(",")) {// csv为逗号分隔文件
                             return list.get(0).toString().split(",").length;
                     } else if (list.get(0).toString().trim().length() != 0) {
                             return 1;
                     } else {
                             return 0;
                     }
             } else {
                     return 0;
             }
     }
     /**
      * 获取制定行
      * @param index
      * @return
      */
     public String getRow(int index) {
             if (this.list.size() != 0) {
                     return (String) list.get(index);
             } else {
                     return null;
             }
     }
     /**
      * 获取指定列
      * @param index
      * @return
      */
     public String getCol(int index) {
             if (this.getColNum() == 0) {
                     return null;
             }
             StringBuffer sb = new StringBuffer();
             String tmp = null;
             int colnum = this.getColNum();
             if (colnum > 1) {
                     for (Iterator it = list.iterator(); it.hasNext();) {
                             tmp = it.next().toString();
                             sb = sb.append(tmp.split(",")[index] + ",");
                     }
             } else {
                     for (Iterator it = list.iterator(); it.hasNext();) {
                             tmp = it.next().toString();
                             sb = sb.append(tmp + ",");
                     }
             }
             String str = new String(sb.toString());
             str = str.substring(0, str.length() - 1);
             return str;
     }
     /**
      * 获取某个单元格
      * @param row
      * @param col
      * @return
      */
     public String getString(int row, int col) {
             String temp = null;
             int colnum = this.getColNum();
             if (colnum > 1) {
                     temp = list.get(row).toString().split(",")[col];
             } else if(colnum == 1){
                     temp = list.get(row).toString();
             } else {
                     temp = null;
             }
             return temp;
     }

     public void CsvClose()throws Exception{
             this.br.close();
     }

}

        2:  构建朴素贝叶斯

        

package beiyesifenleiqi;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;

/*
 * 贝叶斯分类器
 */
public class BeiYeSIFenLeiQi {
      public int[] splicStyleResult;   //分类的结果10种,即0-9
	  public int[][] numberEveryLie;  //每一个数字对应的每一列的概率
	  public double[][] gailvEveryHang;  //每个数字的每一行中的概率
	  public double[] gailvStyleResult;   //训练集中每个数字出现的概率
	  public int totalResultNumber;               //记录总的输入结果的个数就是10,便于计算概率而已
	  public int totalHangNumber;                 //记录总的行数,其实就是784
	  private int xunlianjigeshu;
	  public BeiYeSIFenLeiQi(int styleResult,int totalLie){
		  splicStyleResult=new int[styleResult];  //初始化所需要进行分类的个数
		  numberEveryLie=new int[styleResult][totalLie];   //也就是10和784
		  gailvEveryHang=new double[styleResult][totalLie];   //也就是10和784
		  gailvStyleResult=new double[styleResult];        //每个数字的概率
		  totalResultNumber=styleResult;      //得到结果个数
		  totalHangNumber=totalLie;           //得到列数
	  }
	  /*
	   * 设置训练集的总个数  
	   */
	  public void setXunLianNumber(int geshu){
		  xunlianjigeshu=geshu;
	  }
	  
	  /*
	   * 计算在训练集中,每一个结果标签对应的个数
	   */
	  public void addeveryResultNumber(int numberresult, int currentnumber){     
		  splicStyleResult[currentnumber]+=numberresult;  //对应的标签个数加上线程处理后的个数
	  }
	  /*
	   *  找到对应的结果数字,并且将相应的列为1值的索引数加1,主要用来后面算概率
	   */
	  public void addEveryLieNumber(int numberresult,int liesuoyin,int number){           
		  numberEveryLie[numberresult][liesuoyin]+=number;        //将像素数值为1的对应的数字的索引个数加1,(784中)
	  }
	  
	  /*
	   *  更新处理完概率后每个线程计算出来的概率总和
	   * 
	   */
	  public void updataThreadComputeGaiLv(double[] gailvResult , double[][] gaileveryElement){
		   //更新数字的概率
		  for(int i=0 ; i<totalResultNumber ; i++){
			  gailvStyleResult[i]+=gailvResult[i];
			  for(int j=0 ; j<totalHangNumber ; j++){
				  gailvEveryHang[i][j]+=gaileveryElement[i][j];
			  }
		  }
	  }
	  
	  /*
	   * 打印概率结果
	  */
	  public void printfResult(){
		 for(int i=0;i<10;i++){
//  			 System.out.println(gailvStyleResult[i]);//打印每个数字的概率
//			 for(int j=0;j<784;j++){
//				 System.out.print(gailvEveryHang[i][j]+" ");//打印每个数字的每个像素的概率
//			 }
			 System.out.println();
		 } 
	  }
	  
	 /*
	  * 计算测试集的结果
	  */
	public int computeYuCeResult(double[] binary) {
		double[] everyGailv=new double[10];
		//计算每个数字的可能性概率
		for(int currennumber=0;currennumber<10;currennumber++){
			everyGailv[currennumber]=gailvStyleResult[currennumber];  //得到在训练集中该数字出现的概率
			for(int suoyin=0;suoyin<784;suoyin++){
				if(binary[suoyin]==0){   //表示该像素上没有可能性
					everyGailv[currennumber]=everyGailv[currennumber]*(1-gailvEveryHang[currennumber][suoyin]);//贝叶斯分类的概率计算,因为该位置为0,则表示与训练集汇总的可能性较大
				}
				else if(binary[suoyin]==1){  //表示该位置出现了,则按之后算好的概率进行计算
					everyGailv[currennumber]=everyGailv[currennumber]*gailvEveryHang[currennumber][suoyin];
				}
			}
		}
		//比较存储的10个数字中,概率最大的是哪个,则表示最有可能的预测就是哪个数字
		double sumGailv=0;  //总的概率
		for(int i=0;i<10;i++){
			sumGailv=sumGailv+everyGailv[i];
		}
		for(int j=0;j<10;j++){
			everyGailv[j]=everyGailv[j]/sumGailv;   //得到权重的百分比
		}
		
		double maxGailv=everyGailv[0];
		int maxSuoyin=0;
		for(int max=1;max<10;max++){
			if(maxGailv<everyGailv[max]){
				maxGailv=everyGailv[max];
				maxSuoyin=max;
			}
		}
		return maxSuoyin;   //返回预测的数字
	}
	/*
	 * 加载之前已经训练过的数据
	 */
	public void loadPreviousData(File writePath) throws Exception {
		 FileInputStream in=new FileInputStream(writePath);
		 InputStreamReader isr=new InputStreamReader(in, "UTF-8");  //防止乱码
	     BufferedReader br = new BufferedReader(isr);
	     String currenline ="";
	     int suoyin=0;
	     try {
			while((currenline=br.readLine())!=null){
				 String[] fengeeverynumber=currenline.split(",");  //得到每一行的每小格的数据
				 int totallength=fengeeverynumber.length;
				 gailvStyleResult[suoyin] = Double.valueOf(fengeeverynumber[0]);
					 for(int i=1;i<totallength-1;i++){
						 gailvEveryHang[suoyin][i-1]=Double.valueOf(fengeeverynumber[i]);
					 }			 
					 suoyin++;
			 }
		} catch (IOException e) {			
			e.printStackTrace();
		}
	     finally{
	    	 br.close();
	     }	   	
	}
	/*
	 * 把每个概率写入到Txt文件中,方便后面读
	 */
	public void writeEveryGailv(File writePath) throws IOException {
		int lengthdata = gailvStyleResult.length;   //得到数字概率的个数(其实就是10个)
		FileWriter fw= new FileWriter(writePath);
		BufferedWriter  bw= new BufferedWriter(fw);
		for(int resultsuoyin=0;resultsuoyin<lengthdata;resultsuoyin++){
			bw.write(gailvStyleResult[resultsuoyin]+",");
				for(int i=0;i<784;i++){
					bw.write(gailvEveryHang[resultsuoyin][i]+",");			  //写入数据
				}  
			bw.write("\t\n");                                    //加个换行(一个数据一行)
		}
		bw.close();		
	}
}

        3:线程类(因为数据太多了,就用多线程进行了读取,这样来减少读取的时间)

        

package beiyesifenleiqi;

import java.util.concurrent.CountDownLatch;

import shenjingwangluo2.CSVFileUtil;
import beiyesifenleiqi.Text;
public class startDealThread  implements Runnable{
		int startindex;
		int overindex;
		int trainResultNumber;
		int xunlianjigeshu;
		int totalHangNumber;
		CSVFileUtil resultData;
		CSVFileUtil trainData;
		int[] getnumber;
		int[][] everyNumberGeshu;
		double[] gailvResultNumber;
		double[][] gailvEveryHangLie;
		CountDownLatch countDownLatch;
		Text manythread;
		public startDealThread(Text manythread, CountDownLatch countDownLatch, 
				int suoyin,int oversuoyin,CSVFileUtil util,CSVFileUtil util2,int totalResultNumber,int totalLieNumber) {
			startindex=suoyin;
			this.overindex=oversuoyin;
			resultData=util;
			trainData=util2;
			totalHangNumber=totalLieNumber;     //总的元素的个数(784)
			xunlianjigeshu=oversuoyin-suoyin;   //每个线程训练的个数
			this.countDownLatch=countDownLatch;
			this.manythread=manythread;
			trainResultNumber=totalResultNumber;
			getnumber = new int[totalResultNumber];     //存储每个数字的个数
			everyNumberGeshu=new int[totalResultNumber][totalLieNumber];  //存储每个元素的个数
			gailvResultNumber = new double[totalResultNumber];  //存储每个数字的概率
			gailvEveryHangLie = new double[totalResultNumber][totalLieNumber]; //存储每个元素的概率
		}
		@Override
	    public void run() {  		 	
	   		compute(startindex);           //计算每个数字的次数
	   		computeGaiLVEveryHang();        //计算每个元素的概率
	    	manythread.updataAllData(gailvResultNumber,gailvEveryHangLie);   //计算概率完成之后更新所有线程计算出现的概率
	    	countDownLatch.countDown();   //表示该线程已经进行执行完成
	    	System.out.println("线程我已经完成计算工作!!");
	    }

		/*
	      * 计算数据
	      */
	    private void compute(int currentsuoyin) {
	    	int resultNumber=0;
	    	int suoyinlie=0;
	    	int value=0;
	    	for(int i=currentsuoyin;i<overindex;i++){
	    		 resultNumber=Integer.parseInt(resultData.getString(i, 0));
	    		getnumber[resultNumber]=getnumber[resultNumber]+1;
	    		 while(suoyinlie<784){
	             	value=Integer.parseInt(trainData.getString(i, suoyinlie));
	             	if(value>=128){
	             		addEveryLieNumber(resultNumber, suoyinlie);     //主要是为了让数据中只有0和1这样的灰度数据方便计算
	             	}                                                //而且对于存在1的时候,才进行存储,也就是表示实际有像素点被画
	             	suoyinlie++;
	             }
	             suoyinlie=0;  //处理一个后,记得还原
	    	}		    	
		}
	    /*
	     * 计算每一行中的像素为1的个数
	     */
		private void addEveryLieNumber(int resultNumber, int suoyinlie) {
			  everyNumberGeshu[resultNumber][suoyinlie]+=1;        //将数值为1的索引个数加1,(784中)			
		}	
		
		 /*
		   * 计算每一个结果每一行中的概率
		   */
		 public void computeGaiLVEveryHang(){
			  double tempResult=0;           //每个行的中间结果(784个)  
			  double tempResultgeshu=0;      //得到每个数字出现的个数
			  double styleResult=0; 		//保存每个数字的概率中间变量而已
			  double everyResult=0;			//保存每个元素的个数,中间变量而已
			  for(int i=0;i<trainResultNumber;i++){
				  	tempResultgeshu=getnumber[i];   //得到训练集中,对应数字的个数
				  	if(tempResultgeshu==0){         //防止一个都没出现的情况,为了效果更加的平滑
				  		tempResultgeshu=1;
				  	}
				  	styleResult=tempResultgeshu/xunlianjigeshu;  //得到每个数字在训练集中出现的概率
				    gailvResultNumber[i]=styleResult;  
				  for(int j=0;j<totalHangNumber;j++){
					  if(everyNumberGeshu[i][j]==0){   //表示一个样本都没出现,则加1,防止出现平滑处理
						  everyNumberGeshu[i][j]=1;
					  }
					  everyResult=everyNumberGeshu[i][j];
					  tempResult=everyResult/tempResultgeshu;  //得到对应的概率
					  gailvEveryHangLie[i][j]=tempResult;              //得到每个元素的概率
				  }
			  }
		  }
}

        4:  训练数据及其测试数据的主类    

       

package beiyesifenleiqi;

import java.io.File;
import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.Arrays;
import java.util.Date;
import java.util.concurrent.CountDownLatch;
import javax.xml.crypto.Data;
import shenjingwangluo2.CSVFileUtil;

/*
 * 进行测试
 * 
 */
public class Text {

	private static Text text;
	private static BeiYeSIFenLeiQi beiyesi;
	private static File writePath;
	public static void main(String[] args) throws Exception {
		Date data=new Date();
		SimpleDateFormat si=new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
		System.out.println("开始训练时间:"+si.format(data));
		text = new Text();
		beiyesi = new BeiYeSIFenLeiQi(10, 784);	
		writePath = new File("D:/xunlianresult");
		if(writePath.exists()){  //如果指定的文件已经存在,表示之前已经写入了数据
			beiyesi.loadPreviousData(writePath);
		}
		else{                    //没有训练结果,则进行训练
			writePath.createNewFile();
			dataInit(beiyesi);         //训练数据的处理
			System.out.println("训练数据并存储数据成功!!!");
		}
		computeTextData(beiyesi);  //进行测试
	}
 
	 /*
	 * 进行测试
	 */	
	 private static void computeTextData(BeiYeSIFenLeiQi beiyesi) throws Exception {
		 //输入测试数据
		 CSVFileUtil util3 = new CSVFileUtil("D:\\textdata.csv");
		 int textthang=util3.getRowNum(); //得到测试数据行数	     
	     CSVFileUtil util4 = new CSVFileUtil("D:\\textresult.csv");
	     
	     //二值化进行处理,并且进行预测
	     int getTextNumber=0;   //保存测试数据当前预测的值
	     int yuceresult=0;      //预测的结果
	     int accurencynumber=0;  //预测正确的个数
	     double[] binary=new double[784];   //二值化的信息
	     for(int i=0;i<textthang;i++){
	    	 getTextNumber=Integer.parseInt(util4.getString(i, 0));   //得到测试数据当前预测的值         
	    	 int currentsuoyin=0;
	    	 int value=0;	    
	    	 while(currentsuoyin<784){
	    		 value=Integer.parseInt(util3.getString(i, currentsuoyin));
	    		 if(value>128){  //因为二值化后大于128的就为1
	    			 binary[currentsuoyin]=1;	    			 
	    		 }
	    		 currentsuoyin++;
	    	 }
	    	 currentsuoyin=0;
	    	 //进行预测结果
	    	 yuceresult=beiyesi.computeYuCeResult(binary);   //得到预测的结果
	    	 System.out.print("经过预测的值为:"+yuceresult);  //打印预测的结果
	    	 if(yuceresult==getTextNumber){   //比较预测和真实结果,是否相同
	    	     accurencynumber++;
	    	     System.out.println("(正确)");
	    	 }
	    	 else{
	    		 System.out.println("(错误),实际的数字为:"+getTextNumber);
	    	 }
	    	 Arrays.fill(binary, 0);  //记得每次要把数组清零否则会影响后面的内容
	     }
	     //打印准确度	     
	     double result=(Double.valueOf(accurencynumber)/Double.valueOf(textthang))*100;
	     System.out.println("精确度为:"+result+"%");
	     Date data=new Date();
	     SimpleDateFormat si=new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
		 System.out.println("测试结束时间:"+si.format(data));
	}

	/*
	 * 训练数据的初始化处理
	 */
	private static void dataInit(BeiYeSIFenLeiQi beiyesi) throws Exception {
		
		 //得到训练集的结果标签
		 CSVFileUtil util = new CSVFileUtil("D:\\trainresult.csv");
	     int resulthang=util.getRowNum(); //得到训练结果行数
	     int resultlie=util.getColNum();  //得到训练结果列数
		   //得到训练数据的结果
	       CSVFileUtil util2 = new CSVFileUtil("D:\\traindata.csv");
	       int inputhang=util2.getRowNum(); //得到训练结果行数
	       int inputlie=util2.getColNum();  //得到训练结果列数
	       beiyesi.setXunLianNumber(resulthang);
	       readData(util,util2);                  //初始化数据(用多线程进行处理)
	       
	       //下面这些是最开始写的时候没用多线程进行处理的方法,也可以,就是训练太慢了,自己又改进了s
//	       int resultNumber=0;
//	       int suoyinlie = 0;  
//	       int value=0;
//	     for(int i=0;i<resulthang;i++){      //进行需要处理数据的个数的统一
//	    	 resultNumber=Integer.parseInt(util.getString(i, 0));  //得到结果标签的值
//	    	 beiyesi.addeveryResultNumber(resultNumber);   //对应的个数+1	    	 	    	 
//            
//            while(suoyinlie<784){
//            	value=Integer.parseInt(util2.getString(i, suoyinlie));
//            	if(value>=255/2){
//            		beiyesi.addEveryLieNumber(resultNumber, suoyinlie);     //主要是为了让数据中只有0和1这样的灰度数据方便计算
//            	}                                                //而且对于存在1的时候,才进行存储,也就是表示实际有像素点被画
//            	suoyinlie++;
//            }
//            suoyinlie=0;  //处理一个后,记得还原
//	     }
	     //进行每个数字中对应1位置占总数字个数的概率的计算
//	     beiyesi.computeGaiLVEveryHang();
//		 beiyesi.printfResult();        //打印概率的结果
	}
	/*
	 * 初始化训练数据的内容
	 */
	private static void readData(CSVFileUtil util, CSVFileUtil util2) throws IOException {
		CountDownLatch countDownLatch=new CountDownLatch(10); //开10个线程进行读取数据处理
		for(int i=0;i<10;i++){                                  //开启线程进行统计需要的数量
			Thread start=new Thread(new startDealThread(text,countDownLatch,i*250,(i+1)*250,util,util2,10,784));
			start.start();
		}
		try {
			countDownLatch.await();    //等待所有的子线程全部执行完成,才执行后面的任务
			//beiyesi.computeGaiLVEveryHang();  //所有数字的个数都记录好之后,进行计算概率
			beiyesi.writeEveryGailv(writePath);  //把训练好的结果存放到TXT文件,方便下次直接读取
		} catch (InterruptedException e) {		
			e.printStackTrace();
		}
		
	}
	/*
	 * 更新线程计算完后每个数字和元素的概率
	 */
	public synchronized void updataAllData(double[] gailvResult, double[][] gaileveryElement) {
			
			beiyesi.updataThreadComputeGaiLv(gailvResult, gaileveryElement); //更新概率			
//			int currentnumber=0;
//			for(int i=0;i<10;i++){
//				beiyesi.addeveryResultNumber(getnumber[i],i);  
//				for(int liesuoyin=0;liesuoyin<784;liesuoyin++){
//					currentnumber=everyNumberGeshu[i][liesuoyin];
//					beiyesi.addEveryLieNumber(i, liesuoyin,currentnumber);
//				}
//			}
			
				
	}

}

       上面的代码没有进行太多的优化,所以可能存在一些多余的部分,而且用了线程也是为了方便读取数据而已。。

       通过上面的代码的话,最后达到的准确度有94%左右,所以这个基本还行吧。至少能进行有效的识别了。。。。。。。。。。。…………

 注意:: 下面还说说,自己在做这个实验过程中,遇到的一些关于朴素贝叶斯进行分类的问题吧!(进行贴图总结了)!!!

       手写识别之朴素贝叶斯

           反应出的问题(也就是缺点):

          手写识别之朴素贝叶斯

       好了,这个就讲解到这里了。。。共同进步,慢慢的学习!!!!!!!!!!!!!!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

文章由极客之家整理,本文链接:https://www.bmabk.com/index.php/post/12453.html

(0)
小半的头像小半

相关推荐

极客之家——专业性很强的中文编程技术网站,欢迎收藏到浏览器,订阅我们!