美文网首页
tensorflow CNN实现手写数字识别

tensorflow CNN实现手写数字识别

作者: 随时学丫 | 来源:发表于2019-04-29 14:06 被阅读0次

卷积神经网络

tf.nn.conv2d

TF-卷积函数 tf.nn.conv2d 介绍

tf.nn.conv2d(input, # [batch, in_height, in_width, in_channels]
            filter, # [filter_height, filter_width, in_channels, out_channels]
            strides, # [1,1,1,1]
            padding, # SAME VALID
            use_cudnn_on_gpu=True,
            data_format='NHWC',
            dilations=[1, 1, 1, 1],
            name=None)
  • input: [batch, in_height, in_width, in_channels] 4-D的 Tensor (float32/float64)

     batch_size,高度,宽度,通道数
    
  • filter:[filter_height, filter_width, in_channels, out_channels]

      卷积核高度,   卷积核宽度      输入通道数    输出通道数(卷积核个数)
    
  • strides:步长 [1,1,1,1] 1-D向量,长度为4

  • padding:填充 SAME VALID

  • use_cudnn_on_gpu:是否GPU加速

结果返回一个Tensor,这个输出,就是我们常说的 feature map,shape仍然是 [batch, height, width, channels] 这种形式。

图像卷积

3x3图像,1通道,1x1卷积核

1.考虑一种最简单的情况,现在有一张3×3单通道的图像(对应的shape:[1,3,3,1]),用一个1×1的卷积核(对应的shape:[1,1,1,1])去做卷积,最后会得到一张3×3的feature map

In [4]:

import tensorflow as tf
#                 [batch, in_height, in_width, in_channels]
input = tf.Variable(tf.random_normal([1,3,3,1]))
#                 [filter_height, filter_width, in_channels, out_channels]
filter = tf.Variable(tf.random_normal([1,1,1,1]))
op1 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op1),'\n')
    print(sess.run(op1).shape)

运行结果

[[[[-0.06386475]
   [ 0.30251193]
   [-0.36254457]]

  [[-0.1863834 ]
   [-0.11046342]
   [ 0.12128225]]

  [[-0.16598591]
   [-0.06247617]
   [ 0.10344568]]]] 

(1, 3, 3, 1)

3x3图像,5通道,1x1卷积核

2.增加图片的通道数,使用一张3×3五通道的图像(对应的shape:[1,3,3,5]),用一个1×1的卷积核(对应的shape:[1,1,1,1])去做卷积,仍然是一张3×3的feature map,这就相当于每一个像素点,卷积核都与该像素点的每一个通道做卷积。

In [5]:

import tensorflow as tf
#                 [batch, in_height, in_width, in_channels]
input = tf.Variable(tf.random_normal([1,3,3,5]))
#                 [filter_height, filter_width, in_channels, out_channels]
filter = tf.Variable(tf.random_normal([1,1,5,1]))
op2 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op2),'\n')
    print(sess.run(op2).shape)

运行结果

[[[[ 2.611188 ]
   [ 1.1356436]
   [ 0.1728566]]

  [[ 1.7238054]
   [-2.65044  ]
   [-0.6026321]]

  [[-0.9376406]
   [-2.1341398]
   [ 0.5801886]]]] 

(1, 3, 3, 1)

3x3图像,5通道,3x3卷积核,文本卷积维度和图像一样大,所以卷积之后只有1列

3.把卷积核扩大,现在用3×3的卷积核做卷积,最后的输出是一个值,相当于情况2的feature map所有像素点的值求和

import tensorflow as tf
#                 [batch, in_height, in_width, in_channels]
input = tf.Variable(tf.random_normal([1,3,3,5]))
#                 [filter_height, filter_width, in_channels, out_channels]
filter = tf.Variable(tf.random_normal([3,3,5,1]))
op2 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op2),'\n')
    print(sess.run(op2).shape)

运行结果

[[[[-3.8528938]]]] 

(1, 1, 1, 1)

5x5图像,5通道,3x3卷积核

4.使用更大的图片将情况2的图片扩大到5×5,仍然是3×3的卷积核,令步长为1,输出3×3的feature map

import tensorflow as tf
#                 [batch, in_height, in_width, in_channels]
input = tf.Variable(tf.random_normal([1,5,5,5]))
#                 [filter_height, filter_width, in_channels, out_channels]
filter = tf.Variable(tf.random_normal([3,3,5,1]))
op2 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op2),'\n')
    print(sess.run(op2).shape)

运行结果

[[[[-3.3537188]
   [ 2.6631894]
   [10.7735815]]

  [[ 6.7866626]
   [-5.753437 ]
   [ 6.8379397]]

  [[-7.2338777]
   [-3.8412943]
   [11.663807 ]]]] 

(1, 3, 3, 1)

5x5图像,5通道,3x3卷积核

5.上面我们一直令参数padding的值为‘VALID’,当其为‘SAME’时,表示卷积核可以停留在图像边缘,如下,输出5×5的feature map

import tensorflow as tf
#                 [batch, in_height, in_width, in_channels]
input = tf.Variable(tf.random_normal([1,5,5,5]))
#                 [filter_height, filter_width, in_channels, out_channels]
filter = tf.Variable(tf.random_normal([3,3,5,1]))
op2 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op2),'\n')
    print(sess.run(op2).shape)

运行结果

[[[[ -4.1736193 ]
   [  7.2922435 ]
   [  2.9188058 ]
   [  0.49713266]
   [ -2.956664  ]]

  [[ -3.305164  ]
   [ -7.311406  ]
   [ -5.045771  ]
   [ -2.5354984 ]
   [ -7.40237   ]]

  [[  9.273168  ]
   [  1.2130424 ]
   [ -8.63011   ]
   [  8.675023  ]
   [  4.0911283 ]]

  [[ -2.295607  ]
   [  5.2230077 ]
   [-10.142306  ]
   [ -6.135029  ]
   [  1.3315554 ]]

  [[-11.159186  ]
   [ -3.5029335 ]
   [ -1.638276  ]
   [ -4.381499  ]
   [ -1.0199151 ]]]] 

(1, 5, 5, 1)

5x5图像,5通道,3x3卷积核,3个卷积核

6.如果卷积核有多个

此时输出3张5×5的feature map

import tensorflow as tf
#                 [batch, in_height, in_width, in_channels]
input = tf.Variable(tf.random_normal([1,5,5,5]))
#                 [filter_height, filter_width, in_channels, out_channels]
filter = tf.Variable(tf.random_normal([3,3,5,3]))
op2 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op2),'\n')
    print(sess.run(op2).shape)

运行结果

[[[[  2.7570508   -3.232238    -2.6215773 ]
   [  2.532285     1.9889098    3.87929   ]
   [ -3.187311     8.91769      3.224719  ]
   [  0.44387245  -9.403946     2.468867  ]
   [  0.21311586  -1.590601     5.749056  ]]

  [[ -2.6277003    6.488189     9.992645  ]
   [  1.5766766  -11.48576      0.6145782 ]
   [ -5.0482545   -0.96584886   4.0381684 ]
   [ -0.797274     2.4302173   -3.8855307 ]
   [ -2.6238062    2.05465      2.9259453 ]]

  [[  4.714437    -0.9536078   -2.9879472 ]
   [  1.5400691    1.5240853   -6.90153   ]
   [ -3.6736727    3.85059     -0.5918405 ]
   [  7.023252     2.9593654  -13.595696  ]
   [ -5.041815    -2.7133517    0.6385279 ]]

  [[  1.477376     0.47209492   5.653083  ]
   [ -0.39575818  14.780628    -1.5949147 ]
   [  2.378466   -11.533363    -0.4041656 ]
   [ -0.4129743    6.5807753   -2.7889323 ]
   [  6.1631317   -0.49479347   1.52246   ]]

  [[ -6.97586      1.1432166   -5.064254  ]
   [ -9.823753    -3.1042528   -1.6604922 ]
   [  9.015108   -10.42481     -4.7503257 ]
   [  2.3552632    0.43692362   2.7325256 ]
   [ -2.1840062    2.729301    -4.588225  ]]]] 

(1, 5, 5, 3)

5x5图像,5通道,3x3卷积核,3个卷积核,步长为2

7.步长不为1的情况,文档里说了对于图片,因为只有两维,通常strides取[1,stride,stride,1]

import tensorflow as tf
#                 [batch, in_height, in_width, in_channels]
input = tf.Variable(tf.random_normal([1,5,5,5]))
#                 [filter_height, filter_width, in_channels, out_channels]
filter = tf.Variable(tf.random_normal([3,3,5,3]))
op2 = tf.nn.conv2d(input, filter, strides=[1, 2, 2, 1], padding='SAME')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op2),'\n')
    print(sess.run(op2).shape)

运行结果

[[[[ 2.783487  -2.09441    4.733526 ]
   [-7.3059773 -3.108855   4.022243 ]
   [-4.050215   3.0158758 -4.1893964]]

  [[-6.3690815 -5.2265515  1.1703218]
   [ 9.0784235 -3.5745146 14.855592 ]
   [-0.4078823  4.0576644  4.617129 ]]

  [[ 3.339266  -5.4302483 -3.154387 ]
   [ 1.3765206  2.6518223  5.6584387]
   [-3.9308991  1.4282804 -3.4455342]]]] 

(1, 3, 3, 3)

5x5图像,5通道,3x3卷积核,3个卷积核,步长为2,10张图像

8.如果batch值不为1,同时输入10张图

每张图,都有3张3×3的feature map,输出的shape就是[10,3,3,3]

import tensorflow as tf
#                 [batch, in_height, in_width, in_channels]
input = tf.Variable(tf.random_normal([10,5,5,5]))
#                 [filter_height, filter_width, in_channels, out_channels]
filter = tf.Variable(tf.random_normal([3,3,5,3]))
op2 = tf.nn.conv2d(input, filter, strides=[1, 2, 2, 1], padding='SAME')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op2),'\n')
    print(sess.run(op2).shape)

运行结果

[[[[  3.7294517   -3.107039    -3.3261778 ]
   [  5.441018    -4.0686336    5.619485  ]
   [  2.1390946    3.5674727   -1.5291997 ]]

  [[ -5.6689725   16.974398     2.381929  ]
   [  2.4008992   -0.6694184    1.378117  ]
   [  1.2934582    7.2192235    0.48349503]]

  [[  2.9121394   -0.4573097   -9.765212  ]
   [  1.0088806   -0.7046843    6.591536  ]
   [ -0.72504395   0.43721557  -4.999654  ]]]


 [[[  1.967625    -1.6568589   -6.145099  ]
   [ -4.151078   -10.529405    -2.047928  ]
   [  3.4548922    2.7491624    2.9001775 ]]

  [[ -1.6896939    3.1873543    6.188783  ]
   [ 11.703161     1.6971766   -4.8438787 ]
   [  0.9549799   -1.1131762    4.593415  ]]

  [[ -1.1170579    0.4810401   -1.3526723 ]
   [  2.529728    -1.1482326    3.7958796 ]
   [ -0.24976896   3.3091352   -6.729189  ]]]


 [[[ -3.33531     -1.9344107    2.8165019 ]
   [ -1.6785766   -2.8081656    7.2197647 ]
   [  1.7976431    2.8334517   -0.08083367]]

  [[  2.5362453   -0.68693405   2.2952533 ]
   [  1.5236933    2.129165     0.194734  ]
   [  0.5964938   -8.6989565    5.084363  ]]

  [[  5.066878    -1.3026551   -5.7902007 ]
   [ -2.9802423    4.8924155    5.9025197 ]
   [ -3.933334     5.099715    -0.8536027 ]]]


 [[[ -1.0145748  -10.15126     -7.0179715 ]
   [ -8.451802    -0.17334843  -2.7171214 ]
   [  5.668031    -5.15528     -4.2402534 ]]

  [[ -2.0954626   -1.0145442    3.2066696 ]
   [  2.289553     0.9271075    0.8146973 ]
   [ -0.7423492    4.2153864   -4.70488   ]]

  [[  3.8675358   -0.35446188  -1.1588985 ]
   [ -4.8492827   -5.2945166    3.944246  ]
   [ -0.43092388  -0.8130417    2.3813803 ]]]


 [[[  0.66720986   3.8808417    2.1328838 ]
   [  7.446735    -4.522188    -5.990181  ]
   [ -2.5916054    5.0853543    2.6371577 ]]

  [[ -7.136207     0.6306949   -8.853178  ]
   [  3.7415988   -7.89348     -8.487032  ]
   [ -0.69531     -3.222552     1.5073893 ]]

  [[ -3.106504    -0.01809834  -9.029028  ]
   [ -4.416857     0.13292897   7.7073345 ]
   [  0.9844466    4.2795186   -0.76342046]]]


 [[[ -0.11672409  -7.369146    -4.8543487 ]
   [ -4.230579    -1.1736143    0.74828875]
   [ -0.6568188    6.765464    -4.9761944 ]]

  [[  3.8933635   -7.902747    -0.63001007]
   [  6.8245344    3.9199047   11.168122  ]
   [  3.7043867    0.31197003   0.04769279]]

  [[ -4.1409       8.580945     8.486864  ]
   [ -3.1867335    7.059393     6.296857  ]
   [ -0.9835455   -3.6718185    0.97860974]]]


 [[[  5.193179    -1.2967495    8.170371  ]
   [ -1.3087153   -5.5033283   -1.6919953 ]
   [  0.36510244  -6.296658    -3.7380807 ]]

  [[  3.5434737   13.3447695    1.7701437 ]
   [-10.250333    -9.407058    -3.2337494 ]
   [  8.435421     4.078936     3.4657378 ]]

  [[  3.4681797    2.228949     0.45596147]
   [ -0.57005715   4.670751     2.034872  ]
   [ -1.915133     7.9970365   -3.8922138 ]]]


 [[[  3.950432    -4.4767623   -6.447672  ]
   [  2.595737    -8.553671    -0.45686972]
   [  3.391854    -2.003466    -2.2928245 ]]

  [[ -3.6888146    3.5153918    1.2406276 ]
   [ -0.25753272   2.6999128   -2.8501456 ]
   [ -0.9058769    6.502099    -0.5419939 ]]

  [[ -0.68687534  -6.5038085    2.8593688 ]
   [ -3.683316     2.1430447    5.490655  ]
   [  5.7413816    3.3227494   -7.533464  ]]]


 [[[ -3.484571     5.3650527    2.9336984 ]
   [  0.6027174    3.7776787    1.0154141 ]
   [ -4.7919264    7.149525     1.9800262 ]]

  [[ -2.1547816   -3.2360375   -5.2381744 ]
   [  7.6362724    8.085188     9.068025  ]
   [ -4.549206    -3.8285804    5.8914824 ]]

  [[ -1.9079026    2.9233663    0.9151974 ]
   [  6.70253    -10.376949    -2.2334673 ]
   [  2.7263498    3.202616     3.6564238 ]]]


 [[[  6.6487746   -0.2954742    3.0371974 ]
   [  3.576       -7.2807136    4.2893467 ]
   [ -0.96813136  -5.533345    -6.83936   ]]

  [[ -0.10136782  -1.4625425   -7.0081096 ]
   [  3.8160882   -2.4150543   -3.9401052 ]
   [  2.7480733   -1.4603323   10.289123  ]]

  [[  2.629776    -3.5297518    0.4979372 ]
   [  0.9985927   -8.139794    -0.5185237 ]
   [ -4.5744176  -10.06965      5.6358476 ]]]] 

(10, 3, 3, 3)

文本卷积

5x5文本,1通道,2-gram卷积核,3个卷积核,10句话

import tensorflow as tf
#                 [batch, in_height, in_width, in_channels]
input = tf.Variable(tf.random_normal([10,5,5,1]))
#                 [filter_height, filter_width, in_channels, out_channels]
filter = tf.Variable(tf.random_normal([3,5,1,3]))
op1 = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(op1),'\n')
    print(sess.run(op1).shape)

运行结果

[[[[ -0.28560075   3.4164069    0.64287925]]

  [[  4.922906     6.569628    -4.377281  ]]

  [[  2.0096815   -0.9498653   -5.302     ]]]


 [[[ -0.30866227   0.65657634   0.08617933]]

  [[  2.1648352    2.0540233   -6.2501183 ]]

  [[  1.8437229   -3.3579445    0.648278  ]]]


 [[[ -6.248692    -8.374758     3.1102016 ]]

  [[ -4.6158714    1.0821313    2.8032086 ]]

  [[  1.912092     0.933113    -3.0924444 ]]]


 [[[  7.5399313    9.936766    -1.6083889 ]]

  [[  2.4991071   -0.3938322    5.2363515 ]]

  [[ -3.8184917    1.5327872   -0.9156568 ]]]


 [[[ -1.0705962    1.3645588   -2.2302496 ]]

  [[  0.9711383   -2.6879628   -2.1285567 ]]

  [[ -5.15031     -1.7857913   -0.64766765]]]


 [[[ -4.300858    -0.74519587   4.707138  ]]

  [[  1.1525508   -1.9355469    1.1351813 ]]

  [[ -1.930467     5.30831     -0.11006889]]]


 [[[ -0.30049637  -3.3917482   -0.98812234]]

  [[ -0.78466344  -3.508609     1.8363969 ]]

  [[  1.6145957    0.15216915  -0.27968606]]]


 [[[  1.201814     2.2275777    3.4975147 ]]

  [[  1.7633957    3.9830918   10.16128   ]]

  [[  1.9025049    4.217062     3.2219505 ]]]


 [[[  2.7462778   -0.87272054 -10.7139845 ]]

  [[  0.5596598   -7.9665465   -3.5733411 ]]

  [[  0.02203573   1.8229557    1.1090751 ]]]


 [[[ -2.8659163   -6.198704     1.1388084 ]]

  [[ -5.6100855    2.2285914    1.380748  ]]

  [[ -0.3560774    7.3229613    1.1240004 ]]]] 

(10, 3, 1, 3)

tf.nn.max_pool

TF-池化函数 tf.nn.max_pool 的介绍

tf.nn.max_pool(value, # [batch, in_height, in_width, in_channels]
                ksize, #  [1,in_height,in_width,1]
                strides, # [1,height,width,1]
                padding, # SAME VALID
                data_format='NHWC',
                name=None,)
  • value: 池化层的输入,一般池化层在卷积层后面,输入通常是feature_map

      依然是 [batch, in_height, in_width, in_channels]
    
  • ksize:池化窗口大小,4-D向量,一般是 [1,in_height,in_width,1]

       因为我们不想在 batch,in_channels上做池化,所以维度为1
    
  • strides:步长 [1,height,width,1] 1-D向量,长度为4

  • padding:填充 SAME VALID

返回一个Tensor,类型不变,shape仍然是 [batch, in_height, in_width, in_channels] 这种形式

4x4图像,2通道,池化核2x2,步长1

卷积由 [batch, in_height, in_width, in_channels]

[1,4,4,2] -> [1,3,3,2]

# 2x4x4
a=tf.constant([  
        [[1.0,2.0,3.0,4.0],  
        [5.0,6.0,7.0,8.0],  
        [8.0,7.0,6.0,5.0],  
        [4.0,3.0,2.0,1.0]], # 通道1图像
    
        [[4.0,3.0,2.0,1.0],  # 通道2图像
         [8.0,7.0,6.0,5.0],  
         [1.0,2.0,3.0,4.0],  
         [5.0,6.0,7.0,8.0]]  
    ])  

b=tf.reshape(a,[1,4,4,2])

max_pool_2x2 = tf.nn.max_pool(b,[1,2,2,1],[1,1,1,1],padding='VALID')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('池化前原始,image:\n')
    print(sess.run(a),'\n')
    print(sess.run(a).shape,'\n')
    print('池化前,reshape,image:\n')
    print(sess.run(b),'\n')
    print(sess.run(b).shape,'\n')
    print('池化后image:\n')
    print(sess.run(max_pool_2x2),'\n')
    print(sess.run(max_pool_2x2).shape,'\n')

运行结果

池化前原始,image:

[[[1. 2. 3. 4.]
  [5. 6. 7. 8.]
  [8. 7. 6. 5.]
  [4. 3. 2. 1.]]

 [[4. 3. 2. 1.]
  [8. 7. 6. 5.]
  [1. 2. 3. 4.]
  [5. 6. 7. 8.]]] 

(2, 4, 4) 

池化前,reshape,image:

[[[[1. 2.]
   [3. 4.]
   [5. 6.]
   [7. 8.]]

  [[8. 7.]
   [6. 5.]
   [4. 3.]
   [2. 1.]]

  [[4. 3.]
   [2. 1.]
   [8. 7.]
   [6. 5.]]

  [[1. 2.]
   [3. 4.]
   [5. 6.]
   [7. 8.]]]] 

(1, 4, 4, 2) 

池化后image:

[[[[8. 7.]
   [6. 6.]
   [7. 8.]]

  [[8. 7.]
   [8. 7.]
   [8. 7.]]

  [[4. 4.]
   [8. 7.]
   [8. 8.]]]] 

(1, 3, 3, 2) 

4x4图像,2通道,池化核2x2,步长2

卷积由 [batch, in_height, in_width, in_channels]

[1,4,4,2] -> [1,2,2,2]

# 2x4x4
a=tf.constant([  
        [[1.0,2.0,3.0,4.0],  
        [5.0,6.0,7.0,8.0],  
        [8.0,7.0,6.0,5.0],  
        [4.0,3.0,2.0,1.0]], # 通道1图像
    
        [[4.0,3.0,2.0,1.0],  # 通道2图像
         [8.0,7.0,6.0,5.0],  
         [1.0,2.0,3.0,4.0],  
         [5.0,6.0,7.0,8.0]]  
    ])  

b=tf.reshape(a,[1,4,4,2])

max_pool_2x2 = tf.nn.max_pool(b,[1,2,2,1],[1,2,2,1],padding='VALID')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('池化前原始,image:\n')
    print(sess.run(a),'\n')
    print(sess.run(a).shape,'\n')
    print('池化前,reshape,image:\n')
    print(sess.run(b),'\n')
    print(sess.run(b).shape,'\n')
    print('池化后image:\n')
    print(sess.run(max_pool_2x2),'\n')
    print(sess.run(max_pool_2x2).shape,'\n')

运行结果

池化前原始,image:

[[[1. 2. 3. 4.]
  [5. 6. 7. 8.]
  [8. 7. 6. 5.]
  [4. 3. 2. 1.]]

 [[4. 3. 2. 1.]
  [8. 7. 6. 5.]
  [1. 2. 3. 4.]
  [5. 6. 7. 8.]]] 

(2, 4, 4) 

池化前,reshape,image:

[[[[1. 2.]
   [3. 4.]
   [5. 6.]
   [7. 8.]]

  [[8. 7.]
   [6. 5.]
   [4. 3.]
   [2. 1.]]

  [[4. 3.]
   [2. 1.]
   [8. 7.]
   [6. 5.]]

  [[1. 2.]
   [3. 4.]
   [5. 6.]
   [7. 8.]]]] 

(1, 4, 4, 2) 

池化后image:

[[[[8. 7.]
   [7. 8.]]

  [[4. 4.]
   [8. 8.]]]] 

(1, 2, 2, 2) 

CNN手写数字识别

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

batch_size = 100
n_batch = mnist.train.num_examples // batch_size

def weight_variable(shape):
    return tf.Variable(tf.truncated_normal(shape,stddev=0.1))

def bias_vairable(shape):
    return tf.Variable(tf.constant(0.1, shape=shape))

def conv2d(x,W):
    return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
keep_prob = tf.placeholder(tf.float32)

# x input  [batch,in_height,in_width,in_channels]
# [?,784] reshape -> [-1,28,28,1]
x_image = tf.reshape(x,[-1,28,28,1])

#w        [filter_height,filter_width,in_channels, out_channels]
W_conv1 = weight_variable([5,5,1,32]) # 5*5的采样窗口,32个卷积核从1个平面抽取特征
b_conv1 = bias_vairable([32]) #每个卷积核一个偏置值

# 28*28*1 的图片卷积之后变为 -1*28*28*32
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
# 池化之后变为 -1*14*14*32
h_pool1 = max_pool_2x2(h_conv1)

# 第二次卷积之后变为 14*14*64
W_conv2 = weight_variable([5,5,32,64])
b_conv2 = bias_vairable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1,W_conv2) + b_conv2)

# 第二次池化之后变为 7*7*64
h_pool2 = max_pool_2x2(h_conv2)


# 第一个全连接层
W_fc1 = weight_variable([7*7*64,1024]) #上一层有7*7*64个神经元,全连接层有1024个神经元
b_fc1 = bias_vairable([1024])#1024个节点

# 把池化层2的输出扁平化为1维向量
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
#求第一个全连接的输出
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# 初始化第二个全连接层
W_fc2 = weight_variable([1024,10])
b_fc2 = bias_vairable([10])
logits = tf.matmul(h_fc1_drop,W_fc2) + b_fc2

#计算输出
prediction = tf.nn.sigmoid(logits)

#交叉熵代价函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
#AdamOptimizer
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)

prediction_2 = tf.nn.softmax(prediction)
correct_prediction = (tf.equal(tf.argmax(prediction_2,1), tf.argmax(y,1)))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(21):
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x:batch_xs,y:batch_ys,keep_prob:0.7})
        acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels, keep_prob:1.0})
        print("Iter: " + str(epoch) + ", acc: " + str(acc))

运行结果

Iter: 0, acc: 0.9796
Iter: 1, acc: 0.9863
Iter: 2, acc: 0.9882
Iter: 3, acc: 0.9908
Iter: 4, acc: 0.99
Iter: 5, acc: 0.991
Iter: 6, acc: 0.9926
Iter: 7, acc: 0.989
Iter: 8, acc: 0.9913
Iter: 9, acc: 0.9896
Iter: 10, acc: 0.992
Iter: 11, acc: 0.9913
Iter: 12, acc: 0.9917
Iter: 13, acc: 0.9911
Iter: 14, acc: 0.9912
Iter: 15, acc: 0.9915
Iter: 16, acc: 0.9907
Iter: 17, acc: 0.9905
Iter: 18, acc: 0.9923
Iter: 19, acc: 0.9905
Iter: 20, acc: 0.9909

相关文章

网友评论

      本文标题:tensorflow CNN实现手写数字识别

      本文链接:https://www.haomeiwen.com/subject/susynqtx.html