机器学习--交叉验证函数

xiaoxiao2021-02-28  27

一、交叉验证

在建立分类模型时,交叉验证(Cross Validation)简称为CV,CV是用来验证分类器的性能。它的主体思想是将原始数据进行分组,一部分作为训练集,一部分作为验证集。利用训练集训练出模型,利用验证集来测试模型,以评估分类模型的性能。

二、交叉验证的作用

验证分类器的性能用于模型的选择

三、交叉验证常用的几种方法

3.1 k折交叉验证 K-fold Cross Validation(记为K-CV)

1、将数据集平均分割成K个等份(参数cv值,一般选择5折10折,即测试集为20%) 2、使用1份数据作为测试数据,其余作为训练数据 3、计算测试准确率 4、使用不同的测试集,重复2、3步 5、对测试准确率做平均,作为对未知数据预测准确率的估计

优点: 因为每一个样本数据既可以作为测试集又可以作为训练集,可有效避免欠学习和过学习状态的发生,得到的结果比较有说服力。

3.2 留一法交叉验证 Leave-One-Out Cross Validation(记为LOO-CV)

假设样本数据集中有N个样本数据。将每个样本单独作为测试集,其余N-1个样本作为训练集,这样得到了N个分类器或模型,用这N个分类器或模型的分类准确率的平均数作为此分类器的性能指标。

优点: a. 每一个分类器或模型几乎所有的样本都用来作为训练模型,因此最接近样本,实验评估可靠;

b. 实验过程没有随机因素影响实验结果,所以实验结果可复制,因此实验结果稳定。

缺点: 计算成本高,因为需要建立的模型数量与样本数据数量相同,当N很大时,计算相当耗时。

3.3 留p交叉验证

留p验证指训练集上随机选择p个样本作为测试集,其余作为子训练集。时间复杂度为CpN,是阶乘的复杂度,不可取。

3.4 重复随机子抽样验证 Hold-Out Method

将数据集随机划分为训练集和测试集。对每一个划分,用训练集训练分类器或模型,用测试集评估预测的精确度。进行多次划分,用均值来表示效能。

优点: 与K值无关。严格意义来说Hold-Out Method不属于交叉验证方法,这种方法与k无关。

缺点: 验证集结果准确率的高低和原始分组有很大关系,可能导致一些数据从未做过训练或测试数据;而一些数据不止一次选为训练或测试数据的情况发生,因此结果不具有说服力。

四、交叉验证函数

cross_val_score详情可见官网

train_test_split

#导入 from sklearn.cross_validation import cross_val_score from sklearn.cross_validation import train_test_split

五、代码

例子:垃圾邮件分类 input:

from numpy import * from sklearn import metrics from sklearn.metrics import accuracy_score from sklearn.naive_bayes import GaussianNB as NB from sklearn.neighbors import KNeighborsClassifier as KNN from sklearn.linear_model import LogisticRegression as LR #将词条合并为一个列表 def createVocabList(dataSet): vocabSet = set([]) #创建一个空集 for document in dataSet: vocabSet = vocabSet | set(document) #创建两个集合的并集 return list(vocabSet) #将词汇转化为向量 def bagOfWords2VecMN(vocabList, inputSet): returnVec = [0]*len(vocabList) #初始化 词汇等长的0向量 for word in inputSet: if word in vocabList: returnVec[vocabList.index(word)] += 1 return returnVec #预处理 统一小写,去除长度小于2个的词汇 def textParse(bigString): import re listOfTokens = re.split(r'\W*', bigString) return [tok.lower() for tok in listOfTokens if len(tok) > 2] #统计词频前10 def calcMostFreq(vocabList,fullText): import operator freqDict = {} for token in vocabList: freqDict[token]=fullText.count(token) sortedFreq = sorted(freqDict.items(), key=operator.itemgetter(1), reverse=True) return sortedFreq[:10] #读取数据 def spamTest(): docList=[]; classList = []; fullText =[] for i in range(1,26): wordList = textParse(open('email/spam/%d.txt' % i).read()) docList.append(wordList) fullText.extend(wordList) classList.append(1) wordList = textParse(open('email/ham/%d.txt' % i).read()) docList.append(wordList) fullText.extend(wordList) classList.append(0) vocabList = createVocabList(docList) #创建词列表 top10Words = calcMostFreq(vocabList,fullText) #删除词频前10 for pairW in top10Words: if pairW[0] in vocabList: vocabList.remove(pairW[0]) trainingSet = list(range(50)) #0-49,,50个数字,50封邮件 train_data = [] #存储 所有训练词汇的向量 train_target = [] #存储 类别标签 for docIndex in trainingSet: #得到训练数据的向量 train_data.append(bagOfWords2VecMN(vocabList, docList[docIndex])) train_target.append(classList[docIndex]) return train_data,train_target

5.1 cross_val_score

input:

from sklearn.cross_validation import cross_val_score if __name__ == '__main__': data = [] target = [] data, target = spamTest() clf1 = KNN(n_neighbors=8) clf2 = LR() clf3 = NB() #交叉验证 cv:数据分成的份数,其中一份作为cv集,其余n-1作为训练集(默认为3) for clf,lable in zip([clf1, clf2, clf3],['KNN','LR','NB']): scores = cross_val_score(clf,data,target,cv=5,scoring='accuracy') #print(scores) print("Accuracy:%0.2f (+/-%0.2f)[%s]"%(scores.mean(),scores.std(),lable)) #计算均值及标准差

output:

G:\Anacanda3\lib\re.py:212: FutureWarning: split() requires a non-empty pattern match. return _compile(pattern, flags).split(string, maxsplit) Accuracy:0.64 (+/-0.05)[KNN] Accuracy:0.94 (+/-0.05)[LR] Accuracy:0.92 (+/-0.07)[NB]

5.2 train_test_split

input:

from sklearn.cross_validation import train_test_split if __name__ == '__main__': data = [] target = [] data, target = spamTest() clf1 = KNN(n_neighbors=8) clf2 = LR() clf3 = NB() ''' #交叉验证 cv:数据分成的份数,其中一份作为cv集,其余n-1作为训练集(默认为3) for clf,lable in zip([clf1, clf2, clf3],['KNN','LR','NB']): scores = cross_val_score(clf,data,target,cv=5,scoring='accuracy') #print(scores) print("Accuracy:%0.2f (+/-%0.2f)[%s]"%(scores.mean(),scores.std(),lable)) #计算均值及标准差 ''' #交叉验证 x_train,x_test,y_train,y_test = train_test_split(data,target,test_size=0.2) #交叉验证 20%选取测试集 clf = clf2.fit(x_train, y_train) predicted = clf.predict(x_test) expected = y_test print(metrics.classification_report(expected, predicted)) print(metrics.confusion_matrix(expected, predicted)) print('Score:',accuracy_score(expected,predicted))

output:

G:\Anacanda3\lib\re.py:212: FutureWarning: split() requires a non-empty pattern match. return _compile(pattern, flags).split(string, maxsplit) precision recall f1-score support 0 1.00 1.00 1.00 5 1 1.00 1.00 1.00 5 avg / total 1.00 1.00 1.00 10 [[5 0] [0 5]] Score: 1.0
转载请注明原文地址: https://www.6miu.com/read-2799972.html

最新回复(0)