美文网首页
代码阅读笔记1

代码阅读笔记1

作者: 幽并游侠儿_1425 | 来源:发表于2020-07-21 12:43 被阅读0次

    代码位置

    代码结构:

    1.master文件夹
    (1) dice_loss.py
    (2) eval.py
    (3) predict.py: ** 完全没涉及pruning后的网络**
    (4) pruning.py:
    (5) submit.py:** 完全没涉及pruning后的网络**
    (6) train.py:** 完全没涉及pruning后的网络**
    (7) 文件夹unet
    <1> prune_layers.py
    <2> prune_unet_model.py : class PruneUNet
    <3> prune_unet_parts.py : 对p_double_conv, p_inconv, p_outconv, p_down, p_up 进行了定义

    <4> unet_model.py
    <5> unet_parts.py
    (8) 文件夹util
    <1> load.py
    <2> util.py
    <3> data_vis.py
    <4> crf.py

    阅读目的:

    能用,能跑,放自己的数据能跑

    阅读笔记

    A. pruning.py阅读
    1. pruning.py中的line 91的net.train()的理解
      net的定义:
    net = PruneUNet(n_channels=3, n_classes=1) 
    
    1. 中心for循环的理解
      for循环做了以下四件事情:
      对每个epoch:
      (1) reset the generator for training data and validation data.
      这里实际上对每一个epoch,在开始时对数据都做了traditional augmentation,但是我们的数据量足够多,不需要这么做。待改进
      (2) 取validation dataset的前四项,进行prediction和计算accuracy。
      这里因为validation dataset是自动生成的,所以虽然都是前4项,但是validation dataset是不一样的。
      (3) 用PruneUNet训练training dataset的前两个batch:
      PruneUNet位于unet文件夹的prune_unet_model.py
    • model.eval() :Pytorch会自动把BN和Dropout固定住,不会取平均,而是用训练好的值
    • model.train():让model变成训练模式,此时 dropout和batch normalization的操作在训练时起到防止网络过拟合的作用。
      训练,算loss,反向传播
      然后进行prune,
      对每个epoch,都要循环num_prune_iterations次,每一次运行一遍net.prune。net.prune的具体内容见B,总结下来是去掉一个channel。

    疑问:如果一直执行net.prune,都是去掉最小值,但是去掉最小值之后如果不删掉对应prune_feature_map 里的值,那么每次删掉的module里的filter不是一样的吗?
    回答:每一次找到对应layer_idx和filter_idx之后,对conv2d层执行prune,都需要运行位于prune_layer.py中的函数prune_feature_map,在这个函数中,执行了下面两步:

    indices = Variable(torch.LongTensor([i for i in range(self.out_channels) if i != map_index]))
    self.weight = nn.Parameter(self.weight.index_select(0, indices).data)
    

    对bias和对weight有一样的操作。
    最后将输出channel减一。
    这里重点理解这个index_select函数:
    函数格式:

    index_select(
        dim,
        index)
    

    参数含义:

    dim:表示从第几维挑选数据,类型为int值;index:表示从第一个参数维度中的哪个位置挑选数据,类型为torch.Tensor类的实例;

    (4) 继续对第一次循环里的validation的数据用pruned的代码进行预测和计算loss。
    (5) if save_cp时,保存net.state_dict()

    B. prune_unet_model.py阅读

    PruneUNet这个class中定义了4个函数:
    __ init __,forward,set_pruning和prune。
    其中 __ init __里,所有的down和up layer 以及output layer 都是pruned layer。
    其中prune是具体进性layer prune的函数,做了以下事情:
    (1) 去掉model里的大的block
    (2) 找到泰勒估计中,最小估计值所对应的layer和filter的位置,用prune_feature_map函数进行prune。
    (3) 如果下一层不是最后一层,对应去drop掉下一层的输入channel
    (4) down layer的channel改变之后,对应up layer的channel也要改变,这里用hard code去写。
    进一步去看:line68的taylor_estimates_by_module 和 estimates_by_f_map是怎么计算得到的。
    对每一个module list的module,在line64进行了module.taylor_estimates,去进行排序。

    先取出每个module_list 的module.taylor_estimates和idx,
    再从module.taylor_estimates里取出f_map_idx和对应的估计值estimate。

    C. prune_layers.py阅读

    在prune_layers.py中,定义了class PrunableConv2d(nn.Conv2d)class PrunableBatchNorm2d(nn.BatchNorm2d),对PrunableConv2d(nn.Conv2d),定义了属性taylor_estimates

    D. 提问:
    问题1:

    pruning.py基于前几个training batch和几个epoch和手动输入的num_prune_iterations对unet进行pruning,那么如何用prune好的网络对我们的数据进行计算呢?

    num_prune_iterations = 100,
    epochs=5

    这里又涉及两个问题:(a) prune完需要retrain吗? (b) 如何用pruned的网络进行inference?
    对问题(a),其实在pruning.py里,有反向传播更新梯度值和权重值的过程了,未必要重新去train。
    对问题(b),需要继续阅读代码。
    代码里并没有写,可能需要自己在prediction的代码里导入pruned_unet

    问题2: 如何计算每个module的taylor_estimates

    prune_layers.pyclass PrunableConv2d(nn.Conv2d)中有一个函数_calculate_taylor_estimate(self, _, grad_input, grad_output)专门计算taylor_estimates。
    这里有注释:# skip dim 1 as it is kernel size
    其中,_recent_activations是forward之后该conv2d层的output。

    mul_(value)
    mul()的直接运算形式,即直接执行并且返回修改后的张量

    # skip dim 1 as it is kernel size
    estimates = self._recent_activations.mul_(grad_output[0])
    estimates = estimates.mean(dim=(0, 2, 3))        
    # normalization
    self.taylor_estimates = torch.abs(estimates) / torch.sqrt(torch.sum(estimates * estimates))
    

    修改代码据为己用:

    A. pruning.py改动记录

    1. 把optimizer从SGD改成Adam,和自己的UNet保持一致。(已完成)
    2. line61criterion = nn.BCELoss()改成自己定义的Diceloss
    3. line183的net的定义n_channel从3改成1
    4. summary(net, (3, 640, 640))注释掉这一步可视化,因为暂时不探究其参数含义。
      B.

    相关文章

      网友评论

          本文标题:代码阅读笔记1

          本文链接:https://www.haomeiwen.com/subject/mtabkktx.html