美文网首页
Fisher线性判别(模式识别1)

Fisher线性判别(模式识别1)

作者: 小火伴 | 发表于 2018-01-17 16:50 | 阅读 59 次
0 1 2 3

程序

main.py

# %%
import numpy as np
from Fisher import Fisher
import process_data as pd
import Eval
import time
import os





# Run configuration for main(). To run the Sonar experiment instead, set
# file_name = './Sonar.csv' (process_data's label configuration must match).
file_name = './usps_3&8.csv'

# Number of cross-validation folds, and number of classes
# (main() assumes exactly two classes).
cross_validation_number = 10
class_number = 2

def main():
    '''
    Run n-fold cross-validated two-class Fisher discrimination on file_name.

    For every fold a Fisher discriminant is trained on the training blocks and
    evaluated on the held-out block. Afterwards a confusion matrix (rows: true
    class, columns: predicted class) and the overall accuracy are printed, and
    the projection direction with the largest simplified Fisher criterion is
    saved through Eval.save_par.
    '''
    data = dict(x=[[], []], y=[[], []])
    val_data = dict(x=[[], []], y=[[], []])
    # pred_as1[j]: held-out samples of true class j whose discriminant score
    # is > 0, i.e. that were predicted as class 1.
    pred_as1 = [0] * class_number
    pred_num = [0] * class_number
    best_j = float('-inf')
    # Stays None if no fold yields a criterion above -inf (zero folds, NaN
    # criteria); previously the save step then failed with NameError.
    best_w = None

    datas, len1, len2, len_all = pd.read_datas(file_name)
    for cla, val in pd.split_datas(datas, cross_validation_number):
        # Split every class block into features (all but last column) and label.
        for j in range(len(cla)):
            data['x'][j], data['y'][j] = cla[j][:, :-1], cla[j][:, -1:]
            val_data['x'][j], val_data['y'][j] = val[j][:, :-1], val[j][:, -1:]
        myfisher = Fisher(len1, len2, len_all, data)

        W = myfisher.W_Direction()  # projection direction
        # Score this fold's direction and remember the best one seen so far.
        J = Eval.simple_Fisher_Friterion(W, myfisher.between_class_scatter_matrix())
        if J > best_j:
            best_j, best_w = J, W
        W0 = myfisher.OneKey_W0(W)  # decision threshold

        for j in range(class_number):
            predict_y = myfisher.Pred_result(W, val_data['x'][j], W0)
            pred_num[j] += len(val_data['x'][j])
            # Eval.Comp_as1 takes only the scores; the label argument passed
            # here before raised TypeError (and indexing it crashed on an
            # empty validation block).
            pred_as1[j] += Eval.Comp_as1(predict_y)

    print('  预测      第1类   第2类   全部')
    for j in range(class_number):
        print('实际第{0}类    {1}     {2}      {3}'.format(
            j + 1, pred_as1[j], pred_num[j] - pred_as1[j], pred_num[j]))
    total = pred_num[0] + pred_num[1]
    if total:
        # Correct = class-1 samples predicted as 1 + class-2 samples predicted as 2.
        print('mixed_accuracy:%f' % ((pred_as1[0] + pred_num[1] - pred_as1[1]) / total))
    else:
        print('mixed_accuracy: no validation samples')

    if best_w is None:
        print('no projection direction to save')
    else:
        print('评价函数最大的W保存在', os.path.abspath(Eval.w_filename))
        Eval.save_par(best_w)


# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()

# Unused module-level string literal (evaluated and discarded, no runtime
# effect) that appears to record sample output of earlier runs on the Sonar
# and USPS 3-vs-8 data sets.
'''
sonar
  预测      第1类   第2类   全部
实际第1类    63     33      96
实际第2类    36     72      108
mixed_accuracy:0.661765

usps_3&8
  预测      第1类   第2类   全部
实际第1类    789     31      820
实际第2类    20     680      700
mixed_accuracy:0.966447
'''

Fisher.py

import numpy as np

# Module-level empty placeholder arrays that document the shape conventions
# used in this file; the functions here take their own parameters instead.
x = np.array([])  # samples, n*d
w = np.array([])  # projection direction, d*1
y = np.array([])  # projected scores, length n


def within_class_scatter_matrix(x, m):
    '''
    Within-class scatter matrix of one class: S = sum_i (x_i - m)(x_i - m)^T.

    :param x: samples of the class, n*d
    :param m: class mean vector, d*1 (a flat length-d vector is also accepted)
    :return: within-class scatter matrix, d*d (all zeros when x has no rows)
    :raises ValueError: if x and m have different dimensionality
    '''
    x = np.asarray(x)
    m = np.asarray(m)
    d = np.shape(x)[1]
    # Explicit check instead of assert, which is stripped under `python -O`.
    if m.size != d:
        raise ValueError('dimension mismatch: x has %d features but m has %d'
                         % (d, m.size))
    # Centre all samples at once (n*d); one matrix product then sums all n
    # outer products without materialising an n*d*d intermediate array.
    centered = x - m.reshape(1, d)
    return np.dot(centered.T, centered)


def pooled_within_class_scatter_matrix(wcsm1, wcsm2, L1, L2, Lall):
    '''
    Combine two per-class scatter matrices into one: (L1*wcsm1 + L2*wcsm2) / Lall.

    NOTE(review): within_class_scatter_matrix already sums over the samples,
    so weighting by L1/L2 again counts class size twice; the textbook Fisher
    S_W is commonly just wcsm1 + wcsm2. Confirm the intent before changing.

    :param wcsm1: scatter matrix of class 1, d*d
    :param wcsm2: scatter matrix of class 2, d*d
    :param L1: weight (sample count) of class 1
    :param L2: weight (sample count) of class 2
    :param Lall: normaliser
    :return: pooled within-class scatter matrix, d*d
    '''
    weighted_total = L1 * wcsm1 + L2 * wcsm2
    return weighted_total / Lall

def projective_mean(w, m):
    '''
    Project a class mean vector onto the direction w.

    :param w: projection direction, d*1
    :param m: class mean vector, same shape as w
    :return: the scalar w^T m
    '''
    assert np.shape(w) == np.shape(m)
    product = np.dot(np.transpose(w), m)  # 1*1
    return product.ravel()[0]


def projective_within_class_scatter_matrix(y, M):
    '''
    Scatter of projected samples around their projected class mean.

    :param y: projected samples, length n
    :param M: projected class mean (scalar)
    :return: scalar sum of squared deviations (y_i - M)^2
    '''
    deviation = y - np.array([M])
    return np.linalg.norm(deviation) ** 2


def projective_between_class_scatter_matrix(m1, m2):
    '''
    Between-class scatter after projection: the squared gap between the means.

    :param m1: projected mean of class 1 (scalar)
    :param m2: projected mean of class 2 (scalar)
    :return: scalar (m1 - m2)^2
    '''
    gap = m1 - m2
    return gap ** 2


def w_direction(S_W, m1, m2):
    '''
    Fisher projection direction w = S_W^{-1} (m1 - m2).

    Solves the linear system S_W w = m1 - m2 instead of forming the explicit
    inverse, which is cheaper and numerically more stable. If S_W is exactly
    singular (for example a feature that is constant in the training data gives
    a zero row/column), it falls back to the Moore-Penrose pseudo-inverse
    instead of raising.

    :param S_W: pooled within-class scatter matrix, d*d
    :param m1: class-1 mean vector, d*1
    :param m2: class-2 mean vector, d*1
    :return: projection direction, d*1 (only its direction matters)
    '''
    mean_diff = m1 - m2
    try:
        return np.linalg.solve(S_W, mean_diff)
    except np.linalg.LinAlgError:
        return np.dot(np.linalg.pinv(S_W), mean_diff)


def W0(m1, m2):
    '''
    Bias that puts the decision boundary halfway between the projected means.
    Class prior probabilities are not taken into account.

    :param m1: projected mean of class 1 (scalar)
    :param m2: projected mean of class 2 (scalar)
    :return: threshold w0 = -(m1 + m2) / 2
    '''
    midpoint = (m1 + m2) / 2
    return -midpoint

class Fisher(object):
    '''
    Two-class Fisher linear discriminant fitted to one training split.

    :param L1: weight (sample count) of class 1, forwarded to the pooled scatter
    :param L2: weight (sample count) of class 2
    :param Lall: normaliser for the pooled scatter
    :param data: dict whose 'x' entry holds the two training matrices, n_i*d each
    '''

    def __init__(self, L1, L2, Lall, data):
        self.L1 = L1
        self.L2 = L2
        self.Lall = Lall
        # Per-class mean vectors, each d*1.
        first_x, second_x = data['x'][0], data['x'][1]
        self.M = [self.__class_mean_vector(first_x),
                  self.__class_mean_vector(second_x)]
        self.data_x = data['x']

    def __class_mean_vector(self, x):
        '''
        Mean of the rows of x as a column vector.

        :param x: n*d
        :return: class mean vector, d*1
        '''
        row_mean = np.mean(x, axis=0, keepdims=True)  # 1*d
        return row_mean.T

    def Pred_result(self, w, x, w0):
        '''
        Discriminant scores w^T x_i + w0 for every row of x.

        :param w: projection direction, d*1
        :param x: samples, n*d
        :param w0: scalar threshold
        :return: scores, shape (n,) (main counts score > 0 as class 1)
        '''
        scores = np.dot(x, w).ravel()
        return scores + np.array([w0])

    def W_Direction(self):
        '''
        Fisher projection direction for the stored training data.

        :return: projection direction, d*1
        '''
        sw_first = within_class_scatter_matrix(self.data_x[0], self.M[0])
        sw_second = within_class_scatter_matrix(self.data_x[1], self.M[1])
        pooled = pooled_within_class_scatter_matrix(
            sw_first, sw_second, self.L1, self.L2, self.Lall)
        return w_direction(pooled, self.M[0], self.M[1])

    def OneKey_W0(self, W):
        '''
        Threshold halfway between the two class means projected onto W.

        :param W: projection direction, d*1
        :return: scalar threshold w0
        '''
        mean_first = projective_mean(W, self.M[0])
        mean_second = projective_mean(W, self.M[1])
        return W0(mean_first, mean_second)

    def between_class_scatter_matrix(self):
        '''
        Between-class scatter matrix S_b = (m1 - m2)(m1 - m2)^T built from the
        two stored class means.

        :return: d*d matrix
        '''
        mean_gap = self.M[0] - self.M[1]
        return np.dot(mean_gap, mean_gap.T)

Eval.py

import json
import numpy as np


# Path (relative to the current working directory) used by save_par and
# load_par to store the best projection vector.
w_filename = 'w.json'

def save_par(w):
    with open(w_filename,'w') as fp:
        json.dump(list(w.ravel()),fp)

def load_par():
    with open(w_filename,'r') as fp:
        return np.array(json.load(fp))[:,np.newaxis]

def Comp_as1(y):
    '''
    Count the scores that fall on the class-1 side of the threshold.

    :param y: numpy array of discriminant scores
    :return: number of entries strictly greater than 0
    '''
    on_class1_side = y > 0.0
    return np.sum(on_class1_side)

def simple_Fisher_Friterion(w, S_b):
    '''
    Simplified Fisher criterion: only the between-class term w^T S_b w
    (it is not divided by the within-class term).

    :param w: projection direction, d*1
    :param S_b: between-class scatter matrix, d*d
    :return: scalar w^T S_b w
    '''
    row = np.dot(np.transpose(w), S_b)  # 1*d
    quadratic = np.dot(row, w)          # 1*1
    return quadratic[0][0]

process_data.py

import logging
import numpy as np
import random

#%%
# Data-set specific label configuration.
# Sonar alternative: label = ('R', 'M'), file_size = (208, 61),
#                    label_dict = {'R': 1, 'M': -1}
# USPS digits 3 vs 8:
label = ('3', '8')                       # raw class names in the last CSV column
file_size = (9298, 257)                  # presumably (rows, columns) of the full USPS file
label_dict = dict(zip(label, (3, 8)))    # raw class name -> numeric label

def save_38(data):
    '''
    Debug helper: write rows to 'usps_3&8.csv', then terminate the program.

    :param data: iterable of rows (sequences of values) to write
    '''
    import csv
    import sys
    with open('usps_3&8.csv', 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(data)
    print('写完了')
    # sys.exit instead of the site-module exit(), which is meant only for the
    # interactive shell and is absent when Python runs with -S.
    sys.exit(0)

def read_datas(file_name, label_map=None):
    '''
    Read a comma-separated data file and split its rows by class.

    Every line holds numeric features followed by a raw class name in the last
    column. Features are converted to float, the class name is replaced by its
    numeric label, and the rows of each class are shuffled in place.

    :param file_name: path of the CSV file
    :param label_map: optional dict {raw class name: numeric label} listing
        class 1 first and class 2 second; defaults to the module-level
        label / label_dict configuration
    :return: ((rows_class1, rows_class2), L1, L2, Lall) with the per-class row
        counts L1, L2 and the total Lall = L1 + L2
    :raises KeyError: if a row's class name is not in the label map
    '''
    if label_map is None:
        label_map = {name: label_dict[name] for name in label}
    first_value, second_value = list(label_map.values())[:2]

    with open(file_name, 'r') as f:
        lines = f.readlines()

    datas1, datas2 = [], []
    for idx, line in enumerate(lines):
        # rstrip instead of chopping the last character: the final line may
        # have no trailing newline, which used to cut off part of the label.
        line = line.rstrip('\r\n')
        if not line.strip():
            continue  # skip blank lines such as a trailing empty line
        fields = line.split(',')
        row = [float(value) for value in fields[:-1]]
        row.append(label_map[fields[-1]])  # replace the class name by its number
        if row[-1] == first_value:
            datas1.append(row)
        elif row[-1] == second_value:
            datas2.append(row)
        else:
            print('这个标签 %s 未知,请看一下,在第%d行' % (row[-1], idx))

    # random.shuffle works in place; the former map(random.shuffle, ...) is a
    # lazy iterator in Python 3 and never ran, so nothing was shuffled.
    for rows in (datas1, datas2):
        random.shuffle(rows)
    L1, L2 = len(datas1), len(datas2)
    # Lall was file_size[0], the size of the full USPS file rather than of the
    # file actually read; the total of the two classes is what is meant.
    Lall = L1 + L2
    print('——————数据读取完成!——————')
    return ((datas1, datas2), L1, L2, Lall)
#%%
def split_datas(datas, n):
    '''
    Yield the n cross-validation folds of already shuffled per-class data.

    Fold i holds out, for every class, the i-th contiguous block of
    len(class) // n rows; the remaining rows of that class form its training
    part. Rows beyond n * (len // n) are therefore never held out.

    :param datas: sequence of per-class row lists
    :param n: number of folds
    :return: generator of (training_blocks, validation_blocks), each a list
        with one numpy array per class
    '''
    block_sizes = [len(rows) // n for rows in datas]
    for fold in range(n):
        bounds = [(size * fold, size * (fold + 1)) for size in block_sizes]
        val = [np.array(rows[lo:hi]) for rows, (lo, hi) in zip(datas, bounds)]
        cla = [np.array(rows[:lo] + rows[hi:]) for rows, (lo, hi) in zip(datas, bounds)]
        yield cla, val

相关文章

网友评论

      本文标题:Fisher线性判别(模式识别1)

      本文链接:https://www.haomeiwen.com/subject/efbyoxtx.html