美文网首页
SIFT + SVM/KMeans 图像分类

SIFT + SVM/KMeans 图像分类

作者：缪斯muse | 发表于 2018-06-28 17:34 | 被阅读 0 次

    先贴出完整代码：

    import cv2
    import numpy as np
    import os, codecs
    from sklearn.cluster import KMeans
    from sklearn import svm, metrics
    from sklearn.model_selection import cross_val_score, train_test_split, GridSearchCV
    from sklearn.metrics import confusion_matrix
    
    def get_sift(img_path, size=(32, 32)):
        """Return the SIFT descriptors of one image, or None if there are none.

        The image is resized to ``size`` (bicubic) and converted to grayscale
        before keypoint detection.

        Args:
            img_path: path of the image file.
            size: (width, height) the image is resized to before detection.
                NOTE(review): at the default 32x32 SIFT may find very few
                keypoints -- confirm this resolution is intended.

        Returns:
            The descriptor array produced by ``detectAndCompute`` (one row per
            keypoint), or None when the file cannot be read or no keypoint is
            found. Callers treat None as "skip this image".
        """
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports missing/unreadable files by returning None
            # rather than raising; surface it the same way as "no keypoints".
            return None
        img = cv2.resize(img, size, interpolation=cv2.INTER_CUBIC)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # OpenCV >= 4.4 ships SIFT in the main module; older builds only expose
        # it through the contrib package's xfeatures2d namespace.
        if hasattr(cv2, "SIFT_create"):
            detector = cv2.SIFT_create()
        else:
            detector = cv2.xfeatures2d.SIFT_create()
        _, des = detector.detectAndCompute(gray, None)
        return des
    
    def k_clusters(data_list, k, randomState=None):
        """Extract a SIFT feature vector per image, cluster them with KMeans and
        run the SVM evaluation on the same features.

        Each image's descriptor matrix (one row per keypoint, variable number
        of rows) is averaged over its keypoints into one fixed-length vector,
        so every image contributes a comparable feature row.

        Args:
            data_list: dict mapping image path -> class label. Entries whose
                image yields no SIFT descriptors are deleted from this dict IN
                PLACE, so its remaining keys stay aligned, in order, with the
                feature rows (``svm_classify`` and ``label_dict`` rely on it).
            k: number of KMeans clusters.
            randomState: seed forwarded to ``KMeans(random_state=...)``.

        Returns:
            The fitted ``KMeans`` model; ``labels_[i]`` belongs to the i-th
            remaining key of ``data_list``.

        Raises:
            ValueError: if no image produced any SIFT descriptor.
        """
        print("Generating SIFT features")

        features = []
        missing = []
        for count, path in enumerate(data_list, start=1):
            descriptors = get_sift(path)
            if descriptors is None:
                missing.append(path)
            else:
                # Mean-pool over keypoints: taking reshape(-1, 1)[0] would keep
                # only one scalar of the first descriptor.
                features.append(np.asarray(descriptors, dtype=np.float64).mean(axis=0))
            if count % 50 == 0:
                print("The current state: " + str(count))

        if not features:
            raise ValueError("No SIFT descriptors could be extracted from any image")
        feature_matrix = np.vstack(features)
        kmeans = KMeans(n_clusters=k, random_state=randomState).fit(feature_matrix)

        # Delete after iterating (never while iterating) to keep data_list
        # aligned with the rows of feature_matrix.
        for path in missing:
            del data_list[path]
        svm_classify(data_list, feature_matrix)

        return kmeans
    
    def svm_classify(data_list, x, test_size=0.2, random_state=1, cv=5):
        """Grid-search an SVM on ``x`` and print validation and test accuracy.

        Args:
            data_list: dict whose values are the class labels, in the same
                order as the rows of ``x``.
            x: feature matrix, one row per entry of ``data_list``.
            test_size: fraction of samples held out for the test accuracy.
            random_state: seed for the train/test split.
            cv: number of cross-validation folds used by the grid search.

        Returns:
            The fitted ``GridSearchCV`` object.
        """
        print("Doing SVM")

        y = np.array(list(data_list.values())).flatten()
        x_train, x_test, y_train, y_test = train_test_split(
            x, y, test_size=test_size, random_state=random_state)
        svc = svm.SVC(decision_function_shape="ovo")

        c_values = [0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4]
        gammas = [1/0.1**2, 1/0.5**2, 1, 1/2**2, 1/4**2]
        # SVC ignores gamma for the linear kernel and degree for every kernel
        # except poly, so give each kernel its own sub-grid instead of fitting
        # redundant combinations.
        param_grid = [
            {"kernel": ["rbf"], "C": c_values, "gamma": gammas},
            {"kernel": ["linear"], "C": c_values},
            {"kernel": ["poly"], "C": c_values, "gamma": gammas, "degree": [2, 3]},
        ]
        grid = GridSearchCV(svc, param_grid=param_grid, cv=cv, n_jobs=-1).fit(x_train, y_train)

        best = grid.best_estimator_
        print("**********************")
        print("The best accuracy during validation is: " + str(grid.best_score_) +
              "\nKernel : " + str(best.kernel) +
              "\nUsing " + str(cv) + "-fold cross-validation" +
              ", mean score for cross-validation is " + str(grid.best_score_) +
              "\nPenalty parameter C = " + str(best.C) +
              "\nGamma = " + str(best.gamma) +
              "\nDegree = " + str(best.degree))
        print("The test accuracy is: " + str(metrics.accuracy_score(y_test, grid.predict(x_test))))
        return grid
    
    
    def label_dict(data_list, labels):
        """Map each image's bare filename to its predicted label.

        Args:
            data_list: iterable of image paths (e.g. the keys of the dict built
                by ``read_file``), in the same order as ``labels``.
            labels: one label per path, e.g. ``KMeans.labels_``.

        Returns:
            dict filename -> label. Identical filenames in different
            directories collide; the later one wins.

        Raises:
            ValueError: if ``data_list`` and ``labels`` differ in length
                (``zip`` would otherwise silently drop the extra items).
        """
        # os.path.basename understands the platform separator, unlike a
        # hard-coded split('/') (paths built with os.path.join use '\\' on Windows).
        files = [os.path.basename(data) for data in data_list]
        labels = list(labels)
        if len(files) != len(labels):
            raise ValueError("got %d paths but %d labels" % (len(files), len(labels)))
        return dict(zip(files, labels))
    
    def read_file(path):
        """Scan a dataset laid out as ``path/<class_name>/<image_file>``.

        Hidden entries (names starting with '.', such as ``.DS_Store``), stray
        files at the class level and subdirectories at the image level are
        skipped. Directory listings are sorted so the order of the returned
        dicts, and of everything computed from them, is reproducible.

        Args:
            path: root directory of the dataset.

        Returns:
            (data_list, test_list) where data_list maps each image path to its
            class-directory name and test_list maps each bare image filename to
            its class-directory name. Identical filenames in different classes
            collide in test_list; the last one scanned wins.
        """
        print("Generating data_list")

        data_list = {}
        test_list = {}
        for class_name in sorted(os.listdir(path)):
            class_dir = os.path.join(path, class_name)
            # Calling os.listdir on a plain file would raise NotADirectoryError.
            if class_name.startswith('.') or not os.path.isdir(class_dir):
                continue
            for img in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img)
                if img.startswith('.') or not os.path.isfile(img_path):
                    continue
                data_list[img_path] = class_name
                test_list[img] = class_name
        return data_list, test_list
    
    def calculate_accuracy(real, result):
        """Score a clustering against ground-truth classes (cluster purity).

        KMeans cluster ids are arbitrary, so cluster ``i`` need not correspond
        to the class whose name ends in ``i + 1``. Instead, each cluster is
        labelled with the most common true class among its members. The score
        is the fraction of samples whose true class equals that majority class.
        Caveat: with many more clusters than classes this score is inflated.

        Args:
            real: dict mapping sample key -> true class label.
            result: dict mapping sample key -> cluster id. Every key must also
                be present in ``real``.

        Returns:
            Accuracy as a float in [0, 1].

        Raises:
            ValueError: if ``result`` is empty.
            KeyError: if a key of ``result`` is missing from ``real``.
        """
        from collections import Counter, defaultdict

        if not result:
            raise ValueError("result is empty; nothing to score")
        class_counts = defaultdict(Counter)  # cluster id -> Counter of true labels
        for key, cluster in result.items():
            class_counts[cluster][real[key]] += 1
        right = sum(max(counts.values()) for counts in class_counts.values())
        return right / len(result)
    
    if __name__ == "__main__":
        import sys

        # read_file requires the dataset root, so take it from the command line.
        if len(sys.argv) not in (2, 3):
            sys.exit("usage: " + sys.argv[0] + " DATA_DIR [K]")
        data_list, test_list = read_file(sys.argv[1])
        # Default to one cluster per class directory; an explicit K overrides it.
        k = int(sys.argv[2]) if len(sys.argv) == 3 else len(set(data_list.values()))

        # k_clusters deletes images without SIFT features from data_list in
        # place, which keeps data_list aligned with km.labels_ below.
        km = k_clusters(data_list, k)
        result = label_dict(data_list, km.labels_)

        print("**********************")
        print("The KMeans accuracy is: " + str(calculate_accuracy(test_list, result)))
        print("**********************")
    

    相关文章

      网友评论

          本文标题:SIFT + SVM/KMeans 图像分类

          本文链接:https://www.haomeiwen.com/subject/xsknyftx.html