
Chapter 6: PyTorch Acceleration and Optimization (Performance Tuning), Part 5
Model Optimization: Pruning
Modern models routinely carry millions to tens of millions of parameters, which makes deployment difficult.
Given this parameter bloat, a natural question is whether some parameters can be trimmed away while preserving accuracy and improving performance. That is exactly the technique this section introduces: pruning.
A pruned model is easier to deploy: it occupies less memory, draws less power, and needs fewer hardware resources.
Pruning is a technique that reduces the number of model parameters with minimal effect on performance.
Pruning can be applied to a single layer, to multiple layers, or to an entire model.
A LeNet5 pruning example
- The LeNet5 network structure
The network contains five submodules: conv1, conv2, fc1, fc2, and fc3. You can inspect each layer's weights and biases via named_parameters().
import torch
from torch import nn
import torch.nn.functional as F

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet5()
for n, p in model.named_parameters():
    print(n, ':', p.dtype)
"""
conv1.weight : torch.float32
conv1.bias : torch.float32
conv2.weight : torch.float32
conv2.bias : torch.float32
fc1.weight : torch.float32
fc1.bias : torch.float32
fc2.weight : torch.float32
fc2.bias : torch.float32
fc3.weight : torch.float32
fc3.bias : torch.float32
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet5().to(device)
# Print the parameters (weights and biases) of the conv1 layer
print(list(model.conv1.named_parameters()))
"""
[('weight', Parameter containing:
tensor([[[[-0.0254, -0.1115, -0.0827, 0.0797, 0.0523],
[ 0.0688, 0.0615, -0.0759, -0.0043, -0.0328],
[ 0.0675, 0.0057, -0.0423, -0.0888, -0.0361],
[-0.1121, 0.0517, 0.0228, -0.0806, -0.0785],
[-0.0847, 0.0142, -0.1019, -0.1008, -0.1116]],
[[ 0.1151, -0.0205, 0.0262, -0.0170, 0.0272],
[-0.0599, -0.1126, -0.1107, -0.0598, 0.0042],
[-0.0488, -0.0979, 0.0312, 0.0500, 0.1130],
[ 0.1132, 0.0287, -0.0994, 0.0156, 0.0690],
[-0.0146, 0.1097, 0.0898, 0.0058, 0.0998]],
[[ 0.1002, -0.0450, -0.0825, 0.0467, 0.0433],
[-0.0459, -0.0875, -0.0004, -0.0487, -0.0339],
[ 0.0943, -0.0994, 0.0373, 0.0701, -0.1118],
[ 0.0719, 0.0674, 0.0533, -0.1095, 0.0715],
[-0.0975, 0.0738, 0.1103, -0.0869, 0.0697]]],
[[[ 0.0356, -0.1042, -0.0460, 0.0180, -0.0991],
[-0.0112, -0.0852, 0.0637, 0.0459, -0.0785],
[ 0.0289, 0.0028, -0.0814, 0.0551, -0.0631],
[ 0.0308, -0.0570, -0.0495, -0.0321, 0.0580],
[-0.0466, -0.0319, -0.1131, 0.0338, -0.0257]],
[[ 0.0440, -0.0991, 0.0583, 0.0582, 0.0998],
[ 0.1060, 0.0901, -0.1064, -0.1107, 0.0029],
[-0.0760, 0.0742, -0.0715, -0.0847, 0.0834],
[-0.0198, 0.0122, -0.1090, 0.1133, -0.0565],
[-0.0185, 0.0326, 0.0253, 0.0050, 0.0366]],
[[ 0.0216, -0.0240, -0.0316, -0.0702, -0.1046],
[-0.0834, 0.0543, -0.0351, 0.0328, 0.0366],
[-0.0165, 0.0361, -0.0984, -0.0974, -0.1094],
[ 0.0296, 0.0665, -0.0207, 0.0672, -0.1143],
[ 0.0940, 0.0174, -0.0256, -0.1088, -0.0470]]],
[[[-0.0897, 0.0794, -0.0743, -0.1144, 0.0853],
[-0.0068, 0.0790, -0.0465, 0.0928, 0.0438],
[-0.0138, 0.0193, 0.0531, -0.1072, -0.0127],
[-0.0078, -0.0460, -0.0003, -0.0630, -0.0127],
[ 0.0368, -0.1079, 0.1085, -0.0132, 0.0334]],
[[-0.0276, -0.0570, -0.0167, -0.0674, -0.0074],
[-0.0432, 0.0750, 0.0877, 0.1045, -0.1132],
[-0.0501, 0.0281, -0.0430, -0.1030, 0.1126],
[ 0.0121, -0.0478, 0.0391, 0.0598, 0.0571],
[ 0.1018, 0.0410, -0.0727, -0.0121, -0.0578]],
[[-0.0081, -0.0545, 0.0606, 0.0648, -0.1060],
[ 0.0764, -0.1059, 0.0099, 0.0187, -0.0928],
[-0.0235, -0.0289, 0.0144, -0.0721, -0.0726],
[-0.0386, 0.0340, 0.0867, -0.1127, -0.0857],
[ 0.0107, 0.0146, -0.0849, 0.0524, -0.0900]]],
[[[ 0.0776, 0.0557, -0.0765, -0.1140, 0.0576],
[-0.0059, 0.0346, 0.0455, 0.0054, -0.1109],
[ 0.0361, -0.0814, 0.0402, -0.0102, -0.0602],
[ 0.0790, 0.0850, -0.0594, 0.0911, -0.1050],
[-0.0766, -0.0501, -0.1025, -0.0241, -0.0324]],
[[-0.0884, 0.0497, 0.0938, -0.0150, -0.1074],
[ 0.0292, -0.0107, 0.0899, 0.0282, 0.1153],
[-0.0436, -0.0003, 0.0186, 0.1088, 0.0628],
[ 0.0650, -0.0890, -0.0791, 0.0365, 0.0138],
[-0.0130, -0.0286, -0.0172, 0.0826, 0.0048]],
[[-0.0259, -0.0353, 0.0268, 0.0855, -0.0649],
[-0.0093, 0.0942, 0.0686, -0.0389, -0.0243],
[ 0.0929, -0.1022, -0.0687, -0.0074, 0.1133],
[ 0.1036, 0.0301, 0.0482, 0.0721, 0.0886],
[-0.0423, 0.0273, 0.0875, -0.0517, -0.0984]]],
[[[-0.1149, -0.0226, -0.0565, -0.0358, 0.0875],
[ 0.0141, 0.0230, 0.0436, -0.0414, 0.0428],
[ 0.0095, -0.0065, 0.0439, -0.0247, 0.0942],
[ 0.0874, -0.0608, -0.0328, 0.0800, 0.0835],
[ 0.0480, -0.0660, -0.0568, 0.0273, -0.0273]],
[[-0.1025, 0.0872, -0.0406, -0.0210, 0.0457],
[ 0.0217, 0.0395, 0.0367, -0.0048, -0.0257],
[-0.0946, -0.0753, 0.1068, -0.1134, -0.0725],
[ 0.0415, -0.0461, 0.0042, 0.0258, 0.0976],
[-0.0452, -0.1107, -0.0572, 0.0010, -0.0404]],
[[-0.0284, -0.0555, 0.0831, -0.0226, 0.0250],
[-0.0154, -0.0054, 0.0758, -0.0524, 0.0314],
[-0.0345, -0.0292, 0.0265, -0.1015, -0.0720],
[-0.0884, -0.0717, -0.0888, 0.1005, 0.0853],
[ 0.0581, 0.0599, 0.0862, 0.0140, -0.0529]]],
[[[-0.1089, -0.0443, 0.0031, -0.0961, 0.1019],
[-0.0164, 0.0342, -0.1050, 0.0010, 0.0803],
[ 0.1119, -0.0472, 0.1131, -0.0579, 0.0490],
[ 0.0479, 0.0411, -0.0623, -0.0223, 0.0469],
[-0.0606, 0.0521, -0.0535, 0.1140, -0.0084]],
[[-0.0067, 0.0567, 0.0127, 0.1140, 0.0917],
[ 0.0808, -0.0852, -0.0181, -0.0386, -0.0450],
[ 0.0882, -0.1146, -0.1066, -0.0310, 0.0851],
[ 0.0862, 0.1042, -0.1042, 0.0970, 0.0506],
[-0.1061, -0.0642, -0.0727, -0.0392, -0.1108]],
[[ 0.0244, 0.0048, 0.0172, -0.1152, 0.0844],
[-0.0951, 0.0273, -0.0766, -0.1125, -0.0725],
[ 0.0865, -0.0762, 0.0003, -0.0828, -0.1036],
[ 0.0394, -0.1054, -0.0131, -0.0687, -0.0399],
[ 0.0303, 0.0047, -0.0058, 0.0210, -0.0742]]]], device='cuda:0',
requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.1018, 0.0211, 0.0306, 0.0260, 0.0620, -0.0182], device='cuda:0',
requires_grad=True))]
"""
Local vs. Global Pruning
Local pruning targets an individual module (for example, a single layer) at a time. For example:
import torch.nn.utils.prune as prune
# Randomly prune 25% of the weights of the conv1 layer only
prune.random_unstructured(model.conv1, name='weight', amount=0.25)
# Biases can be pruned the same way
prune.random_unstructured(model.conv1, name='bias', amount=0.25)
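Under the hood, pruning does not delete parameters; it reparameterizes them. The original tensor is kept as weight_orig, a binary weight_mask buffer is registered, and a forward pre-hook recomputes weight = weight_orig * weight_mask on every forward pass. A quick check (expected keys shown as comments):

# The original tensors move to <name>_orig, and <name>_mask buffers appear
print([n for n, _ in model.conv1.named_parameters()])
# ['weight_orig', 'bias_orig']
print([n for n, _ in model.conv1.named_buffers()])
# ['weight_mask', 'bias_mask']
# 'weight' is now a plain attribute; pruned entries are exact zeros
print(model.conv1.weight[0, 0])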
Pruning can be applied iteratively, so in practice you can prune the same model along different dimensions with different methods, as sketched below.
Note also that pruning is not limited to whole modules; it operates on individual parameters, so you can target a specific weight or bias tensor.
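For instance, the same parameter can be pruned twice with different methods; torch.nn.utils.prune chains the masks through a PruningContainer, so the second pass only acts on the units that survived the first. A minimal sketch (peeking at the private _forward_pre_hooks dict is for illustration only):

# Magnitude-based unstructured pruning first ...
prune.l1_unstructured(model.conv2, name='weight', amount=0.2)
# ... then structured pruning of whole output channels (L2 norm along dim 0)
prune.ln_structured(model.conv2, name='weight', amount=0.5, n=2, dim=0)
# Both methods are now recorded in the hook's PruningContainer
for hook in model.conv2._forward_pre_hooks.values():
    if isinstance(hook, prune.PruningContainer):
        print([type(m).__name__ for m in hook])
# ['L1Unstructured', 'LnStructured']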
Example: prune the weights of every convolutional layer by 30% and every linear layer by 50%:
model = LeNet5().to(device)
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        # Prune all 2D convolutional layers by 30%
        prune.random_unstructured(module, name='weight', amount=0.3)
    elif isinstance(module, torch.nn.Linear):
        # Prune all linear layers by 50%
        prune.random_unstructured(module, name='weight', amount=0.5)
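A quick way to confirm the loop worked is to list the mask buffers that pruning registered on each layer:

# Every pruned layer now owns a weight_mask buffer
print(dict(model.named_buffers()).keys())
# dict_keys(['conv1.weight_mask', 'conv2.weight_mask',
#            'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])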
Global pruning applies a single pruning criterion across the entire model at once.
Example: prune 25% of the model's parameters globally, ranked by L1 magnitude:
model = LeNet5().to(device)
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)
# Prune 25% of all the parameters in the entire model
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.25,
)
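Because global pruning compares magnitudes across layers, the 25% sparsity is spread unevenly: layers with many small weights lose more than the others. A sketch of how to measure the resulting per-layer and global sparsity:

# Fraction of weights zeroed in each pruned tensor, and overall
total_zeros, total_elems = 0, 0
for module, pname in parameters_to_prune:
    w = getattr(module, pname)
    zeros = int(torch.sum(w == 0))
    total_zeros += zeros
    total_elems += w.nelement()
    print(f"{type(module).__name__}.{pname}: {100.0 * zeros / w.nelement():.1f}% sparsity")
print(f"global sparsity: {100.0 * total_zeros / total_elems:.1f}%")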
The prune module API provided by PyTorch
torch.nn.utils.prune ships a family of ready-made pruning functions, each taking the module and the name of the parameter to prune:
- prune.identity(module, name): applies a mask of all ones (no weights removed)
- prune.random_unstructured(module, name, amount): removes randomly chosen individual weights
- prune.l1_unstructured(module, name, amount): removes the weights with the smallest L1 magnitude
- prune.random_structured(module, name, amount, dim): removes randomly chosen whole channels along dim
- prune.ln_structured(module, name, amount, n, dim): removes the channels with the smallest Ln norm along dim
- prune.global_unstructured(parameters, pruning_method, **kwargs): prunes across many (module, name) pairs at once
- prune.custom_from_mask(module, name, mask): applies a user-supplied mask
- prune.remove(module, name): makes pruning permanent by folding the mask into the parameter
- prune.is_pruned(module): reports whether any pruning is currently applied to the module
The corresponding classes (BasePruningMethod, PruningContainer, Identity, RandomUnstructured, L1Unstructured, RandomStructured, LnStructured, CustomFromMask) back these functions and can also be used directly.
Defining a custom pruning method
If you can't find a pruning method that suits your needs, you can create your own. To do so, subclass the BasePruningMethod class provided in torch.nn.utils.prune. You will need to write your own __init__() constructor and a compute_mask() method that describes how your pruning method computes the mask. In addition, you need to specify the type of pruning via PRUNING_TYPE ('structured', 'unstructured', or 'global').
Example: a method that zeroes out every other entry of the tensor:
class MyPruningMethod(prune.BasePruningMethod):
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        # Start from the default mask and zero out every other
        # entry of the flattened tensor
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

def my_unstructured(module, name):
    MyPruningMethod.apply(module, name)
    return module
# Apply the custom pruning method to the bias of fc1
model = LeNet5().to(device)
my_unstructured(model.fc1, name='bias')
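The pruning reparameterization (the _orig parameter, the mask buffer, and the forward pre-hook) stays attached until you finalize it. prune.remove folds the mask into the parameter and strips that machinery, which is what you want before exporting or serving the model:

# fc1 currently holds 'bias_orig' plus a 'bias_mask' buffer
print([n for n, _ in model.fc1.named_buffers()])    # ['bias_mask']
# Make the pruning permanent: bias = bias_orig * bias_mask
prune.remove(model.fc1, 'bias')
print([n for n, _ in model.fc1.named_parameters()])
# ['weight', 'bias'] -- plain parameters again, zeros baked in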