一、导入所需的库
In [1]:
# Import the libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Make figures interactive (adjustable) inside the notebook
%matplotlib notebook
# Use the SimHei font so that Chinese text in figures is displayed
plt.rc('font', family='SimHei', size=8)
二、导入数据
In [2]:
# Load the mall-customer data. Goal: find the target customers by running a
# cluster analysis on annual income and spending score.
dataset = pd.read_csv('Mall_Customers.csv')
dataset
Out[2]:
CustomerIDGenreAgeAnnual Income (k$)Spending Score (1-100)
01Male19153912Male21158123Female2016634Female23167745Female31174056Female22177667Female3518678Female23189489Male64193910Female3019721011Male6719141112Female3519991213Female5820151314Female2420771415Male3720131516Male2220791617Female3521351718Male2021661819Male5223291920Female3523982021Male3524352122Male2524732223Female462552324Male3125732425Female5428142526Male2928822627Female4528322728Male3528612829Female4029312930Female232987..................170171Male408713171172Male288775172173Male368710173174Male368792174175Female528813175176Female308886176177Male588815177178Male278869178179Male599314179180Male359390180181Female379732181182Female329786182183Male469815183184Female299888184185Female419939185186Male309997186187Female5410124187188Male2810168188189Female4110317189190Female3610385190191Female3410323191192Female3210369192193Male331138193194Female3811391194195Female4712016195196Female3512079196197Female4512628197198Male3212674198199Male3213718199200Male3013783
200 rows × 5 columns
In [4]:
# Feature matrix for clustering: annual income and spending score.
# The columns are selected by name rather than by position (3:5) so that a
# change in the CSV column order cannot silently pick the wrong features.
FEATURE_COLUMNS = ['Annual Income (k$)', 'Spending Score (1-100)']
X = dataset[FEATURE_COLUMNS].values
X
Out[4]:
array([[ 15, 39],
[ 15, 81],
[ 16, 6],
[ 16, 77],
[ 17, 40],
[ 17, 76],
[ 18, 6],
[ 18, 94],
[ 19, 3],
[ 19, 72],
[ 19, 14],
[ 19, 99],
[ 20, 15],
[ 20, 77],
[ 20, 13],
[ 20, 79],
[ 21, 35],
[ 21, 66],
[ 23, 29],
[ 23, 98],
[ 24, 35],
[ 24, 73],
[ 25, 5],
[ 25, 73],
[ 28, 14],
[ 28, 82],
[ 28, 32],
[ 28, 61],
[ 29, 31],
[ 29, 87],
[ 30, 4],
[ 30, 73],
[ 33, 4],
[ 33, 92],
[ 33, 14],
[ 33, 81],
[ 34, 17],
[ 34, 73],
[ 37, 26],
[ 37, 75],
[ 38, 35],
[ 38, 92],
[ 39, 36],
[ 39, 61],
[ 39, 28],
[ 39, 65],
[ 40, 55],
[ 40, 47],
[ 40, 42],
[ 40, 42],
[ 42, 52],
[ 42, 60],
[ 43, 54],
[ 43, 60],
[ 43, 45],
[ 43, 41],
[ 44, 50],
[ 44, 46],
[ 46, 51],
[ 46, 46],
[ 46, 56],
[ 46, 55],
[ 47, 52],
[ 47, 59],
[ 48, 51],
[ 48, 59],
[ 48, 50],
[ 48, 48],
[ 48, 59],
[ 48, 47],
[ 49, 55],
[ 49, 42],
[ 50, 49],
[ 50, 56],
[ 54, 47],
[ 54, 54],
[ 54, 53],
[ 54, 48],
[ 54, 52],
[ 54, 42],
[ 54, 51],
[ 54, 55],
[ 54, 41],
[ 54, 44],
[ 54, 57],
[ 54, 46],
[ 57, 58],
[ 57, 55],
[ 58, 60],
[ 58, 46],
[ 59, 55],
[ 59, 41],
[ 60, 49],
[ 60, 40],
[ 60, 42],
[ 60, 52],
[ 60, 47],
[ 60, 50],
[ 61, 42],
[ 61, 49],
[ 62, 41],
[ 62, 48],
[ 62, 59],
[ 62, 55],
[ 62, 56],
[ 62, 42],
[ 63, 50],
[ 63, 46],
[ 63, 43],
[ 63, 48],
[ 63, 52],
[ 63, 54],
[ 64, 42],
[ 64, 46],
[ 65, 48],
[ 65, 50],
[ 65, 43],
[ 65, 59],
[ 67, 43],
[ 67, 57],
[ 67, 56],
[ 67, 40],
[ 69, 58],
[ 69, 91],
[ 70, 29],
[ 70, 77],
[ 71, 35],
[ 71, 95],
[ 71, 11],
[ 71, 75],
[ 71, 9],
[ 71, 75],
[ 72, 34],
[ 72, 71],
[ 73, 5],
[ 73, 88],
[ 73, 7],
[ 73, 73],
[ 74, 10],
[ 74, 72],
[ 75, 5],
[ 75, 93],
[ 76, 40],
[ 76, 87],
[ 77, 12],
[ 77, 97],
[ 77, 36],
[ 77, 74],
[ 78, 22],
[ 78, 90],
[ 78, 17],
[ 78, 88],
[ 78, 20],
[ 78, 76],
[ 78, 16],
[ 78, 89],
[ 78, 1],
[ 78, 78],
[ 78, 1],
[ 78, 73],
[ 79, 35],
[ 79, 83],
[ 81, 5],
[ 81, 93],
[ 85, 26],
[ 85, 75],
[ 86, 20],
[ 86, 95],
[ 87, 27],
[ 87, 63],
[ 87, 13],
[ 87, 75],
[ 87, 10],
[ 87, 92],
[ 88, 13],
[ 88, 86],
[ 88, 15],
[ 88, 69],
[ 93, 14],
[ 93, 90],
[ 97, 32],
[ 97, 86],
[ 98, 15],
[ 98, 88],
[ 99, 39],
[ 99, 97],
[101, 24],
[101, 68],
[103, 17],
[103, 85],
[103, 23],
[103, 69],
[113, 8],
[113, 91],
[120, 16],
[120, 79],
[126, 28],
[126, 74],
[137, 18],
[137, 83]], dtype=int64)
三、寻找最佳组数(手肘法:找到拐点)
In [12]:
from sklearn.cluster import KMeans

# Elbow method: fit K-Means for k = 1..10 and record, for each k, the
# within-cluster sum of squares (WCSS). The "elbow" of the resulting curve
# suggests a reasonable number of clusters.
wcss = []
for i in range(1, 11):
    kmeans = KMeans(n_clusters=i, max_iter=300, n_init=10, init='k-means++', random_state=0)
    kmeans.fit(X)
    # inertia_ is the sum of squared distances of the samples to their closest
    # centroid, i.e. a within-cluster measure (not a between-cluster distance).
    wcss.append(kmeans.inertia_)

# Draw on a dedicated figure: with the interactive backend selected at the top
# of the notebook, bare pyplot calls go to whichever figure is currently active.
fig, ax = plt.subplots()
ax.plot(range(1, 11), wcss)
ax.set_title(u'手肘图像')
ax.set_xlabel(u'集群数')
# Label corrected: the plotted quantity is the within-cluster sum of squares.
ax.set_ylabel(u'组内平方和 (WCSS)')
plt.show()
从图中可以看出,组数超过 5 之后曲线下降明显变缓(即拐点),因此最佳组数为5
四、K平均聚类算法训练
In [18]:
# Train the final K-Means model with 5 clusters.
# init='k-means++' is used to avoid the random-initialisation trap.
kmeans = KMeans(
    n_clusters=5,
    init='k-means++',
    n_init=10,
    max_iter=300,
    random_state=0,
)
y_kmeans = kmeans.fit_predict(X)

# Print the cluster assignment of every sample ...
print(y_kmeans)
# ... and the centroid coordinates found by the clustering.
print(kmeans.cluster_centers_)
[4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4
3 4 3 4 3 4 1 4 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 2 0 2 1 2 0 2 0 2 1 2 0 2 0 2 0 2 0 2 1 2 0 2 0 2
0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0 2 0
2 0 2 0 2 0 2 0 2 0 2 0 2 0 2]
[[ 88.2 17.11428571]
[ 55.2962963 49.51851852]
[ 86.53846154 82.12820513]
[ 25.72727273 79.36363636]
[ 26.30434783 20.91304348]]
五、可视化集群
In [15]:
# Visualise the clusters: one colour per cluster plus the centroids.
# A dedicated figure/axes pair is created so that, under the interactive
# backend selected at the top of the notebook, this plot is not drawn on top
# of a previously active figure.
fig, ax = plt.subplots()

# Colour for each cluster index. The segment interpretation in the comments is
# tied to the cluster numbering printed by the training cell above.
cluster_colors = [
    'red',      # cluster 0: rational / careful spenders
    'blue',     # cluster 1: standard customers
    'green',    # cluster 2: target customers
    'cyan',     # cluster 3: irrational spenders, handle with care
    'magenta',  # cluster 4: spending-sensitive customers
]
for cluster_id, color in enumerate(cluster_colors):
    members = X[y_kmeans == cluster_id]  # rows assigned to this cluster
    ax.scatter(members[:, 0], members[:, 1], s=100, c=color,
               label='Cluster {}'.format(cluster_id))

# Centroids are drawn larger so they stand out from the sample points.
ax.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1],
           s=300, c='yellow', label='Centroids')
ax.set_title(u'顾客群组')
ax.set_xlabel(u'年收入')
ax.set_ylabel(u'购物指数')
ax.legend()  # show the cluster labels
plt.show()
六、项目地址
https://coding.net/u/RuoYun/p/Python-of-machine-learning/git/tree/master/00机器学习/4.集群和关联规则学习/1.K平均聚类?public=true