《机器学习实战》之K—近邻算法实现手写体数字识别

xiaoxiao2021-02-28  71

一、问题描述

主要程序为kNN.py,在主程序中,包含函数: classify0(inX, dataSet, labels, k):实现分类file2matrix(filename):将文本文件转换为矩阵,本例中没有用到autoNum(dataSet):这个函数没有用到,作用是实现均值归一化img2vector(filename):将图片转化为向量,图片大小是32*32,转化后的向量为1*1024,Detect_Test():数字识别和错误率计算函数

二、各函数的代码

classify0(inX, dataSet, labels, k): # -*- coding=utf-8 -*- from os import listdir from numpy import * import operator def classify0(inX, dataSet, labels, k): dataSetSize = dataSet.shape[0] diffMat = tile(inX, (dataSetSize,1)) - dataSet sqDiffMat = diffMat**2 sqDistances = sqDiffMat.sum(axis=1) distances = sqDistances**0.5 sortedDistIndicies = distances.argsort() classCount={} for i in range(k): voteIlabel = labels[sortedDistIndicies[i]] classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0] img2vector(filename): def img2vector(filename): #将图片转化为向量 returnVect=zeros((1,1024)) fr=open(filename) for i in range(32): lineStr=fr.readline() for j in range(32): returnVect[0,32*i+j]=int(lineStr[j]) return returnVect Detect_Test(): #手写体识别函数&错误率检测函数 def Detect_Test(): hwLabels=[] trainingFileList=listdir('E:/PythonApplication/kNN/trainingDigits') #获取目录内容 m=len(trainingFileList) #获取文件的个数 trainingMat=zeros((m,1024)) #创建一个矩阵,m行1024列,用来存储转化后的数字向量 for i in range(m): fileNameStr=trainingFileList[i] #从文件名解析分类数字 fileStr=fileNameStr.split('.')[0] #用[0]操作符保证操作的是数字矩阵中的每一列而不是每一行 classNumStr=int(fileStr.split('_')[0]) #同上,因为文件名的形式是0_4.txt,所以为了得到分类标签,可以只取‘_’前面的数字 hwLabels.append(classNumStr) trainingMat[i,:]=img2vector('E:/PythonApplication/kNN/trainingDigits/%s' % fileNameStr) testFileList=listdir('E:/PythonApplication/kNN/testDigits') errCount=0 #计算识别错误率 mTest=len(testFileList) for i in range(mTest): fileNameStr=testFileList[i] fileStr=fileNameStr.split('.')[0] classNumStr=int(fileStr.split('_')[0]) vectorOfTest=img2vector('E:/PythonApplication/kNN/testDigits/%s' % fileNameStr) result=classify0(vectorOfTest,trainingMat,hwLabels,3) print 'the classfiler came back with %d the real is:%d' % (result,classNumStr),'\t',i if (result!=classNumStr): errCount += 1.0 print '错误个数:',errCount print '错误率:',errCount/float(mTest)

三、补充

本次实验使用的IDE为PyCharm社区版;实验数据中原始图片是以二进制格式存储的,分为训练集和测试集,整体结构图如下:

其中单独的一张图片存储内容如下:

数据集下载:数据集+源码下载。写到最后:moulei007@gmail.com
转载请注明原文地址: https://www.6miu.com/read-37113.html

最新回复(0)