美文网首页
Pytorch袖珍手册之十四

Pytorch袖珍手册之十四

作者: 深思海数_willschang | 来源:发表于2021-08-30 10:20 被阅读0次
pytorch pocket reference

第六章 Pytorch加速及优化(性能提升) 之五

模型优化--剪枝 Pruning

现在的模型基本上都是成百上千万个参数,使得模型部署变得十分困难。
在现在模型参数过多的状况下,大家就在寻找是否有什么方式可以在精减掉一些参数,又可以保证模型精度及性能提升,即就是本节所要介绍的——剪枝技巧。
通过剪枝后,模型部署更便捷(占用更少内存,低功耗及减少一些硬件设备资源)。

Pruning is a technique that reduces the number of model parameters with minimal effect on performance.

剪枝可以在单层(a single layer),多层(multiple layer)或整个模型(an entire model)中进行。

LeNet5剪枝示例

  • LeNet5网络结构
    含有5个子模块:conv1,conv2,fc1,fc2和fc3。可以通过name_parameters()来查看各模块(层)的权重与偏置值情况。
import torch
from torch import nn
import torch.nn.functional as F

class LeNet5(nn.Module):
   def __init__(self):
       super(LeNet5, self).__init__()
       self.conv1 = nn.Conv2d(3, 6, 5)
       self.conv2 = nn.Conv2d(6, 16, 5)
       self.fc1 = nn.Linear(16*5*5, 120)
       self.fc2 = nn.Linear(120, 84)
       self.fc3 = nn.Linear(84, 10)

   def forward(self, x):
       x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
       x = F.max_pool2d(F.relu(self.conv2(x)), 2)
       x = x.view(-1, int(x.nelement() / x.shape[0]))
       x = F.relu(self.fc1(x))
       x = F.relu(self.fc2(x))
       x = self.fc3(x)
       return x

model = LeNet5()

for n, p in model.named_parameters():
   print(n, ':', p.dtype)

"""
conv1.weight : torch.float32
conv1.bias : torch.float32
conv2.weight : torch.float32
conv2.bias : torch.float32
fc1.weight : torch.float32
fc1.bias : torch.float32
fc2.weight : torch.float32
fc2.bias : torch.float32
fc3.weight : torch.float32
fc3.bias : torch.float32
"""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet5().to(device)
# 打印出conv1层的参数(权重与偏置)
print(list(model.conv1.named_parameters()))
""""
[('weight', Parameter containing:
tensor([[[[-0.0254, -0.1115, -0.0827,  0.0797,  0.0523],
         [ 0.0688,  0.0615, -0.0759, -0.0043, -0.0328],
         [ 0.0675,  0.0057, -0.0423, -0.0888, -0.0361],
         [-0.1121,  0.0517,  0.0228, -0.0806, -0.0785],
         [-0.0847,  0.0142, -0.1019, -0.1008, -0.1116]],

        [[ 0.1151, -0.0205,  0.0262, -0.0170,  0.0272],
         [-0.0599, -0.1126, -0.1107, -0.0598,  0.0042],
         [-0.0488, -0.0979,  0.0312,  0.0500,  0.1130],
         [ 0.1132,  0.0287, -0.0994,  0.0156,  0.0690],
         [-0.0146,  0.1097,  0.0898,  0.0058,  0.0998]],

        [[ 0.1002, -0.0450, -0.0825,  0.0467,  0.0433],
         [-0.0459, -0.0875, -0.0004, -0.0487, -0.0339],
         [ 0.0943, -0.0994,  0.0373,  0.0701, -0.1118],
         [ 0.0719,  0.0674,  0.0533, -0.1095,  0.0715],
         [-0.0975,  0.0738,  0.1103, -0.0869,  0.0697]]],


       [[[ 0.0356, -0.1042, -0.0460,  0.0180, -0.0991],
         [-0.0112, -0.0852,  0.0637,  0.0459, -0.0785],
         [ 0.0289,  0.0028, -0.0814,  0.0551, -0.0631],
         [ 0.0308, -0.0570, -0.0495, -0.0321,  0.0580],
         [-0.0466, -0.0319, -0.1131,  0.0338, -0.0257]],

        [[ 0.0440, -0.0991,  0.0583,  0.0582,  0.0998],
         [ 0.1060,  0.0901, -0.1064, -0.1107,  0.0029],
         [-0.0760,  0.0742, -0.0715, -0.0847,  0.0834],
         [-0.0198,  0.0122, -0.1090,  0.1133, -0.0565],
         [-0.0185,  0.0326,  0.0253,  0.0050,  0.0366]],

        [[ 0.0216, -0.0240, -0.0316, -0.0702, -0.1046],
         [-0.0834,  0.0543, -0.0351,  0.0328,  0.0366],
         [-0.0165,  0.0361, -0.0984, -0.0974, -0.1094],
         [ 0.0296,  0.0665, -0.0207,  0.0672, -0.1143],
         [ 0.0940,  0.0174, -0.0256, -0.1088, -0.0470]]],


       [[[-0.0897,  0.0794, -0.0743, -0.1144,  0.0853],
         [-0.0068,  0.0790, -0.0465,  0.0928,  0.0438],
         [-0.0138,  0.0193,  0.0531, -0.1072, -0.0127],
         [-0.0078, -0.0460, -0.0003, -0.0630, -0.0127],
         [ 0.0368, -0.1079,  0.1085, -0.0132,  0.0334]],

        [[-0.0276, -0.0570, -0.0167, -0.0674, -0.0074],
         [-0.0432,  0.0750,  0.0877,  0.1045, -0.1132],
         [-0.0501,  0.0281, -0.0430, -0.1030,  0.1126],
         [ 0.0121, -0.0478,  0.0391,  0.0598,  0.0571],
         [ 0.1018,  0.0410, -0.0727, -0.0121, -0.0578]],

        [[-0.0081, -0.0545,  0.0606,  0.0648, -0.1060],
         [ 0.0764, -0.1059,  0.0099,  0.0187, -0.0928],
         [-0.0235, -0.0289,  0.0144, -0.0721, -0.0726],
         [-0.0386,  0.0340,  0.0867, -0.1127, -0.0857],
         [ 0.0107,  0.0146, -0.0849,  0.0524, -0.0900]]],


       [[[ 0.0776,  0.0557, -0.0765, -0.1140,  0.0576],
         [-0.0059,  0.0346,  0.0455,  0.0054, -0.1109],
         [ 0.0361, -0.0814,  0.0402, -0.0102, -0.0602],
         [ 0.0790,  0.0850, -0.0594,  0.0911, -0.1050],
         [-0.0766, -0.0501, -0.1025, -0.0241, -0.0324]],

        [[-0.0884,  0.0497,  0.0938, -0.0150, -0.1074],
         [ 0.0292, -0.0107,  0.0899,  0.0282,  0.1153],
         [-0.0436, -0.0003,  0.0186,  0.1088,  0.0628],
         [ 0.0650, -0.0890, -0.0791,  0.0365,  0.0138],
         [-0.0130, -0.0286, -0.0172,  0.0826,  0.0048]],

        [[-0.0259, -0.0353,  0.0268,  0.0855, -0.0649],
         [-0.0093,  0.0942,  0.0686, -0.0389, -0.0243],
         [ 0.0929, -0.1022, -0.0687, -0.0074,  0.1133],
         [ 0.1036,  0.0301,  0.0482,  0.0721,  0.0886],
         [-0.0423,  0.0273,  0.0875, -0.0517, -0.0984]]],


       [[[-0.1149, -0.0226, -0.0565, -0.0358,  0.0875],
         [ 0.0141,  0.0230,  0.0436, -0.0414,  0.0428],
         [ 0.0095, -0.0065,  0.0439, -0.0247,  0.0942],
         [ 0.0874, -0.0608, -0.0328,  0.0800,  0.0835],
         [ 0.0480, -0.0660, -0.0568,  0.0273, -0.0273]],

        [[-0.1025,  0.0872, -0.0406, -0.0210,  0.0457],
         [ 0.0217,  0.0395,  0.0367, -0.0048, -0.0257],
         [-0.0946, -0.0753,  0.1068, -0.1134, -0.0725],
         [ 0.0415, -0.0461,  0.0042,  0.0258,  0.0976],
         [-0.0452, -0.1107, -0.0572,  0.0010, -0.0404]],

        [[-0.0284, -0.0555,  0.0831, -0.0226,  0.0250],
         [-0.0154, -0.0054,  0.0758, -0.0524,  0.0314],
         [-0.0345, -0.0292,  0.0265, -0.1015, -0.0720],
         [-0.0884, -0.0717, -0.0888,  0.1005,  0.0853],
         [ 0.0581,  0.0599,  0.0862,  0.0140, -0.0529]]],


       [[[-0.1089, -0.0443,  0.0031, -0.0961,  0.1019],
         [-0.0164,  0.0342, -0.1050,  0.0010,  0.0803],
         [ 0.1119, -0.0472,  0.1131, -0.0579,  0.0490],
         [ 0.0479,  0.0411, -0.0623, -0.0223,  0.0469],
         [-0.0606,  0.0521, -0.0535,  0.1140, -0.0084]],

        [[-0.0067,  0.0567,  0.0127,  0.1140,  0.0917],
         [ 0.0808, -0.0852, -0.0181, -0.0386, -0.0450],
         [ 0.0882, -0.1146, -0.1066, -0.0310,  0.0851],
         [ 0.0862,  0.1042, -0.1042,  0.0970,  0.0506],
         [-0.1061, -0.0642, -0.0727, -0.0392, -0.1108]],

        [[ 0.0244,  0.0048,  0.0172, -0.1152,  0.0844],
         [-0.0951,  0.0273, -0.0766, -0.1125, -0.0725],
         [ 0.0865, -0.0762,  0.0003, -0.0828, -0.1036],
         [ 0.0394, -0.1054, -0.0131, -0.0687, -0.0399],
         [ 0.0303,  0.0047, -0.0058,  0.0210, -0.0742]]]], device='cuda:0',
      requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.1018,  0.0211,  0.0306,  0.0260,  0.0620, -0.0182], device='cuda:0',
      requires_grad=True))]

"""
局部或全局剪枝

局部剪枝 主要是对个别模块(如某一层,模块等)进行剪枝操作。如下所示:

import torch.nn.utils.prune as prune

# 只对conv1层的权重进行随机剪枝操作
prune.random_unstructured(model.conv1, name='weight', amount=0.25)
# 也可对偏置进行剪枝操作
prune.random_unstructured(model.conv1, name='bias', amount=0.25)

剪枝可以迭代式运用,因此在实际应用中可以针对不同维度用不同的方法进行剪枝操作。
同时剪枝不仅可应用于模块中,也可以对参数进行剪枝的。
示例:对网络结构中的卷积层进行权重剪枝,对线性连接层进行不同的权重比值剪枝操作

model = LeNet5().to(device)

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        # Prune all 2D convolutional layers by 30%
        prune.random_unstructured(module,name='weight', amount=0.3)
    # Prune all linear layers by 50%.
    elif isinstance(module, torch.nn.Linear):
        prune.random_unstructured(module,  name='weight', amount=0.5)

全局剪枝 就是对整个模型进行剪枝操作。
示例:对模型的参数进行25%剪枝操作

model = LeNet5().to(device)
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

# prune 25% of all the parameters in the entire model
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.25
)
Pytorch提供的prune模块接口
image.png
image.png

自定义自己的剪枝方法

If you can’t find a pruning method that suits your needs, you can create your own pruning method. To do so, create a subclass from the BasePruningMethod class provided in torch.nn.utils.prune.
you will need to write your own _init_() constructor and compute_mask() method to describe how your pruning method computes the mask. In addition, you’ll need to specify the type of pruning (structured, unstructured, or global).

示例:

class MyPruningMethod(prune.BasePruningMethod):
    PRUNING_TYPE = 'unstructured'
    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

def my_unstructured(module, name):
    MyPruningMethod.apply(module, name)
    return module

# 对模型进行自定义剪枝操作
model = LeNet5().to(device)
my_unstructured(model.fc1, name='bias')

相关文章

  • Pytorch袖珍手册之十四

    第六章 Pytorch加速及优化(性能提升) 之五 模型优化--剪枝 Pruning 现在的模型基本上都是成百上千...

  • Pytorch袖珍手册之十

    第六章 Pytorch加速及优化(性能提升)之一 在实际应用中,我们可能面对的数据是比之前章节里的还要多,模型网络...

  • Pytorch袖珍手册之十一

    第六章 Pytorch加速及优化(性能提升) 之二 模型并行处理 model parallel processin...

  • Pytorch袖珍手册之五

    我用阿里云盘分享了「OReilly.PyTorch.Pocket.R...odels.149209000X.pdf...

  • Pytorch袖珍手册之四

    第三章 基于Pytorch的深度学习开发 前面章节我们已经了解tensor及其操作,这章主要就是学习如何用Pyto...

  • Pytorch袖珍手册之八

    我用阿里云盘分享了「OReilly.PyTorch.Pocket.R...odels.149209000X.pdf...

  • Pytorch袖珍手册之九

    第五章 基于Pytorch的深度学习网络结构自主式开发 前面章节我们主要通过pytorch提供的类,函数和各种库进...

  • Pytorch袖珍手册之六

    我用阿里云盘分享了「OReilly.PyTorch.Pocket.R...odels.149209000X.pdf...

  • Pytorch袖珍手册之七

    我用阿里云盘分享了「OReilly.PyTorch.Pocket.R...odels.149209000X.pdf...

  • Pytorch袖珍手册之十三

    第六章 Pytorch加速及优化(性能提升) 之四 模型优化--量化 Quantization 模型量化属于模型压...

网友评论

      本文标题:Pytorch袖珍手册之十四

      本文链接:https://www.haomeiwen.com/subject/ceowiltx.html