kmeans算法
kmeans算法是一种聚类算法,用于无标签数据的自行归类。
讲kmeans的原理有很多,个人参考的是以下一个
刘建平:K-Means聚类算法原理
需要注意的是,kmeans算法只适用于凸数据集,无法适用于凹数据集。
python实现
个人使用numpy对kmeans类进行了实现,如以下代码所示
#kmeans.py
import numpy as np
import matplotlib.pyplot as plt
from numpy.core.defchararray import center
#K: num of clusters
#N: max iteration
class kmeans:
def __init__(self, K:int, N:int=100):
self.K = K
self.N = N
self.ret_cluster = None
self.ndpoints = None
self.centerpoints = None
def train(self, ndpoints:np.ndarray, show_process:bool=True):
self.ndpoints = ndpoints#should be something like a 2-dim array
#select k start points out of ndpoints
if self.K > np.shape(ndpoints)[0]:
print("[-] K too big")
return
else:
centerpoints_index = set()
while len(centerpoints_index) < self.K:
centerpoints_index.add(np.random.randint(0, self.K))
self.centerpoints = self.ndpoints[list(centerpoints_index)]
#begin clustering
self.ret_cluster = list()
for i in range(self.K):
self.ret_cluster.append(list())
for i in range(self.N):
self.ret_cluster = list()
for ii in range(self.K):
self.ret_cluster.append(list())
for j in range(np.shape(self.ndpoints)[0]):
#find the min distance and attach the point to one cluster
dist_k = np.ndarray((self.K,), dtype=np.float)
for k in range(self.K):
dist_k[k] = np.linalg.norm(self.centerpoints[k] - self.ndpoints[j])
self.ret_cluster[dist_k.argmin()].append(self.ndpoints[j])
#re-calculate the centerpoints
for k in range(self.K):
self.centerpoints[k] = np.average(np.array(self.ret_cluster[k]), axis=0)
def print_clusters(self):
if self.ret_cluster is None:
print("[-] no clusters are created yet")
return
else:
print("[+] Num of clusters : ", len(self.ret_cluster), sep=' ', end='\n')
ind = 0
for cluster in self.ret_cluster:
ind += 1
print("cluster", ind, ":", sep=' ', end='\n')
print(cluster, end="\n\n")
def draw_clusters_2d(self):
if self.ret_cluster is None:
print("[-] no clusters are created yet")
return
elif np.shape(self.ndpoints)[1] != 2:
print("[-] dimension higher than 2, which is not considered by this kmeans instance")
return
else:
print("[+] drawing by matplotlib")
#draw clusters using matplotlib.pyplot
ax = plt.figure(0)
for cluster in self.ret_cluster:
color = (np.random.random(), np.random.random(), np.random.random())
for point in cluster:
plt.scatter(point[0], point[1], c=color)
for center in self.centerpoints:
plt.scatter(center[0], center[1], marker='+')
plt.show()
通过代码创建对象实例并进行训练。
#main.py
import kmeans
import numpy as np
kmeans_cluster_machine = kmeans.kmeans(3)
ndpoints = np.array([
[-1.26, 0.46],
[-1.15, 0.49],
[-1.19, 0.36],
[-1.33, 0.28],
[-1.06, 0.22],
[-1.27, 0.03],
[-1.28, 0.15],
[-1.06, 0.08],
[-1.00, 0.38],
[-0.44, 0.29],
[-0.37, 0.45],
[-0.22, 0.36],
[-0.34, 0.18],
[-0.42, 0.06],
[-0.11, 0.12],
[-0.17, 0.32],
[-0.27, 0.08],
[-0.49, -0.34],
[-0.39, -0.28],
[-0.40, -0.45],
[-0.15, -0.33],
[-0.15, -0.21],
[-0.33, -0.30],
[-0.23, -0.45],
[-0.27, -0.59],
[-0.61, -0.65],
[-0.61, -0.53],
[-0.52, -0.53],
[-0.42, -0.56],
[-1.39, -0.26]])
kmeans_cluster_machine.train(ndpoints)
kmeans_cluster_machine.print_clusters()
kmeans_cluster_machine.draw_clusters_2d()
输出结果为
[+] Num of clusters : 3
cluster 1 :
[array([-1.26, 0.46]), array([-1.15, 0.49]), array([-1.19, 0.36]), array([-1.33, 0.28]), array([-1.06, 0.22]), array([-1.27,
0.03]), array([-1.28, 0.15]), array([-1.06, 0.08]), array([-1. , 0.38]), array([-1.39, -0.26])]
cluster 2 :
[array([-0.44, 0.29]), array([-0.37, 0.45]), array([-0.22, 0.36]), array([-0.34, 0.18]), array([-0.42, 0.06]), array([-0.11,
0.12]), array([-0.17, 0.32]), array([-0.27, 0.08])]
cluster 3 :
[array([-0.49, -0.34]), array([-0.39, -0.28]), array([-0.4 , -0.45]), array([-0.15, -0.33]), array([-0.15, -0.21]), array([-0.33,
-0.3 ]), array([-0.23, -0.45]), array([-0.27, -0.59]), array([-0.61, -0.65]), array([-0.61, -0.53]), array([-0.52, -0.53]), array([-0.42, -0.56])]
[+] drawing by matplotlib
画出的图为
kmeans聚类结果
可以直观地看出kmeans实现了预期的聚类效果
总结
- kmeans是一种简单而且高效的算法,可以对数据进行很好的聚类,但是也有缺点,由其缺点衍生出kmeans++、KNN、BIRCH等算法
- 进行kmeans类的实现过程中,有许多子算法值得注意,比如:从一个序列中不重复的挑选个数固定的部分元素,本类采用了使用python集合,向其中添加随机元素避免重复的方法进行处理。
网友评论