# Environment: Python 3, scikit-learn 0.18
"""
python 3
scikit-learn 0.18
"""
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

import input_data
# Load MNIST through the local ``input_data`` helper. one_hot=False is passed
# so labels are presumably plain class labels rather than one-hot vectors,
# which is what the scikit-learn estimator and metrics below expect.
mnist = input_data.read_data_sets('mnist/', one_hot=False)
x = mnist.train.images
y = mnist.train.labels

# Hold out 20% of the training images for a final check. random_state pins the
# split so the grid-search score and the held-out evaluation are reproducible
# across runs (previously the split, and therefore best_score_, changed on
# every execution).
train_data, validation_data, train_labels, validation_labels = train_test_split(
    x, y, test_size=0.2, random_state=0)

# Seed the estimator's internal randomness so tree construction is repeatable.
dtree = DecisionTreeClassifier(random_state=0)

criterion_options = ['gini', 'entropy']
splitter_options = ['best', 'random']
param_griddtree = dict(criterion=criterion_options, splitter=splitter_options)

# Try every (criterion, splitter) combination (2 x 2 = 4 candidates), scoring
# each by accuracy over 10 cross-validation folds of the training portion only.
griddtree = GridSearchCV(dtree, param_griddtree, cv=10, scoring='accuracy', verbose=1)
griddtree.fit(train_data, train_labels)

print('best score is:', griddtree.best_score_)
print('best params are :', griddtree.best_params_)

# Evaluate on the held-out split, which the grid search never saw. With
# GridSearchCV's default refit=True, predict() uses the best parameter set
# retrained on all of train_data.
validation_pred = griddtree.predict(validation_data)
print('validation accuracy is:', accuracy_score(validation_labels, validation_pred))
print('confusion matrix:')
print(confusion_matrix(validation_labels, validation_pred))
print('classification report:')
print(classification_report(validation_labels, validation_pred))
# Result:
# The grid search took about 4.5 min to find the best parameters.