KNN(K Nearest Neighbor)
k鄰近算法可以算是一種監督式學習算法,從部分已知的資料來推測未知的資料。
我們以下圖為例子,綠色是未知點我們要預測他,藍跟紅是已知點。
首先我們要先假定一個k,這個k表示從未知點計算多少個離他最近的幾個已知點,假設k=3,我們會先計算離綠色最近的3個已知點,依距離排列是(藍、藍、紅)。
我們可以自己設定選擇法,如從中隨機選擇一個來預測、或從中選擇多數者、或以距離做加權選擇合適者、或降低k值直到只有一個類別。
k值選擇
k值選擇其中一種方法是我們可以從數個已知點,以不同的k值去預測他們,選擇正確率高的k值。
多維距離計算
計算向量空間上的距離幾種方法
來自wiki百科缺點
- 資料不是關鍵特徵時(資料的類別只取決於其中幾個特徵),那麼k鄰近算法預測可能與真實相差甚遠,這時也可以使用feature select的一些方法,如PCA主成分分析進行降維之後再進行k鄰近算法分類。
- 維度大的時候,非常容易受雜訊影響,因為維度一多任何一個維的誤差都會影響距離的估算。
- 維度大的時候彼此離的距離越遠,資料非常離散,點與點的距離會與平均距離相差不大,造成計算距離變得沒有意義。
K-mean(K平均算法)
K-mean與KNN概念類似,而K-mean是非監督式的一種集群(clusters)算法,我們利用計算點之間的距離將鄰近的點分為一個群,有兩個原則1.群組中心式所有同一個群組中的算平均,2.群組中每一個點都比其他群組的點更接近中心。
如果使用窮舉法會非常耗時,所以K-mean會使用最大期望算法(Expectation-maximization)。
- 最大期望算法:
1.隨機選擇與類別數量相同的幾個點當作群組中心
2.E-step:隨機指定幾個點找到與他最近的群組中心歸為同一類
3.M-step:計算上一步同一類的平均值為新的群組中心
4.判定所有舊群組中心是否接近新的群組中心(收斂),接近停止,不接近則重複2.3.4步驟。
-
收斂
但最大期望算法在不同的初始群組中心並不一定會收斂到全局最佳解,所以sklearn預設使用10組不同的隨機群組中心執行10次。 -
Kernel
K-mean也可以使用一些SVM技巧擬合非線性結果。
sklearn實作
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
#資料建立
from sklearn.datasets.samples_generator import make_blobs
X,Y = make_blobs(n_samples=300,n_features=2,centers=4,cluster_std=1.5,random_state=0)
plt.scatter(X[:,0],X[:,1],c=Y,cmap='jet')
<matplotlib.collections.PathCollection at 0x19cd89a0a20>
from sklearn.cluster import KMeans
KM = KMeans(n_clusters=4)
KM.fit(X)
predict = KM.predict(X)
plt.scatter(X[:,0],X[:,1],c=predict,cmap='jet')
<matplotlib.collections.PathCollection at 0x19cd92cd7b8>
from sklearn.metrics import confusion_matrix
import seaborn as sns
print(Y[:20],'\n',predict[:20])
confusion = confusion_matrix(Y, predict)
print(confusion)
sns.heatmap(confusion,annot=True)
[1 3 0 3 1 1 2 0 3 3 2 3 0 3 1 0 0 1 2 2]
[2 0 3 0 3 1 2 3 0 0 2 0 3 0 1 2 3 1 2 2]
[[ 5 7 8 55]
[ 0 63 8 4]
[ 3 2 64 6]
[64 0 2 9]]
<matplotlib.axes._subplots.AxesSubplot at 0x19cd93c1860>
new_predict = np.vectorize({0:3,1:1,2:2,3:0}.get)(predict)
Accuracy = np.mean(Y == new_predict)
print('Accuracy:',Accuracy)
Accuracy: 0.82
error = (Y != new_predict).astype(np.int8)
plt.scatter(X[:,0],X[:,1],c=error,cmap='jet',alpha=0.6)
<matplotlib.collections.PathCollection at 0x19cd9492ba8>
KNN與K-mean差異
https://zhuanlan.zhihu.com/p/31580379
其他文章
https://blog.csdn.net/luanpeng825485697/article/details/78962316
https://ithelp.ithome.com.tw/articles/10197110
网友评论