Training a DCGAN on the MNIST dataset

Author: small瓜瓜 | Published 2020-04-09 06:39
    import tensorflow as tf
    from tensorflow.keras import datasets, Sequential, layers, losses
    from PIL import Image
    import numpy as np
    import os
    
    
    def save_result(val_out, val_block_size, image_path, color_mode):
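        # Tile a batch of images (expected in [-1, 1]) into a grid with
        # val_block_size images per row and save it as one image file.
        # Note: the color_mode argument is accepted but not used.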
        def preprocess(img):
            img = ((img + 1.0) * 127.5).astype(np.uint8)
            # img = img.astype(np.uint8)
            return img
    
        preprocessed = preprocess(val_out)
        final_image = np.array([])
        single_row = np.array([])
        for b in range(val_out.shape[0]):
            # concat image into a row
            if single_row.size == 0:
                single_row = preprocessed[b, :, :, :]
            else:
                single_row = np.concatenate((single_row, preprocessed[b, :, :, :]), axis=1)
    
            # concat image row to final_image
            if (b + 1) % val_block_size == 0:
                if final_image.size == 0:
                    final_image = single_row
                else:
                    final_image = np.concatenate((final_image, single_row), axis=0)
    
                # reset single row
                single_row = np.array([])
    
        if final_image.shape[2] == 1:
            final_image = np.squeeze(final_image, axis=2)
        Image.fromarray(final_image).save(image_path)
    
    
    (x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
    
    # Inspect the MNIST data shapes and dtypes
    print(x_train.shape, x_train.dtype, y_train.shape, x_test.shape, y_test.shape)
    
    # Generator: latent vector => 4 * 4 * 7 feature map => 28 * 28 * 1 image
    # MNIST: (60000, 28, 28) train - (10000, 28, 28) test
    generator = Sequential([
        layers.Dense(4 * 4 * 7, activation=tf.nn.leaky_relu),
        layers.Reshape(target_shape=(4, 4, 7)),
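        # 'valid'-padding deconv shapes: 4x4 -> (4-1)*2+5 = 11x11 -> 13x13 -> (13-1)*2+4 = 28x28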
        layers.Conv2DTranspose(14, 5, 2, activation=tf.nn.leaky_relu),
        layers.BatchNormalization(),
        layers.Conv2DTranspose(5, 3, 1, activation=tf.nn.leaky_relu),
        layers.BatchNormalization(),
        layers.Conv2DTranspose(1, 4, 2, activation=tf.nn.tanh),
        layers.Reshape(target_shape=(28, 28)),
    ])
    
    discriminator = Sequential([
        layers.Reshape((28, 28, 1)),
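        # 'valid'-padding conv shapes: 28x28 -> 13x13 -> 11x11 -> 4x4, flattened to 448 features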
        layers.Conv2D(3, 4, 2, activation=tf.nn.leaky_relu),
        layers.BatchNormalization(),
        layers.Conv2D(12, 3, 1, activation=tf.nn.leaky_relu),
        layers.BatchNormalization(),
        layers.Conv2D(28, 5, 2, activation=tf.nn.leaky_relu),
        layers.BatchNormalization(),
        layers.Flatten(),
        layers.Dense(1)
    ])
    
    # Hyperparameters
    dim_h = 100
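    # effectively "train until manually stopped" (the run below was stopped after epoch 146)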
    epochs = int(9e+7)
    batch_size = 128
    learning_rate = 2e-3
    
    
    def preprocess(pre_x, pre_y):
        # Scale images to [-1, 1] so real samples match the generator's tanh output range
        pre_x = tf.cast(pre_x, dtype=tf.float32) / 127.5 - 1.
        pre_y = tf.cast(pre_y, dtype=tf.int32)
        return pre_x, pre_y
    
    
    db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
        .map(preprocess).shuffle(batch_size * 5).batch(batch_size, drop_remainder=True)
    
    db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)) \
        .map(preprocess).shuffle(batch_size * 5).batch(batch_size, drop_remainder=True)
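    # db_test is prepared for symmetry but never used by the adversarial loop below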
    
    generator.build((None, dim_h))
    generator.summary()
    
    discriminator.build((None, 28, 28, 1))
    discriminator.summary()
    
    # Sanity check: generator and discriminator inputs/outputs line up
    print(generator(tf.random.normal((1, dim_h))))
    print(discriminator(tf.random.normal((1, 28, 28, 1))))
    
    g_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
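    # from_logits=True: the discriminator's final Dense(1) emits raw logits (no sigmoid)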
    cross_entropy = losses.BinaryCrossentropy(from_logits=True)
    
    for epoch in range(epochs):
        for step, (true_x, y) in enumerate(db_train):
            with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
                # Sample a batch of latent vectors from a standard normal
                random_seek = tf.random.normal((batch_size, dim_h))
                # Generate a batch of fake images
                false_x = generator(random_seek, training=True)
                # Score the fake and real images with the discriminator
                false_y = discriminator(false_x, training=True)
                true_y = discriminator(true_x, training=True)
                false_loss = cross_entropy(tf.zeros_like(false_y), false_y)
                true_loss = cross_entropy(tf.ones_like(true_y), true_y)
                d_loss = false_loss + true_loss
                g_loss = cross_entropy(tf.ones_like(false_y), false_y)
            d_grad = d_tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(zip(d_grad, discriminator.trainable_variables))
            g_grad = g_tape.gradient(g_loss, generator.trainable_variables)
            g_optimizer.apply_gradients(zip(g_grad, generator.trainable_variables))
    
        print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
        # Save a 10x10 grid of generated samples
        z = tf.random.normal([100, dim_h])
        fake_image = generator(z, training=False)
        if not os.path.exists('mnist-images'):
            os.mkdir('mnist-images')
        img_path = os.path.join('mnist-images', 'gan-one%d.png' % epoch)
        fake_image = tf.expand_dims(fake_image, axis=3)
        save_result(fake_image.numpy(), 10, img_path, color_mode='P')
    

Run output:

    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense (Dense)                multiple                  11312     
    _________________________________________________________________
    reshape (Reshape)            multiple                  0         
    _________________________________________________________________
    conv2d_transpose (Conv2DTran multiple                  2464      
    _________________________________________________________________
    batch_normalization (BatchNo multiple                  56        
    _________________________________________________________________
    conv2d_transpose_1 (Conv2DTr multiple                  635       
    _________________________________________________________________
    batch_normalization_1 (Batch multiple                  20        
    _________________________________________________________________
    conv2d_transpose_2 (Conv2DTr multiple                  81        
    _________________________________________________________________
    reshape_1 (Reshape)          multiple                  0         
    =================================================================
    Total params: 14,568
    Trainable params: 14,530
    Non-trainable params: 38
    _________________________________________________________________
    Model: "sequential_1"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    reshape_2 (Reshape)          multiple                  0         
    _________________________________________________________________
    conv2d (Conv2D)              multiple                  51        
    _________________________________________________________________
    batch_normalization_2 (Batch multiple                  12        
    _________________________________________________________________
    conv2d_1 (Conv2D)            multiple                  336       
    _________________________________________________________________
    batch_normalization_3 (Batch multiple                  48        
    _________________________________________________________________
    conv2d_2 (Conv2D)            multiple                  8428      
    _________________________________________________________________
    batch_normalization_4 (Batch multiple                  112       
    _________________________________________________________________
    flatten (Flatten)            multiple                  0         
    _________________________________________________________________
    dense_1 (Dense)              multiple                  449       
    =================================================================
    Total params: 9,436
    Trainable params: 9,350
    Non-trainable params: 86
    _________________________________________________________________
    0 d-loss: 0.3025842308998108 g-loss: 2.1815402507781982
    1 d-loss: 0.4411752223968506 g-loss: 2.2589144706726074
    2 d-loss: 0.44797778129577637 g-loss: 1.6934151649475098
    3 d-loss: 0.5456695556640625 g-loss: 2.8530282974243164
    4 d-loss: 0.4662773609161377 g-loss: 2.8335046768188477
    5 d-loss: 0.3083723187446594 g-loss: 3.31571102142334
    6 d-loss: 0.23971307277679443 g-loss: 2.7306787967681885
    7 d-loss: 0.5093101263046265 g-loss: 2.2037010192871094
    8 d-loss: 0.36190980672836304 g-loss: 4.701327323913574
    9 d-loss: 0.43783730268478394 g-loss: 4.386983394622803
    10 d-loss: 0.3293834328651428 g-loss: 2.7903919219970703
    11 d-loss: 0.3341054916381836 g-loss: 3.461742877960205
    12 d-loss: 0.30024251341819763 g-loss: 2.703348159790039
    13 d-loss: 0.5041056871414185 g-loss: 2.055236577987671
    14 d-loss: 0.3214653432369232 g-loss: 3.836017370223999
    15 d-loss: 0.353255033493042 g-loss: 2.407291889190674
    16 d-loss: 0.29138171672821045 g-loss: 3.264908790588379
    17 d-loss: 0.26765525341033936 g-loss: 5.04957914352417
    18 d-loss: 0.30802297592163086 g-loss: 4.81702995300293
    19 d-loss: 0.4219457805156708 g-loss: 5.073997497558594
    20 d-loss: 0.3222273588180542 g-loss: 4.902792930603027
    21 d-loss: 0.2720641791820526 g-loss: 4.062989234924316
    22 d-loss: 0.23554465174674988 g-loss: 3.8483152389526367
    23 d-loss: 0.7654502987861633 g-loss: 1.9122107028961182
    24 d-loss: 0.30941855907440186 g-loss: 3.5728230476379395
    25 d-loss: 0.3057532012462616 g-loss: 3.3135852813720703
    26 d-loss: 0.27834975719451904 g-loss: 4.216555118560791
    27 d-loss: 0.3380710482597351 g-loss: 3.4310202598571777
    28 d-loss: 0.24519062042236328 g-loss: 3.3996081352233887
    29 d-loss: 0.3752197027206421 g-loss: 4.753103256225586
    30 d-loss: 0.3422132134437561 g-loss: 2.5223147869110107
    31 d-loss: 0.7221729755401611 g-loss: 6.252880573272705
    32 d-loss: 0.2636100947856903 g-loss: 2.730095863342285
    33 d-loss: 0.5032351613044739 g-loss: 4.74068021774292
    34 d-loss: 0.5151199698448181 g-loss: 2.1353204250335693
    35 d-loss: 0.3672966957092285 g-loss: 3.2035529613494873
    36 d-loss: 0.26749682426452637 g-loss: 4.3134074211120605
    37 d-loss: 0.4011297821998596 g-loss: 3.9894635677337646
    38 d-loss: 0.30018627643585205 g-loss: 3.174570322036743
    39 d-loss: 0.3114895224571228 g-loss: 3.8470301628112793
    40 d-loss: 0.4029478430747986 g-loss: 4.338008403778076
    41 d-loss: 0.2539215683937073 g-loss: 3.3293800354003906
    42 d-loss: 0.4008435904979706 g-loss: 4.759911060333252
    43 d-loss: 0.3200976550579071 g-loss: 3.518287420272827
    44 d-loss: 0.23928441107273102 g-loss: 3.9704060554504395
    45 d-loss: 0.2731139063835144 g-loss: 2.855978488922119
    46 d-loss: 0.2689163088798523 g-loss: 3.992715835571289
    47 d-loss: 0.4422256052494049 g-loss: 2.3679072856903076
    48 d-loss: 0.3424515128135681 g-loss: 4.078521251678467
    49 d-loss: 0.4493892192840576 g-loss: 5.751364231109619
    50 d-loss: 0.15650558471679688 g-loss: 3.686434268951416
    51 d-loss: 0.34632429480552673 g-loss: 2.620640516281128
    52 d-loss: 0.2551218867301941 g-loss: 3.5799636840820312
    53 d-loss: 0.6334245800971985 g-loss: 6.563322067260742
    54 d-loss: 0.5916560292243958 g-loss: 4.386569976806641
    55 d-loss: 0.4112924039363861 g-loss: 4.473291873931885
    56 d-loss: 0.17079852521419525 g-loss: 3.530954360961914
    57 d-loss: 0.29201382398605347 g-loss: 3.409097909927368
    58 d-loss: 0.5939719080924988 g-loss: 2.13434100151062
    59 d-loss: 0.4775002896785736 g-loss: 1.9119606018066406
    60 d-loss: 0.27252131700515747 g-loss: 4.3400983810424805
    61 d-loss: 0.27781713008880615 g-loss: 3.718961238861084
    62 d-loss: 0.3048217296600342 g-loss: 3.391570568084717
    63 d-loss: 0.29252439737319946 g-loss: 3.842097759246826
    64 d-loss: 0.2879011034965515 g-loss: 2.697906255722046
    65 d-loss: 0.4146934151649475 g-loss: 2.5592713356018066
    66 d-loss: 0.25841444730758667 g-loss: 3.3485231399536133
    67 d-loss: 0.34248021245002747 g-loss: 2.9927332401275635
    68 d-loss: 0.19441872835159302 g-loss: 3.862999677658081
    69 d-loss: 0.40257516503334045 g-loss: 2.6037850379943848
    70 d-loss: 0.33036375045776367 g-loss: 3.1049559116363525
    71 d-loss: 0.2422482967376709 g-loss: 3.0365424156188965
    72 d-loss: 0.24604055285453796 g-loss: 3.5101194381713867
    73 d-loss: 1.3223328590393066 g-loss: 9.184656143188477
    74 d-loss: 0.20355640351772308 g-loss: 3.8176610469818115
    75 d-loss: 0.1851392537355423 g-loss: 3.5737180709838867
    76 d-loss: 0.23111382126808167 g-loss: 3.312542676925659
    77 d-loss: 0.12925150990486145 g-loss: 3.7841544151306152
    78 d-loss: 0.4086154103279114 g-loss: 2.4864935874938965
    79 d-loss: 0.29721730947494507 g-loss: 2.7293453216552734
    80 d-loss: 0.2804826498031616 g-loss: 3.2309751510620117
    81 d-loss: 0.22704683244228363 g-loss: 3.60378360748291
    82 d-loss: 0.21729540824890137 g-loss: 3.577629327774048
    83 d-loss: 0.18626506626605988 g-loss: 4.590834617614746
    84 d-loss: 0.39497512578964233 g-loss: 2.1493382453918457
    85 d-loss: 0.3183228075504303 g-loss: 5.248997688293457
    86 d-loss: 0.19268733263015747 g-loss: 4.473655700683594
    87 d-loss: 0.2456638216972351 g-loss: 4.372949600219727
    88 d-loss: 0.19136309623718262 g-loss: 4.760179042816162
    89 d-loss: 0.22348923981189728 g-loss: 4.247585296630859
    90 d-loss: 0.2525639533996582 g-loss: 5.267736434936523
    91 d-loss: 0.22230832278728485 g-loss: 4.386148452758789
    92 d-loss: 0.36075153946876526 g-loss: 3.1002907752990723
    93 d-loss: 0.13224007189273834 g-loss: 4.696763038635254
    94 d-loss: 0.32201671600341797 g-loss: 3.0803260803222656
    95 d-loss: 0.3892339766025543 g-loss: 6.24675178527832
    96 d-loss: 0.2373712956905365 g-loss: 3.386235475540161
    97 d-loss: 0.28235626220703125 g-loss: 2.9006311893463135
    98 d-loss: 0.40496787428855896 g-loss: 2.5861682891845703
    99 d-loss: 0.23271213471889496 g-loss: 3.9647161960601807
    100 d-loss: 0.21597206592559814 g-loss: 4.855806350708008
    101 d-loss: 0.2240012288093567 g-loss: 3.9054088592529297
    102 d-loss: 0.18842440843582153 g-loss: 3.8246123790740967
    103 d-loss: 0.3447532653808594 g-loss: 5.665205478668213
    104 d-loss: 0.5192641615867615 g-loss: 6.705690860748291
    105 d-loss: 0.24415946006774902 g-loss: 3.7206287384033203
    106 d-loss: 0.6034714579582214 g-loss: 2.0236997604370117
    107 d-loss: 0.35751140117645264 g-loss: 4.795225143432617
    108 d-loss: 0.3134361207485199 g-loss: 5.555920600891113
    109 d-loss: 0.32808077335357666 g-loss: 8.050716400146484
    110 d-loss: 0.6285152435302734 g-loss: 8.5546293258667
    111 d-loss: 0.2469012439250946 g-loss: 4.087368011474609
    112 d-loss: 0.20410647988319397 g-loss: 3.3828139305114746
    113 d-loss: 0.21991196274757385 g-loss: 3.9452338218688965
    114 d-loss: 0.17175406217575073 g-loss: 4.084678649902344
    115 d-loss: 0.1731255054473877 g-loss: 4.145017623901367
    116 d-loss: 0.1547868698835373 g-loss: 4.784377574920654
    117 d-loss: 0.2667906880378723 g-loss: 3.4771580696105957
    118 d-loss: 1.0230473279953003 g-loss: 9.154877662658691
    119 d-loss: 0.12319549918174744 g-loss: 4.466263294219971
    120 d-loss: 0.18622779846191406 g-loss: 4.449687480926514
    121 d-loss: 0.1805429607629776 g-loss: 5.678531646728516
    122 d-loss: 0.8003900647163391 g-loss: 1.3731383085250854
    123 d-loss: 0.16159069538116455 g-loss: 4.971116065979004
    124 d-loss: 0.3253968060016632 g-loss: 6.937915802001953
    125 d-loss: 0.23539406061172485 g-loss: 4.205863952636719
    126 d-loss: 0.1837971806526184 g-loss: 4.123740196228027
    127 d-loss: 0.24565982818603516 g-loss: 4.757949352264404
    128 d-loss: 0.37760159373283386 g-loss: 2.4897053241729736
    129 d-loss: 0.6163845062255859 g-loss: 2.643826484680176
    130 d-loss: 0.31684044003486633 g-loss: 4.014310836791992
    131 d-loss: 0.1259973794221878 g-loss: 5.237534046173096
    132 d-loss: 0.2750729024410248 g-loss: 6.2339630126953125
    133 d-loss: 0.623859703540802 g-loss: 8.091939926147461
    134 d-loss: 0.17348800599575043 g-loss: 3.39019775390625
    135 d-loss: 0.2724202871322632 g-loss: 4.073886871337891
    136 d-loss: 0.1955956220626831 g-loss: 4.553360462188721
    137 d-loss: 0.31421488523483276 g-loss: 3.671926975250244
    138 d-loss: 0.16272708773612976 g-loss: 5.542398929595947
    139 d-loss: 0.18932181596755981 g-loss: 3.518281936645508
    140 d-loss: 0.18914400041103363 g-loss: 3.4262115955352783
    141 d-loss: 0.20608733594417572 g-loss: 3.9967246055603027
    142 d-loss: 0.11989393830299377 g-loss: 5.275413990020752
    143 d-loss: 0.22725379467010498 g-loss: 5.856490135192871
    144 d-loss: 0.24864554405212402 g-loss: 3.8872880935668945
    145 d-loss: 0.34549522399902344 g-loss: 7.933394908905029
    146 d-loss: 0.217384934425354 g-loss: 3.0618345737457275
    Process finished with exit code 0
    

Below are the results after 100+ epochs of training, though a few dozen epochs already produce similar output. The quality is not great: there is some mode collapse, and a more serious problem is that zeros are over-represented among the samples, probably because the region of the latent space that maps to zeros is disproportionately large.


[Generated sample grids: gan-one0.png through gan-one10.png]
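One way to check the "too many zeros" impression quantitatively (not part of the original post) is to classify a large batch of generated samples with any separately trained MNIST classifier and histogram the predicted digits. A minimal sketch: mnist_classifier.h5 is a hypothetical path to a classifier you have saved yourself, and generator refers to the model from the script above.

    import numpy as np
    import tensorflow as tf

    # Hypothetical: any MNIST classifier you trained and saved separately
    classifier = tf.keras.models.load_model('mnist_classifier.h5')

    z = tf.random.normal((1000, 100))        # 100 = dim_h in the script above
    fake = generator(z, training=False)      # (1000, 28, 28), values in [-1, 1]
    fake = (fake + 1.0) / 2.0                # rescale to [0, 1] for the classifier
    preds = tf.argmax(classifier(tf.expand_dims(fake, -1)), axis=-1).numpy()
    print(np.bincount(preds, minlength=10))  # a spike at index 0 confirms the bias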

The discriminator in a GAN can be viewed as a loss function that keeps upgrading itself: the more realistic the generator's images become, the smaller the generator's loss, which in turn drives the generator to keep improving.
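Concretely, the two cross_entropy calls in the training loop implement the standard non-saturating GAN losses. Writing $D$ for the discriminator (with the sigmoid applied to its logit) and $G$ for the generator:

$$\mathcal{L}_D = -\mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] - \mathbb{E}_{z \sim \mathcal{N}(0,I)}[\log(1 - D(G(z)))]$$

$$\mathcal{L}_G = -\mathbb{E}_{z \sim \mathcal{N}(0,I)}[\log D(G(z))]$$

Minimizing $\mathcal{L}_G$ rewards the generator for samples the current discriminator scores as real, so as $D$ improves, the training signal it provides to $G$ improves with it.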

Original article: https://www.haomeiwen.com/subject/wbafmhtx.html