美文网首页
k-means算法

k-means算法

作者: 就是果味熊 | 来源:发表于2020-03-05 15:38 被阅读0次
    #%%
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    import cv2
    #%%
    
    def assignment(df, centroids, colmap):
        """Assign every point in *df* to its nearest centroid.

        For each centroid key ``k`` a column ``distance_from_<k>`` holding the
        Euclidean distance of every row to that centroid is written to *df*.
        Two more columns are then (re)written: ``closest`` with the key of the
        nearest centroid, and ``color`` with ``colmap[closest]``.

        Args:
            df: DataFrame with numeric ``x`` and ``y`` columns. Modified in place.
            centroids: Mapping of centroid key -> ``[x, y]`` position.
            colmap: Mapping of centroid key -> colour for that cluster.

        Returns:
            The same *df* object, with the distance, ``closest`` and ``color``
            columns added or refreshed.

        Raises:
            KeyError: If a centroid key chosen as closest is missing from *colmap*.
        """
        # Remember which key produced which column, so the winning column can be
        # translated back to its key directly. Parsing the column name with
        # str.lstrip would strip a *set of characters*, not the prefix, and
        # forcing the remainder through int() restricts keys to integers.
        column_to_key = {}
        for key in centroids:
            centre = centroids[key]
            column = 'distance_from_{}'.format(key)
            df[column] = np.sqrt(
                (df['x'] - centre[0]) ** 2 + (df['y'] - centre[1]) ** 2
            )
            column_to_key[column] = key

        # idxmin(axis=1) returns, per row, the label of the column holding the
        # smallest distance.
        nearest_column = df.loc[:, list(column_to_key)].idxmin(axis=1)
        df['closest'] = nearest_column.map(column_to_key)
        # Explicit lookup so a key missing from colmap raises instead of
        # silently producing NaN.
        df['color'] = df['closest'].map(lambda key: colmap[key])
        return df
    
    
    def update(df, centroids):
        """Move each centroid to the mean position of the points assigned to it.

        Args:
            df: DataFrame with ``x``, ``y`` and ``closest`` columns, where
                ``closest`` holds the key of the centroid each row belongs to.
            centroids: Mapping of centroid key -> mutable ``[x, y]`` position.
                Updated in place.

        Returns:
            The same *centroids* mapping, with positions recalculated.
        """
        for key in centroids:
            members = df[df['closest'] == key]
            if members.empty:
                # A cluster with no points has no mean; taking it anyway would
                # yield NaN and permanently lose the centroid. Keep the previous
                # position so it can still pick up points in a later round.
                continue
            centroids[key][0] = np.mean(members['x'])
            centroids[key][1] = np.mean(members['y'])
        return centroids
    
    #%%
    
    def _plot_clusters(df, centroids, colmap):
        """Show the points coloured by ``df['color']`` and the centroid markers."""
        plt.scatter(df['x'], df['y'], color=df['color'], alpha=0.5, edgecolors='k')
        for key in centroids:
            plt.scatter(*centroids[key], color=colmap[key], linewidths=6)
        plt.xlim(0, 80)
        plt.ylim(0, 80)
        plt.show()


    def main():
        """Run a small k-means demo on a fixed 2-D data set, plotting every step.

        Three centroids are drawn at random integer positions in [0, 80), the
        points are assigned to them, and then update/assignment rounds repeat
        until no point changes cluster or the iteration cap is reached. A
        figure is shown after the initial assignment and after every centroid
        update.
        """
        df = pd.DataFrame({
            'x': [12, 20, 28, 18, 10, 29, 33, 24, 45, 45, 52, 51, 52, 55, 53, 55, 61, 64, 69, 72, 23],
            'y': [39, 36, 30, 52, 54, 20, 46, 55, 59, 63, 70, 66, 63, 58, 23, 14, 8, 19, 7, 24, 77],
        })

        k = 3
        max_iterations = 10
        # Randomly choose the initial centroids.
        centroids = {
            i: [np.random.randint(0, 80), np.random.randint(0, 80)] for i in range(k)
        }
        colmap = {0: 'r', 1: 'g', 2: 'b'}

        df = assignment(df, centroids, colmap)
        _plot_clusters(df, centroids, colmap)

        for _ in range(max_iterations):
            plt.close()

            # Snapshot the current assignment to detect convergence below.
            previous_closest = df['closest'].copy(deep=True)
            centroids = update(df, centroids)

            # Plotted before re-assignment, so the colours still reflect the
            # previous round while the centroids are already moved.
            _plot_clusters(df, centroids, colmap)

            df = assignment(df, centroids, colmap)

            # Converged: no point switched cluster in this round.
            if previous_closest.equals(df['closest']):
                break
        
    # Run the demo only when this file is executed as a script, not on import.
    if __name__ == '__main__':
        main()
    

    相关文章

      网友评论

          本文标题:k-means算法

          本文链接:https://www.haomeiwen.com/subject/gbjfrhtx.html