import tensorflow as tf
from tensorflow.keras import datasets, Sequential, layers, losses
from PIL import Image
import numpy as np
import os
def save_result(val_out, val_block_size, image_path, color_mode):
    def preprocess(img):
        # map tanh output from [-1, 1] back to uint8 [0, 255]
        img = ((img + 1.0) * 127.5).astype(np.uint8)
        return img

    preprocessed = preprocess(val_out)
    final_image = np.array([])
    single_row = np.array([])
    for b in range(val_out.shape[0]):
        # concat images into a row
        if single_row.size == 0:
            single_row = preprocessed[b, :, :, :]
        else:
            single_row = np.concatenate((single_row, preprocessed[b, :, :, :]), axis=1)
        # concat the completed row onto final_image
        if (b + 1) % val_block_size == 0:
            if final_image.size == 0:
                final_image = single_row
            else:
                final_image = np.concatenate((final_image, single_row), axis=0)
            # reset the row
            single_row = np.array([])
    # drop the channel axis for grayscale output
    if final_image.shape[2] == 1:
        final_image = np.squeeze(final_image, axis=2)
    Image.fromarray(final_image).save(image_path)
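
# (annotation) hypothetical call: save_result(fake.numpy(), 10, 'grid.png', color_mode='P')
# with fake of shape (100, 28, 28, 1) tiles the batch into one 280x280 PNG;
# color_mode is accepted for API compatibility, but PIL infers the mode from the array.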
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
# inspect the MNIST shapes: (60000, 28, 28) train, (10000, 28, 28) test
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

# generator: latent vector -> 4 * 4 * 7 feature map -> 28 * 28 * 1 image
generator = Sequential([
    layers.Dense(4 * 4 * 7, activation=tf.nn.leaky_relu),
    layers.Reshape(target_shape=(4, 4, 7)),
    layers.Conv2DTranspose(14, 5, 2, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Conv2DTranspose(5, 3, 1, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Conv2DTranspose(1, 4, 2, activation=tf.nn.tanh),
    layers.Reshape(target_shape=(28, 28)),
])
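# (annotation) shape check: with the default 'valid' padding, Conv2DTranspose
# maps H -> (H - 1) * stride + kernel, so the spatial size grows
# 4 -> (4-1)*2+5 = 11 -> (11-1)*1+3 = 13 -> (13-1)*2+4 = 28,
# which is how the 4x4x7 seed becomes a 28x28x1 image.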
discriminator = Sequential([
    layers.Reshape((28, 28, 1)),
    layers.Conv2D(3, 4, 2, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Conv2D(12, 3, 1, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Conv2D(28, 5, 2, activation=tf.nn.leaky_relu),
    layers.BatchNormalization(),
    layers.Flatten(),
    layers.Dense(1)
])
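# (annotation) shape check: with 'valid' padding, Conv2D maps
# H -> floor((H - kernel) / stride) + 1, so 28 -> 13 -> 11 -> 4;
# Flatten therefore emits 4*4*28 = 448 features, and the final Dense(1)
# has 448 weights + 1 bias = 449 params, matching the summary below.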
# hyperparameters
dim_h = 100          # latent dimension
epochs = int(9e+7)   # effectively "train forever"; interrupt the run manually
batch_size = 128
learning_rate = 2e-3
def preprocess(pre_x, pre_y):
    # scale images to [-1, 1] to match the generator's tanh output range
    pre_x = tf.cast(pre_x, dtype=tf.float32) / 127.5 - 1.
    pre_y = tf.cast(pre_y, dtype=tf.int32)
    return pre_x, pre_y

db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
    .map(preprocess).shuffle(batch_size * 5).batch(batch_size, drop_remainder=True)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)) \
    .map(preprocess).shuffle(batch_size * 5).batch(batch_size, drop_remainder=True)
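# (annotation) each training batch is true_x: (128, 28, 28) float32 in [-1, 1]
# and y: (128,) int32; the labels are never used below, since a vanilla GAN
# is unsupervised, and db_test likewise goes unused in this script.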
generator.build((None, dim_h))
generator.summary()
discriminator.build((None, 28, 28, 1))
discriminator.summary()
# sanity check: one forward pass through each network
print(generator(tf.random.normal((1, dim_h))).shape)
print(discriminator(tf.random.normal((1, 28, 28, 1))).shape)

# beta_1 = 0.5 is the usual DCGAN choice; it helps stabilize adversarial training
g_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
cross_entropy = losses.BinaryCrossentropy(from_logits=True)
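# (annotation) from_logits=True applies the sigmoid inside the loss: for a raw
# score s and label y it computes -y*log(sigmoid(s)) - (1-y)*log(1-sigmoid(s)).
# With the labels chosen below, d_loss = -log D(x) - log(1 - D(G(z))) and
# g_loss = -log D(G(z)).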
for epoch in range(epochs):
    for step, (true_x, y) in enumerate(db_train):
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            # sample a batch of standard-normal latent vectors
            random_seed = tf.random.normal((batch_size, dim_h))
            # generate a batch of fake images
            false_x = generator(random_seed, training=True)
            # score the fake and real images with the discriminator
            false_y = discriminator(false_x, training=True)
            true_y = discriminator(true_x, training=True)
            # discriminator: push fakes toward 0 and reals toward 1
            false_loss = cross_entropy(tf.zeros_like(false_y), false_y)
            true_loss = cross_entropy(tf.ones_like(true_y), true_y)
            d_loss = false_loss + true_loss
            # generator: try to make the discriminator output 1 on fakes
            g_loss = cross_entropy(tf.ones_like(false_y), false_y)
        d_grad = d_tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(d_grad, discriminator.trainable_variables))
        g_grad = g_tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(g_grad, generator.trainable_variables))
    print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
    # save a 10 x 10 grid of generated samples each epoch
    z = tf.random.normal([100, dim_h])
    fake_image = generator(z, training=False)
    if not os.path.exists('mnist-images'):
        os.mkdir('mnist-images')
    img_path = os.path.join('mnist-images', 'gan-one%d.png' % epoch)
    fake_image = tf.expand_dims(fake_image, axis=3)
    save_result(fake_image.numpy(), 10, img_path, color_mode='P')
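
A quick aside on the generator loss: training against ones_like targets minimizes -log D(G(z)), the non-saturating loss, rather than the original minimax term log(1 - D(G(z))). The difference matters early in training, when the discriminator confidently rejects every fake. A small numeric sketch (my own illustration, not part of the original script):

import numpy as np

s = -4.0                   # a discriminator logit for an obvious fake
d = 1 / (1 + np.exp(-s))   # sigmoid: D(G(z)) ~ 0.018

# gradient magnitude w.r.t. the logit s for each choice of generator loss:
print(abs(-d))             # minimizing log(1 - D): |grad| ~ 0.018, vanishing
print(abs(-(1 - d)))       # minimizing -log D:     |grad| ~ 0.982, strong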
Run output:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) multiple 11312
_________________________________________________________________
reshape (Reshape) multiple 0
_________________________________________________________________
conv2d_transpose (Conv2DTran multiple 2464
_________________________________________________________________
batch_normalization (BatchNo multiple 56
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr multiple 635
_________________________________________________________________
batch_normalization_1 (Batch multiple 20
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr multiple 81
_________________________________________________________________
reshape_1 (Reshape) multiple 0
=================================================================
Total params: 14,568
Trainable params: 14,530
Non-trainable params: 38
_________________________________________________________________
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
reshape_2 (Reshape) multiple 0
_________________________________________________________________
conv2d (Conv2D) multiple 51
_________________________________________________________________
batch_normalization_2 (Batch multiple 12
_________________________________________________________________
conv2d_1 (Conv2D) multiple 336
_________________________________________________________________
batch_normalization_3 (Batch multiple 48
_________________________________________________________________
conv2d_2 (Conv2D) multiple 8428
_________________________________________________________________
batch_normalization_4 (Batch multiple 112
_________________________________________________________________
flatten (Flatten) multiple 0
_________________________________________________________________
dense_1 (Dense) multiple 449
=================================================================
Total params: 9,436
Trainable params: 9,350
Non-trainable params: 86
_________________________________________________________________
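
(Aside: the Output Shape column reads "multiple" because the models were built via build() rather than with a concrete Input layer; the parameter counts are still exact.)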
0 d-loss: 0.3025842308998108 g-loss: 2.1815402507781982
1 d-loss: 0.4411752223968506 g-loss: 2.2589144706726074
2 d-loss: 0.44797778129577637 g-loss: 1.6934151649475098
3 d-loss: 0.5456695556640625 g-loss: 2.8530282974243164
4 d-loss: 0.4662773609161377 g-loss: 2.8335046768188477
5 d-loss: 0.3083723187446594 g-loss: 3.31571102142334
6 d-loss: 0.23971307277679443 g-loss: 2.7306787967681885
7 d-loss: 0.5093101263046265 g-loss: 2.2037010192871094
8 d-loss: 0.36190980672836304 g-loss: 4.701327323913574
9 d-loss: 0.43783730268478394 g-loss: 4.386983394622803
10 d-loss: 0.3293834328651428 g-loss: 2.7903919219970703
11 d-loss: 0.3341054916381836 g-loss: 3.461742877960205
12 d-loss: 0.30024251341819763 g-loss: 2.703348159790039
13 d-loss: 0.5041056871414185 g-loss: 2.055236577987671
14 d-loss: 0.3214653432369232 g-loss: 3.836017370223999
15 d-loss: 0.353255033493042 g-loss: 2.407291889190674
16 d-loss: 0.29138171672821045 g-loss: 3.264908790588379
17 d-loss: 0.26765525341033936 g-loss: 5.04957914352417
18 d-loss: 0.30802297592163086 g-loss: 4.81702995300293
19 d-loss: 0.4219457805156708 g-loss: 5.073997497558594
20 d-loss: 0.3222273588180542 g-loss: 4.902792930603027
21 d-loss: 0.2720641791820526 g-loss: 4.062989234924316
22 d-loss: 0.23554465174674988 g-loss: 3.8483152389526367
23 d-loss: 0.7654502987861633 g-loss: 1.9122107028961182
24 d-loss: 0.30941855907440186 g-loss: 3.5728230476379395
25 d-loss: 0.3057532012462616 g-loss: 3.3135852813720703
26 d-loss: 0.27834975719451904 g-loss: 4.216555118560791
27 d-loss: 0.3380710482597351 g-loss: 3.4310202598571777
28 d-loss: 0.24519062042236328 g-loss: 3.3996081352233887
29 d-loss: 0.3752197027206421 g-loss: 4.753103256225586
30 d-loss: 0.3422132134437561 g-loss: 2.5223147869110107
31 d-loss: 0.7221729755401611 g-loss: 6.252880573272705
32 d-loss: 0.2636100947856903 g-loss: 2.730095863342285
33 d-loss: 0.5032351613044739 g-loss: 4.74068021774292
34 d-loss: 0.5151199698448181 g-loss: 2.1353204250335693
35 d-loss: 0.3672966957092285 g-loss: 3.2035529613494873
36 d-loss: 0.26749682426452637 g-loss: 4.3134074211120605
37 d-loss: 0.4011297821998596 g-loss: 3.9894635677337646
38 d-loss: 0.30018627643585205 g-loss: 3.174570322036743
39 d-loss: 0.3114895224571228 g-loss: 3.8470301628112793
40 d-loss: 0.4029478430747986 g-loss: 4.338008403778076
41 d-loss: 0.2539215683937073 g-loss: 3.3293800354003906
42 d-loss: 0.4008435904979706 g-loss: 4.759911060333252
43 d-loss: 0.3200976550579071 g-loss: 3.518287420272827
44 d-loss: 0.23928441107273102 g-loss: 3.9704060554504395
45 d-loss: 0.2731139063835144 g-loss: 2.855978488922119
46 d-loss: 0.2689163088798523 g-loss: 3.992715835571289
47 d-loss: 0.4422256052494049 g-loss: 2.3679072856903076
48 d-loss: 0.3424515128135681 g-loss: 4.078521251678467
49 d-loss: 0.4493892192840576 g-loss: 5.751364231109619
50 d-loss: 0.15650558471679688 g-loss: 3.686434268951416
51 d-loss: 0.34632429480552673 g-loss: 2.620640516281128
52 d-loss: 0.2551218867301941 g-loss: 3.5799636840820312
53 d-loss: 0.6334245800971985 g-loss: 6.563322067260742
54 d-loss: 0.5916560292243958 g-loss: 4.386569976806641
55 d-loss: 0.4112924039363861 g-loss: 4.473291873931885
56 d-loss: 0.17079852521419525 g-loss: 3.530954360961914
57 d-loss: 0.29201382398605347 g-loss: 3.409097909927368
58 d-loss: 0.5939719080924988 g-loss: 2.13434100151062
59 d-loss: 0.4775002896785736 g-loss: 1.9119606018066406
60 d-loss: 0.27252131700515747 g-loss: 4.3400983810424805
61 d-loss: 0.27781713008880615 g-loss: 3.718961238861084
62 d-loss: 0.3048217296600342 g-loss: 3.391570568084717
63 d-loss: 0.29252439737319946 g-loss: 3.842097759246826
64 d-loss: 0.2879011034965515 g-loss: 2.697906255722046
65 d-loss: 0.4146934151649475 g-loss: 2.5592713356018066
66 d-loss: 0.25841444730758667 g-loss: 3.3485231399536133
67 d-loss: 0.34248021245002747 g-loss: 2.9927332401275635
68 d-loss: 0.19441872835159302 g-loss: 3.862999677658081
69 d-loss: 0.40257516503334045 g-loss: 2.6037850379943848
70 d-loss: 0.33036375045776367 g-loss: 3.1049559116363525
71 d-loss: 0.2422482967376709 g-loss: 3.0365424156188965
72 d-loss: 0.24604055285453796 g-loss: 3.5101194381713867
73 d-loss: 1.3223328590393066 g-loss: 9.184656143188477
74 d-loss: 0.20355640351772308 g-loss: 3.8176610469818115
75 d-loss: 0.1851392537355423 g-loss: 3.5737180709838867
76 d-loss: 0.23111382126808167 g-loss: 3.312542676925659
77 d-loss: 0.12925150990486145 g-loss: 3.7841544151306152
78 d-loss: 0.4086154103279114 g-loss: 2.4864935874938965
79 d-loss: 0.29721730947494507 g-loss: 2.7293453216552734
80 d-loss: 0.2804826498031616 g-loss: 3.2309751510620117
81 d-loss: 0.22704683244228363 g-loss: 3.60378360748291
82 d-loss: 0.21729540824890137 g-loss: 3.577629327774048
83 d-loss: 0.18626506626605988 g-loss: 4.590834617614746
84 d-loss: 0.39497512578964233 g-loss: 2.1493382453918457
85 d-loss: 0.3183228075504303 g-loss: 5.248997688293457
86 d-loss: 0.19268733263015747 g-loss: 4.473655700683594
87 d-loss: 0.2456638216972351 g-loss: 4.372949600219727
88 d-loss: 0.19136309623718262 g-loss: 4.760179042816162
89 d-loss: 0.22348923981189728 g-loss: 4.247585296630859
90 d-loss: 0.2525639533996582 g-loss: 5.267736434936523
91 d-loss: 0.22230832278728485 g-loss: 4.386148452758789
92 d-loss: 0.36075153946876526 g-loss: 3.1002907752990723
93 d-loss: 0.13224007189273834 g-loss: 4.696763038635254
94 d-loss: 0.32201671600341797 g-loss: 3.0803260803222656
95 d-loss: 0.3892339766025543 g-loss: 6.24675178527832
96 d-loss: 0.2373712956905365 g-loss: 3.386235475540161
97 d-loss: 0.28235626220703125 g-loss: 2.9006311893463135
98 d-loss: 0.40496787428855896 g-loss: 2.5861682891845703
99 d-loss: 0.23271213471889496 g-loss: 3.9647161960601807
100 d-loss: 0.21597206592559814 g-loss: 4.855806350708008
101 d-loss: 0.2240012288093567 g-loss: 3.9054088592529297
102 d-loss: 0.18842440843582153 g-loss: 3.8246123790740967
103 d-loss: 0.3447532653808594 g-loss: 5.665205478668213
104 d-loss: 0.5192641615867615 g-loss: 6.705690860748291
105 d-loss: 0.24415946006774902 g-loss: 3.7206287384033203
106 d-loss: 0.6034714579582214 g-loss: 2.0236997604370117
107 d-loss: 0.35751140117645264 g-loss: 4.795225143432617
108 d-loss: 0.3134361207485199 g-loss: 5.555920600891113
109 d-loss: 0.32808077335357666 g-loss: 8.050716400146484
110 d-loss: 0.6285152435302734 g-loss: 8.5546293258667
111 d-loss: 0.2469012439250946 g-loss: 4.087368011474609
112 d-loss: 0.20410647988319397 g-loss: 3.3828139305114746
113 d-loss: 0.21991196274757385 g-loss: 3.9452338218688965
114 d-loss: 0.17175406217575073 g-loss: 4.084678649902344
115 d-loss: 0.1731255054473877 g-loss: 4.145017623901367
116 d-loss: 0.1547868698835373 g-loss: 4.784377574920654
117 d-loss: 0.2667906880378723 g-loss: 3.4771580696105957
118 d-loss: 1.0230473279953003 g-loss: 9.154877662658691
119 d-loss: 0.12319549918174744 g-loss: 4.466263294219971
120 d-loss: 0.18622779846191406 g-loss: 4.449687480926514
121 d-loss: 0.1805429607629776 g-loss: 5.678531646728516
122 d-loss: 0.8003900647163391 g-loss: 1.3731383085250854
123 d-loss: 0.16159069538116455 g-loss: 4.971116065979004
124 d-loss: 0.3253968060016632 g-loss: 6.937915802001953
125 d-loss: 0.23539406061172485 g-loss: 4.205863952636719
126 d-loss: 0.1837971806526184 g-loss: 4.123740196228027
127 d-loss: 0.24565982818603516 g-loss: 4.757949352264404
128 d-loss: 0.37760159373283386 g-loss: 2.4897053241729736
129 d-loss: 0.6163845062255859 g-loss: 2.643826484680176
130 d-loss: 0.31684044003486633 g-loss: 4.014310836791992
131 d-loss: 0.1259973794221878 g-loss: 5.237534046173096
132 d-loss: 0.2750729024410248 g-loss: 6.2339630126953125
133 d-loss: 0.623859703540802 g-loss: 8.091939926147461
134 d-loss: 0.17348800599575043 g-loss: 3.39019775390625
135 d-loss: 0.2724202871322632 g-loss: 4.073886871337891
136 d-loss: 0.1955956220626831 g-loss: 4.553360462188721
137 d-loss: 0.31421488523483276 g-loss: 3.671926975250244
138 d-loss: 0.16272708773612976 g-loss: 5.542398929595947
139 d-loss: 0.18932181596755981 g-loss: 3.518281936645508
140 d-loss: 0.18914400041103363 g-loss: 3.4262115955352783
141 d-loss: 0.20608733594417572 g-loss: 3.9967246055603027
142 d-loss: 0.11989393830299377 g-loss: 5.275413990020752
143 d-loss: 0.22725379467010498 g-loss: 5.856490135192871
144 d-loss: 0.24864554405212402 g-loss: 3.8872880935668945
145 d-loss: 0.34549522399902344 g-loss: 7.933394908905029
146 d-loss: 0.217384934425354 g-loss: 3.0618345737457275
Process finished with exit code 0
Below are the results after training for 100+ epochs, although essentially the same quality appears within the first few dozen. The results are not especially good: there is a small amount of mode collapse (several grid cells repeat the same digit), and a more serious problem is that zeros are over-represented, probably because the region of the data distribution that zeros occupy is comparatively large.
[Generated sample grids: gan-one0.png through gan-one10.png]
The discriminator in a GAN can be seen as a loss function that keeps upgrading itself: the more realistic the generator's images become, the smaller the generator's loss, which in turn drives the generator to keep improving.
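
In terms of the variables in the listing above (a paraphrase, not new code):

    g_loss = -log D(G(z))                    # falls toward 0 as fakes become realistic
    d_loss = -log D(x) - log(1 - D(G(z)))    # the "loss function" that keeps retraining itself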