Implementing a CVAE Dialog System in PyTorch

Author: 又双叒叕苟了一天 | Published 2020-01-10 11:30

    [Figure: CVAE model architecture]

    As shown in the figure above, the CVAE model adds a prior network and a recognition network on top of a standard seq2seq model. During training, the latent variable used for decoding is sampled from the recognition network (which also sees the response), while at test time it is sampled from the prior network. The dialog act and bow predictions in the figure (the blue and yellow parts) are not considered here.
    The basic modules are therefore Embedding, Encoder, PriorNet, RecognizeNet, and Decoder.

    Embedding

    import torch.nn as nn
    
    
    class Embedding(nn.Module):
        def __init__(self, num_vocab,
                     embedding_size,
                     pad_id=0,
                     dropout=0.1):
            super(Embedding, self).__init__()
            self.embedding = nn.Embedding(num_vocab, embedding_size, padding_idx=pad_id)
            self.dropout = nn.Dropout(p=dropout)
    
        def forward(self, x):  # [batch, seq]
            return self.dropout(self.embedding(x))  # [batch, seq, embedding_size]
    

    The constructor arguments are the vocabulary size, the embedding dimension, the id of the pad token in the vocabulary, and the dropout probability. The class is essentially a thin wrapper around the nn.Embedding module.
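As a quick sanity check (a minimal sketch, not part of the original code), padding_idx keeps the pad token's embedding row at zero, so pad positions contribute nothing downstream:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(10, 4, padding_idx=0)  # vocab of 10, embedding size 4, pad_id=0
x = torch.tensor([[1, 2, 0, 0]])          # a batch of one sentence with two pad tokens
out = emb(x)                              # [batch=1, seq=4, embedding_size=4]
assert out.shape == (1, 4, 4)
assert out[0, 2].abs().sum().item() == 0.0  # the pad positions embed to all zeros
```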

    Encoder

    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
    
    
    class Encoder(nn.Module):
        def __init__(self, cell_type,  # RNN type
                     input_size,  # input dimension
                     output_size,  # output dimension
                     num_layers,  # number of RNN layers
                     bidirectional=False,  # whether the RNN is bidirectional
                     dropout=0.1):  # dropout probability
            super(Encoder, self).__init__()
            assert cell_type in ['GRU', 'LSTM']  # restrict the RNN type
    
            if bidirectional:  # bidirectional case
                assert output_size % 2 == 0
                cell_size = output_size // 2  # per-direction hidden size
            else:
                cell_size = output_size
    
            self.bidirectional = bidirectional
            self.cell_type = cell_type
            self.rnn_cell = getattr(nn, cell_type)(input_size=input_size,
                                                   hidden_size=cell_size,
                                                   num_layers=num_layers,
                                                   bidirectional=bidirectional,
                                                   dropout=dropout)
    
        def forward(self, x,  # [seq, batch, dim]
                    length):  # [batch]
            x = pack_padded_sequence(x, length, enforce_sorted=False)
    
            # output: [seq, batch, dim*directions] output at every time step
            # final_state: [layers*directions, batch, dim] final state of each layer
            output, final_state = self.rnn_cell(x)
            output = pad_packed_sequence(output)[0]
    
            if self.bidirectional:  # for a bidirectional RNN, concatenate the two directions to form each layer's final state
                if self.cell_type == 'GRU':
                    final_state_forward = final_state[0::2, :, :]  # [layers, batch, dim]
                    final_state_back = final_state[1::2, :, :]  # [layers, batch, dim]
                    final_state = torch.cat([final_state_forward, final_state_back], 2)  # [layers, batch, dim*2]
                else:
                    final_state_h, final_state_c = final_state
                    final_state_h = torch.cat([final_state_h[0::2, :, :], final_state_h[1::2, :, :]], 2)
                    final_state_c = torch.cat([final_state_c[0::2, :, :], final_state_c[1::2, :, :]], 2)
                    final_state = (final_state_h, final_state_c)
    
            # output = [seq, batch, dim]
            # final_state = [layers, batch, dim]
            return output, final_state
    

    The parameters are explained in the comments. The other thing worth noting is what pack_padded_sequence and pad_packed_sequence do. Because data is batched, every input is padded up to the length of the longest sentence in the batch, and the padded positions do not need to be computed; it is enough to compute up to the last token before the padding. pack_padded_sequence compresses the padding away so that it has no effect, and the length argument tells the RNN the true length of each sentence in the batch. If all the sentences in the batch are already sorted by length in decreasing order, enforce_sorted can be set to True to speed things up; if they are not sorted, it must be set to False, otherwise an error is raised. pad_packed_sequence is the inverse operation; its return value is a 2-tuple whose first element is the padded output we need and whose second element is the sentence lengths, i.e. the same length argument passed in earlier. Usually taking the first element is enough.
    For a bidirectional encoder, the forward and backward states are concatenated to form the final state of each layer.
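The packing round trip can be sketched as follows (a toy example with made-up sizes, not from the original code); note that the padded positions of the unpacked output come back as zeros and the lengths are returned unchanged:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
rnn = torch.nn.GRU(input_size=3, hidden_size=5)
x = torch.randn(4, 2, 3)            # [seq=4, batch=2, dim=3]
lengths = torch.tensor([4, 2])      # the second sequence has two pad steps
packed = pack_padded_sequence(x, lengths, enforce_sorted=False)
output, final_state = rnn(packed)   # the RNN skips the pad steps
output, out_lengths = pad_packed_sequence(output)
assert output.shape == (4, 2, 5)                  # [seq, batch, hidden]
assert output[2:, 1].abs().sum().item() == 0.0    # pad positions are zeros
assert out_lengths.tolist() == [4, 2]             # the lengths come back unchanged
```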

    Prior Network

    import torch.nn as nn
    
    
    class PriorNet(nn.Module):
        r""" Network that computes the prior p(z|x); x is the final output of the post encoder """
        def __init__(self, x_size,  # dimension of the post encoding
                     latent_size,  # latent variable dimension
                     dims):  # hidden layer dimensions
            super(PriorNet, self).__init__()
            assert len(dims) >= 1  # an MLP with at least two linear layers
    
            dims = [x_size] + dims + [latent_size*2]
            dims_input = dims[:-1]
            dims_output = dims[1:]
    
            self.latent_size = latent_size
            self.mlp = nn.Sequential()
            for idx, (x, y) in enumerate(zip(dims_input[:-1], dims_output[:-1])):
                self.mlp.add_module(f'linear{idx}', nn.Linear(x, y))  # linear layer
                self.mlp.add_module(f'activate{idx}', nn.Tanh())  # activation
            self.mlp.add_module('output', nn.Linear(dims_input[-1], dims_output[-1]))
    
        def forward(self, x):  # [batch, x_size]
            predict = self.mlp(x)  # [batch, latent_size*2]
            mu, logvar = predict.split([self.latent_size]*2, 1)
            return mu, logvar
    

    The prior network is essentially a multilayer perceptron whose goal is to predict the mean and log-variance of z for the prior p(z|x), since the prior distribution is usually assumed to be a Gaussian \mathcal{N}(\mu, \sigma^2) with two parameters, \mu and \sigma^2. Why not predict the variance directly? Because sampling from a Gaussian is not differentiable and has to be done via the reparameterization trick: sampling from \mathcal{N}(\mu, \sigma^2) is equivalent to sampling \varepsilon from \mathcal{N}(0, 1) and computing \mu + \varepsilon\sigma. Predicting the log-variance makes this convenient, since \sigma = \exp(\frac{1}{2}\log\sigma^2).
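The trick can be verified numerically (a small illustration, not from the original code): samples built as \mu + \varepsilon\sigma with standard-normal \varepsilon recover the target mean and standard deviation:

```python
import torch

torch.manual_seed(0)
mu = torch.full((10000, 1), 3.0)            # target mean 3
logvar = torch.full((10000, 1), 4.0).log()  # target variance 4, i.e. std 2
eps = torch.randn(10000, 1)                 # epsilon sampled from N(0, 1)
z = mu + (0.5 * logvar).exp() * eps         # differentiable w.r.t. mu and logvar
assert abs(z.mean().item() - 3.0) < 0.1     # empirical mean is close to 3
assert abs(z.std().item() - 2.0) < 0.1      # empirical std is close to 2
```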

    Recognition Network

    import torch
    import torch.nn as nn
    
    
    class RecognizeNet(nn.Module):
        r""" Network that computes the posterior p(z|x,y); x and y are the final outputs of the two encoders """
        def __init__(self, x_size,  # dimension of the post encoding
                     y_size,  # dimension of the response encoding
                     latent_size,  # latent variable dimension
                     dims):  # hidden layer dimensions
            super(RecognizeNet, self).__init__()
            assert len(dims) >= 1  # an MLP with at least two linear layers
    
            dims = [x_size+y_size] + dims + [latent_size*2]
            dims_input = dims[:-1]
            dims_output = dims[1:]
    
            self.latent_size = latent_size
            self.mlp = nn.Sequential()
            for idx, (x, y) in enumerate(zip(dims_input[:-1], dims_output[:-1])):
                self.mlp.add_module(f'linear{idx}', nn.Linear(x, y))  # linear layer
                self.mlp.add_module(f'activate{idx}', nn.Tanh())  # activation
            self.mlp.add_module('output', nn.Linear(dims_input[-1], dims_output[-1]))
    
        def forward(self, x,  # [batch, x_size]
                    y):  # [batch, y_size]
            x = torch.cat([x, y], 1)  # [batch, x_size+y_size]
            predict = self.mlp(x)  # [batch, latent_size*2]
            mu, logvar = predict.split([self.latent_size]*2, 1)
            return mu, logvar
    

    The recognition network is essentially another multilayer perceptron; it merely adds the posterior information produced by encoding the response. It could just as well be combined with the prior network into a single module.

    Decoder

    import torch.nn as nn
    
    
    class Decoder(nn.Module):
        def __init__(self, cell_type,  # RNN type
                     input_size,  # input dimension
                     output_size,  # output dimension
                     num_layer,  # number of RNN layers
                     dropout=0.1):  # dropout probability
            super(Decoder, self).__init__()
            assert cell_type in ['GRU', 'LSTM']  # restrict the RNN type
    
            self.cell_type = cell_type
            self.rnn_cell = getattr(nn, cell_type)(
                input_size=input_size,
                hidden_size=output_size,
                num_layers=num_layer,
                dropout=dropout)
    
        def forward(self, x,  # full input [seq, batch, dim] or a single step [1, batch, dim]
                    state):  # initial state [layers*directions, batch, dim]
            # output: [seq, batch, dim*directions] output at every time step
            # final_state: [layers*directions, batch, dim] final state of each layer
            output, final_state = self.rnn_cell(x, state)
            return output, final_state
    

    Essentially no different from the encoder.

    Other Modules

    Initializing the decoder's initial state from the encoder's final state and the latent variable z

    import torch.nn as nn
    
    
    class PrepareState(nn.Module):
        r""" Prepares the decoder's initial state, initialized from the latent variable and the encoder state """
        def __init__(self, input_size,  # dimension of the vector used to initialize the state
                     decoder_cell_type,  # decoder RNN type
                     decoder_output_size,  # decoder hidden size
                     decoder_num_layers):  # number of decoder layers
            super(PrepareState, self).__init__()
            assert decoder_cell_type in ['GRU', 'LSTM']
    
            self.decoder_cell_type = decoder_cell_type
            self.num_layers = decoder_num_layers
            self.linear = nn.Linear(input_size, decoder_output_size)
    
        def forward(self, x):  # [batch, dim]
            if self.num_layers > 1:
                states = self.linear(x).unsqueeze(0).repeat(self.num_layers, 1, 1)  # [num_layers, batch, output_size]
            else:
                states = self.linear(x).unsqueeze(0)
            if self.decoder_cell_type == 'LSTM':
                return states, states  # (h, c)
            else:
                return states
    

    The encoder's last-layer final state is concatenated with the latent variable and passed through a linear layer, and the result is repeated to match the number of layers of the decoder's state.
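The shape bookkeeping can be sketched with made-up dimensions (enc_dim, latent, and so on are illustrative names, not the config names used above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, enc_dim, latent, dec_dim, num_layers = 2, 6, 4, 5, 3
linear = nn.Linear(enc_dim + latent, dec_dim)  # maps [z; x] to the decoder dimension
z = torch.randn(batch, latent)                 # latent variable
x = torch.randn(batch, enc_dim)                # encoder's last-layer final state
states = linear(torch.cat([z, x], 1)).unsqueeze(0).repeat(num_layers, 1, 1)
assert states.shape == (num_layers, batch, dec_dim)  # one copy per decoder layer
```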

    The Full Model

    Initializing the components

        def __init__(self, config):
            super(Model, self).__init__()
            self.config = config
    
            # embedding layer
            self.embedding = Embedding(config.num_vocab,  # vocabulary size
                                       config.embedding_size,  # embedding dimension
                                       config.pad_id,  # pad_id
                                       config.dropout)

            # post encoder
            self.post_encoder = Encoder(config.post_encoder_cell_type,  # RNN type
                                        config.embedding_size,  # input dimension
                                        config.post_encoder_output_size,  # output dimension
                                        config.post_encoder_num_layers,  # number of RNN layers
                                        config.post_encoder_bidirectional,  # whether bidirectional
                                        config.dropout)  # dropout probability

            # response encoder
            self.response_encoder = Encoder(config.response_encoder_cell_type,
                                            config.embedding_size,  # input dimension
                                            config.response_encoder_output_size,  # output dimension
                                            config.response_encoder_num_layers,  # number of RNN layers
                                            config.response_encoder_bidirectional,  # whether bidirectional
                                            config.dropout)  # dropout probability

            # prior network
            self.prior_net = PriorNet(config.post_encoder_output_size,  # post encoding dimension
                                      config.latent_size,  # latent variable dimension
                                      config.dims_prior)  # hidden layer dimensions

            # recognition network
            self.recognize_net = RecognizeNet(config.post_encoder_output_size,  # post encoding dimension
                                              config.response_encoder_output_size,  # response encoding dimension
                                              config.latent_size,  # latent variable dimension
                                              config.dims_recognize)  # hidden layer dimensions

            # decoder initial state
            self.prepare_state = PrepareState(config.post_encoder_output_size+config.latent_size,
                                              config.decoder_cell_type,
                                              config.decoder_output_size,
                                              config.decoder_num_layers)

            # decoder
            self.decoder = Decoder(config.decoder_cell_type,  # RNN type
                                   config.embedding_size,  # input dimension
                                   config.decoder_output_size,  # output dimension
                                   config.decoder_num_layers,  # number of RNN layers
                                   config.dropout)  # dropout probability

            # output layer
            self.projector = nn.Sequential(
                nn.Linear(config.decoder_output_size, config.num_vocab),
                nn.Softmax(-1)
            )
    

    The thing to note is the final projector layer, which maps the decoder output to the vocabulary dimension to predict the probability of each word.
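Since the projector already applies a Softmax, the loss computed later takes clamp + log of its output. As a side note (not from the original article), this is numerically equivalent to keeping raw logits and using F.cross_entropy, which is the more stable variant one might swap in:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)               # [batch, num_vocab] pre-softmax scores
labels = torch.randint(0, 10, (4,))
probs = torch.softmax(logits, -1)         # what a Softmax projector outputs
loss_a = F.nll_loss(probs.clamp_min(1e-12).log(), labels)  # the article's recipe
loss_b = F.cross_entropy(logits, labels)  # same value computed from logits
assert torch.allclose(loss_a, loss_b, atol=1e-5)
```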

    Defining the forward pass

        def forward(self, inputs, inference=False, max_len=60, gpu=True):
            if not inference:  # training
                id_posts = inputs['posts']  # [batch, seq]
                len_posts = inputs['len_posts']  # [batch]
                id_responses = inputs['responses']  # [batch, seq]
                len_responses = inputs['len_responses']  # [batch]
                sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
                len_decoder = id_responses.size(1) - 1
    
                embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
                embed_responses = self.embedding(id_responses)  # [batch, seq, embed_size]
                # state: [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
                _, state_responses = self.response_encoder(embed_responses.transpose(0, 1), len_responses)
                if isinstance(state_posts, tuple):
                    state_posts = state_posts[0]
                if isinstance(state_responses, tuple):
                    state_responses = state_responses[0]
                x = state_posts[-1, :, :]  # [batch, dim]
                y = state_responses[-1, :, :]  # [batch, dim]
    
                # p(z|x)
                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                # p(z|x,y)
                mu, logvar = self.recognize_net(x, y)  # [batch, latent]
                # reparameterization
                z = mu + (0.5 * logvar).exp() * sampled_latents  # [batch, latent]
    
                # the decoder input is the response with its last token (end_id) removed
                decoder_inputs = embed_responses[:, :-1, :].transpose(0, 1)  # [seq-1, batch, embed_size]
                decoder_inputs = decoder_inputs.split([1] * len_decoder, 0)  # per-step decoder inputs: seq-1 tensors of [1, batch, embed_size]
                first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layer, batch, dim_out]
    
                outputs = []
                for idx in range(len_decoder):
                    if idx == 0:
                        state = first_state  # decoder initial state
                    decoder_input = decoder_inputs[idx]  # input at the current time step [1, batch, embed_size]
                    # output: [1, batch, dim_out]
                    # state: [num_layer, batch, dim_out]
                    output, state = self.decoder(decoder_input, state)
                    outputs.append(output)
    
                outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq-1, dim_out]
                output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]
    
                return output_vocab, _mu, _logvar, mu, logvar
            else:  # inference
                id_posts = inputs['posts']  # [batch, seq]
                len_posts = inputs['len_posts']  # [batch]
                sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
                batch_size = id_posts.size(0)
    
                embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
                # state: [layers, batch, dim]
                _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
                if isinstance(state_posts, tuple):  # for an LSTM, take h
                    state_posts = state_posts[0]  # [layers, batch, dim]
                x = state_posts[-1, :, :]  # take the last layer [batch, dim]
    
                # p(z|x)
                _mu, _logvar = self.prior_net(x)  # [batch, latent]
                # reparameterization
                z = _mu + (0.5 * _logvar).exp() * sampled_latents  # [batch, latent]
    
                first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layer, batch, dim_out]
                done = torch.tensor([0] * batch_size).bool()
                first_input_id = (torch.ones((1, batch_size)) * self.config.start_id).long()
                if gpu:
                    done = done.cuda()
                    first_input_id = first_input_id.cuda()
    
                outputs = []
                for idx in range(max_len):
                    if idx == 0:  # first time step
                        state = first_state  # decoder initial state
                        decoder_input = self.embedding(first_input_id)  # initial decoder input [1, batch, embed_size]
                    else:
                        decoder_input = self.embedding(next_input_id)  # [1, batch, embed_size]
                    # output: [1, batch, dim_out]
                    # state: [num_layers, batch, dim_out]
                    output, state = self.decoder(decoder_input, state)
                    outputs.append(output)
    
                    vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                    next_input_id = torch.argmax(vocab_prob, 2)  # greedily take the most probable word as the next input [1, batch]
    
                    _done = next_input_id.squeeze(0) == self.config.end_id  # sequences that finished decoding at this step [batch]
                    done = done | _done  # all sequences finished so far
                    if done.sum() == batch_size:  # stop early once every sequence has finished
                        break
    
                outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq, dim_out]
                output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]
    
                return output_vocab, _mu, _logvar, None, None
    

    id_posts is the id representation of the input, len_posts is the length of each input, and sampled_latents are latent variables sampled from a standard normal distribution.
    Note that the encoded representations x and y of the post and the response both use the LSTM's hidden state h (not the cell state c), or the GRU's h, taken from the last layer. This follows the source code of the paper the figure at the top comes from.
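Extracting h from an LSTM state follows the same pattern as in the forward pass; a tiny standalone sketch (dimensions made up):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=5, num_layers=2)
x = torch.randn(7, 2, 3)         # [seq, batch, embed_size]
_, state = lstm(x)               # for an LSTM, state is the tuple (h, c)
if isinstance(state, tuple):     # take h and drop c, as the model does
    state = state[0]             # [layers, batch, dim]
x_enc = state[-1, :, :]          # the last layer's hidden state [batch, dim]
assert x_enc.shape == (2, 5)
```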

    Computing the Loss

    def compute_loss(outputs, labels, masks, global_step):
        def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):  # [batch, latent]
            """ 两个高斯分布之间的kl散度公式 """
            kld = 0.5 * torch.sum(prior_logvar - recog_logvar - 1
                                  + recog_logvar.exp() / prior_logvar.exp()
                                  + (prior_mu - recog_mu).pow(2) / prior_logvar.exp(), 1)
            return kld  # [batch]
    
        # output_vocab: [batch, len_decoder, num_vocab] softmax probability for each word
        output_vocab, _mu, _logvar, mu, logvar = outputs  # prior mean and log-variance, posterior mean and log-variance
    
        token_per_batch = masks.sum(1)  # number of tokens contributing to the loss per sample [batch]
        len_decoder = masks.size(1)  # decoding length
    
        output_vocab = output_vocab.reshape(-1, config.num_vocab)  # [batch*len_decoder, num_vocab]
        labels = labels.reshape(-1)  # [batch*len_decoder]
        masks = masks.reshape(-1)  # [batch*len_decoder]
    
        # nll_loss expects log-probabilities: it just negates the value at the label index;
        # reduction='none' means nothing is summed or averaged
        _nll_loss = F.nll_loss(output_vocab.clamp_min(1e-12).log(), labels, reduction='none')  # -log likelihood of each token [batch*len_decoder]
        _nll_loss = _nll_loss * masks  # ignore tokens that should not contribute to the loss [batch*len_decoder]
    
        nll_loss = _nll_loss.reshape(-1, len_decoder).sum(1)  # NLL loss of each sample [batch]
    
        ppl = nll_loss / token_per_batch.clamp_min(1e-12)  # NLL averaged over the valid tokens; perplexity is its exponential [batch]
    
        # KL divergence loss [batch]
        kld_loss = gaussian_kld(mu, logvar, _mu, _logvar)
    
        # KL annealing
        kld_weight = min(1.0 * (global_step % (2*config.kl_step)) / config.kl_step, 1)  # cyclical annealing
    
        # total loss
        loss = nll_loss + kld_weight * kld_loss
    
        return loss, nll_loss, kld_loss, ppl, kld_weight
    

    Note that nll_loss is the NLL summed over each sample, while ppl averages that NLL over the tokens that actually count. Also, don't forget to multiply by the mask so that tokens that should not contribute to the loss, such as pad, are ignored.
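The closed-form gaussian_kld above can be cross-checked against torch.distributions (a sanity check, not part of the original code):

```python
import torch
from torch.distributions import Normal, kl_divergence

def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):
    # KL( N(recog_mu, exp(recog_logvar)) || N(prior_mu, exp(prior_logvar)) )
    return 0.5 * torch.sum(prior_logvar - recog_logvar - 1
                           + recog_logvar.exp() / prior_logvar.exp()
                           + (prior_mu - recog_mu).pow(2) / prior_logvar.exp(), 1)

torch.manual_seed(0)
mu, logvar = torch.randn(4, 3), torch.randn(4, 3)    # posterior parameters
_mu, _logvar = torch.randn(4, 3), torch.randn(4, 3)  # prior parameters
kld = gaussian_kld(mu, logvar, _mu, _logvar)         # [batch]
q = Normal(mu, (0.5 * logvar).exp())                 # posterior as a distribution
p = Normal(_mu, (0.5 * _logvar).exp())               # prior as a distribution
assert torch.allclose(kld, kl_divergence(q, p).sum(1), atol=1e-5)
```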
    GitHub: https://github.com/Kirito0918/cvae-dialog
