美文网首页
pytorch如何不受权重文件限制,随心所欲定义模型

pytorch如何不受权重文件限制,随心所欲定义模型

作者: 瀚文文文问问 | 来源:发表于2022-01-01 13:46 被阅读0次

问题引入

闲来无事,自己搭建了yolov3的backbone darknet53玩一玩,使用pytorch搭建完成后,去yolov3的官网下载了yolo_weight.pth权重文件,并不能直接载入,有这样几个问题:

  • 我搭建的模型只是backbone部分,而官方的权重文件包含了yolov3_head部分,如何去掉yolov3_head只载入backbone部分的权重成为了第一个问题。
  • 我在搭建darkent53的时候只是参照了yolov3论文的图片,只有一些残差结构的内部去看了源码,所以我整体构建的方法和每一层的命名与官方不同。
  • 第三个问题是我打印了我自己的参数字典(model_dict)和官方权重的参数字典(weights_dict),我发现我的所有bn层的参数都比官网的多了一个参数num_batches_tracked,这是由于官方训练得到的权重文件是基于torch0.3.1,版本太老导致的。
    左边weights_dict,右边图model_dict.png

解决

我通过两个步骤解决了问题,首先过滤掉model_dict中的num_batches_tracked,之后使用循环进行遍历赋值。

##############################################################
#  > File Name        : darknet53.py
#  > Author           : zhw
#  > Created Time     : 2021年12月31日 星期五 22时12分45秒
##############################################################
import torch 
import torch.nn as nn
import math
from collections import OrderedDict

def ConvBNLRelu(in_channels, out_channels, kernel, stride=1, padding=0):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, stride, padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.1)
        )

class BasicBlock(nn.Module):
    def __init__(self, in_channels, channels_list):
        super(BasicBlock,self).__init__()
        self.conv1 = ConvBNLRelu(in_channels, channels_list[0], 1, 1)
        self.conv2 = ConvBNLRelu(channels_list[0], channels_list[1], 3, 1, 1)
    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        x += residual
        return x
class Darknet53(nn.Module):
    def __init__(self):
        super(Darknet53, self).__init__()
        self.in_channels = 32
        layers = [1, 2, 8, 8, 4]
        self.conv1 = nn.Conv2d(3,self.in_channels,3,1,1,bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu1 = nn.LeakyReLU(0.1)
        self.layer1 = self._make_layer([32,64], layers[0])
        self.layer2 = self._make_layer([64,128], layers[1])
        self.layer3 = self._make_layer([128,256], layers[2])
        self.layer4 = self._make_layer([256,512], layers[3])
        self.layer5 = self._make_layer([512,1024], layers[4])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
    def _make_layer(self, channels_list, blocks):
        layers = []
        layers.append(("ds_conv", nn.Conv2d(self.in_channels, channels_list[1], 3, 2, 1, bias=False)))
        layers.append(("ds_bn", nn.BatchNorm2d(channels_list[1])))
        layers.append(("ds_relu", nn.LeakyReLU(0.1)))
        self.in_channels = channels_list[1]
        for i in range(blocks):
            layers.append(("residual_{}".format(i), BasicBlock(self.in_channels, channels_list)))
        return nn.Sequential(OrderedDict(layers))
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out2 = self.layer3(out)
        out1 = self.layer4(out2)
        out0 = self.layer5(out1)
        return out0, out1, out2

def main():
    t = torch.randn([4,3,416,416])
    model = Darknet53()
    weights_path = "/home/zhw/dataset/weights/object_detection/yolo_weights.pth"
    weights_dict = torch.load(weights_path)
    weights_list_key = list(weights_dict.keys())
    len_weights = len(weights_list_key)
    model_dict = model.state_dict()
    # 过滤掉num_batches_tracked参数
    model_dict = {k: v for k, v in model_dict.items() if 'num_batches_tracked' not in k}
    model_list_key = list(model_dict.keys())
    len_model_dict = len(model_list_key)
    m,n = 0,0
    # 循环赋值,并保证shape一致
    while m < len_weights and n < len_model_dict:
        weights_name,model_name = weights_list_key[m],model_list_key[n]
        weights_shape,model_shape = weights_dict[weights_name].shape,model_dict[model_name].shape
        if weights_shape != model_shape:
           continue
        model_dict[model_name] = weights_dict[weights_name] 
        n += 1
        m += 1
    model.load_state_dict(model_dict)
    if n == min(len_weights, len_model_dict):
        print("all weights was loaded")
    #out0, out1, out2 = model(t) 
    #print("out0 shape: ", out0.shape)
    #print("out1 shape: ", out1.shape)
    #print("out2 shape: ", out2.shape)

if __name__ == "__main__":
    main()

总结

这种方法对好多情况其实都适用,但有的时候我们搭建的模型名字都与权重一致,只是想加载部分权重的话,其实用不着这么麻烦,之后我会给出总结。

相关文章

网友评论

      本文标题:pytorch如何不受权重文件限制,随心所欲定义模型

      本文链接:https://www.haomeiwen.com/subject/qxboqrtx.html