1. 举个例子
def _to_text(label):
    """Return *label* as text: decode UTF-8 ``bytes``, pass anything else through.

    Calling ``.decode('utf-8')`` unconditionally only works on byte strings
    (Python 2 ``str``); a Python 3 ``str`` has no ``decode`` method.
    """
    if isinstance(label, bytes):
        return label.decode('utf-8')
    return label


def plot_attention(data, X_label=None, Y_label=None, figsize=(20, 8)):
    """Plot an attention matrix as a heatmap.

    Args:
        data: attention matrix with shape [ty, tx], truncated before 'PAD'.
            Drawn with ``ax.pcolor``, so entry [i, j] fills the cell spanning
            x in [j, j + 1] and y in [i, i + 1].
        X_label: optional list of size tx with the encoder tokens
            (``str`` or UTF-8 ``bytes``); used as x-axis tick labels.
        Y_label: optional list of size ty with the decoder tokens
            (``str`` or UTF-8 ``bytes``); used as y-axis tick labels.
        figsize: figure size passed to ``plt.subplots``; defaults to the
            previously hard-coded ``(20, 8)``.
    """
    fig, ax = plt.subplots(figsize=figsize)
    ax.pcolor(data, cmap=plt.cm.Blues, alpha=0.9)

    # pcolor draws cell j over [j, j + 1], so a tick at j + 0.5 centres each
    # label under its own cell instead of on the boundary between two cells.
    # Each axis is labelled independently, so passing only one list still works.
    if X_label is not None:
        X_label = [_to_text(label) for label in X_label]
        ax.set_xticks([i + 0.5 for i in range(len(X_label))], minor=False)
        ax.set_xticklabels(X_label, minor=False, rotation=45)
    if Y_label is not None:
        Y_label = [_to_text(label) for label in Y_label]
        ax.set_yticks([i + 0.5 for i in range(len(Y_label))], minor=False)
        ax.set_yticklabels(Y_label, minor=False)

    ax.grid(True)
2. 参数
- X_label：encoder 端句子按词（word）切分得到的 list，长度应与 data 的列数 tx 一致；
- Y_label：decoder 端句子按词（word）切分得到的 list，长度应与 data 的行数 ty 一致。
网友评论