CVAE Model Structure
The CVAE model structure is shown in the figure above. On top of a seq2seq model, CVAE adds a prior network and a recognition network. During training, the latent variable used for decoding is sampled from the recognition network, while at test time it is sampled from the prior network. The dialog act and bow (bag-of-words) predictions in the figure, i.e. the blue and yellow parts, are not considered here.
The basic modules are therefore Embedding, Encoder, PriorNet, RecognizeNet, and Decoder.
Embedding
import torch.nn as nn

class Embedding(nn.Module):
    def __init__(self, num_vocab,
                 embedding_size,
                 pad_id=0,
                 dropout=0.1):
        super(Embedding, self).__init__()
        self.embedding = nn.Embedding(num_vocab, embedding_size, padding_idx=pad_id)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):  # [batch, seq]
        return self.dropout(self.embedding(x))  # [batch, seq, embedding_size]
The arguments are the vocabulary size, the embedding dimension, the id of the pad token in the vocabulary, and the dropout probability. This is essentially a thin wrapper around the nn.Embedding module.
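A quick usage sketch (the vocabulary size and embedding dimension below are arbitrary illustrative values):

import torch

embedding = Embedding(num_vocab=1000, embedding_size=64, pad_id=0, dropout=0.1)
token_ids = torch.randint(0, 1000, (4, 10))  # a toy batch of token ids [batch=4, seq=10]
vectors = embedding(token_ids)
print(vectors.shape)  # torch.Size([4, 10, 64])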
Encoder
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Encoder(nn.Module):
    def __init__(self, cell_type,       # rnn type
                 input_size,            # input dimension
                 output_size,           # output dimension
                 num_layers,            # number of rnn layers
                 bidirectional=False,   # whether bidirectional
                 dropout=0.1):          # dropout
        super(Encoder, self).__init__()
        assert cell_type in ['GRU', 'LSTM']  # restrict the rnn type
        if bidirectional:  # if bidirectional
            assert output_size % 2 == 0
            cell_size = output_size // 2  # rnn hidden size
        else:
            cell_size = output_size
        self.bidirectional = bidirectional
        self.cell_type = cell_type
        self.rnn_cell = getattr(nn, cell_type)(input_size=input_size,
                                               hidden_size=cell_size,
                                               num_layers=num_layers,
                                               bidirectional=bidirectional,
                                               dropout=dropout)

    def forward(self, x,      # [seq, batch, dim]
                length):      # [batch]
        x = pack_padded_sequence(x, length, enforce_sorted=False)
        # output: [seq, batch, dim*directions] output at every time step
        # final_state: [layers*directions, batch, dim] final state of every layer
        output, final_state = self.rnn_cell(x)
        output = pad_packed_sequence(output)[0]
        if self.bidirectional:  # if bidirectional, concatenate the two directions as each layer's final state
            if self.cell_type == 'GRU':
                final_state_forward = final_state[0::2, :, :]  # [layers, batch, dim]
                final_state_back = final_state[1::2, :, :]     # [layers, batch, dim]
                final_state = torch.cat([final_state_forward, final_state_back], 2)  # [layers, batch, dim*2]
            else:
                final_state_h, final_state_c = final_state
                final_state_h = torch.cat([final_state_h[0::2, :, :], final_state_h[1::2, :, :]], 2)
                final_state_c = torch.cat([final_state_c[0::2, :, :], final_state_c[1::2, :, :]], 2)
                final_state = (final_state_h, final_state_c)
        # output: [seq, batch, dim]
        # final_state: [layers, batch, dim]
        return output, final_state
The parameters are documented in the comments. The parts worth explaining are pack_padded_sequence and pad_packed_sequence. pack_padded_sequence compresses away the padding in the sequences: since the data comes in batches, every input is padded to the length of the longest sentence in the batch, and there is no point computing over the pad positions; only the tokens up to each sentence's real length matter. pack_padded_sequence removes the influence of the padding. The length argument gives the true length of each sentence in the batch, i.e. how many steps actually need computing. If all sentences in the batch are already sorted by length in decreasing order, enforce_sorted can be set to True to speed up the computation; with unsorted input that setting raises an error, so pass False instead. pad_packed_sequence is the inverse operation; it returns a 2-tuple whose first element is the desired output and whose second element is the sentence lengths, i.e. the same length values that were passed in. Usually only the first element is needed, as the demo below shows.
For a bidirectional encoder, the forward and backward states are concatenated to form the final state output.
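To make the pack/unpack behavior concrete, here is a tiny standalone demo (toy values, independent of the model):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# two sequences of true lengths 3 and 2, padded to length 3: [seq=3, batch=2, dim=1]
x = torch.tensor([[[1.], [4.]],
                  [[2.], [5.]],
                  [[3.], [0.]]])  # the trailing 0 is padding
lengths = torch.tensor([3, 2])
packed = pack_padded_sequence(x, lengths, enforce_sorted=False)
unpacked, out_lengths = pad_packed_sequence(packed)
print(unpacked.shape)  # torch.Size([3, 2, 1]) -- padding restored
print(out_lengths)     # tensor([3, 2]) -- the lengths passed in come back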
Prior Network
import torch.nn as nn

class PriorNet(nn.Module):
    r""" Network that computes the prior p(z|x); x is the final-step output of the post encoder """
    def __init__(self, x_size,      # post encoding dimension
                 latent_size,       # latent variable dimension
                 dims):             # hidden layer dimensions
        super(PriorNet, self).__init__()
        assert len(dims) >= 1  # at least a two-layer perceptron
        dims = [x_size] + dims + [latent_size*2]
        dims_input = dims[:-1]
        dims_output = dims[1:]
        self.latent_size = latent_size
        self.mlp = nn.Sequential()
        for idx, (x, y) in enumerate(zip(dims_input[:-1], dims_output[:-1])):
            self.mlp.add_module(f'linear{idx}', nn.Linear(x, y))  # linear layer
            self.mlp.add_module(f'activate{idx}', nn.Tanh())      # activation
        self.mlp.add_module('output', nn.Linear(dims_input[-1], dims_output[-1]))

    def forward(self, x):  # [batch, x_size]
        predict = self.mlp(x)  # [batch, latent_size*2]
        mu, logvar = predict.split([self.latent_size]*2, 1)
        return mu, logvar
The prior network is essentially a multilayer perceptron whose job is to compute the mean and log-variance of the prior, since the prior is usually assumed to be a Gaussian distribution with two parameters, μ and σ². Why not predict the variance directly? Because sampling from a Gaussian is not differentiable, the sample has to be produced via the reparameterization trick: drawing z ~ N(μ, σ²) is equivalent to drawing ε ~ N(0, I) and computing z = μ + σε. With the network predicting logσ², the standard deviation is simply σ = exp(0.5·logσ²), which makes the computation convenient (and keeps the variance positive by construction).
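In code, the reparameterization step looks like this (a minimal sketch with made-up shapes; in the full model below the noise arrives as sampled_latents):

import torch

mu = torch.zeros(4, 32)       # predicted mean [batch, latent_size]
logvar = torch.zeros(4, 32)   # predicted log-variance [batch, latent_size]
eps = torch.randn_like(mu)    # eps ~ N(0, I)
sigma = (0.5 * logvar).exp()  # sigma = exp(log(sigma^2) / 2)
z = mu + sigma * eps          # z ~ N(mu, sigma^2), differentiable w.r.t. mu and logvar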
Recognition Network
import torch
import torch.nn as nn

class RecognizeNet(nn.Module):
    r""" Network that computes the posterior p(z|x,y); x and y are the final-step outputs of the two encoders """
    def __init__(self, x_size,      # post encoding dimension
                 y_size,            # response encoding dimension
                 latent_size,       # latent variable dimension
                 dims):             # hidden layer dimensions
        super(RecognizeNet, self).__init__()
        assert len(dims) >= 1  # at least a two-layer perceptron
        dims = [x_size+y_size] + dims + [latent_size*2]
        dims_input = dims[:-1]
        dims_output = dims[1:]
        self.latent_size = latent_size
        self.mlp = nn.Sequential()
        for idx, (x, y) in enumerate(zip(dims_input[:-1], dims_output[:-1])):
            self.mlp.add_module(f'linear{idx}', nn.Linear(x, y))  # linear layer
            self.mlp.add_module(f'activate{idx}', nn.Tanh())      # activation
        self.mlp.add_module('output', nn.Linear(dims_input[-1], dims_output[-1]))

    def forward(self, x,   # [batch, x_size]
                y):        # [batch, y_size]
        x = torch.cat([x, y], 1)  # [batch, x_size+y_size]
        predict = self.mlp(x)     # [batch, latent_size*2]
        mu, logvar = predict.split([self.latent_size]*2, 1)
        return mu, logvar
The recognition network is also just a multilayer perceptron; the only difference is the extra posterior information from the response encoding. It could equally well be merged with the prior network into a single module. A quick shape check is shown below.
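A shape check for the two networks together (the sizes are arbitrary; in the full model they come from config):

import torch

prior_net = PriorNet(x_size=128, latent_size=32, dims=[64])
recognize_net = RecognizeNet(x_size=128, y_size=128, latent_size=32, dims=[64])
x = torch.randn(4, 128)           # post encoding [batch, x_size]
y = torch.randn(4, 128)           # response encoding [batch, y_size]
_mu, _logvar = prior_net(x)       # prior p(z|x)
mu, logvar = recognize_net(x, y)  # posterior p(z|x,y)
print(mu.shape, logvar.shape)     # torch.Size([4, 32]) torch.Size([4, 32])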
Decoder
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, cell_type,   # rnn type
                 input_size,        # input dimension
                 output_size,       # output dimension
                 num_layer,         # number of rnn layers
                 dropout=0.1):      # dropout
        super(Decoder, self).__init__()
        assert cell_type in ['GRU', 'LSTM']  # restrict the rnn type
        self.cell_type = cell_type
        self.rnn_cell = getattr(nn, cell_type)(
            input_size=input_size,
            hidden_size=output_size,
            num_layers=num_layer,
            dropout=dropout)

    def forward(self, x,     # input [seq, batch, dim], or a single step [1, batch, dim]
                state):      # initial state [layers*directions, batch, dim]
        # output: [seq, batch, dim*directions] output at every time step
        # final_state: [layers*directions, batch, dim] final state of every layer
        output, final_state = self.rnn_cell(x, state)
        return output, final_state
Not much different from the encoder; a single-step usage sketch follows.
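A single-step decoding sketch (a GRU decoder with illustrative sizes):

import torch

decoder = Decoder(cell_type='GRU', input_size=64, output_size=128, num_layer=1, dropout=0)
state = torch.zeros(1, 4, 128)      # initial state [num_layers, batch, output_size]
step_input = torch.randn(1, 4, 64)  # one time step [1, batch, input_size]
output, state = decoder(step_input, state)
print(output.shape)  # torch.Size([1, 4, 128])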
Other Modules
Initializing the decoder's initial state from the encoder's final state and the latent variable z
import torch.nn as nn

class PrepareState(nn.Module):
    r""" Prepare the decoder's initial state, initialized from the latent variable and the encoder output """
    def __init__(self, input_size,          # dimension of the vector used to initialize the state
                 decoder_cell_type,         # decoder rnn type
                 decoder_output_size,       # decoder hidden size
                 decoder_num_layers):       # number of decoder layers
        super(PrepareState, self).__init__()
        assert decoder_cell_type in ['GRU', 'LSTM']
        self.decoder_cell_type = decoder_cell_type
        self.num_layers = decoder_num_layers
        self.linear = nn.Linear(input_size, decoder_output_size)

    def forward(self, x):  # [batch, dim]
        if self.num_layers > 1:
            states = self.linear(x).unsqueeze(0).repeat(self.num_layers, 1, 1)  # [num_layers, batch, output_size]
        else:
            states = self.linear(x).unsqueeze(0)
        if self.decoder_cell_type == 'LSTM':
            return states, states  # (h, c)
        else:
            return states
This concatenates the final state of the encoder's last layer with the latent variable, passes the result through a linear layer, and then replicates it once per decoder layer to match the shape the decoder expects for its state.
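A usage sketch (sizes for illustration; in the model, input_size equals post_encoder_output_size + latent_size):

import torch

prepare_state = PrepareState(input_size=160,  # e.g. 128 (encoder state) + 32 (latent)
                             decoder_cell_type='GRU',
                             decoder_output_size=128,
                             decoder_num_layers=2)
z = torch.randn(4, 32)   # latent variable [batch, latent_size]
x = torch.randn(4, 128)  # last-layer final state of the post encoder [batch, dim]
state = prepare_state(torch.cat([z, x], 1))
print(state.shape)  # torch.Size([2, 4, 128]) -- replicated across both decoder layers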
The Full Model
Initializing the components
def __init__(self, config):
    super(Model, self).__init__()
    self.config = config
    # embedding layer
    self.embedding = Embedding(config.num_vocab,       # vocabulary size
                               config.embedding_size,  # embedding dimension
                               config.pad_id,          # pad_id
                               config.dropout)
    # post encoder
    self.post_encoder = Encoder(config.post_encoder_cell_type,      # rnn type
                                config.embedding_size,              # input dimension
                                config.post_encoder_output_size,    # output dimension
                                config.post_encoder_num_layers,     # number of rnn layers
                                config.post_encoder_bidirectional,  # whether bidirectional
                                config.dropout)                     # dropout probability
    # response encoder
    self.response_encoder = Encoder(config.response_encoder_cell_type,
                                    config.embedding_size,                  # input dimension
                                    config.response_encoder_output_size,    # output dimension
                                    config.response_encoder_num_layers,     # number of rnn layers
                                    config.response_encoder_bidirectional,  # whether bidirectional
                                    config.dropout)                         # dropout probability
    # prior network
    self.prior_net = PriorNet(config.post_encoder_output_size,  # post encoding dimension
                              config.latent_size,               # latent variable dimension
                              config.dims_prior)                # hidden layer dimensions
    # recognition network
    self.recognize_net = RecognizeNet(config.post_encoder_output_size,      # post encoding dimension
                                      config.response_encoder_output_size,  # response encoding dimension
                                      config.latent_size,                   # latent variable dimension
                                      config.dims_recognize)                # hidden layer dimensions
    # decoder initial state
    self.prepare_state = PrepareState(config.post_encoder_output_size+config.latent_size,
                                      config.decoder_cell_type,
                                      config.decoder_output_size,
                                      config.decoder_num_layers)
    # decoder
    self.decoder = Decoder(config.decoder_cell_type,    # rnn type
                           config.embedding_size,       # input dimension
                           config.decoder_output_size,  # output dimension
                           config.decoder_num_layers,   # number of rnn layers
                           config.dropout)              # dropout probability
    # output layer
    self.projector = nn.Sequential(
        nn.Linear(config.decoder_output_size, config.num_vocab),
        nn.Softmax(-1)
    )
Note the projector layer at the end: it maps the decoder output to the vocabulary dimension, producing a probability for each word.
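Because the Softmax sits inside the model, the loss function later has to take the log of these probabilities itself (clamped for numerical safety, as shown in compute_loss below). A common alternative design, not used by this repo, is to output raw logits and let the loss fuse log-softmax and NLL:

import torch.nn as nn

decoder_output_size, num_vocab = 128, 1000             # illustrative sizes
projector = nn.Linear(decoder_output_size, num_vocab)  # emit raw logits instead of probabilities
criterion = nn.CrossEntropyLoss(reduction='none')      # fuses LogSoftmax and NLLLoss, numerically stabler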
Defining the forward pass
def forward(self, inputs, inference=False, max_len=60, gpu=True):
    if not inference:  # training
        id_posts = inputs['posts']  # [batch, seq]
        len_posts = inputs['len_posts']  # [batch]
        id_responses = inputs['responses']  # [batch, seq]
        len_responses = inputs['len_responses']  # [batch]
        sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
        len_decoder = id_responses.size(1) - 1
        embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
        embed_responses = self.embedding(id_responses)  # [batch, seq, embed_size]
        # state: [layers, batch, dim]
        _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
        _, state_responses = self.response_encoder(embed_responses.transpose(0, 1), len_responses)
        if isinstance(state_posts, tuple):
            state_posts = state_posts[0]
        if isinstance(state_responses, tuple):
            state_responses = state_responses[0]
        x = state_posts[-1, :, :]  # [batch, dim]
        y = state_responses[-1, :, :]  # [batch, dim]
        # p(z|x)
        _mu, _logvar = self.prior_net(x)  # [batch, latent]
        # p(z|x,y)
        mu, logvar = self.recognize_net(x, y)  # [batch, latent]
        # reparameterization
        z = mu + (0.5 * logvar).exp() * sampled_latents  # [batch, latent]
        # the decoder input is the response with the end_id removed
        decoder_inputs = embed_responses[:, :-1, :].transpose(0, 1)  # [seq-1, batch, embed_size]
        decoder_inputs = decoder_inputs.split([1] * len_decoder, 0)  # decoder input per step: seq-1 tensors of [1, batch, embed_size]
        first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layers, batch, dim_out]
        outputs = []
        for idx in range(len_decoder):
            if idx == 0:
                state = first_state  # decoder initial state
            decoder_input = decoder_inputs[idx]  # input at the current time step [1, batch, embed_size]
            # output: [1, batch, dim_out]
            # state: [num_layers, batch, dim_out]
            output, state = self.decoder(decoder_input, state)
            outputs.append(output)
        outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq-1, dim_out]
        output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]
        return output_vocab, _mu, _logvar, mu, logvar
    else:  # inference
        id_posts = inputs['posts']  # [batch, seq]
        len_posts = inputs['len_posts']  # [batch]
        sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
        batch_size = id_posts.size(0)
        embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
        # state: [layers, batch, dim]
        _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
        if isinstance(state_posts, tuple):  # for an lstm, take h
            state_posts = state_posts[0]  # [layers, batch, dim]
        x = state_posts[-1, :, :]  # take the last layer [batch, dim]
        # p(z|x)
        _mu, _logvar = self.prior_net(x)  # [batch, latent]
        # reparameterization
        z = _mu + (0.5 * _logvar).exp() * sampled_latents  # [batch, latent]
        first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layers, batch, dim_out]
        done = torch.tensor([0] * batch_size).bool()
        first_input_id = (torch.ones((1, batch_size)) * self.config.start_id).long()
        if gpu:
            done = done.cuda()
            first_input_id = first_input_id.cuda()
        outputs = []
        for idx in range(max_len):
            if idx == 0:  # first time step
                state = first_state  # decoder initial state
                decoder_input = self.embedding(first_input_id)  # decoder initial input [1, batch, embed_size]
            else:
                decoder_input = self.embedding(next_input_id)  # [1, batch, embed_size]
            # output: [1, batch, dim_out]
            # state: [num_layers, batch, dim_out]
            output, state = self.decoder(decoder_input, state)
            outputs.append(output)
            vocab_prob = self.projector(output)  # [1, batch, num_vocab]
            next_input_id = torch.argmax(vocab_prob, 2)  # greedily pick the most probable word as the next input [1, batch]
            _done = next_input_id.squeeze(0) == self.config.end_id  # sequences that finished decoding this step [batch]
            done = done | _done  # all finished sequences
            if done.sum() == batch_size:  # stop early once every sequence has finished
                break
        outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq, dim_out]
        output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]
        return output_vocab, _mu, _logvar, None, None
id_posts is the id representation of the input, len_posts is the length of each input, and sampled_latents is the noise sampled from a standard normal distribution used in the reparameterization step.
One thing to note: the encoded representations x and y of the post and the response use the LSTM's hidden state h (not the cell state c), or the GRU's h, and only from the last layer. This follows the source code of the paper whose figure appears at the top. A sketch of assembling the inputs dict is shown below.
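A sketch of a training step (toy values; model is assumed to be an instance of the Model class above, and the sizes must be consistent with its config, e.g. latent_size = 32):

import torch

batch_size, latent_size = 4, 32
inputs = {
    'posts': torch.randint(1, 1000, (batch_size, 20)),        # [batch, seq]
    'len_posts': torch.tensor([20, 18, 15, 9]),               # [batch]
    'responses': torch.randint(1, 1000, (batch_size, 16)),    # [batch, seq]
    'len_responses': torch.tensor([16, 12, 10, 7]),           # [batch]
    'sampled_latents': torch.randn(batch_size, latent_size),  # noise from N(0, I)
}
output_vocab, _mu, _logvar, mu, logvar = model(inputs, inference=False)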
Computing the Loss
import torch
import torch.nn.functional as F

def compute_loss(outputs, labels, masks, global_step):
    def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):  # [batch, latent]
        """ closed-form kl divergence between two gaussian distributions """
        kld = 0.5 * torch.sum(prior_logvar - recog_logvar - 1
                              + recog_logvar.exp() / prior_logvar.exp()
                              + (prior_mu - recog_mu).pow(2) / prior_logvar.exp(), 1)
        return kld  # [batch]

    # output_vocab: [batch, len_decoder, num_vocab] softmax probability of every word
    output_vocab, _mu, _logvar, mu, logvar = outputs  # prior mean/log-variance, posterior mean/log-variance
    token_per_batch = masks.sum(1)  # number of tokens per sample that count towards the loss [batch]
    len_decoder = masks.size(1)  # decoding length
    output_vocab = output_vocab.reshape(-1, config.num_vocab)  # [batch*len_decoder, num_vocab]
    labels = labels.reshape(-1)  # [batch*len_decoder]
    masks = masks.reshape(-1)  # [batch*len_decoder]
    # nll_loss expects log-probabilities: it just gathers the value at the label index and negates it;
    # reduction='none' means no summing or averaging is applied
    _nll_loss = F.nll_loss(output_vocab.clamp_min(1e-12).log(), labels, reduction='none')  # -log likelihood per token [batch*len_decoder]
    _nll_loss = _nll_loss * masks  # mask out tokens that should not contribute to the loss [batch*len_decoder]
    nll_loss = _nll_loss.reshape(-1, len_decoder).sum(1)  # nll loss per sample [batch]
    ppl = nll_loss / token_per_batch.clamp_min(1e-12)  # ppl is averaged over the valid tokens [batch]
    # kl divergence loss [batch]
    kld_loss = gaussian_kld(mu, logvar, _mu, _logvar)
    # kl annealing
    kld_weight = min(1.0 * (global_step % (2*config.kl_step)) / config.kl_step, 1)  # cyclical annealing
    # total loss
    loss = nll_loss + kld_weight * kld_loss
    return loss, nll_loss, kld_loss, ppl, kld_weight
Note that the nll loss is just summed per sample, whereas the ppl computation has to average each sample's nll loss over the number of tokens that actually count. Also, don't forget to multiply by the mask so that tokens that should not contribute to the loss, such as pad, are ignored. The closed-form KL term is spelled out below.
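For reference, gaussian_kld implements the closed-form KL divergence between the recognition distribution q = N(μ_q, σ_q²) and the prior p = N(μ_p, σ_p²), summed over the latent dimensions; in LaTeX:

\mathrm{KL}(q \,\|\, p) = \frac{1}{2} \sum_{i} \left( \log\sigma_{p,i}^{2} - \log\sigma_{q,i}^{2} - 1 + \frac{\sigma_{q,i}^{2}}{\sigma_{p,i}^{2}} + \frac{(\mu_{p,i} - \mu_{q,i})^{2}}{\sigma_{p,i}^{2}} \right)

This matches the code term by term: prior_logvar − recog_logvar − 1 + recog_var/prior_var + (prior_mu − recog_mu)²/prior_var, halved and summed over dimension 1.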
GitHub: https://github.com/Kirito0918/cvae-dialog