首先是K-menas算法函数,这里使用3维形式,尽管Iris数据集有4维,但是考虑到4维数据无法绘制图像,所以在这里选择三维。
function [ resX,resY,resZ,recordd] = Kmeans( x,y,z,k )
%resX 存储结果
%resY
%resZ
%centerX 存储聚类中心点
%centerY
%centerZ
%oldcenterX 存储上一次的聚类中心,用于判断是否达到结束循环的条件
%oldcenterY
%oldcenterZ
%recordd 记录每个聚类中的元素个数
resX=zeros(k,length(x));
resY=zeros(k,length(y));
resZ=zeros(k,length(x));
centerX=zeros(1,k);
centerY=zeros(1,k);
centerZ=zeros(1,k);
oldcenterX=zeros(1,k);
oldcenterY=zeros(1,k);
oldcenterZ=zeros(1,k);
recordd=zeros(1,k);
for i=1:k
%生成随机聚类中心
centerX(i)=x(round(rand()*length(x)));
centerY(i)=y(round(rand()*length(x)));
centerZ(i)=z(round(rand()*length(x)));
%保证聚类中心不能够重复
if(i > 1 && centerX(i) == centerX(i-1) && centerY(i) == centerY(i-1)&¢erZ(i) == centerZ(i-1))
i=i-1;
end
end
while 1
resX(:)=0; %将每个聚类的点清空,用新的聚类中心重新进行聚类
resY(:)=0;
resZ(:)=0;
recordd(:)=0;
for i=1:length(x)
%给每一个点分类
belong=1;
for j=2:k
%采用欧式距离,寻找到当前点最近的聚类中心
if(power((x(i)-centerX(belong)),2)+power((y(i)-centerY(belong)),2)+power((z(i)-centerZ(belong)),2)>power((x(i)-centerX(j)),2)+power((y(i)-centerY(j)),2)+power((z(i)-centerZ(j)),2))
belong=j;
end
end
%当前聚类中心内部点数加1
recordd(belong)=recordd(belong)+1;
%将此点放入聚类中心中
resX(belong,recordd(belong))=x(i);
resY(belong,recordd(belong))=y(i);
resZ(belong,recordd(belong))=z(i);
end
%存放聚类中心
oldcenterX=centerX;
oldcenterY=centerY;
oldcenterZ=centerZ;
%更新聚类中心
for i=1:k
%recordd(i)=0表示此类内部无点,则保持聚类中心不变
if(recordd(i)==0)
continue
end
centerX(i)=sum(resX(i,:))/recordd(i);
centerY(i)=sum(resY(i,:))/recordd(i);
centerZ(i)=sum(resZ(i,:))/recordd(i);
end
%如果聚类中心没有改变就停止
if mean([centerX == oldcenterX centerY == oldcenterY centerZ==oldcenterZ]) == 1
break;
end
end
centerX
centerY
centerZ
%下面只是优化内存
maxPos = max(recordd);
resX = resX(:,1:maxPos);
resY = resY(:,1:maxPos);
resZ = resZ(:,1:maxPos);
end
然后是读取数据和显示函数
由于.data数据集本身matlab读取较为困难,所以将其转化为了.txt形式
%读取Iris数据
[data1,data2,data3,data4,data5]=textread('Iris.txt','%f%f%f%f%s','delimiter',',');
%选取k=3
k = 3;
%调用函数,这里选择的是feature1、feature2、feature4
[resX resY resZ record] = Kmeans(data1,data2,data4,k);
%绘制图像
for i = 1:length(record)
plot3(resX(i,1:record(i)),resY(i,1:record(i)),resZ(i,1:record(i)),'*')
hold on
end
% 下面是标记出每一个类别的类别代表点
for i = 1:length(record)
plot3(mean(resX(i,1:record(i)),2)',mean(resY(i,1:record(i)),2)',mean(resZ(i,1:record(i)),2)','Marker','square','Color','k','MarkerFaceColor','k','LineStyle','none')
end
网友评论