Attention Rescoring Decoding in WeNet
We know that CTC decoding is non-autoregressive, whereas decoding with a Transformer decoder is autoregressive, so one major drawback of the Transformer is its slow decoding speed.
In recent years, joint CTC/attention training has brought large gains in recognition performance. During training, the encoder output is fed both to the attention decoder and to the CTC branch,
and the two decoding heads produce two losses, which are summed with different weights to form the model's overall loss used for back-propagation and parameter updates. For example:
Loss = 0.3 * loss_ctc + 0.7 * loss_attn    (CTC weight 0.3, attention weight 1 - 0.3 = 0.7)
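As a concrete illustration, here is a minimal sketch of this weighted combination. It is not WeNet's actual training code; combine_losses is a hypothetical helper and the scalar losses are dummies:

import torch

def combine_losses(loss_ctc: torch.Tensor,
                   loss_att: torch.Tensor,
                   ctc_weight: float = 0.3) -> torch.Tensor:
    # Weighted sum of the CTC loss and the attention-decoder loss;
    # this mirrors the idea behind WeNet's ctc_weight option but is
    # written here purely for illustration.
    return ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_att

# Toy usage with dummy scalar losses that track gradients
loss_ctc = torch.tensor(2.0, requires_grad=True)
loss_att = torch.tensor(1.5, requires_grad=True)
loss = combine_losses(loss_ctc, loss_att, ctc_weight=0.3)
loss.backward()  # gradients flow back through both branches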
Taking this joint-training idea a step further: during attention decoding, the label (token) inputs to the decoder act somewhat like a language model and strengthen the prediction, but a plain scoring pass has no such inputs. That is what makes the rescoring mechanism used in WeNet so appealing:
the encoder output is first run through CTC prefix beam search to obtain the n best hypotheses, and these n-best results are then fed to the decoder as its label inputs, which neatly solves the problem of the decoder lacking prior knowledge at scoring time (a usage sketch follows the code listing below).
Naturally, the better the n-best results from CTC, the better the final rescored output. The _ctc_prefix_beam_search method itself is also well worth studying; its search procedure is quite interesting.
The code is as follows:
def _ctc_prefix_beam_search(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    beam_size: int,
    decoding_chunk_size: int = -1,
    num_decoding_left_chunks: int = -1,
    simulate_streaming: bool = False,
) -> Tuple[List[List[int]], torch.Tensor]:
    """ CTC prefix beam search inner implementation

    Args:
        speech (torch.Tensor): (batch, max_len, feat_dim)
        speech_length (torch.Tensor): (batch, )
        beam_size (int): beam size for beam search
        decoding_chunk_size (int): decoding chunk for dynamic chunk
            trained model.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
            0: used for training, it's prohibited here
        simulate_streaming (bool): whether do encoder forward in a
            streaming fashion

    Returns:
        List[List[int]]: nbest results
        torch.Tensor: encoder output, (1, max_len, encoder_dim),
            it will be used for rescoring in attention rescoring mode
    """
    assert speech.shape[0] == speech_lengths.shape[0]
    assert decoding_chunk_size != 0
    batch_size = speech.shape[0]
    # For CTC prefix beam search, we only support batch_size=1
    assert batch_size == 1
    # Let's assume B = batch_size and N = beam_size
    # 1. Encoder forward and get CTC score
    encoder_out, encoder_mask = self._forward_encoder(
        speech, speech_lengths, decoding_chunk_size,
        num_decoding_left_chunks,
        simulate_streaming)  # (B, maxlen, encoder_dim)
    maxlen = encoder_out.size(1)
    ctc_probs = self.ctc.log_softmax(
        encoder_out)  # (1, maxlen, vocab_size)
    ctc_probs = ctc_probs.squeeze(0)
    # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
    cur_hyps = [(tuple(), (0.0, -float('inf')))]
    # 2. CTC beam search step by step
    for t in range(0, maxlen):
        logp = ctc_probs[t]  # (vocab_size,)
        # key: prefix, value (pb, pnb), default value(-inf, -inf)
        next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
        # 2.1 First beam prune: select topk best
        top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
        for s in top_k_index:
            s = s.item()
            ps = logp[s].item()
            for prefix, (pb, pnb) in cur_hyps:
                last = prefix[-1] if len(prefix) > 0 else None
                if s == 0:  # blank
                    n_pb, n_pnb = next_hyps[prefix]
                    n_pb = log_add([n_pb, pb + ps, pnb + ps])
                    next_hyps[prefix] = (n_pb, n_pnb)
                elif s == last:
                    # Update *ss -> *s;
                    n_pb, n_pnb = next_hyps[prefix]
                    n_pnb = log_add([n_pnb, pnb + ps])
                    next_hyps[prefix] = (n_pb, n_pnb)
                    # Update *s-s -> *ss, - is for blank
                    n_prefix = prefix + (s, )
                    n_pb, n_pnb = next_hyps[n_prefix]
                    n_pnb = log_add([n_pnb, pb + ps])
                    next_hyps[n_prefix] = (n_pb, n_pnb)
                else:
                    n_prefix = prefix + (s, )
                    n_pb, n_pnb = next_hyps[n_prefix]
                    n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
                    next_hyps[n_prefix] = (n_pb, n_pnb)
        # 2.2 Second beam prune
        next_hyps = sorted(next_hyps.items(),
                           key=lambda x: log_add(list(x[1])),
                           reverse=True)
        cur_hyps = next_hyps[:beam_size]
    hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
    return hyps, encoder_out
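One note on the code above: the log_add helper is imported from WeNet's utility module and is not shown in this excerpt; it computes a numerically stable log-sum-exp so that probabilities can be accumulated in log space. A minimal sketch of its behavior (assumed to match the upstream helper, though the exact code may differ):

import math
from typing import List

def log_add(args: List[float]) -> float:
    # Stable log(exp(a1) + exp(a2) + ...): subtract the max before exponentiating
    # to avoid underflow when all inputs are very negative log-probabilities.
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))

# Example: combining two log-probabilities
print(log_add([math.log(0.2), math.log(0.3)]))  # ≈ log(0.5)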
def ctc_prefix_beam_search(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    beam_size: int,
    decoding_chunk_size: int = -1,
    num_decoding_left_chunks: int = -1,
    simulate_streaming: bool = False,
) -> List[int]:
    """ Apply CTC prefix beam search

    Args:
        speech (torch.Tensor): (batch, max_len, feat_dim)
        speech_length (torch.Tensor): (batch, )
        beam_size (int): beam size for beam search
        decoding_chunk_size (int): decoding chunk for dynamic chunk
            trained model.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
            0: used for training, it's prohibited here
        simulate_streaming (bool): whether do encoder forward in a
            streaming fashion

    Returns:
        List[int]: CTC prefix beam search nbest results
    """
    hyps, _ = self._ctc_prefix_beam_search(speech, speech_lengths,
                                           beam_size, decoding_chunk_size,
                                           num_decoding_left_chunks,
                                           simulate_streaming)
    return hyps[0]
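Note that this public wrapper keeps only the single best hypothesis (hyps[0], a (token_id_tuple, log_score) pair) and discards the encoder output, whereas the inner _ctc_prefix_beam_search returns the full n-best list plus the encoder output needed for rescoring. A hedged usage sketch, where model, speech and speech_lengths are assumed to be an already-loaded WeNet ASR model and properly shaped feature tensors, not objects defined in this excerpt:

# speech: (1, max_len, feat_dim) acoustic features, speech_lengths: (1,)
best_hyp = model.ctc_prefix_beam_search(speech, speech_lengths, beam_size=10)
token_ids, ctc_log_score = best_hyp  # unpack the (prefix, score) pair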
def attention_rescoring(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    beam_size: int,
    decoding_chunk_size: int = -1,
    num_decoding_left_chunks: int = -1,
    ctc_weight: float = 0.0,
    simulate_streaming: bool = False,
    reverse_weight: float = 0.0,
) -> List[int]:
    """ Apply attention rescoring decoding. CTC prefix beam search
    is applied first to get the nbest, then we rescore the nbest on
    the attention decoder with the corresponding encoder out

    Args:
        speech (torch.Tensor): (batch, max_len, feat_dim)
        speech_length (torch.Tensor): (batch, )
        beam_size (int): beam size for beam search
        decoding_chunk_size (int): decoding chunk for dynamic chunk
            trained model.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
            0: used for training, it's prohibited here
        simulate_streaming (bool): whether do encoder forward in a
            streaming fashion
        reverse_weight (float): right to left decoder weight
        ctc_weight (float): ctc score weight

    Returns:
        List[int]: Attention rescoring result
    """
    assert speech.shape[0] == speech_lengths.shape[0]
    assert decoding_chunk_size != 0
    if reverse_weight > 0.0:
        # decoder should be a bitransformer decoder if reverse_weight > 0.0
        assert hasattr(self.decoder, 'right_decoder')
    device = speech.device
    batch_size = speech.shape[0]
    # For attention rescoring we only support batch_size=1
    assert batch_size == 1
    # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
    hyps, encoder_out = self._ctc_prefix_beam_search(
        speech, speech_lengths, beam_size, decoding_chunk_size,
        num_decoding_left_chunks, simulate_streaming)
    assert len(hyps) == beam_size
    hyps_pad = pad_sequence([
        torch.tensor(hyp[0], device=device, dtype=torch.long)
        for hyp in hyps
    ], True, self.ignore_id)  # (beam_size, max_hyps_len)
    ori_hyps_pad = hyps_pad
    hyps_lens = torch.tensor([len(hyp[0]) for hyp in hyps],
                             device=device,
                             dtype=torch.long)  # (beam_size,)
    hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
    hyps_lens = hyps_lens + 1  # Add <sos> at beginning
    encoder_out = encoder_out.repeat(beam_size, 1, 1)
    encoder_mask = torch.ones(beam_size,
                              1,
                              encoder_out.size(1),
                              dtype=torch.bool,
                              device=device)
    # used for right to left decoder
    r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
    r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
                                self.ignore_id)
    decoder_out, r_decoder_out, _ = self.decoder(
        encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
        reverse_weight)  # (beam_size, max_hyps_len, vocab_size)
    decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
    decoder_out = decoder_out.cpu().numpy()
    # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
    # conventional transformer decoder.
    r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
    r_decoder_out = r_decoder_out.cpu().numpy()
    # Only use decoder score for rescoring
    best_score = -float('inf')
    best_index = 0
    for i, hyp in enumerate(hyps):
        score = 0.0
        for j, w in enumerate(hyp[0]):
            score += decoder_out[i][j][w]
        score += decoder_out[i][len(hyp[0])][self.eos]
        # add right to left decoder score
        if reverse_weight > 0:
            r_score = 0.0
            for j, w in enumerate(hyp[0]):
                r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
            r_score += r_decoder_out[i][len(hyp[0])][self.eos]
            score = score * (1 - reverse_weight) + r_score * reverse_weight
        # add ctc score
        score += hyp[1] * ctc_weight
        if score > best_score:
            best_score = score
            best_index = i
    return hyps[best_index][0], best_score
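Finally, the promised usage sketch of the full two-pass pipeline. This is a hypothetical call rather than a snippet from WeNet itself: model is assumed to be a trained ASR model exposing the methods above, speech/speech_lengths are assumed feature tensors with batch size 1, and the weights shown are just example values:

# First pass: CTC prefix beam search (run inside attention_rescoring) produces the n-best;
# second pass: the attention decoder rescores them and the best-scoring one is returned.
token_ids, score = model.attention_rescoring(
    speech, speech_lengths,
    beam_size=10,
    ctc_weight=0.5,        # weight of the CTC n-best score added to the decoder score
    reverse_weight=0.0,    # > 0 only if the model has a right-to-left decoder
)
# token_ids is the tuple of token ids of the best hypothesis; map it back to text
# with the model's symbol table / tokenizer to get the final transcription.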