Implementing SqueezeNet in MXNet

Author: cptn3m0 | Published 2016-06-22 23:30, read 1874 times

    SqueezeNet is a way of compressing parameters: the model becomes much smaller, but the amount of computation is not reduced by much.
    paper: http://arxiv.org/abs/1602.07360
    github: https://github.com/DeepScale/SqueezeNet

    The MXNet Python model code:

    
    """
    paper: http://arxiv.org/abs/1602.07360
    github: https://github.com/DeepScale/SqueezeNet
    
    @article{SqueezeNet,
        Author = {Forrest N. Iandola and Matthew W. Moskewicz and Khalid Ashraf and Song Han and William J. Dally and Kurt Keutzer},
        Title = {SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <1MB model size},
        Journal = {arXiv:1602.07360},
        Year = {2016}
    }
    
    squeezenet
    """
    
    import find_mxnet
    import mxnet as mx
    
    # Fire module factory: a 1x1 "squeeze" conv followed by parallel 1x1 and 3x3 "expand" convs
    def FireModelFactory(data, size):
        if size == 1:
            n_s1x1 = 16
            n_e1x1 = 64
            n_e3x3 = 64
        elif size == 2:
            n_s1x1 = 32
            n_e1x1 = 128
            n_e3x3 = 128
        elif size == 3:
            n_s1x1 = 48
            n_e1x1 = 192
            n_e3x3 = 192
        elif size == 4:
            n_s1x1 = 64
            n_e1x1 = 256
            n_e3x3 = 256
        else:
            raise ValueError("size must be 1, 2, 3 or 4")
    
        squeeze1x1 = mx.symbol.Convolution(
                data   = data, 
                kernel = (1,1), 
                pad    = (0,0),
                num_filter = n_s1x1 )
    
        relu_squeeze1x1 = mx.symbol.Activation( data=squeeze1x1, act_type="relu" )
    
        expand1x1 = mx.symbol.Convolution(
                data   = relu_squeeze1x1,
                kernel = (1,1),
                pad    = (0,0),
                num_filter = n_e1x1 )
    
        relu_expand1x1 = mx.symbol.Activation(data=expand1x1, act_type="relu" )
    
        expand3x3 = mx.symbol.Convolution(
                data   = relu_squeeze1x1,
                kernel = (3,3),
                pad    = (1,1),
                num_filter = n_e3x3 )
    
        relu_expand3x3 = mx.symbol.Activation(data=expand3x3, act_type="relu" )
        
        concat = mx.symbol.Concat( *[relu_expand1x1, relu_expand3x3] )
    
        return concat 
    
    def get_symbol(num_classes = 1000):
        data = mx.symbol.Variable(name="data")
    
        conv1 = mx.symbol.Convolution(data=data, kernel=(7,7), stride=(2,2), num_filter=96)
        relu_conv1 = mx.symbol.Activation(data=conv1, act_type="relu")
        maxpool1 = mx.symbol.Pooling(data=relu_conv1, kernel=(3,3), stride=(2,2), pool_type="max")
    
        fire2 = FireModelFactory(data=maxpool1, size=1)
        fire3 = FireModelFactory(data=fire2, size=1)
        fire4 = FireModelFactory(data=fire3, size=2)
    
        maxpool4 = mx.symbol.Pooling(data=fire4, kernel=(3,3), stride=(2,2), pool_type="max")
    
        fire5 = FireModelFactory(data=maxpool4, size=2)
        fire6 = FireModelFactory(data=fire5, size=3)
        fire7 = FireModelFactory(data=fire6, size=3)
        fire8 = FireModelFactory(data=fire7, size=4)
    
        maxpool8 = mx.symbol.Pooling(data=fire8, kernel=(3,3), stride=(2,2), pool_type="max")
    
        fire9 = FireModelFactory(data=maxpool8, size=4)
        dropout_fire9 = mx.symbol.Dropout(data=fire9, p=0.5)
    
        conv10 = mx.symbol.Convolution(data=dropout_fire9, kernel=(1,1), pad=(1,1), num_filter=1000)
        relu_conv10 = mx.symbol.Activation(data=conv10, act_type="relu")
        avgpool10 = mx.symbol.Pooling(data=relu_conv10, global_pool=True, kernel=(13,13), stride=(1,1), pool_type="avg")
    
        flatten = mx.symbol.Flatten(data=avgpool10, name='flatten')
    
        softmax = mx.symbol.SoftmaxOutput(data=flatten, name="softmax")
        return softmax
    
    
    

    The concrete parameter settings from the paper:
    This table is quite comprehensive; you can compare it against your own network to check whether your implementation has problems.

    fire_net2.png
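One quick sanity check against the paper's per-layer table is to recompute each fire module's parameter count from its squeeze/expand widths. A minimal sketch in plain Python (the helper name `fire_params` is my own, not from the paper or MXNet):

```python
def fire_params(in_channels, s1x1, e1x1, e3x3):
    """Parameter count (weights + biases) of one fire module."""
    squeeze = in_channels * s1x1 * 1 * 1 + s1x1   # 1x1 squeeze conv
    expand1 = s1x1 * e1x1 * 1 * 1 + e1x1          # 1x1 expand conv
    expand3 = s1x1 * e3x3 * 3 * 3 + e3x3          # 3x3 expand conv
    return squeeze + expand1 + expand3

# fire2: 96 input channels, s1x1=16, e1x1=64, e3x3=64
print(fire_params(96, 16, 64, 64))  # 11920, matching the 11,920 listed for fire2
```

If your own implementation's layer shapes do not reproduce these counts, the wiring of the squeeze/expand branches is probably off.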

    Below is a diagram of the SqueezeNet network architecture:

    squeezenet-227.png

    This image is a bit too large; if anyone knows a way to resize it, please leave a comment and I will fix it.


    Things to note

    **Usage of MXNet's global pooling**

    avgpool10 = mx.symbol.Pooling(data=relu_conv10, global_pool=True, kernel=(13,13), stride=(1,1), pool_type="avg")
    
    

    This call has a parameter, `global_pool`, which is a switch deciding whether the operation is a global pooling or an ordinary pooling.
    There is also a small pitfall here: the code has a minor flaw, though nothing serious. If I find the time, I will submit a PR.
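When global pooling is enabled, the pooling window becomes the full spatial extent of the input, so each channel is reduced to a single averaged value regardless of the `kernel` argument. A pure-Python illustration of that semantics (not MXNet code, just the idea):

```python
def global_avg_pool(feature_map):
    """feature_map: list of channels, each a 2-D list (H x W).
    Returns one averaged value per channel, like a global average pool."""
    pooled = []
    for channel in feature_map:
        total = sum(sum(row) for row in channel)
        count = len(channel) * len(channel[0])
        pooled.append(total / count)
    return pooled

# one 2x2 channel: (1 + 3 + 5 + 7) / 4 = 4.0
print(global_avg_pool([[[1.0, 3.0], [5.0, 7.0]]]))  # [4.0]
```

This is why SqueezeNet can end with a 1x1 conv producing `num_classes` channels followed by global average pooling instead of a fully connected layer.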


    Training script

    python train_imagenet.py \
      --batch-size  512 \
      --lr 0.04 \
      --lr-factor .94  \
      --gpus 0,1,2,3 \
      --num-epoch 1000 \
      --network squeezenet \
      --model-prefix model/squeezenet \
      --data-dir /search/data/1k/ \
      --train-dataset train_1k_256.rec \
      --val-dataset val_1k_256.rec \
      --data-shape 227 \
      --log-dir /search/data/user/cptn3m0/logs/ \
      --log-file suqeezenet-1m.log
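
For reference, `--lr-factor .94` multiplies the learning rate by 0.94 at each decay step of the schedule (per epoch in this setup, as far as I can tell), so the effective rate decays geometrically. A quick sketch of the resulting schedule (`lr_at_epoch` is just an illustrative helper, not part of the training script):

```python
def lr_at_epoch(base_lr, factor, epoch):
    """Learning rate after `epoch` decay steps of a geometric schedule."""
    return base_lr * factor ** epoch

# starting from --lr 0.04 with --lr-factor .94
for epoch in (0, 10, 50, 100):
    print(epoch, lr_at_epoch(0.04, 0.94, epoch))
```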
    

    SqueezeNet training has now converged. Compared with the latest version of Caffe, MXNet has a clear advantage in convergence speed.

    I will post the training results later.

    The Caffe version's network description file (prototxt) and model parameters (caffemodel) can be found at:

    https://github.com/DeepScale/SqueezeNet

    I am still training the MXNet version of SqueezeNet and have not found good hyperparameters yet. What a headache!


    Training results (training still in progress)

    Top-1
    Epoch[0] Validation-accuracy=0.022162
    Epoch[1] Validation-accuracy=0.077706
    Epoch[2] Validation-accuracy=0.159512
    Epoch[3] Validation-accuracy=0.214047
    Epoch[4] Validation-accuracy=0.254604
    Epoch[5] Validation-accuracy=0.286223
    Epoch[6] Validation-accuracy=0.322604
    Epoch[7] Validation-accuracy=0.340362
    Epoch[8] Validation-accuracy=0.353355
    Epoch[9] Validation-accuracy=0.377770
    Epoch[10] Validation-accuracy=0.383929
    Epoch[11] Validation-accuracy=0.395719
    Epoch[12] Validation-accuracy=0.407187
    Epoch[13] Validation-accuracy=0.422513
    Epoch[14] Validation-accuracy=0.411445
    Epoch[15] Validation-accuracy=0.429787
    Epoch[16] Validation-accuracy=0.440171
    Epoch[17] Validation-accuracy=0.448293
    Epoch[18] Validation-accuracy=0.450375
    Epoch[19] Validation-accuracy=0.463289
    Epoch[20] Validation-accuracy=0.460333
    Epoch[21] Validation-accuracy=0.467395
    Epoch[22] Validation-accuracy=0.468770
    Epoch[23] Validation-accuracy=0.473502
    
    
    Top-5
    Epoch[0] Validation-top_k_accuracy_5=0.078205
    Epoch[1] Validation-top_k_accuracy_5=0.207151
    Epoch[2] Validation-top_k_accuracy_5=0.354402
    Epoch[3] Validation-top_k_accuracy_5=0.433394
    Epoch[4] Validation-top_k_accuracy_5=0.484714
    Epoch[5] Validation-top_k_accuracy_5=0.524042
    Epoch[6] Validation-top_k_accuracy_5=0.568060
    Epoch[7] Validation-top_k_accuracy_5=0.586814
    Epoch[8] Validation-top_k_accuracy_5=0.602932
    Epoch[9] Validation-top_k_accuracy_5=0.627730
    Epoch[10] Validation-top_k_accuracy_5=0.636201
    Epoch[11] Validation-top_k_accuracy_5=0.644149
    Epoch[12] Validation-top_k_accuracy_5=0.656270
    Epoch[13] Validation-top_k_accuracy_5=0.666474
    Epoch[14] Validation-top_k_accuracy_5=0.660881
    Epoch[15] Validation-top_k_accuracy_5=0.679548
    Epoch[16] Validation-top_k_accuracy_5=0.691067
    Epoch[17] Validation-top_k_accuracy_5=0.695876
    Epoch[18] Validation-top_k_accuracy_5=0.699000
    Epoch[19] Validation-top_k_accuracy_5=0.710499
    Epoch[20] Validation-top_k_accuracy_5=0.709810
    Epoch[21] Validation-top_k_accuracy_5=0.711854
    Epoch[22] Validation-top_k_accuracy_5=0.714525
    Epoch[23] Validation-top_k_accuracy_5=0.718730
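The curves above can be pulled out of the training log mechanically. A small sketch that parses lines of this `Epoch[n] Validation-accuracy=x` form and reports the best epoch (the regex and `best_epoch` helper are my own, not part of MXNet):

```python
import re

LINE_RE = re.compile(r"Epoch\[(\d+)\] Validation-accuracy=([\d.]+)")

def best_epoch(log_lines):
    """Return (epoch, accuracy) for the highest validation accuracy seen."""
    scores = {}
    for line in log_lines:
        m = LINE_RE.search(line)
        if m:
            scores[int(m.group(1))] = float(m.group(2))
    return max(scores.items(), key=lambda kv: kv[1])

log = [
    "Epoch[13] Validation-accuracy=0.422513",
    "Epoch[14] Validation-accuracy=0.411445",
    "Epoch[23] Validation-accuracy=0.473502",
]
print(best_epoch(log))  # (23, 0.473502)
```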
    

      Comments

      • Quiet_c28d: Hello, I would like to train SqueezeNet under MXNet. After putting squeezenet.py into symbols, what else do I need to do?
        Looking forward to your reply, thanks!
        cptn3m0: @Quiet_c28d My model was based on a very old version of MXNet, and I don't know whether the current framework still looks the same. One suggestion: refer to how the classic networks that ship with MXNet, such as alexnet, are written.
        Quiet_c28d: After I modified the parameters of get_symbol() in squeezenet.py it could train, but Train-accuracy barely grows and Validation-accuracy hardly changes either. What could be the cause? (train: 10000+ images, val: 1000+ images)
      • 朱小虎XiaohuZhu: Things should be better after our revamp.

      Link to this article: https://www.haomeiwen.com/subject/dxundttx.html