美文网首页大作业
K-means报告(模式识别3)

K-means报告(模式识别3)

作者: 小火伴 | 来源:发表于2018-01-17 17:18 被阅读59次
    K-Means-1.png K-Means-2.png K-Means-3.png K-Means-4.png

    程序


    
    # coding: utf-8
    
    # # 第三次模式识别作业
    
    
    # In[1]:
    
    get_ipython().magic('matplotlib inline')
    
    
    # In[2]:
    
    from sklearn.datasets import load_iris
    import matplotlib.pyplot as plt
    import numpy as np
    
    
    # In[ ]:
    
    K=3
    iris = load_iris()
    X = iris.data
    Y = iris.target
    
    
    # # 随机洗牌数据
    
    # In[13]:
    
    shuffle_para=np.arange(Y.shape[0])
    np.random.shuffle(shuffle_para)
    X,Y=X[shuffle_para],Y[shuffle_para]
    
    
    # # 每次随机一样
    
    # In[ ]:
    
    np.random.seed(980406)
    
    
    # # 分类
    
    # In[ ]:
    
    cla=[]
    for i in range(K):
        cla.append(np.where(Y==i))
    
    
    # # 初始点
    
    # In[14]:
    
    initial_point=X[np.random.randint(0,X.shape[0],(3,))]
    initial_point
    
    
    # In[15]:
    
    mean_point=initial_point
    
    
    # In[16]:
    
    print(X.shape)
    
    
    # # 开始迭代
    
    # In[17]:
    
    accu=[]
    n=0
    while True:
        # 计算到k个中心的欧氏距离
        distances=[]
        for p in mean_point:
            distances.append(np.linalg.norm((X-p),axis=1))
            pass
        distances=np.array(distances)
        y=np.argmin(distances,0)
        y=np.array(y,dtype=int)
        # 保存上次点
        last_point=mean_point
        # 生成新点
        mean_point=[]
        for i in range(K):
            mean_point.append(np.mean(X[(y==i),:],axis=0))
        mean_point=np.array(mean_point)
        J=np.linalg.norm(last_point-mean_point,axis=1)
        # 每一个都是<0.01
        if False not in list(J<0.001):
            break
            pass
        if(n==20):
            print('到达最大迭代次数')
            break
        
        # 看把原始数据的每一类还保留多少个为一类
        corr=0
        for c in cla:
            corr+=(max(np.bincount(y[c])))
        accu.append(corr/Y.shape[0])
        print(accu[-1])
        n+=1
        pass
    
    
    # # 画图
    
    # In[18]:
    
    plt.ylim([0.6,1])
    plt.xticks(list(range(n)), rotation=20)
    plt.xlabel('Interations')
    plt.ylabel('Accuracy')
    plt.plot(np.arange(n),accu)
    
    
    # In[19]:
    
    mean_point.shape
    
    
    # In[20]:
    
    label=(('Sepal length','Sepal width'),('Petal length','Petal width'))
    def scat(i):
        plt.scatter(X[:, i*2], X[:,2*(i+1)-1], c=y,marker='+')
        plt.scatter(mean_point[:,i*2],mean_point[:,(i+1)*2-1],c=np.arange(K),marker='o')
        plt.xlabel(label[i][0])
        plt.ylabel(label[i][1])
    i=0
    scat(i)
    
    
    # In[21]:
    
    scat(1)
    
    
    

    相关文章

      网友评论

        本文标题:K-means报告(模式识别3)

        本文链接:https://www.haomeiwen.com/subject/kqpuoxtx.html