Deep Learning Model Compression: Model Pruning


Author: 小黄不头秃 | Published: 2023-07-03 17:25

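After a model has been trained, and before it is deployed to production, we usually optimize it. For example, if the model is too large to run on edge devices, it has to be compressed. So how do we compress and optimize a model?

The goal is to reduce the model's size while preserving its original accuracy as much as possible.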

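Methods for model optimization:

I. Model pruning

Types of model pruning:

  • Structured pruning: removes whole structures, such as channels, filters, or even entire layers, so the shape of the network changes. (It usually has a larger impact on accuracy.)
  • Unstructured pruning: removes individual weights (connections) rather than whole neurons or channels, so the architecture and tensor shapes stay the same. (Turning the resulting irregular sparsity into real speedups often requires hardware or library support for sparse computation.)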
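To make the distinction concrete, here is a minimal sketch using PyTorch's `torch.nn.utils.prune` on a small hypothetical 4x6 `nn.Linear` layer (the layer size and the random seed are arbitrary choices for illustration). Unstructured L1 pruning zeroes individual weights wherever they happen to be small, while structured Ln pruning along `dim=0` zeroes entire rows, i.e. whole output neurons. In both cases PyTorch only applies a mask, so the tensor shapes stay the same; physically removing pruned rows or channels would require rebuilding the layer.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Two identical 4x6 linear layers so both pruning styles start from the same weights
unstructured_layer = nn.Linear(6, 4)
structured_layer = nn.Linear(6, 4)
structured_layer.load_state_dict(unstructured_layer.state_dict())

# Unstructured: mask the 50% of individual weights with the smallest |w| (12 of 24)
prune.l1_unstructured(unstructured_layer, name="weight", amount=0.5)

# Structured: mask the 50% of rows (output neurons) with the smallest L2 norm (2 of 4)
prune.ln_structured(structured_layer, name="weight", amount=0.5, n=2, dim=0)

print(unstructured_layer.weight_mask)  # zeros scattered across the matrix
print(structured_layer.weight_mask)    # zeros fill entire rows
print(tuple(unstructured_layer.weight.shape), tuple(structured_layer.weight.shape))  # (4, 6) (4, 6)
```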

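II. Implementing model pruning in code

Official tutorial: "Pruning Tutorial" (PyTorch Tutorials 2.0.1+cu117 documentation)

(1) Unstructured pruning
  1. Import the required packages and build the network model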
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels x 5x5 feature map (for a 28x28 input)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)
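The `16 * 5 * 5` input size of `fc1` only works for 28x28 inputs (for example MNIST-sized images): 28 becomes 26 after `conv1`, 13 after pooling, 11 after `conv2`, and 5 after the second pooling. A 32x32 input would instead produce 16 * 6 * 6 = 576 features and fail in `fc1`. The sketch below re-declares the same layers so it runs on its own (using `torch.flatten` in place of the `view` call), checks the output shape, and counts the parameters as a baseline before pruning; the batch size of 2 is arbitrary.

```python
import torch
from torch import nn
import torch.nn.functional as F

class LeNet(nn.Module):
    # Same layer sizes as the LeNet above
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 28 -> 26 -> 13
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 13 -> 11 -> 5
        x = torch.flatten(x, 1)                     # 16 * 5 * 5 = 400 features
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model = LeNet()
dummy = torch.randn(2, 1, 28, 28)  # batch of 2 single-channel 28x28 images
out = model(dummy)
print(out.shape)  # torch.Size([2, 10])

num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # 60074
```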
  2. Prune a single layer, here with random pruning
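module = model.conv1
print(list(module.named_parameters()))  # before pruning: 'weight' and 'bias'

print(list(module.named_buffers()))  # before pruning: empty, no masks yet

prune.random_unstructured(module, name="weight", amount=0.3)  # float amount: prune 30% of the weights

prune.random_unstructured(module, name="weight", amount=3)  # int amount: prune 3 more individual weights (not neurons)

# print(module._forward_pre_hooks)  # the hook that recomputes weight as weight_orig * weight_mask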
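The commented-out `_forward_pre_hooks` line hints at how `torch.nn.utils.prune` works internally. The sketch below uses a standalone `nn.Conv2d(1, 6, 3)` with the same shape as `conv1` (the seed is arbitrary) to show the reparameterization: the original tensor is moved to the parameter `weight_orig`, the mask is stored as the buffer `weight_mask`, and `weight` becomes a plain attribute equal to `weight_orig * weight_mask`, recomputed by a forward pre-hook. It also shows that repeated pruning calls on the same parameter accumulate: a float `amount` is a fraction and an int `amount` is an absolute count, and each new call only chooses among the weights that are still unpruned.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(1, 6, 3)  # weight has 6 * 1 * 3 * 3 = 54 entries

prune.random_unstructured(conv, name="weight", amount=0.3)  # round(0.3 * 54) = 16 weights
prune.random_unstructured(conv, name="weight", amount=3)    # 3 more, chosen among the 38 still unpruned

# 'weight' is no longer a parameter: the original values live in 'weight_orig'
print([name for name, _ in conv.named_parameters()])  # ['bias', 'weight_orig']
# The binary mask is registered as a buffer
print([name for name, _ in conv.named_buffers()])     # ['weight_mask']

# 'weight' is recomputed from the original tensor and the mask
print(torch.equal(conv.weight, conv.weight_orig * conv.weight_mask))  # True
print(int((conv.weight_mask == 0).sum()))  # 19 = 16 + 3
```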

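Random pruning is generally not recommended. A common alternative is L1 unstructured pruning, which zeroes out the parameters with the smallest absolute values (smallest L1 magnitude) and keeps the larger ones. The call below applies it to the bias of conv1, removing its 3 smallest-magnitude entries.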

prune.l1_unstructured(module, name="bias", amount=3)  # zero the 3 bias values with the smallest |b|
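A minimal check of which entries L1 unstructured pruning selects, using a hypothetical one-output `nn.Linear` layer whose four weights are set by hand: with `amount=2`, the two entries with the smallest absolute values are masked out, regardless of their sign.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 1, bias=False)
with torch.no_grad():
    # Hand-picked weights so the result is easy to verify
    layer.weight.copy_(torch.tensor([[0.1, -2.0, 0.05, 1.0]]))

# Mask the 2 weights with the smallest absolute value (0.05 and 0.1)
prune.l1_unstructured(layer, name="weight", amount=2)

print(layer.weight_mask)      # tensor([[0., 1., 0., 1.]])
print(layer.weight.tolist())  # [[0.0, -2.0, 0.0, 1.0]]: the large-magnitude weights survive
```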
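(2) Structured pruning

Structured pruning removes whole groups of parameters from a layer, such as entire output channels, rather than individual weights. Here, conv1's output channels (dim=0) are ranked by their L2 norm (n=2) and the weakest 50% are pruned: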

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)
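The sketch below repeats the same `ln_structured` call on a fresh, standalone `nn.Conv2d(1, 6, 3)` (seed chosen arbitrarily) and verifies three things: exactly 3 of the 6 output channels are masked as whole filters, the weight tensor keeps its shape (PyTorch only masks, it does not delete channels), and the surviving channels are the ones with the largest L2 norms. Actually shrinking the layer would mean building a new `Conv2d` with fewer output channels and adjusting the next layer's input channels, which this sketch does not do.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(1, 6, 3)  # weight shape (6, 1, 3, 3): 6 output channels

# Rank the 6 output channels (dim=0) by L2 norm (n=2) and mask the weakest 50%
prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)

# A channel is pruned when its whole 1x3x3 filter is masked out
pruned = (conv.weight_mask == 0).flatten(1).all(dim=1)
print(pruned.tolist())           # three True entries, one per pruned channel
print(tuple(conv.weight.shape))  # (6, 1, 3, 3): the shape does not change

# Every kept channel has an L2 norm at least as large as every pruned one
norms = conv.weight_orig.detach().flatten(1).norm(p=2, dim=1)
print(norms[pruned].max() <= norms[~pruned].min())  # tensor(True)
```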

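To prune the whole network, loop over its modules and apply a pruning rate per layer type (20% for convolutional layers, 40% for linear layers):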

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist
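Because `torch.nn.utils.prune` keeps both `weight_orig` and `weight_mask`, a pruned model actually holds more tensors than before. `prune.remove` folds each mask into its weight and removes the reparameterization, making the pruning permanent. The sketch below applies the same per-type rule to a small hypothetical `nn.Sequential` network (standing in for LeNet, with 1x8x8 inputs assumed) and then calls `prune.remove`. The zeros are still stored in dense tensors, so the saved file only gets smaller if sparse storage or compression is used on top of this.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
# Hypothetical small network standing in for LeNet; expects 1x8x8 inputs
net = nn.Sequential(
    nn.Conv2d(1, 4, 3),        # weight: 4*1*3*3 = 36 entries
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 6 * 6, 10),  # weight: 144*10 = 1440 entries
)

# Same rule as above: 20% of each conv layer, 40% of each linear layer
for module in net.modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.2)  # round(0.2 * 36) = 7 weights
    elif isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.4)  # 0.4 * 1440 = 576 weights

# Make the pruning permanent: 'weight' becomes a regular Parameter again,
# and weight_orig, weight_mask and the forward pre-hook are removed
for module in net.modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        prune.remove(module, "weight")

for name, module in net.named_modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        zeros = int((module.weight == 0).sum())
        print(name, zeros, "of", module.weight.nelement())
print(dict(net.named_buffers()).keys())  # dict_keys([]): no masks left
```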

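Global pruning (only unstructured pruning is supported). Instead of removing a fixed fraction from each layer separately, global pruning ranks all of the listed parameters together and removes the 20% with the smallest L1 magnitude across the whole model, so the sparsity of individual layers can differ: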

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
# Report the sparsity of each layer and of the whole model
total_zeros, total_elements = 0, 0
for layer_name in ["conv1", "conv2", "fc1", "fc2", "fc3"]:
    weight = getattr(model, layer_name).weight
    zeros = int(torch.sum(weight == 0))
    total_zeros += zeros
    total_elements += weight.nelement()
    print("Sparsity in {}.weight: {:.2f}%".format(layer_name, 100. * zeros / weight.nelement()))
print("Global sparsity: {:.2f}%".format(100. * total_zeros / total_elements))
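Global pruning uses a single magnitude cut-off for all listed parameters, which is why per-layer sparsity varies. A minimal sketch to check this, with two hypothetical `nn.Linear` layers where one layer's weights are deliberately scaled up 10x (the sizes, scale, and seed are illustrative assumptions): exactly 20% of the 300 weights are masked overall, and every pruned magnitude is no larger than every kept magnitude, whichever layer it belongs to.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
# Two hypothetical layers with deliberately different weight scales
small = nn.Linear(20, 10)  # 200 weights
large = nn.Linear(10, 10)  # 100 weights
with torch.no_grad():
    large.weight.mul_(10.0)  # make this layer's weights 10x larger in magnitude

prune.global_unstructured(
    [(small, "weight"), (large, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.2,  # 20% of all 300 weights = 60
)

for name, layer in [("small", small), ("large", large)]:
    zeros = int((layer.weight_mask == 0).sum())
    print(name, zeros, "of", layer.weight_mask.numel())  # most pruned weights come from 'small'

# One global cut-off: every pruned |w| is <= every kept |w|, across both layers
magnitudes = torch.cat([small.weight_orig.detach().flatten(),
                        large.weight_orig.detach().flatten()]).abs()
masks = torch.cat([small.weight_mask.flatten(), large.weight_mask.flatten()])
print(magnitudes[masks == 0].max() <= magnitudes[masks == 1].min())  # tensor(True)
```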


Article link: https://www.haomeiwen.com/subject/lvkxudtx.html