CTC Loss 详解
CTC(Connectionist Temporal Classification)避免了输入与输出手动对齐,适合OCR或语音这样的序列应用;
给定输入序列 $X = [x_1, x_2, \dots, x_T]$ 以及对应的标签序列 $Y = [y_1, y_2, \dots, y_U]$,其中 $T$ 与 $U$ 不一定相等。我们的任务是找到一个从 $X$ 到 $Y$ 的映射,这类对时序数据进行分类的算法称为 Temporal Classification。其难点在于:
- X和Y的长度都是变化且不相等的.
- 对于一个端到端的模型,我们并不想手动设计X和Y之间的对齐.
CTC 提供了解决方案:对于一个给定的输入序列 $X$,CTC 给出所有可能输出 $Y$ 的概率分布。根据这个分布,我们可以输出最可能的结果,或者给出某个特定输出的概率。
损失函数:给定输入序列 $X$,我们希望最大化标签 $Y$ 的后验概率 $P(Y \mid X)$;并且 $P(Y \mid X)$ 应当对模型参数可导,这样才能用梯度下降算法进行优化。
1.1 对齐
朴素的做法是让每个时间片输出一个字符,再合并相邻的重复字符,但这样有两个问题:
- 几乎不可能让 $X$ 的每个时间片都与输出 $Y$ 中的某个字符对应上,例如 OCR 中字符之间的间隔、语音识别中的停顿;
- 无法处理连续重复字符的情况,例如单词 “APPLE”,合并相邻重复字符后输出的是 “APLE” 而非 “APPLE”。
为此 CTC 引入了空白符 blank(下文记作 `_`)。从 CTC 输出路径到最终标签的映射分两步,且顺序不能颠倒:
- 先合并连续相同的字符;
- 再去除所有空白字符。
由于先合并、后去空白,两个相同字符之间只要隔着空白就不会被合并,因此可以正确输出 “APPLE” 中的 “PP”。
比如,对于长度为 10 的输入序列,以下 RNN 输出路径都可以映射为 `apple`:
- `_aappp_ple`
- `ap_p_ll__e`
- `__app_ple_`
为了在动态规划中显式表示空白,在标签 $l$ 的首尾以及每两个字符之间插入空白符,得到扩展标签 $l'$,其长度为 $|l'| = 2|l| + 1$。比如:`l = apple` ⇒ `l' = _a_p_p_l_e_`(长度 $2 \times 5 + 1 = 11$)。
1.2 路径搜索与动态规划
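[原文此处为三张路径搜索示意图] 在上图的路径搜索中,记号如下(下标从 1 开始):
- $\alpha_t(s)$:第 $t$ 个时刻、位于扩展标签 $l'$ 第 $s$ 个字符的所有路径前缀的前向概率之和;
- $y^t_{k}$:预测矩阵中第 $t$ 时刻输出字符 $k$ 的概率;$y^t_{l'_s}$ 即第 $t$ 时刻输出 $l'_s$ 的概率;
- $p(l \mid x)$:输入序列为 $x$ 时输出为 $l$ 的概率,也就是我们要最大化的目标。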
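初始化:合法路径只能从开头的空白符或第一个标签字符开始,记空白符为 $b$,则 $\alpha_1(1) = y^1_{b}$,$\alpha_1(2) = y^1_{l'_2}$,其余 $\alpha_1(s) = 0$。之后的递推分两种情况:
(1) 如果 $l'_s$ 为空白符,或 $l'_s = l'_{s-2}$(与前一个标签字符相同),则 $\alpha_t(s)$ 只能由上一时刻的同一位置 $s$ 或前一位置 $s-1$ 转移而来:
$$\alpha_t(s) = \big(\alpha_{t-1}(s) + \alpha_{t-1}(s-1)\big)\, y^t_{l'_s}$$
不能从 $s-2$ 直接跳过来。若 $l'_s$ 为空白,跳过 $s-1$ 会漏掉 GT 中的一个字符,路径无法解码回 GT。若 $l'_s = l'_{s-2}$,跳过中间的空白会使两个重复字符被合并(如 “pp” 变成 “p”)。
(2) 其他情况下,还可以从 $s-2$ 跳转过来:
$$\alpha_t(s) = \big(\alpha_{t-1}(s) + \alpha_{t-1}(s-1) + \alpha_{t-1}(s-2)\big)\, y^t_{l'_s}$$
其中三个加项的含义如下: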
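- $\alpha_{t-1}(s-1)$:上一时刻位于 $s-1$(当 $l'_s$ 为非空白字符时,$l'_{s-1}$ 就是它前面的空白),当前时刻前进到第 $s$ 个字符的概率;
- $\alpha_{t-1}(s)$:上一时刻已位于第 $s$ 个字符,当前时刻继续输出同一字符(字符 $l'_s$ 连续出现)的概率;
- $\alpha_{t-1}(s-2)$:上一时刻位于 GT 中的上一个字符 $l'_{s-2}$,当前时刻直接跳到第 $s$ 个字符、中间不经过空白符的概率。只有当 $l'_s$ 非空白且 $l'_s \ne l'_{s-2}$ 时才允许这一项(因为字符之间都隔了一个空白符)。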
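利用前向概率计算 CTC loss。合法路径必须结束于 $l'$ 的最后一个字符(空白)或倒数第二个字符(最后一个标签字符),因此
$$p(l \mid x) = \alpha_T(|l'|) + \alpha_T(|l'| - 1)$$
最大化 $p(l \mid x)$ 等价于在对数域最小化其负对数:
$$\mathcal{L} = -\ln p(l \mid x)$$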
CTC 概率图前向概率:
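在 PyTorch 的 CTC 前向计算中,会计算 log 域的前向概率矩阵 `log_alpha`(反向传播计算梯度时会复用),其核心如下:
- 对 batch 中的每个样本,用 $(t, s)$ 两层循环做动态规划,计算 $\log \alpha_t(s)$。
  - 两种情况都需要的两项 $\log\alpha_{t-1}(s)$、$\log\alpha_{t-1}(s-1)$ 分别记为 `la1`、`la2`。
  - 再根据 $l'_s$ 是否与 $l'_{s-2}$ 相同,决定第三项 `la3` 取 $\log\alpha_{t-1}(s-2)$ 还是 $-\infty$。当 $l'_s$ 为空白时二者必然相同。
- 在对数域中计算可以避免概率连乘导致的数值下溢。三项之和用 log-sum-exp 技巧计算,先减去最大值 `lamax` 以防止 `exp` 溢出:
$$\log\alpha_t(s) = \mathrm{lamax} + \log\sum_{i=1}^{3} e^{\mathrm{la}_i - \mathrm{lamax}} + \log y^t_{l'_s}$$
对应代码为: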
// Eq. (6) in log space: logsumexp(la1, la2, la3), shifted by lamax so exp() cannot overflow, plus the log-prob of emitting l'_s at time t.
log_alpha_a[t][s] = std::log(std::exp(la1 - lamax) + std::exp(la2 - lamax) + std::exp(la3 - lamax)) + lamax + log_probs_a[t][current_target_prime];
以下代码节选自 PyTorch 中 CTC loss 的 CPU 实现(`aten/src/ATen/native/LossCTC.cpp`):
/// @brief Symbol at position idx of the blank-extended target l'.
///
/// l' interleaves blanks with the labels: [BLANK, l_0, BLANK, l_1, ..., BLANK].
/// Even positions are therefore always BLANK, and odd position idx holds the
/// label at index idx / 2 of this sample's target.
/// @param target  Flat targets buffer (read-only).
/// @param offset  Index in `target` where this sample's labels start.
/// @param stride  Distance between consecutive labels of this sample.
/// @param idx     Position in l' (0 <= idx < 2 * target_length + 1).
/// @param BLANK   Class index used for the blank symbol.
/// @return The class index of l'[idx].
template <typename target_t>
static inline int64_t get_target_prime(const target_t* target, int64_t offset, int64_t stride, int64_t idx, int64_t BLANK) {
  if (idx % 2 == 0) {
    return BLANK;
  }
  return target[offset + stride * (idx / 2)];
}
Tensor log_alpha = at::empty({batch_size, log_probs.size(0), 2 * max_target_length + 1}, log_probs.options());
Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options());
auto lpp = log_probs.permute({1, 0, 2});
auto log_probs_a_global = lpp.accessor<scalar_t, 3>();
auto log_alpha_a_global = log_alpha.accessor<scalar_t, 3>();
auto targets_data = targets.data_ptr<target_t>();
auto neg_log_likelihood_a = neg_log_likelihood.accessor<scalar_t, 1>();
// alpha calculation for the first row, the three equations for alpha_1 above eq (6)
// first the default
log_alpha.narrow(1, 0, 1).fill_(neginf);
at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
for (int64_t b = start; b < end; b++)
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
auto log_probs_a = log_probs_a_global[b];
auto log_alpha_a = log_alpha_a_global[b];
int64_t tg_batch_offset = tg_batch_offsets[b];
// the first two items of alpha_t above eq (6)
log_alpha_a[0][0] = log_probs_a[0][BLANK];
if (target_length > 0)
log_alpha_a[0][1] = log_probs_a[0][get_target_prime(targets_data, tg_batch_offset, tg_target_stride, 1, BLANK)];
// now the loop over the inputs
for (int64_t t = 1; t < input_length; t++)
for (int64_t s = 0; s < 2 * target_length + 1; s++)
auto current_target_prime = get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s, BLANK);
// this loop over s could be parallel/vectorized, too, but the required items are one index apart
// alternatively, one might consider moving s to the outer loop to cache current_target_prime more (but then it needs to be descending)
// for the cuda implementation, that gave a speed boost.
// This is eq (6) and (7), la1,2,3 are the three summands. We keep track of the maximum for the logsumexp calculation.
scalar_t la1 = log_alpha_a[t - 1][s];
scalar_t lamax = la1;
scalar_t la2, la3;
if (s > 0)
la2 = log_alpha_a[t - 1][s - 1];
if (la2 > lamax)
lamax = la2;
la2 = neginf;
if ((s > 1) && (get_target_prime(targets_data, tg_batch_offset, tg_target_stride, s - 2, BLANK) !=
la3 = log_alpha_a[t - 1][s - 2];
if (la3 > lamax)
lamax = la3;
la3 = neginf;
if (lamax == neginf) // cannot do neginf-neginf
lamax = 0;
// this is the assignment of eq (6)
log_alpha_a[t][s] = std::log(std::exp(la1 - lamax) + std::exp(la2 - lamax) + std::exp(la3 - lamax)) + lamax + log_probs_a[t][current_target_prime];
// the likelihood is the the sum of the last two alphas, eq (8), the loss is the negative log likelihood
if (target_length == 0)
// if the target is empty then there is no preceding BLANK state and hence there is no path to merge
neg_log_likelihood_a[b] = -log_alpha_a[input_length - 1][0];
scalar_t l1 = log_alpha_a[input_length - 1][target_length * 2];
scalar_t l2 = log_alpha_a[input_length - 1][target_length * 2 - 1];
scalar_t m = std::max(l1, l2);
m = ((m == neginf) ? 0 : m);
scalar_t log_likelihood = std::log(std::exp(l1 - m) + std::exp(l2 - m)) + m;
neg_log_likelihood_a[b] = -log_likelihood;
下面提供一个基于 numpy 的 Python 版 CTC 实现(前向/后向概率、梯度及几种解码方法),方便理解:
import numpy as np
# Log-domain zero: log(0) = -inf. (`np.float` was removed in NumPy 1.24; `np.inf` works in every version.)
ninf = -np.inf
def _logsumexp(a, b):
np.log(np.exp(a) + np.exp(b))
if a < b:
a, b = b, a
if b == ninf:
return a
return a + np.log(1 + np.exp(b - a))
def logsumexp(*args):
    """Return log(sum(exp(x) for x in args)) for one or more scalar log-values.

    Folds the pairwise, overflow-safe `_logsumexp` over the arguments from left
    to right, so -inf entries (zero probabilities) contribute nothing.
    Raises IndexError when called with no arguments.
    """
    res = args[0]
    for e in args[1:]:
        res = _logsumexp(res, e)
    return res
class CTC:
    """Reference NumPy implementation of the CTC forward/backward recursions.

    `labels` is expected to be the blank-expanded target l' (for example the
    output of `insert_blank`): [blank, l1, blank, l2, ..., blank].
    `y` / `log_y` is a (T, V) array of per-frame probabilities / log-probabilities.
    `alpha`, `beta` and `backward` work in log space, which avoids underflow on
    long inputs. The `_vanilla` methods and `gradient` work in probability space.
    Every recursion takes `blank` (default 0), the class index of the blank symbol.
    """

    def __init__(self):
        pass

    def forward(self):
        """Placeholder: not implemented in this reference code."""
        raise NotImplementedError

    @staticmethod
    def _positions(labels, V):
        """For each symbol 0..V-1, the list of indices in `labels` where it occurs."""
        return [[i for i, c in enumerate(labels) if c == s] for s in range(V)]

    def alpha(self, log_y, labels, blank=0):
        """Log-space forward variables.

        Returns a (T, L) array whose entry [t, i] is log alpha_t(i): the log of
        the total probability of all path prefixes over frames 0..t that end
        at labels[i]. Impossible states hold -inf.
        """
        T, V = log_y.shape
        L = len(labels)
        log_alpha = np.ones([T, L]) * ninf
        # A path starts on the leading blank or (if there is one) the first label.
        log_alpha[0, 0] = log_y[0, labels[0]]
        if L > 1:
            log_alpha[0, 1] = log_y[0, labels[1]]
        for t in range(1, T):
            for i in range(L):
                s = labels[i]
                a = log_alpha[t - 1, i]  # stay on the same symbol
                if i - 1 >= 0:
                    a = logsumexp(a, log_alpha[t - 1, i - 1])  # advance by one
                # Skipping over the blank in between is only allowed for a
                # non-blank symbol that differs from the one two steps back.
                if i - 2 >= 0 and s != blank and s != labels[i - 2]:
                    a = logsumexp(a, log_alpha[t - 1, i - 2])
                log_alpha[t, i] = a + log_y[t, s]
        return log_alpha

    def beta(self, log_y, labels, blank=0):
        """Log-space backward variables (mirror image of `alpha`).

        Returns a (T, L) array whose entry [t, i] is the log of the total
        probability of all path suffixes over frames t..T-1 that start at
        labels[i] and end on a valid final state. Like `alpha`, it includes
        the frame-t emission log_y[t, labels[i]].
        """
        T, V = log_y.shape
        L = len(labels)
        log_beta = np.ones([T, L]) * ninf
        # A path ends on the trailing blank or (if there is one) the last label.
        log_beta[-1, -1] = log_y[-1, labels[-1]]
        if L > 1:
            log_beta[-1, -2] = log_y[-1, labels[-2]]
        for t in range(T - 2, -1, -1):
            for i in range(L):
                s = labels[i]
                a = log_beta[t + 1, i]
                if i + 1 < L:
                    a = logsumexp(a, log_beta[t + 1, i + 1])
                if i + 2 < L and s != blank and s != labels[i + 2]:
                    a = logsumexp(a, log_beta[t + 1, i + 2])
                log_beta[t, i] = a + log_y[t, s]
        return log_beta

    def backward(self, log_y, labels, blank=0):
        """Log of the gradient of ln p(l|x) w.r.t. the probabilities y[t, k].

        Log-space counterpart of `gradient`. Returns a (T, V) array; symbols
        that do not occur in `labels` get -inf (i.e. zero gradient).
        """
        T, V = log_y.shape
        L = len(labels)
        log_alpha = self.alpha(log_y, labels, blank)
        log_beta = self.beta(log_y, labels, blank)
        if L > 1:
            log_p = logsumexp(log_alpha[-1, -1], log_alpha[-1, -2])
        else:
            log_p = log_alpha[-1, -1]
        positions = self._positions(labels, V)
        log_grad = np.ones([T, V]) * ninf
        for t in range(T):
            for s in range(V):
                for i in positions[s]:
                    log_grad[t, s] = logsumexp(log_grad[t, s],
                                               log_alpha[t, i] + log_beta[t, i])
                # alpha and beta both contain y[t, s], hence the division by y^2.
                log_grad[t, s] -= 2 * log_y[t, s]
        log_grad -= log_p
        return log_grad

    def predict(self):
        """Placeholder: not implemented in this reference code."""
        raise NotImplementedError

    def ctc_prefix(self):
        """Placeholder: not implemented in this reference code."""
        raise NotImplementedError

    def ctc_beamsearch(self):
        """Placeholder: not implemented in this reference code."""
        raise NotImplementedError

    def alpha_vanilla(self, y, labels, blank=0):
        """Probability-space version of `alpha`; returns a (T, L) array."""
        T, V = y.shape  # T: time steps, V: number of classes
        L = len(labels)  # length of the blank-expanded label sequence
        alpha = np.zeros([T, L])
        alpha[0, 0] = y[0, labels[0]]
        if L > 1:
            alpha[0, 1] = y[0, labels[1]]
        for t in range(1, T):
            for i in range(L):
                s = labels[i]
                a = alpha[t - 1, i]
                if i - 1 >= 0:
                    a += alpha[t - 1, i - 1]
                if i - 2 >= 0 and s != blank and s != labels[i - 2]:
                    a += alpha[t - 1, i - 2]
                alpha[t, i] = a * y[t, s]
        return alpha

    def beta_vanilla(self, y, labels, blank=0):
        """Probability-space version of `beta`; returns a (T, L) array."""
        T, V = y.shape
        L = len(labels)
        beta = np.zeros([T, L])
        beta[-1, -1] = y[-1, labels[-1]]
        if L > 1:
            beta[-1, -2] = y[-1, labels[-2]]
        for t in range(T - 2, -1, -1):
            for i in range(L):
                s = labels[i]
                a = beta[t + 1, i]
                if i + 1 < L:
                    a += beta[t + 1, i + 1]
                if i + 2 < L and s != blank and s != labels[i + 2]:
                    a += beta[t + 1, i + 2]
                beta[t, i] = a * y[t, s]
        return beta

    def gradient(self, y, labels, blank=0):
        """Gradient of ln p(l|x) w.r.t. the probabilities y[t, k].

        grad[t, k] = sum_{i : labels[i] == k} alpha_t(i) * beta_t(i) / (y[t, k]^2 * p),
        where both alpha and beta include y[t, k] (hence the square).
        Returns a (T, V) array.
        """
        T, V = y.shape
        L = len(labels)
        alpha = self.alpha_vanilla(y, labels, blank)
        # Must be the probability-space beta to match alpha_vanilla;
        # `self.beta` expects log-probabilities.
        beta = self.beta_vanilla(y, labels, blank)
        p = alpha[-1, -1] + (alpha[-1, -2] if L > 1 else 0.0)
        positions = self._positions(labels, V)
        grad = np.zeros([T, V])
        for t in range(T):
            for s in range(V):
                for i in positions[s]:
                    grad[t, s] += alpha[t, i] * beta[t, i]
                grad[t, s] /= y[t, s] ** 2
        grad /= p
        return grad
def check_grad(y, labels, w=-1, v=-1, toleration=1e-3, delta=1e-10):
    """Compare CTC.gradient at y[w, v] with a central finite difference of ln p(l|x).

    Args:
        y: (T, V) float array of per-frame probabilities. It is perturbed in
            place and always restored before returning.
        labels: blank-expanded label sequence (blank = 0).
        w, v: time index and class index of the entry to check.
        toleration: absolute difference above which a mismatch is printed.
        delta: finite-difference step.

    Returns:
        The absolute difference between the analytic and numerical gradients.
    """
    ctc = CTC()
    grad_1 = ctc.gradient(y, labels)[w, v]

    def log_prob():
        # ln p(l|x) = ln(alpha_T(last) + alpha_T(second to last)) for the current y.
        alpha = ctc.alpha_vanilla(y, labels)
        total = alpha[-1, -1] + (alpha[-1, -2] if alpha.shape[1] > 1 else 0.0)
        return np.log(total)

    original = y[w, v]
    try:
        y[w, v] = original + delta
        log_p1 = log_prob()
        y[w, v] = original - delta
        log_p2 = log_prob()
    finally:
        y[w, v] = original  # restore the caller's array even if a call fails
    grad_2 = (log_p1 - log_p2) / (2 * delta)
    err = np.abs(grad_1 - grad_2)
    if err > toleration:
        print('[%d, %d]:%.2e' % (w, v, err))
    return err
def remove_blank(labels, blank=0):
    """Collapse a CTC path into its label sequence.

    First merges runs of identical consecutive symbols, then drops every
    `blank`. For example, with blank=0, [0, 1, 1, 0, 1, 2, 2] -> [1, 1, 2].
    The blank between the two 1s keeps them from being merged.
    """
    merged = []
    previous = None
    for label in labels:
        if label != previous:
            merged.append(label)
            previous = label
    return [label for label in merged if label != blank]
def insert_blank(labels, blank=0):
    """Build the blank-expanded sequence l' used by the CTC recursions.

    A `blank` is placed before, between and after the symbols of `labels`, so
    the result has length 2 * len(labels) + 1, e.g. [1, 2] -> [0, 1, 0, 2, 0].
    """
    return [blank] + [sym for label in labels for sym in (label, blank)]
def greedy_decode(y, blank=0):
    """Best-path decoding.

    Picks the most probable class at every frame of the (T, V) matrix `y`, then
    collapses that path with `remove_blank`. Returns (raw_path, decoded_labels).
    """
    best_path = np.argmax(y, axis=1)
    decoded = remove_blank(best_path, blank)
    return best_path, decoded
def beam_decode(y, beam_size=10):
    """Plain beam search over raw (uncollapsed) CTC paths.

    After each frame of the (T, V) probability matrix `y`, keeps the
    `beam_size` paths with the highest summed log-probability. Returns a list
    of (path, score) pairs in descending score order. Paths are not collapsed,
    so different paths that map to the same label sequence compete instead of
    being merged.
    """
    _, num_symbols = y.shape
    log_y = np.log(y)
    beam = [([], 0)]
    for frame in log_y:
        candidates = [
            (path + [k], score + frame[k])
            for path, score in beam
            for k in range(num_symbols)
        ]
        # list.sort is stable (also with reverse=True), so ties keep insertion order
        candidates.sort(key=lambda item: item[1], reverse=True)
        beam = candidates[:beam_size]
    return beam
def prefix_beam_decode(y, beam_size=10, blank=0):
    """CTC prefix beam search.

    Unlike `beam_decode`, the beam holds collapsed label prefixes. Each prefix
    tracks two log-probabilities: paths ending in blank (p_b) and paths ending
    in a non-blank (p_nb). Probability mass from all paths that collapse to the
    same prefix is merged.

    Args:
        y: (T, V) array of per-frame probabilities.
        beam_size: number of prefixes kept after each frame.
        blank: class index of the blank symbol.

    Returns:
        List of (prefix_tuple, (log_p_blank, log_p_nonblank)) pairs, sorted by
        total log-probability in descending order.
    """
    from collections import defaultdict

    T, V = y.shape
    log_y = np.log(y)
    beam = [(tuple(), (0, ninf))]  # empty prefix: prob 1 of "ending in blank"
    for t in range(T):  # for every timestep
        new_beam = defaultdict(lambda: (ninf, ninf))
        for prefix, (p_b, p_nb) in beam:
            for i in range(V):  # for every state
                p = log_y[t, i]
                if i == blank:
                    # A blank never changes the prefix; it only feeds p_b.
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)
                    continue

                end_t = prefix[-1] if prefix else None
                # extend the current prefix with symbol i
                new_prefix = prefix + (i,)
                new_p_b, new_p_nb = new_beam[new_prefix]
                if i != end_t:
                    new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
                else:
                    # A repeated symbol only starts a new character when a
                    # blank separates it from the previous one; otherwise the
                    # two would be merged by the CTC collapse.
                    new_p_nb = logsumexp(new_p_nb, p_b + p)
                new_beam[new_prefix] = (new_p_b, new_p_nb)

                # keep the current prefix: a repeat without a blank in between merges into it
                if i == end_t:
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_nb = logsumexp(new_p_nb, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)
        # top beam_size
        beam = sorted(new_beam.items(), key=lambda x: logsumexp(*x[1]), reverse=True)
        beam = beam[:beam_size]
    return beam