美文网首页
TensorFlow之使用训练好的VGG模型

TensorFlow之使用训练好的VGG模型

作者: 你要好好学习呀 | 来源:发表于2019-05-07 19:26 被阅读0次

    对于读取imagenet-vgg-verydeep-19.mat,python代码为:

    matpath = r"D:\imagenet-vgg-verydeep-19.mat"
    a = scipy.io.loadmat(matpath)
    

    在python代码中a是一个dict类型,长度为3,对应layers,classes,normalization. data["layers"]可以取到matlab中layers (层参数) 对应的数据结构,data["classes"]可以取到matlab中classes(分类信息)对应的数据结构,data["normalization"]可以取到matlab中normalization(正则化参数/像素平均值)对应的数据结构。
    data['layers']是一个143的ndarray.data['layers'][0]就是一个长度为43的ndarray,对应vgg19的43个各层操作的结果.具体为对应关系为:
    0 对应 conv1_1 (3, 3, 3, 64)
    1 对应 relu
    2 对应 conv1_2 (3, 3, 64, 64)
    3 对应 relu
    4 对应 maxpool
    5 对应 conv2_1 (3, 3, 64, 128)
    6 对应 relu
    7 对应 conv2_2 (3, 3, 128, 128)
    8 对应 relu
    9 对应 maxpool
    10 对应 conv3_1 (3, 3, 128, 256)
    11 对应 relu
    12 对应 conv3_2 (3, 3, 256, 256)
    13 对应 relu
    14 对应 conv3_3 (3, 3, 256, 256)
    15 对应 relu
    16 对应 conv3_4 (3, 3, 256, 256)
    17 对应 relu
    18 对应 maxpool
    19 对应 conv4_1 (3, 3, 256, 512)
    20 对应 relu
    21 对应 conv4_2 (3, 3, 512, 512)
    22 对应 relu
    23 对应 conv4_3 (3, 3, 512, 512)
    24 对应 relu
    25 对应 conv4_4 (3, 3, 512, 512)
    26 对应 relu
    27 对应 maxpool
    28 对应 conv5_1 (3, 3, 512, 512)
    29 对应 relu
    30 对应 conv5_2 (3, 3, 512, 512)
    31 对应 relu
    32 对应 conv5_3 (3, 3, 512, 512)
    33 对应 relu
    34 对应 conv5_4 (3, 3, 512, 512)
    35 对应 relu
    36 对应 maxpool
    37 对应 fullyconnected (7, 7, 512, 4096)
    38 对应 relu
    39 对应 fullyconnected (1, 1, 4096, 4096)
    40 对应 relu
    41 对应 fullyconnected (1, 1, 4096, 1000)
    42 对应 softmax
    Vgg-19的layers部分参数数据结构应该是:

    import scipy.io
    import numpy as np
    import tensorflow as tf
    import os
    import scipy.misc
    import matplotlib.pyplot as plt
    ''''
    enumerate()是python的内置函数、适用于python2.x和python3.x
    enumerate在字典上是枚举、列举的意思
    enumerate参数为可遍历/可迭代的对象(如列表、字符串)
    enumerate多用于在for循环中得到计数,利用它可以同时获得索引和值,即需要index和value值的时候可以使用enumerate
    
    matconvnet:weights are [width,,height,in_channels,out_channels]           
    tensorflow: weights are [height,width,in_channels,out_channels] 
    则需要转换成TensorFlow支持的格式    
    
    numpy.reshape(重塑)给数组一个新的形状而不改变其数据numpy.reshape(a, newshape, order=’C’)
    '''
    def _conv_layer(input,weights,biases):
        conv=tf.nn.conv2d(input,tf.constant(weights),strides=(1,1,1,1),padding='SAME')
        return tf.nn.bias_add(conv,biases)
    def _pool_layer(input):
        return tf.nn.max_pool(input,ksize=(1,2,2,1),strides=(1,2,2,1),padding='SAME')
    def preprocess(image,mean_pixel):
        return image-mean_pixel
    def unprocess(image,mean_pixel):
        return image+mean_pixel
    def imread(path):
        return scipy.misc .imread(path).astype(np.float)
    def imsave(path,img):
        img=np.clip(image,0,255).astype(np.uint8)
        scipy.misc.imsave(path,img)
    print("function for VGG ready")
    
    def Vnet(data_path,_input_image):
        layers=(
            'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
            'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
            'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
            'conv4_1','relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
            'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4'
        )
        data = scipy.io.loadmat(data_path)
        mean=data['normalization'][0][0][0]
        mean_pixel=np.mean(mean,axis=(0,1))#测试的输入值也要减去均值
        weights=data['layers'][0]
        net={}#字典结构
        current=_input_image
        for i,name in enumerate(layers):
            print('i=',i)
            kind=name[:4]
            if kind=='conv':
                kernels, bias = weights[i][0][0][0][0]
                kernels = np.transpose(kernels, (1, 0, 2, 3))
                bias = bias.reshape(-1)#变为一行
                current=_conv_layer(current, kernels, bias)
            elif kind=='relu':
                current=tf.nn.relu(current)
            elif kind=='pool':
                current=_pool_layer(current)
            net[name]=current#保存每一层的前向传播的结果
            '''
            python中assert断言是声明其布尔值必须为真的判定,如果发生异常就说明表达示为假。
            可以理解assert断言语句为raise-if-not,用来测试表示式,其返回值为假,就会触发异常。
            '''
    
        assert len(net)==len(layers)
        return net, mean_pixel, layers
    print('Network for VGG ready')
    
    #把每一层的结果可视化的表现出来
    cwd=os.getcwd()#得到当前的工作路径
    VGG_PATH=cwd+"/data/imagenet-vgg-verydeep-19.mat"
    IMG_PATH=cwd+"/data/cat.jpg"
    input_image=imread(IMG_PATH )
    shape=(1,input_image.shape[0],input_image.shape[1],input_image.shape[2])
    with tf.Session() as sess:
        image = tf.placeholder('float', shape=shape)
        nets, mean_pixel, all_layers = Vnet(VGG_PATH, image)
        input_image_pre = np.array([preprocess(input_image, mean_pixel)])
        layers = all_layers
        for i,layer in enumerate (layers):
            print("[%d%d]%s"%(i+1,len(layers),layer))
            features=nets[layer].eval(feed_dict={image:input_image_pre})
            print("Type of 'features' is", type(features ))
            print("Shape of 'feature' is %s"%(features.shape,))
            if 1:
                plt.figure(i+1,figsize=(10,5))
                plt.matshow(features[0,:,:,0],cmap=plt.cm.gray,fignum=i+1)
                plt.title(" "+layer)
                plt.colorbar()
                plt.show()
    print("finished")
    '''
    
    
    

    相关文章

      网友评论

          本文标题:TensorFlow之使用训练好的VGG模型

          本文链接:https://www.haomeiwen.com/subject/tmhjoqtx.html