美文网首页
TORCH09-05:逻辑回归

TORCH09-05:逻辑回归

作者: 杨强AT南京 | 来源:发表于2020-10-15 08:00 被阅读0次

    1. 人工智能的技术学习

    1.1. 基于工程数据集的划分

    1. 机器视觉

    2. 自然语言

    3. 商业数据分析

    1.2. 数据集,特征抽取,分类

    1. 模式识别

    2. 机器学习

    1.3. 技术讲解的结构

    1. 分类

    2. 特征学习(深度神经网络 / 卷积神经网络 / 循环神经网络(lstm))

    3. 数据集

    1.4. 技术框架选择

    1. Maxnet
    2. Tensorflow
    3. Pytorch
    • UI:Qt
      • |- OpenCV
      • |- Numpy
      • |- Matplotlib

    2. 分类

    2.1. 逻辑回归(分类模型)的数学模型

    • 目标:
      1. 分类的原理
      2. 分类结果的概率(Sigmoid函数 + Softmax函数)
    1. 回顾线性回归模型
    • y = x W + b

    • 实际的模型

      • y ^{\prime} = xW + b +\epsilon
      • y -y ^{\prime} = \epsilon
      • \epsilon服从Gauss分布
    • {x_1, x_2, \dots, x_n} \to {y_1, y_2, \dots, y_n}

    1. 逻辑分类模型
    • y_A = xW_A + b_A

    • y_B = xW_B + b_B

    • y =\begin{cases} 1, y_A - y_B >0 \\ 0, \text{其他} \end{cases}

    • y =\begin{cases} 1, x(W_A - W_B)>0 \\ 0, \text{其他} \end{cases}

    • xW^{\prime} + b^{\prime}= y

    • 条件

      • y值必须是离散的
    • 误差模型

      • y = Wx + b + \epsilon
    • 对一个样本x

      • p(y = 1) = p(xW + b + \epsilon>0) = p(\epsilon > -xW -b)
    • 如果假设p服从正态分布

      • p(y=1) = 1- F_{\epsilon}(-xW)
    • probit模型

    • 要命的问题:

      • 高斯分布没有累积分布函数;
      • 近似高斯累积分布函数的函数:Lapalce分布函数
    • 逻辑分布函数:

      • p(x) = \dfrac{e^{-x}}{(1 + e ^{-x})^2}
    • Sigmoid(x) = \dfrac{1}{1 + e^ {-x}}

      • 逻辑分布中的x替换为-1.702x ,则曲线与高斯累积分布函数的曲线,基本上完全重合。
    • 多个样本误差概率

      • p(Y) = \prod \limits _{y_i \in Y} p(y_i)
    • \prod h(X_i)^{y_i} (1- h(X_i))^{1-y_i}

    • 概率最大,找到W b是的概率最大

    • 最小值模型

      • 求自然对数,取负数。
      • L= - \sum \limits _i y_i ln(h(X_i)) - (1-y_i)(1- h(X_i))
    • 这个函数的最小值,无法使用最小二乘法求解:

      • 梯度下降法
    • 这个最小值的求解函数:

      • 俗称损失函数
    • 总结:

      • 逻辑回归模型:

        • y = Sigmoid(xW + b) \in [0,1)
      • 损失模型:

        • 交叉熵函数
      • 求解模型:

        • 梯度下降法
    • y = \begin{bmatrix}1 \\ 0 \end{bmatrix}

    • y = \begin{bmatrix}0 \\ 1 \end{bmatrix}

    • y = \begin{bmatrix}x_1 \\ x_2 \end{bmatrix}

    • x_1> x_2属于第一类,且概率是x_1

    2.2. 梯度下降法

    2.2.1. 数据怎么表示Tensor

    1. 数据表示方式

      1. Python的数据表达式/数据变量

        • int,float
        • list,tuple
      2. numpy

        • 向量 + 向量运算
      3. Tensor

        • 向量 + 向量运算 + 求导
    2. PyTorch的张量的表示

    import torch
    # help(torch)
    # help(torch.tensor)
    # help(torch.storage)
    
    • 张量的构造:

      • BoolTensor
      • ByteTensor
      • CharTensor
      • DoubleTensor
      • FloatTensor
      • IntTensor
      • LongTensor
      • ShortTensor
    • 张量:

      • 构造器
      • 运算符
        • 数学运算
        • 数据结构运算
          • 下标运算
      • 数据
      • 函数
        • 基本操作(数据结构)
        • 数学运算(四则+向量+矩阵运算+矩阵分析)
        • 自动求导
    import torch
    help(torch.FloatTensor.__init__)
    
    Help on wrapper_descriptor:
    
    __init__(self, /, *args, **kwargs)
        Initialize self.  See help(type(self)) for accurate signature.
    
    • 编程语言:
      • Cython
    import torch
    t = torch.FloatTensor((1, 1, 2))
    print(t.shape)
    
    
    torch.Size([3])
    
    t2 = torch.Tensor([1,2,3,4])   # FloatTensor /DoubleTensor
    print(t2)
    help(torch.Tensor.__init__)
    
    tensor([1., 2., 3., 4.])
    Help on wrapper_descriptor:
    
    __init__(self, /, *args, **kwargs)
        Initialize self.  See help(type(self)) for accurate signature.
    
    • 提供工厂模式
      • 函数构造对象
    import torch
    # help(torch.tensor)
    # help(torch.zeros)
    # 特殊张量的工厂函数
    
    • 推荐方式
    import torch
    tr = torch.tensor([1,2,3], dtype=torch.float64)
    print(tr)
    
    tensor([1., 2., 3.], dtype=torch.float64)
    
    1. Tensor的下标操作运算
    import torch
    t = torch.tensor(
        [
            [1, 2, 3, 4],
            [4, 5, 6, 7],
            [6, 7, 8, 9]
        ]
    )
    t
    
    tensor([[1, 2, 3, 4],
            [4, 5, 6, 7],
            [6, 7, 8, 9]])
    
    t[1]
    
    tensor([4, 5, 6, 7])
    
    t[0:2:1]
    
    tensor([[1, 2, 3, 4],
            [4, 5, 6, 7]])
    
    t[slice(0,2,1)]
    
    tensor([[1, 2, 3, 4],
            [4, 5, 6, 7]])
    
    t[...]
    
    tensor([[1, 2, 3, 4],
            [4, 5, 6, 7],
            [6, 7, 8, 9]])
    
    • 多个参数
    t[0,1]
    
    tensor(2)
    
    t[0:2:, 1]
    
    tensor([2, 5])
    
    t[...,1]
    
    tensor([2, 5, 7])
    
    • 支持数组
    import torch
    import numpy as np
    t = torch.tensor(
        [
            [1, 2, 3],
            [4, 5, 6],
            [6, 7, 8]
        ]
    )
    
    
    idx = np.array(
        [
            [1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]
        ]
    )
    
    print(idx[ [0,1] ])
    print(t[ [1, 2] ])
    
    [[1 2 3]
     [4 5 6]]
    tensor([[4, 5, 6],
            [6, 7, 8]])
    
    • 支持逻辑数组
    import torch
    import numpy as np
    t = torch.tensor(
        [
            [1, 2, 3],
            [4, 5, 6],
            [6, 7, 8]
        ]
    )
    
    
    idx = np.array(
        [
            [1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]
        ]
    )
    
    # print(idx > 3)
    # print(t > 4 )
    # t [ t > 3] = 88 
    # print(t)
    print(t[t>3])
    
    tensor([4, 5, 6, 6, 7, 8])
    
    t == t
    t.T
    
    
    tensor([[1, 4, 6],
            [2, 5, 7],
            [3, 6, 8]])
    

    2.2.2. 求导实现(自动求导)

    • 自定求导的规则

      • 对值求导
      • 上下变化:函数:在函数的环境中
        • requires_grad = False/True (未来是的时候,会被自定跟踪)
        • grad_fn
        • grad返回导数(某个值点)
        • dtype:float求导
    • f^{\prime}(x) = \lim \limits _{\epsilon \to 0} \dfrac{f(x+\epsilon)- f(x-\epsilon)}{2\epsilon}

    import torch
    # 1. 定义求导的值
    x = torch.Tensor([5])  # , dtype=torch.float
    x.requires_grad=True
    # 2. 运算的函数与求导的值有关
    y = x ** 2
    # 3. 求导
    
    y.backward()
    # 4. 输出导数
    print(x.grad)
    
    
    tensor([10.])
    
    
    y_ = 2 * x
    print(y_)
    
    
    tensor([10.], grad_fn=<MulBackward0>)
    

    2.2.3. 梯度下降的实现

    • 梯度下降算法模型:
      • x -= grad * \eta
        • grad决定方向正确
        • \eta决定速度
    1. numpy
      • y = x^2 - 2x + 1 = (x - 1) ^ 2
    import numpy as np
    x = 0
    grad_fn = lambda x:  2 *x - 2
    learning_rate = 0.01
    
    epoch = 1000
    
    x_list = [] 
    
    # 迭代梯度下降
    for e in range(epoch):
        # 1. 求导
        x_grad = grad_fn(x)
        # 2. 迭代更新x
        x -= learning_rate * x_grad
        # 3. 记录x
        x_list.append(x)
    
    print(x)
    %matplotlib inline
    import matplotlib.pyplot as plt
    plt.plot(range(epoch), x_list)
    plt.show()
    
    0.9999999983170326
    
    1. tensor实现
    import torch
    
    x = torch.Tensor([0.0])
    x.requires_grad = True
    
    learning_rate = 0.01
    epoch = 1000
    x_list = [] 
    
    for e in range(epoch):
        # 1. 函数(损失函数)
        y = x ** 2 - 2 * x + 1
        # 2. 求导
        y.backward(retain_graph=True)
        
        # 费求导跟踪环境
        with torch.autograd.no_grad():
            # 3. 更新
            x -= learning_rate * x.grad   # 这个操作被自定求导进行了跟踪
            # 4. 记录x
            x_list.append(x.detach().clone().numpy())
            x.grad.zero_()
        
    print(x.detach().clone().numpy())
    
        
    print(x)
    %matplotlib inline
    import matplotlib.pyplot as plt
    plt.plot(range(epoch), x_list)
    plt.show()
    
    [0.99999857]
    tensor([1.0000], requires_grad=True)
    

    2.3. 分类器(逻辑回归)的学习实现

    • 使用鸢尾花数据

    • 逻辑回归模型:

      1. 输出模型
        • y= S(xW)
      2. 误差模型
        • L = cross_entropy_loss(x)
          • 隐含的W(W b)x(x, 1)
      3. 求最小值(梯度下降法)
      4. 最总通过w得到一个最小w是的损失最下
        • 起到分类的作用
    1. 输出模型的实现
      • 鸢尾花数据集
    import sklearn
    import sklearn.datasets
    import torch
    
    data, target = sklearn.datasets.load_iris(return_X_y=True)
    data, target
    
    (array([[5.1, 3.5, 1.4, 0.2],
            [4.9, 3. , 1.4, 0.2],
            [4.7, 3.2, 1.3, 0.2],
            [4.6, 3.1, 1.5, 0.2],
            [5. , 3.6, 1.4, 0.2],
            [5.4, 3.9, 1.7, 0.4],
            [4.6, 3.4, 1.4, 0.3],
            [5. , 3.4, 1.5, 0.2],
            [4.4, 2.9, 1.4, 0.2],
            [4.9, 3.1, 1.5, 0.1],
            [5.4, 3.7, 1.5, 0.2],
            [4.8, 3.4, 1.6, 0.2],
            [4.8, 3. , 1.4, 0.1],
            [4.3, 3. , 1.1, 0.1],
            [5.8, 4. , 1.2, 0.2],
            [5.7, 4.4, 1.5, 0.4],
            [5.4, 3.9, 1.3, 0.4],
            [5.1, 3.5, 1.4, 0.3],
            [5.7, 3.8, 1.7, 0.3],
            [5.1, 3.8, 1.5, 0.3],
            [5.4, 3.4, 1.7, 0.2],
            [5.1, 3.7, 1.5, 0.4],
            [4.6, 3.6, 1. , 0.2],
            [5.1, 3.3, 1.7, 0.5],
            [4.8, 3.4, 1.9, 0.2],
            [5. , 3. , 1.6, 0.2],
            [5. , 3.4, 1.6, 0.4],
            [5.2, 3.5, 1.5, 0.2],
            [5.2, 3.4, 1.4, 0.2],
            [4.7, 3.2, 1.6, 0.2],
            [4.8, 3.1, 1.6, 0.2],
            [5.4, 3.4, 1.5, 0.4],
            [5.2, 4.1, 1.5, 0.1],
            [5.5, 4.2, 1.4, 0.2],
            [4.9, 3.1, 1.5, 0.2],
            [5. , 3.2, 1.2, 0.2],
            [5.5, 3.5, 1.3, 0.2],
            [4.9, 3.6, 1.4, 0.1],
            [4.4, 3. , 1.3, 0.2],
            [5.1, 3.4, 1.5, 0.2],
            [5. , 3.5, 1.3, 0.3],
            [4.5, 2.3, 1.3, 0.3],
            [4.4, 3.2, 1.3, 0.2],
            [5. , 3.5, 1.6, 0.6],
            [5.1, 3.8, 1.9, 0.4],
            [4.8, 3. , 1.4, 0.3],
            [5.1, 3.8, 1.6, 0.2],
            [4.6, 3.2, 1.4, 0.2],
            [5.3, 3.7, 1.5, 0.2],
            [5. , 3.3, 1.4, 0.2],
            [7. , 3.2, 4.7, 1.4],
            [6.4, 3.2, 4.5, 1.5],
            [6.9, 3.1, 4.9, 1.5],
            [5.5, 2.3, 4. , 1.3],
            [6.5, 2.8, 4.6, 1.5],
            [5.7, 2.8, 4.5, 1.3],
            [6.3, 3.3, 4.7, 1.6],
            [4.9, 2.4, 3.3, 1. ],
            [6.6, 2.9, 4.6, 1.3],
            [5.2, 2.7, 3.9, 1.4],
            [5. , 2. , 3.5, 1. ],
            [5.9, 3. , 4.2, 1.5],
            [6. , 2.2, 4. , 1. ],
            [6.1, 2.9, 4.7, 1.4],
            [5.6, 2.9, 3.6, 1.3],
            [6.7, 3.1, 4.4, 1.4],
            [5.6, 3. , 4.5, 1.5],
            [5.8, 2.7, 4.1, 1. ],
            [6.2, 2.2, 4.5, 1.5],
            [5.6, 2.5, 3.9, 1.1],
            [5.9, 3.2, 4.8, 1.8],
            [6.1, 2.8, 4. , 1.3],
            [6.3, 2.5, 4.9, 1.5],
            [6.1, 2.8, 4.7, 1.2],
            [6.4, 2.9, 4.3, 1.3],
            [6.6, 3. , 4.4, 1.4],
            [6.8, 2.8, 4.8, 1.4],
            [6.7, 3. , 5. , 1.7],
            [6. , 2.9, 4.5, 1.5],
            [5.7, 2.6, 3.5, 1. ],
            [5.5, 2.4, 3.8, 1.1],
            [5.5, 2.4, 3.7, 1. ],
            [5.8, 2.7, 3.9, 1.2],
            [6. , 2.7, 5.1, 1.6],
            [5.4, 3. , 4.5, 1.5],
            [6. , 3.4, 4.5, 1.6],
            [6.7, 3.1, 4.7, 1.5],
            [6.3, 2.3, 4.4, 1.3],
            [5.6, 3. , 4.1, 1.3],
            [5.5, 2.5, 4. , 1.3],
            [5.5, 2.6, 4.4, 1.2],
            [6.1, 3. , 4.6, 1.4],
            [5.8, 2.6, 4. , 1.2],
            [5. , 2.3, 3.3, 1. ],
            [5.6, 2.7, 4.2, 1.3],
            [5.7, 3. , 4.2, 1.2],
            [5.7, 2.9, 4.2, 1.3],
            [6.2, 2.9, 4.3, 1.3],
            [5.1, 2.5, 3. , 1.1],
            [5.7, 2.8, 4.1, 1.3],
            [6.3, 3.3, 6. , 2.5],
            [5.8, 2.7, 5.1, 1.9],
            [7.1, 3. , 5.9, 2.1],
            [6.3, 2.9, 5.6, 1.8],
            [6.5, 3. , 5.8, 2.2],
            [7.6, 3. , 6.6, 2.1],
            [4.9, 2.5, 4.5, 1.7],
            [7.3, 2.9, 6.3, 1.8],
            [6.7, 2.5, 5.8, 1.8],
            [7.2, 3.6, 6.1, 2.5],
            [6.5, 3.2, 5.1, 2. ],
            [6.4, 2.7, 5.3, 1.9],
            [6.8, 3. , 5.5, 2.1],
            [5.7, 2.5, 5. , 2. ],
            [5.8, 2.8, 5.1, 2.4],
            [6.4, 3.2, 5.3, 2.3],
            [6.5, 3. , 5.5, 1.8],
            [7.7, 3.8, 6.7, 2.2],
            [7.7, 2.6, 6.9, 2.3],
            [6. , 2.2, 5. , 1.5],
            [6.9, 3.2, 5.7, 2.3],
            [5.6, 2.8, 4.9, 2. ],
            [7.7, 2.8, 6.7, 2. ],
            [6.3, 2.7, 4.9, 1.8],
            [6.7, 3.3, 5.7, 2.1],
            [7.2, 3.2, 6. , 1.8],
            [6.2, 2.8, 4.8, 1.8],
            [6.1, 3. , 4.9, 1.8],
            [6.4, 2.8, 5.6, 2.1],
            [7.2, 3. , 5.8, 1.6],
            [7.4, 2.8, 6.1, 1.9],
            [7.9, 3.8, 6.4, 2. ],
            [6.4, 2.8, 5.6, 2.2],
            [6.3, 2.8, 5.1, 1.5],
            [6.1, 2.6, 5.6, 1.4],
            [7.7, 3. , 6.1, 2.3],
            [6.3, 3.4, 5.6, 2.4],
            [6.4, 3.1, 5.5, 1.8],
            [6. , 3. , 4.8, 1.8],
            [6.9, 3.1, 5.4, 2.1],
            [6.7, 3.1, 5.6, 2.4],
            [6.9, 3.1, 5.1, 2.3],
            [5.8, 2.7, 5.1, 1.9],
            [6.8, 3.2, 5.9, 2.3],
            [6.7, 3.3, 5.7, 2.5],
            [6.7, 3. , 5.2, 2.3],
            [6.3, 2.5, 5. , 1.9],
            [6.5, 3. , 5.2, 2. ],
            [6.2, 3.4, 5.4, 2.3],
            [5.9, 3. , 5.1, 1.8]]),
     array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]))
    
    x = torch.Tensor(data[0:100])
    y = torch.Tensor(target[0:100])
    x, y
    
    (tensor([[5.1000, 3.5000, 1.4000, 0.2000],
             [4.9000, 3.0000, 1.4000, 0.2000],
             [4.7000, 3.2000, 1.3000, 0.2000],
             [4.6000, 3.1000, 1.5000, 0.2000],
             [5.0000, 3.6000, 1.4000, 0.2000],
             [5.4000, 3.9000, 1.7000, 0.4000],
             [4.6000, 3.4000, 1.4000, 0.3000],
             [5.0000, 3.4000, 1.5000, 0.2000],
             [4.4000, 2.9000, 1.4000, 0.2000],
             [4.9000, 3.1000, 1.5000, 0.1000],
             [5.4000, 3.7000, 1.5000, 0.2000],
             [4.8000, 3.4000, 1.6000, 0.2000],
             [4.8000, 3.0000, 1.4000, 0.1000],
             [4.3000, 3.0000, 1.1000, 0.1000],
             [5.8000, 4.0000, 1.2000, 0.2000],
             [5.7000, 4.4000, 1.5000, 0.4000],
             [5.4000, 3.9000, 1.3000, 0.4000],
             [5.1000, 3.5000, 1.4000, 0.3000],
             [5.7000, 3.8000, 1.7000, 0.3000],
             [5.1000, 3.8000, 1.5000, 0.3000],
             [5.4000, 3.4000, 1.7000, 0.2000],
             [5.1000, 3.7000, 1.5000, 0.4000],
             [4.6000, 3.6000, 1.0000, 0.2000],
             [5.1000, 3.3000, 1.7000, 0.5000],
             [4.8000, 3.4000, 1.9000, 0.2000],
             [5.0000, 3.0000, 1.6000, 0.2000],
             [5.0000, 3.4000, 1.6000, 0.4000],
             [5.2000, 3.5000, 1.5000, 0.2000],
             [5.2000, 3.4000, 1.4000, 0.2000],
             [4.7000, 3.2000, 1.6000, 0.2000],
             [4.8000, 3.1000, 1.6000, 0.2000],
             [5.4000, 3.4000, 1.5000, 0.4000],
             [5.2000, 4.1000, 1.5000, 0.1000],
             [5.5000, 4.2000, 1.4000, 0.2000],
             [4.9000, 3.1000, 1.5000, 0.2000],
             [5.0000, 3.2000, 1.2000, 0.2000],
             [5.5000, 3.5000, 1.3000, 0.2000],
             [4.9000, 3.6000, 1.4000, 0.1000],
             [4.4000, 3.0000, 1.3000, 0.2000],
             [5.1000, 3.4000, 1.5000, 0.2000],
             [5.0000, 3.5000, 1.3000, 0.3000],
             [4.5000, 2.3000, 1.3000, 0.3000],
             [4.4000, 3.2000, 1.3000, 0.2000],
             [5.0000, 3.5000, 1.6000, 0.6000],
             [5.1000, 3.8000, 1.9000, 0.4000],
             [4.8000, 3.0000, 1.4000, 0.3000],
             [5.1000, 3.8000, 1.6000, 0.2000],
             [4.6000, 3.2000, 1.4000, 0.2000],
             [5.3000, 3.7000, 1.5000, 0.2000],
             [5.0000, 3.3000, 1.4000, 0.2000],
             [7.0000, 3.2000, 4.7000, 1.4000],
             [6.4000, 3.2000, 4.5000, 1.5000],
             [6.9000, 3.1000, 4.9000, 1.5000],
             [5.5000, 2.3000, 4.0000, 1.3000],
             [6.5000, 2.8000, 4.6000, 1.5000],
             [5.7000, 2.8000, 4.5000, 1.3000],
             [6.3000, 3.3000, 4.7000, 1.6000],
             [4.9000, 2.4000, 3.3000, 1.0000],
             [6.6000, 2.9000, 4.6000, 1.3000],
             [5.2000, 2.7000, 3.9000, 1.4000],
             [5.0000, 2.0000, 3.5000, 1.0000],
             [5.9000, 3.0000, 4.2000, 1.5000],
             [6.0000, 2.2000, 4.0000, 1.0000],
             [6.1000, 2.9000, 4.7000, 1.4000],
             [5.6000, 2.9000, 3.6000, 1.3000],
             [6.7000, 3.1000, 4.4000, 1.4000],
             [5.6000, 3.0000, 4.5000, 1.5000],
             [5.8000, 2.7000, 4.1000, 1.0000],
             [6.2000, 2.2000, 4.5000, 1.5000],
             [5.6000, 2.5000, 3.9000, 1.1000],
             [5.9000, 3.2000, 4.8000, 1.8000],
             [6.1000, 2.8000, 4.0000, 1.3000],
             [6.3000, 2.5000, 4.9000, 1.5000],
             [6.1000, 2.8000, 4.7000, 1.2000],
             [6.4000, 2.9000, 4.3000, 1.3000],
             [6.6000, 3.0000, 4.4000, 1.4000],
             [6.8000, 2.8000, 4.8000, 1.4000],
             [6.7000, 3.0000, 5.0000, 1.7000],
             [6.0000, 2.9000, 4.5000, 1.5000],
             [5.7000, 2.6000, 3.5000, 1.0000],
             [5.5000, 2.4000, 3.8000, 1.1000],
             [5.5000, 2.4000, 3.7000, 1.0000],
             [5.8000, 2.7000, 3.9000, 1.2000],
             [6.0000, 2.7000, 5.1000, 1.6000],
             [5.4000, 3.0000, 4.5000, 1.5000],
             [6.0000, 3.4000, 4.5000, 1.6000],
             [6.7000, 3.1000, 4.7000, 1.5000],
             [6.3000, 2.3000, 4.4000, 1.3000],
             [5.6000, 3.0000, 4.1000, 1.3000],
             [5.5000, 2.5000, 4.0000, 1.3000],
             [5.5000, 2.6000, 4.4000, 1.2000],
             [6.1000, 3.0000, 4.6000, 1.4000],
             [5.8000, 2.6000, 4.0000, 1.2000],
             [5.0000, 2.3000, 3.3000, 1.0000],
             [5.6000, 2.7000, 4.2000, 1.3000],
             [5.7000, 3.0000, 4.2000, 1.2000],
             [5.7000, 2.9000, 4.2000, 1.3000],
             [6.2000, 2.9000, 4.3000, 1.3000],
             [5.1000, 2.5000, 3.0000, 1.1000],
             [5.7000, 2.8000, 4.1000, 1.3000]]),
     tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.,
             1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
             1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
             1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]))
    
    # 逻辑回归模型
    w = torch.randn(1, 4)  # 迭代更新
    b = torch.randn(1)
    
    y_ = torch.nn.functional.linear(input=x, weight=w, bias=b)
    print(y_.shape)
    sy_ = torch.sigmoid(y_)
    print(sy_.shape)
    
    torch.Size([100, 1])
    torch.Size([100, 1])
    
    1. 损失模型
    import sklearn
    import sklearn.datasets
    import torch
    
    data, target = sklearn.datasets.load_iris(return_X_y=True)
    
    x = torch.Tensor(data[0:100])
    y = torch.Tensor(target[0:100]).view(100, 1)
    
    # 逻辑回归模型
    w = torch.randn(1, 4)  # 迭代更新
    b = torch.randn(1)
    
    y_ = torch.nn.functional.linear(input=x, weight=w, bias=b)
    print(y_.shape)
    sy_ = torch.sigmoid(y_)
    print(sy_.shape)
    
    loss_logit = torch.nn.functional.binary_cross_entropy_with_logits(y_, y, reduction="sum")
    loss = torch.nn.functional.binary_cross_entropy(sy_, y, reduction="sum")
    print(loss, loss_logit)
    
    torch.Size([100, 1])
    torch.Size([100, 1])
    tensor(146.3248) tensor(146.3248)
    
    1. 分类器实现
    import sklearn
    import sklearn.datasets
    import torch
    
    data, target = sklearn.datasets.load_iris(return_X_y=True)
    
    x = torch.Tensor(data[0:100])
    y = torch.Tensor(target[0:100]).view(100, 1)   # 数据格式
    
    # 逻辑回归模型
    w = torch.randn(1, 4)  # 迭代更新
    b = torch.randn(1)
    
    
    # 学习目标:w,b
    w.requires_grad= True
    b.requires_grad= True
    
    epoch = 10000
    learning_rate = 0.0001
    
    for e in range(epoch):
        # 1. 计算输出
        y_ = torch.nn.functional.linear(input=x, weight=w, bias=b)
        sy_ = torch.sigmoid(y_)
        
        # 2. 计算误差
        loss = torch.nn.functional.binary_cross_entropy(sy_, y, reduction="mean")
        # 3. 求导
        loss.backward()
        # 4. 建立无导跟踪环境
        if e % 500 == 0:
            with torch.autograd.no_grad():
                # 5. 更新w,b
                w -= learning_rate * w.grad
                b -= learning_rate * b.grad
                # 6. 清零
                w.grad.zero_()
                b.grad.zero_()
                # 预测,观察损失
                sy_[sy_ > 0.5] = 1
                sy_[sy_ <=0.5] = 0
                correct_rate = (sy_ == y).float().mean()
                print(F"轮数:{e:05d},损失:{loss:10.6f}, 测试准确率:{correct_rate * 100.0:8.2f}%")
    
    
    轮数:00000,损失:  1.077572, 测试准确率:   50.00%
    轮数:00500,损失:  1.076803, 测试准确率:   50.00%
    轮数:01000,损失:  0.718069, 测试准确率:   50.00%
    轮数:01500,损失:  0.466284, 测试准确率:   57.00%
    轮数:02000,损失:  0.314386, 测试准确率:   83.00%
    轮数:02500,损失:  0.228021, 测试准确率:   97.00%
    轮数:03000,损失:  0.177865, 测试准确率:   99.00%
    轮数:03500,损失:  0.147339, 测试准确率:   99.00%
    轮数:04000,损失:  0.127924, 测试准确率:  100.00%
    轮数:04500,损失:  0.115136, 测试准确率:  100.00%
    轮数:05000,损失:  0.106481, 测试准确率:  100.00%
    轮数:05500,损失:  0.100490, 测试准确率:  100.00%
    轮数:06000,损失:  0.096258, 测试准确率:  100.00%
    轮数:06500,损失:  0.093207, 测试准确率:  100.00%
    轮数:07000,损失:  0.090956, 测试准确率:  100.00%
    轮数:07500,损失:  0.089251, 测试准确率:  100.00%
    轮数:08000,损失:  0.087923, 测试准确率:  100.00%
    轮数:08500,损失:  0.086854, 测试准确率:  100.00%
    轮数:09000,损失:  0.085964, 测试准确率:  100.00%
    轮数:09500,损失:  0.085200, 测试准确率:  100.00%
    
    • 作业:
      1. 独立完成鸢尾花分类;

      2. 整理笔记

        • 每一个代码都跑一遍
        • 补充你自己的理解
        • 补充某些函数帮助,并加上自己的理解;
      3. 四个特征取两个特征来训练,并且可视化分类的效果;(可选)

        • matplotlib可视化;
      4. 提交到git服务器,更新:README.md


    相关文章

      网友评论

          本文标题:TORCH09-05:逻辑回归

          本文链接:https://www.haomeiwen.com/subject/ddcnpktx.html