美文网首页
【Tool】Keras 基础学习 V 网络可视化

【Tool】Keras 基础学习 V 网络可视化

作者: ItchyHiker | 来源:发表于2018-09-19 20:23 被阅读0次

可视化网络每一层输出

5H324J.png
5H3oe3.png
5H3g36.png
5H3qRK.png
5H376B.png
5H3nMp.png
5H59SE.png
5H530G.png
5H55Pn.png

从网络每一层输出的激活可以看出，随着网络加深，提取到的特征也越来越抽象。

可视化卷积核

不同卷积核对不同特征的敏感程度不同。通过梯度上升（gradient ascent）算法，可以找到使某个卷积核响应（response）最大的输入图像，从而可视化卷积核。

5H58US.png
5H5nNO.png
5H5Snh.png
5H5gZA.png
5H5iWH.png
5H5xcN.png
5H52iu.png
5H5Rm9.png
5H5wbe.png
5H5q9q.png
5H93Kd.png
from keras.models import load_model
from keras.models import Model
from keras.models import Input
from keras import backend as K
from keras.preprocessing import image
from keras.applications.vgg16 import VGG16
import numpy as np

import matplotlib.pyplot as plt


model = load_model("cat_dog_classification/tuneModel.h5")
model.summary()

# Single test image, resized to 150x150, wrapped into a batch of one and
# with its pixel values divided by 255.
# NOTE(review): absolute, machine-specific path — adjust before running.
img_path = "/Users/yuhua.cheng/Documents/study/Keras/cat_dog_classification/data/test/9999.jpg"
img = image.load_img(img_path, target_size=(150, 150))
img_tensor = np.expand_dims(image.img_to_array(img), axis=0) / 255.

# visualize activations
# Probe layers 1..15 of the model (index 0 is skipped — presumably the input
# layer; confirm against model.summary()).
layer_names = [layer.name for layer in model.layers[1:16]]
layer_outputs = [layer.output for layer in model.layers[1:16]]

# A single model that returns every probed layer's output for one input.
activation_model = Model(inputs=model.input, outputs=layer_outputs)
activations = activation_model.predict(x=img_tensor, batch_size=1)
images_per_row = 16
for layer_name, layer_activation in zip(layer_names, activations):
    # The tiling below indexes activations as [batch, row, col, channel];
    # skip outputs of any other rank (e.g. flattened or dense layers)
    # instead of crashing.
    if layer_activation.ndim != 4:
        continue
    n_features = layer_activation.shape[-1]
    # NOTE(review): assumes square feature maps (height == width).
    size = layer_activation.shape[1]
    # Number of grid rows; channels beyond the last full row of
    # `images_per_row` tiles are not drawn.
    n_cols = n_features // images_per_row
    if n_cols == 0:
        # Fewer channels than one grid row: nothing to draw.
        continue
    display_grid = np.zeros((size*n_cols, images_per_row*size))

    for col in range(n_cols):
        for row in range(images_per_row):
            # Copy so the in-place normalization below does not overwrite
            # the data held in `activations`.
            channel_image = layer_activation[0, :, :, col*images_per_row + row].copy()
            channel_image -= channel_image.mean()
            std = channel_image.std()
            # A constant channel (e.g. a dead ReLU) has std == 0; dividing by
            # it would produce NaNs, so leave it at zero (maps to value 128).
            if std > 0:
                channel_image /= std
            channel_image *= 64
            channel_image += 128
            channel_image = np.clip(channel_image, 0, 255).astype('uint8')
            display_grid[col*size:(col+1)*size, row*size:(row+1)*size] = channel_image

    # Figure size in inches: one inch per channel tile.
    scale = 1./size
    plt.figure(figsize=(scale*display_grid.shape[1], scale*display_grid.shape[0]))
    plt.title(layer_name)
    plt.grid(False)
    plt.imshow(display_grid, aspect='auto', cmap='viridis')
    plt.savefig(layer_name)
plt.show()

# visualize kernels
def deprocess_image(x):
    """Convert a numeric image array into a displayable uint8 image.

    The values are standardized (zero mean, std scaled to about 0.1),
    shifted to center on 0.5, clipped to [0, 1] and scaled to [0, 255].

    Args:
        x: numeric array of any shape. It is not modified.

    Returns:
        A new uint8 array with the same shape as ``x``.
    """
    # Out-of-place arithmetic so the caller's array is left untouched and
    # integer arrays are accepted (in-place float ops on them would raise).
    x = x - x.mean()
    x = x / (x.std() + 1e-5)  # 1e-5 avoids division by zero on constant input
    x = x * 0.1
    x = x + 0.5
    x = np.clip(x, 0, 1)
    # Values are already in [0, 255] here; astype truncates toward zero.
    return (x * 255).astype('uint8')

def generate_pattern(layer_name, filter_index, size=150, steps=40, step=1.):
    """Synthesize an input image that strongly activates one filter.

    Starting from a noisy gray image, runs gradient ascent on the input so
    that the mean activation of channel ``filter_index`` in layer
    ``layer_name`` of the global ``model`` increases.

    Args:
        layer_name: name of a layer of ``model`` whose output is indexed as
            [batch, height, width, channel].
        filter_index: channel whose mean activation is maximized.
        size: height and width of the synthesized square image.
            NOTE(review): ``model.input`` must accept size x size inputs.
        steps: number of gradient-ascent iterations.
        step: step size applied to the normalized gradient per iteration.

    Returns:
        The synthesized (size, size, 3) image converted to uint8 by
        ``deprocess_image``.
    """
    layer_output = model.get_layer(layer_name).output
    loss = K.mean(layer_output[:, :, :, filter_index])

    grads = K.gradients(loss, model.input)[0]
    # Normalize by the gradient's RMS so the update size does not depend on
    # the raw gradient magnitude; 1e-5 guards against division by zero.
    grads /= (K.sqrt(K.mean(K.square(grads))) + 1e-5)

    iterate = K.function([model.input], [loss, grads])

    # Gray start image with small noise: values in [128, 148).
    # NOTE(review): the activation section divides its input image by 255,
    # while this starts in the 0-255 range — confirm which scale the model
    # expects.
    input_img_data = np.random.random((1, size, size, 3)) * 20 + 128
    for _ in range(steps):
        _, grads_value = iterate([input_img_data])
        # Gradient *ascent*: move the input in the direction that raises loss.
        input_img_data += grads_value * step

    return deprocess_image(input_img_data[0])

for layer_name in layer_names:
    size = 64
    margin = 5

    # 8 x 8 grid of filter patterns separated by `margin` black pixels.
    results = np.zeros((8*size + 7*margin, 8*size + 7*margin, 3), dtype='uint8')

    for i in range(8):
        for j in range(8):
            # Requires the layer to have at least 64 channels (indices 0-63).
            filter_img = generate_pattern(layer_name, i + (j*8), size=size)
            row_start = i*size + i*margin
            row_end = row_start + size
            col_start = j*size + j*margin
            col_end = col_start + size
            results[row_start:row_end, col_start:col_end, :] = filter_img

    plt.figure(figsize=(20, 20))
    plt.title(layer_name)
    plt.imshow(results)
    # Distinct file name: the activation plots are saved under plain
    # `layer_name`, and reusing that name would overwrite them.
    plt.savefig(layer_name + '_filters')

plt.show()

相关文章

网友评论

      本文标题:【Tool】Keras 基础学习 V 网络可视化

      本文链接:https://www.haomeiwen.com/subject/fiwpnftx.html