美文网首页
运用深度学习写对联(一)

运用深度学习写对联(一)

作者: 远古胖达 | 来源:发表于2018-01-24 22:55 被阅读0次

Start

看了一篇文章,使用深度学习写对联。很应景。
自己也去做一下这个实验。
使用RNN,这种网络对于时序行数据有很好的效果。

数据

半人工半程序,搜集了4300条对联数据,未排重。
格式为一行一对,使用英文逗号和句号,如下:

骏马凌云敬业同圆华夏梦,灵羊送福挥毫共绘核心图.
马蹄踏上复兴路,羊角顶开幸福门.
丰年富足人欢笑,盛世平安犬不惊.
九州日月开春景,四海笙歌颂狗年.

使用一个类去存储字典数据,其中有两种不同的数据结构:

  1. 文字到文字编号的映射(map)
  2. 文字编号到文字的映射(array)
idx_to_char = list(set(couplet_chars))
char_to_idx = dict([(char, i) for i, char in enumerate(idx_to_char)])
vocab_size = len(char_to_idx)
couplet_indices = [char_to_idx[char] for char in couplet_chars]

输入网络的数据需要时定长向量,将couplet_indices中的文字编码转变成one-hot编码。

def get_inputs(data):
    return [nd.one_hot(X, vocab_size) for X in data.T]

定义网络

RNN的原理图如下:


image.png

这次使用的RNN中,存在一层隐藏层,当前层的输入由两部分组成,一是来自输入层,二是来着当前层的前一时刻的输出,即图中的隐藏状态。
所以我们需要的参数矩阵相对于普通深度网络,多了一个相对于隐藏状态的参数矩阵:W_hh

def get_params():
    # 隐含层
    W_xh = nd.random_normal(scale=std, shape=(input_dim, hidden_dim), ctx=ctx)
    W_hh = nd.random_normal(scale=std, shape=(hidden_dim, hidden_dim), ctx=ctx)
    b_h = nd.zeros(hidden_dim, ctx=ctx)

    # 输出层
    W_hy = nd.random_normal(scale=std, shape=(hidden_dim, output_dim), ctx=ctx)
    b_y = nd.zeros(output_dim, ctx=ctx)

    params = [W_xh, W_hh, b_h, W_hy, b_y]
    for param in params:
        param.attach_grad()
    return params

def rnn(inputs, state, *params):
    # num_steps: 采样步长
    # inputs: num_steps 个尺寸为 batch_size * vocab_size 矩阵。
    # H: 尺寸为 batch_size * hidden_dim 矩阵。
    # outputs: num_steps 个尺寸为 batch_size * vocab_size 矩阵。
    H = state
    W_xh, W_hh, b_h, W_hy, b_y = params
    outputs = []
    for X in inputs:
        H = nd.tanh(nd.dot(X, W_xh) + nd.dot(H, W_hh) + b_h)
        Y = nd.dot(H, W_hy) + b_y
        outputs.append(Y)
    return (outputs, H)

训练

给定初始时刻的输入,两个字,春,犬。

image.png

结束

至此结束,有一些对联还是不错的。
之后会详细分析RNN和本次的模型,以及如何评价模型的好坏。

相关文章

网友评论

      本文标题:运用深度学习写对联(一)

      本文链接:https://www.haomeiwen.com/subject/xvypaxtx.html