美文网首页
nn.Linear()

nn.Linear()

作者: 三方斜阳 | 来源:发表于2021-02-17 09:10 被阅读0次

官网 nn.Linear()详解

Linear

作用:对输入数据进行线性变换

例子:

import torch
m = torch.nn.Linear(20, 30)
input = torch.randn(128, 20)#输入数据的维度(128,20)
output = m(input)
print(m.weight.shape)
print(m.bias.shape)
print(output.size())
 >>
torch.Size([30, 20])
torch.Size([30])
torch.Size([128, 30])
>>

理解:

线性变换的权重值 weight 和 偏置值 bias 会伴随训练过程不管更新参数,也就是注释中的 learnable ,他们的初始时刻都随机初始化 在区间 :
(-\sqrt{k},\sqrt{k}) , k=1/infeatures
上面的例子可以看到,输入数据会跟一个权重矩阵 A 相乘,A.shape=[30, 20],偏重为一个一维tensor,长度为[30],权重矩阵相乘得到的128个30维的向量,最后会给每一个向量加上这个偏置误差tensor,所以就对应线性变换公式:y=x*A^T+b
y=[128,20]*[20,30]+[30]=[128,30]
于是nn.Linear()也等价与下面的:

output = torch.mm(input , m.weight.t()) + m.bias  
print(output.size())
>>torch.Size([128, 30])

这个函数是用来设置神经网络中的全连接层的,输入输出都是二维 tensor
in_features:指的是输入的二维tensor的大小,即输入的[batch_size, size]中的size。
out_features:指的是输出的二维tensor的大小,即输出的二维张量的形状为[batch_size,output_size],也代表了该全连接层的神经元个数。

相关文章

  • pytorch.nn 相关函数

    PyTorch的nn.Linear()详解 PyTorch的nn.Linear()是用于设置网络中的全连接层的,需...

  • nn.linear()

    import torchimport torch.nn nn.linear()是用来设置网络中的全连接层的,而在全...

  • nn.Linear()

    官网 nn.Linear()详解[https://pytorch.org/docs/master/generate...

  • pyTorch上的TimeDistributed

    Keras有个TimeDistributed包装器,pytorch上用nn.Linear就能实现。老是忘在这里记录...

  • nn.Linear()和nn.Conv2d()

    今天看别人代码时,突然忘了nn.Linear()和nn.Conv2d()网络的相关信息,复习一下。nn.Linea...

网友评论

      本文标题:nn.Linear()

      本文链接:https://www.haomeiwen.com/subject/qzcbxltx.html