Machine_learning (continuously updated......)

Author: 进击的STE | Published 2018-07-13 22:42

1. Linear Regression

This section covers two ways to fit a linear model: batch gradient descent and the closed-form normal equation.
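For reference, with design matrix X (one row per sample, bias column included), label vector y, sample count m, learning rate \alpha, and parameter vector \theta, the three quantities the code below computes are:

    J(\theta) = \frac{1}{2m}\,(X\theta - y)^\top (X\theta - y)           (cost function)
    \theta = (X^\top X)^{-1} X^\top y                                    (normal equation)
    \theta \leftarrow \theta - \frac{\alpha}{m}\, X^\top (X\theta - y)   (batch gradient step)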

#-*- coding:utf-8 -*-
import numpy as np
import time 



# Decorator that measures a function's running time;
# the wrapped call returns a (result, runtime) tuple
def compute_for_time(func):
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        runtime = end_time - start_time
        return result, runtime
    return wrapper

# Cost function: J(theta) = (X.theta - y)^T (X.theta - y) / (2m)
def cost_function(theta, X, y):
    '''
        theta: parameter vector
        X: feature matrix (with bias column)
        y: label vector
    '''
    m = len(X)  # number of training samples
    return (X.dot(theta) - y).T.dot(X.dot(theta) - y) / (2 * m)
    

# Normal equation
@compute_for_time
def regular_equation(X, y):
    '''
        X: feature matrix
        y: label vector
    Returns:
        theta: coefficient vector
    '''
    # pinv (pseudo-inverse) also copes with a singular X^T.X
    theta = np.linalg.pinv(X.T.dot(X)).dot(X.T).dot(y)
    error = cost_function(theta, X, y)
    print('Final cost: {}'.format(error))
    return theta
    
# Batch gradient descent
@compute_for_time
def compute_for_theta(alpha, maxloops, X, y):
    '''
        alpha: learning rate
        maxloops: maximum number of iterations
        X: feature matrix
        y: label vector
    Returns:
        theta: coefficient vector
    '''
    m, n = X.shape

    # Initialize theta randomly; an all-zeros start also works
    theta = np.random.randn(n)
    #theta = np.zeros(n)

    minerror = 1e-3  # convergence threshold: stop once the update becomes negligible

    is_converged = False
    count = 0

    while not is_converged:

        sum_tmp = X.dot(theta) - y  # residual vector X.theta - y

        # Batch update of every component at once; a for loop updating
        # theta component by component would also work
        temp_theta = theta - alpha * (sum_tmp.dot(X)) / m

        if np.sum(abs(temp_theta - theta)) < minerror:
            is_converged = True
            print('Parameters have converged')

        count += 1
        theta = temp_theta
        error = cost_function(theta, X, y)

        if count % 10 == 0:
            print('Iteration {}: current cost {}'.format(count, error))

        if count > maxloops:
            is_converged = True
            print('Reached the maximum number of iterations: {}'.format(maxloops))

    return theta
    
# Note: loaded feature data usually needs some preprocessing first, e.g.
# (data_load stands for the array read from file):
X = data_load[:, 0].reshape(-1, 1)  # reshape the 1-D array into a column vector: [1, 2, 3, 4] ---> [[1], [2], [3], [4]]
X = np.hstack((np.ones_like(X), X))  # horizontal stack: prepend a bias column of ones
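A minimal usage sketch, not from the original post: it fits both solvers to synthetic data and unpacks the (theta, runtime) tuple produced by the timing decorator. The synthetic data and the hyperparameters (seed, alpha=0.01, maxloops=10000) are illustrative assumptions.

# Usage sketch on synthetic data y ~ 3 + 2x (illustrative values)
if __name__ == '__main__':
    np.random.seed(0)
    x = np.linspace(0, 10, 100)
    y = 3 + 2 * x + np.random.randn(100)  # true intercept 3, slope 2, plus noise

    X = x.reshape(-1, 1)                  # column vector of features
    X = np.hstack((np.ones_like(X), X))   # prepend the bias column

    # Each call returns (theta, runtime) because of the decorator
    theta_ne, t_ne = regular_equation(X, y)
    theta_gd, t_gd = compute_for_theta(0.01, 10000, X, y)

    print('Normal equation:  theta = {}, time = {:.6f}s'.format(theta_ne, t_ne))
    print('Gradient descent: theta = {}, time = {:.6f}s'.format(theta_gd, t_gd))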
