An improved fashion-mnist DCGAN


Author: 圣_狒司机 | Published 2019-08-16 12:45

This version improves how the loss function is written, which makes the code simpler.
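Both networks in the script are trained with Keras's built-in `binary_crossentropy`, which (up to a small clipping constant) is the batch mean of -[y·log(p) + (1 - y)·log(1 - p)], where y is the label and p the discriminator's sigmoid output. The NumPy sketch below is illustrative only (`bce` and `dbce_dp` are hypothetical helper names, not part of the original): it evaluates that loss on a small discriminator batch and checks the sign of the loss derivative when the generator's target is the fixed label 1 - D(G(z)) used in the script.

```python
import numpy as np

def bce(y, p, eps=1e-7):
    """Mean binary cross-entropy -[y*log(p) + (1-y)*log(1-p)] over a batch."""
    p = np.clip(p, eps, 1 - eps)  # avoid log(0), similar to Keras's epsilon clipping
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))

def dbce_dp(y, p):
    """Derivative of the per-sample loss with respect to the output p."""
    return -y / p + (1 - y) / (1 - p)

# Discriminator phase: real images labeled 1, fakes labeled 0
p_real = np.array([0.9, 0.8])
p_fake = np.array([0.2, 0.1])
d_loss = bce(np.concatenate([np.ones(2), np.zeros(2)]),
             np.concatenate([p_real, p_fake]))

# Generator phase: target = 1 - D(G(z)), computed once and held fixed
p0 = np.array([0.1, 0.3, 0.7])   # D's current outputs on three fakes
grad = dbce_dp(1 - p0, p0)       # derivative at the starting point p = p0
print(d_loss, grad)
```

To first order, a generator update moves p opposite to the sign of the derivative. So for fakes the discriminator currently rejects (p0 < 0.5) the generator is pushed toward "real", while for p0 > 0.5 the fixed target 1 - p0 pulls p back down: with this target the per-sample loss is minimized at p = 1 - p0 rather than at p = 1. For comparison, the more common choice is a target of 1 for every fake.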

    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    from tensorflow.keras.datasets.fashion_mnist import load_data
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import (Dense, Reshape, Conv2DTranspose, Conv2D,
                                         MaxPool2D, Flatten, BatchNormalization)
    
    N = 150         # number of real images to train on (more = slower)
    NOISE_DIM = 10  # length of the generator's input noise vector
    
    (train_x, _), _ = load_data()
    # Scale pixels to [-1, 1] so real images share the range of the generator's tanh output
    x_real = (train_x[:N] / 127.5 - 1.0).astype("float32")
    y_real = np.ones((N, 1), dtype="float32")   # real images are labeled 1
    
    # Generator: noise -> 4x4x128 -> 7x7x64 -> 14x14x32 -> 28x28x1 -> 28x28 image
    g = Sequential([Dense(4*4*128, input_shape=(NOISE_DIM,)),
                    Reshape((4, 4, 128)),
                    Conv2DTranspose(64, (4, 4), padding="valid", activation="relu"),
                    BatchNormalization(),
                    Conv2DTranspose(32, (2, 2), strides=(2, 2), padding="same", activation="relu"),
                    BatchNormalization(),
                    Conv2DTranspose(1, (2, 2), strides=(2, 2), padding="same", activation="tanh"),
                    Reshape((28, 28))])
    
    # Discriminator: 28x28 image -> probability that the image is real
    d = Sequential([Reshape((28, 28, 1), input_shape=(28, 28)),
                    Conv2D(32, (2, 2), padding="same", activation="relu"),
                    MaxPool2D((2, 2)),
                    Conv2D(64, (2, 2), padding="same", activation="relu"),
                    MaxPool2D((2, 2)),
                    Conv2D(64, (2, 2), padding="valid", activation="relu"),
                    MaxPool2D((2, 2)),
                    Flatten(),
                    Dense(1, activation="sigmoid")])
    gan = Sequential([g, d])  # noise -> G -> D, used to train G
    d.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    
    for i in range(50):
        print(f"=============== Discriminator, round {i+1} ================")
        d.trainable = True
        # Draw a fresh set of fakes from the current generator, labeled 0
        x_fake = g(tf.random.uniform((N, NOISE_DIM), 0, 1)).numpy()
        y_fake = np.zeros((N, 1), dtype="float32")
        x = np.concatenate([x_real, x_fake])
        y = np.concatenate([y_real, y_fake])
        # Shuffle across all 2N samples so real and fake images are mixed in each batch
        dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(2 * N).batch(20)
        d.fit(dataset, epochs=2)
    
        print(f"=============== Generator, round {i+1} ================")
        # Freeze D and recompile the stacked model so only G's weights are updated
        d.trainable = False
        gan.compile(optimizer="adam", loss="binary_crossentropy")
        z = tf.random.uniform((100, NOISE_DIM), 0, 1)
        # Target is 1 - D(G(z)), computed once and held fixed for this round
        y = 1 - d(g(z))
        gan.fit(z, y, epochs=50)
    
    img = g(tf.random.uniform((1, NOISE_DIM), 0, 1))[0].numpy()
    plt.imshow(img, cmap="gray")
    plt.show()
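The layer sizes in the script can be checked by hand. The pure-Python sketch below uses hypothetical helper names (they are not Keras APIs) and the standard output-size rules: for `Conv2DTranspose`, input × stride for "same" and input × stride + max(kernel - stride, 0) for "valid"; for `Conv2D`, ceil(input / stride) for "same" and (input - kernel) // stride + 1 for "valid"; and `MaxPool2D` with its default strides (equal to the pool size) and "valid" padding.

```python
import math

def conv_transpose_out(size, kernel, stride=1, padding="valid"):
    """Output length of a Conv2DTranspose layer (no output_padding, dilation 1)."""
    if padding == "same":
        return size * stride
    return size * stride + max(kernel - stride, 0)

def conv_out(size, kernel, stride=1, padding="valid"):
    """Output length of a Conv2D layer (dilation 1)."""
    if padding == "same":
        return math.ceil(size / stride)
    return (size - kernel) // stride + 1

def pool_out(size, pool=2):
    """Output length of MaxPool2D with default strides (= pool size), 'valid' padding."""
    return (size - pool) // pool + 1

# Generator: Dense output reshaped to 4x4, then three transposed convolutions
g_sizes = [4]
for kernel, stride, padding in [(4, 1, "valid"), (2, 2, "same"), (2, 2, "same")]:
    g_sizes.append(conv_transpose_out(g_sizes[-1], kernel, stride, padding))

# Discriminator: (Conv2D, MaxPool2D) three times on a 28x28 input
d_sizes = [28]
for kernel, padding in [(2, "same"), (2, "same"), (2, "valid")]:
    d_sizes.append(pool_out(conv_out(d_sizes[-1], kernel, 1, padding)))
flat_features = d_sizes[-1] ** 2 * 64  # the last Conv2D has 64 filters

print(g_sizes, d_sizes, flat_features)
```

The generator therefore ends at exactly 28x28 (4 → 7 → 14 → 28), matching Fashion-MNIST, and the discriminator's final `Dense` layer sees 3 × 3 × 64 = 576 features.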
    

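After about 20 rounds of training, the generated images become recognizable.
You can change how many training images are loaded; of course, the more you load, the slower it gets!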
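One way to see where the slowdown comes from is to count optimizer steps per outer round under the script's settings: the discriminator runs 2 epochs over real plus fake images in batches of 20, while the generator runs 50 epochs over a fixed 100 noise vectors with Keras's default `fit` batch size of 32. The helper below is a hypothetical illustration, not part of the original script, and it counts steps only, not the time each step takes.

```python
import math

def steps_per_round(n_real, d_batch=20, d_epochs=2,
                    g_samples=100, g_batch=32, g_epochs=50):
    """Optimizer steps in one outer round: (discriminator steps, generator steps)."""
    # D sees n_real real images plus n_real fakes, split into batches of d_batch
    d_steps = math.ceil(2 * n_real / d_batch) * d_epochs
    # G always trains on g_samples noise vectors, independent of n_real
    g_steps = math.ceil(g_samples / g_batch) * g_epochs
    return d_steps, g_steps

for n in (150, 1500, 60000):
    print(n, steps_per_round(n))
```

Only the discriminator phase grows with the amount of data. With 150 images it takes 30 steps per round against the generator's 200. Using the full 60,000-image training set raises it to 12,000 steps per round (400 times as many), and each round must also generate 60,000 fakes.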
