PyTorch 中线性模块 `nn.Linear` 的实现如下（简化摘录）：在 `__init__` 函数中定义 `weight` 和 `bias` 两个参数，并在末尾调用 `reset_parameters()` 对它们进行默认初始化。
class Linear(Module):
    """Apply an affine transformation ``y = input @ weight.T + bias``.

    Simplified excerpt of ``torch.nn.Linear``.

    Args:
        in_features: size of the last dimension of ``input``.
        out_features: size of the last dimension of the output.
        bias: if ``False``, no additive bias is learned and ``self.bias``
            is registered as ``None``.

    Attributes:
        weight: learnable ``Parameter`` of shape ``(out_features, in_features)``.
        bias: learnable ``Parameter`` of shape ``(out_features,)``, or ``None``.
    """

    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty only allocates storage; every value is overwritten by
        # reset_parameters() below (torch.Tensor(*sizes) is the legacy form).
        self.weight = Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            # Register the name so `self.bias` exists as None; F.linear then
            # skips the bias term.
            self.register_parameter('bias', None)
        # Without this method defined, construction would fail with
        # AttributeError; it provides the default initialization.
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialize ``weight`` and ``bias`` with PyTorch's default scheme.

        ``weight`` uses ``kaiming_uniform_(a=sqrt(5))``, which amounts to
        U(-1/sqrt(in_features), 1/sqrt(in_features)); ``bias`` is drawn from
        the same range.
        """
        import math
        from torch.nn import init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in = self.in_features
            # Guard against division by zero for a zero-width input.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Return ``F.linear(input, weight, bias)``."""
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        """Describe the layer configuration inside ``repr(module)``."""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
因此，若要自定义 Linear 子模块的参数初始化，可以使用 `torch.nn.init` 中的函数直接覆盖其 `weight` 和 `bias` 的值。下面的示例演示了如何对单个 Linear 子模块进行参数初始化：
import torch.nn as nn
from torch.nn import init
from collections import OrderedDict

# Number of input features for the linear layer. The printed output below
# shows in_features=2, so the example uses 2; set it to your data's width.
num_inputs = 2

# Building nn.Sequential from an OrderedDict names each submodule, so the
# layer is reachable both by name (net.linear) and by position (net[0]).
net = nn.Sequential(OrderedDict([
    ('linear', nn.Linear(num_inputs, 1)),
]))
print(net)
print(net[0])
print(type(net[0]))

# Overwrite the default initialization of this single Linear submodule.
# The trailing-underscore init functions fill the tensors in place.
init.normal_(net[0].weight, mean=0.0, std=0.01)
init.constant_(net[0].bias, val=0.0)
# Alternatively, modify the bias directly without torch.nn.init:
#     with torch.no_grad():
#         net[0].bias.fill_(0)
# (the older `net[0].bias.data.fill_(0)` also works, but `.data` bypasses
# autograd's tracking, so the no_grad form is preferred)
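# ---------------- 输出 ----------------
Sequential(
  (linear): Linear(in_features=2, out_features=1, bias=True)
)
Linear(in_features=2, out_features=1, bias=True)
<class 'torch.nn.modules.linear.Linear'>
（注：最后一行是 `print(type(net[0]))` 的输出。若 `net` 是自定义的 `LinearNet` 类而不是 `nn.Sequential`，第一行会显示为 `LinearNet(`。）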
网友评论