k近邻法(k-nearest neighbor)

xiaoxiao2021-02-28  47

博主原创  http://blog.csdn.net/xuelabizp/article/details/50931493

1.什么是k近邻法

k近邻法是一种基本的多分类和回归的算法,常常简称为kNN。kNN在李航的《统计学习方法》中的描述如下:

给定一个训练数据集,对新的输入实例,在数据集中找到与该实例最近邻的k个实例,这k个实例的多数属于某个类,就把该输入实例分为这个类。

可以用一个简单的例子说明一下kNN,二维坐标下有一些点,如图所示: 数据集包含A、B两类数据,具体如下表所示:

xylabel00.03A0.010A11.05B10.95B

现有新的实例(0.1,0.1),要求将其分类。 第一步,计算输入实例和数据集各个数据的欧氏距离:[0.12, 0.01, 1.31, 1.24] 第二步,将计算的距离按照从小到大排序,统计前k个数据的类别,这里假设k为3,则前3个距离最近的数据类为AAB 第三步,将输入实例判断为频率最高的类,本例中A的频率最高(为2),即输入实例是A类数据

2.kNN三要素

kNN的三要素是k,距离度量和分类决策规则。

2.1k

如果选择小的k值,则只有和输入实例比较近的点才会对预测结果产生影响,这样做会导致分类系统的抗噪声能力弱,如果输入实例附近恰好有噪声,分类就极大地可能出错,导致过拟合。

如果选择大的k值,相当于在较大领域进行预测,假设k值和数据集数据的个数一样,则无论输入什么实例,都将分类为数据集中数量最多的类别。

一般情况下,k值选取一个比较小的数值。通常使用交叉验证法选取最优k值。

2.2距离度量

假设数据有n维,则距离的定义为: 

Lp(xi,xj)=(l=1n|x(l)ix(l)j|p)1pLp(xi,xj)=(∑l=1n|xi(l)−xj(l)|p)1p 这里p>=1,当p=1时,称为曼哈顿距离;当p=2时,称为欧氏距离,一般都使用欧氏距离。

2.3分类决策规则

kNN的分类策略规则是多数表决规则,即前k个最小距离中数量最多的类别决定输入实例的类别。

3.使用kNN对iris数据集中的花进行分类

3.1iris数据集

iris以鸢尾花的特征作为数据来源,常用在分类操作中。该数据集由3种不同类型的鸢尾花的50个样本数据构成。其中的一个种类与另外两个种类是线性可分离的,后两个种类是非线性可分离的。 该数据集包含了5个属性: & Sepal.Length(花萼长度),单位是cm; & Sepal.Width(花萼宽度),单位是cm; & Petal.Length(花瓣长度),单位是cm; & Petal.Width(花瓣宽度),单位是cm; & 种类:Iris Setosa(山鸢尾)、Iris Versicolour(杂色鸢尾),以及Iris Virginica(维吉尼亚鸢尾)。

由于花瓣宽度变化很小,将其省略后根据前三维数据画出散点图,如下所示: 

3.2载入数据

def file2matrix(fileName): file = open(fileName) allLines = file.readlines() row = len(allLines) dataSet = zeros((row, 4)) labels = [] index = 0 for line in allLines: line = line.strip() listFromLine = line.split(',') dataSet[index, :] = listFromLine[0:4] labels.append(listFromLine[-1]) #取最后一维为标签 index += 1 return dataSet, labels #数据集和标签分开 1234567891011121314

3.3kNN算法

def kNN(x, dataSet, labels, k): dataSetSize = dataSet.shape[0] distance1 = tile(x, (dataSetSize,1)) - dataSet #欧氏距离计算开始 distance2 = distance1 ** 2 #每个元素平方 distance3 = distance2.sum(axis=1) #矩阵每行相加 distance4 = distance3 ** 0.5 #欧氏距离计算结束 sortedIndex = distance4.argsort() #返回从小到大排序的索引 classCount = {} for i in range (k): #统计前k个数据类的数量 label = labels[sortedIndex[i]] classCount[label] = classCount.get(label,0) + 1 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) #从大到小按类别数目排序 return sortedClassCount[0][0] 12345678910111213

3.3kNN算法测试

def kNN_test(): testRatio = 0.1 #取数据集的前0.1为测试数据 dataSet, labels = file2matrix('irisdata_test.txt') row = dataSet.shape[0] testNum = int(row * testRatio) error = 0.0 #判断错误的个数 for i in range (testNum): result = kNN(dataSet[i, :], dataSet[testNum:row, :], labels[testNum:row], 3) print 'the result came back with: %s, the real answer is: %s' % (result, labels[i]) if (result != labels[i]): error += 1.0 print 'error rate is: %f' % (error/float(testNum)) 123456789101112

3.4小结

输出结果如下: 分类效果还是不错的,但是由于后两种花是非线性可分离的,故在交界处的数据很可能分类错误,可以使用SVM等方法将非线性可分离的数据分离当有部分维数的数值较大的时候,会较大的影响距离计算,可以使用(xmin)/(maxmin)(x−min)/(max−min)对该维度进行归一化处理

4.总结

欢迎在我的GitHub中下载源代码,MachineLearningAction仓库里面有常见的机器学习算法处理常见数据集的各种实例kNN没有明显的学习过程,属于惰性学习方法kNN适合于多分类问题,当维数较大时,比SVM快k值过小导致对局部数据敏感,抗噪能力差;k值过大,会因为数据集中实例不均衡导致分类出错当数据集较大时,计算量较大,因为每次分类要进行一次全局运算kNN多应用于文本分类、模式识别、聚类分析,多分类领域
转载请注明原文地址: https://www.6miu.com/read-2627262.html

最新回复(0)