LinerUnit

作者: 乔一波一 | 来源:发表于2023-08-09 10:39 被阅读0次

    线性单元

    1. 感知器有一个问题,当面对的数据集不是线性可分的时候,『感知器规则』可能无法收敛,这意味着我们永远也无法完成一个感知器的训练。为了解决这个问题,我们使用一个连续的线性函数来替代感知器的阶跃函数,这种感知器就叫做线性单元。线性单元在面对线性不可分的数据集时,会收敛到一个最佳的近似上。
    2. 那么线性单元就是将感知机的输出激活函数由分段函数改为了连续函数,进而输出的值域也由{0,1}\rightarrow[-\infty,+\infty]

    举例说明

    当我们说模型时,我们实际上在谈论根据输入x预测输出y的算法。比如,x可以是一个人的工作年限,y可以是他的月薪,我们可以用某种算法来根据一个人的工作年限来预测他的收入。\\y=w*x+b

    其中w,b是可以拟合年限输入和月薪输出的待求权重参数。工作年限称为一个特征,输入可以包含多个特征如:行业,公司,职级等。当特征变多时,对应的每个特征都需要一个权重w_i用于拟合输入和输出之间的关系。
    \\y = w_1*x_1+w_2*x_2+\dots+w_n*x_n+b,矩阵表示
    y=\textbf{W}^T\textbf{X}\\其中
    \textbf{W}=\begin{bmatrix} w_i\\ \vdots \\ w_n\\ b \\ \end{bmatrix}, \textbf{X}=\begin{bmatrix} x_i \\ \vdots \\ x_n \\ 1\\ \end{bmatrix}\\

    代码

    由于相较于Perceptron只改变了激活函数,所以我们可以继承Perceptron快速实现LinerUnit

    class LinerUnit(Perceptron):
        def __init__(self, input_dim, activator) -> None:
            super().__init__(input_dim, activator)
    

    生成训练数据,定义可视化

    # 新定义的连续线性激活函数
    def liner_activater(x):
        return x
    
    def get_training_dataset():
        """
        construct training_set, consist of n samples
        Working years and corresponding salary.
        """
        data = [[5], [3], [8], [1.4], [10.1], [8.1]]
        labels = [5500, 2300, 7600, 1800, 11400, 20000]
        return data, labels
    
    def train_liner_unit(iterations, lr):
        """
        Train a liner_unit with training_set.
        """
        lu = LinerUnit(input_dim=1, activator=liner_activater)
        lu.train(*get_training_dataset(), iterations=iterations, lr=lr)
        return lu
    
    def show_results(linear_unit, samples):
        """
        Visualize the line after the linear unit fit
        """
        predicts = [linear_unit.predict(s) for s in samples]
        plt.scatter(samples, predicts, marker="o")
        x_fit = np.linspace(start=0, stop=max(samples), num=100)
        y_fit = linear_unit.weights * x_fit + linear_unit.bias
        plt.plot(x_fit, y_fit, linestyle="-")
        plt.xlabel("Working years")
        plt.ylabel("Salary")
        plt.show()
    

    训练,测试,并可视化

    if __name__ == "__main__":
        linear_unit = train_liner_unit(10, 0.1)
        test_samples = [[3.4], [15], [1.5], [6.3], [8]]
        # test
        for year in test_samples:
            print(f"Work {year} years, monthly salary = {linear_unit.predict(year)}")
    
        show_results(linear_unit=linear_unit, samples=test_samples)
    

    结果

    控制台输出.png 可视化结果.png

    相关文章

      网友评论

          本文标题:LinerUnit

          本文链接:https://www.haomeiwen.com/subject/amiopdtx.html