BeamSearch算法原理及代码解析

1.算法原理

beam search有一个超参数beam_size,设为 k 。第一个时间步长,选取当前条件概率最大的 k 个词,当做候选输出序列的第一个词。之后的每个时间步长,基于上个步长的输出序列,挑选出所有组合中条件概率最大的 k 个,作为该时间步长下的候选输出序列。始终保持 k 个候选。最后从k 个候选中挑出最优的。

2.中心思想

假设有n句话,每句话的长度为T。encoder的输出shape为(n, T, hidden_dim),扩展成(n*beam_size, T, hidden_dim)。decoder第一次输入shape为(n, 1),扩展到(n*beam_size, 1)。经过一次解码,输出得分的shape为(n*beam_size, vocab_size),路径得分log_prob的shape为(n*beam_size, 1),两者相加得到当前帧的路径得分。reshape到(n, beam_size*vocab_size),取topk(beam_size),得到排序后的索引(n, beam_size),索引除以vocab_size,得到的是每句话的beam_id,用来获取当前路径前一个字;对vocab_size取余,得到的是每句话的token_id,用来获取当前路径下一次字。

3.代码解析

def beam_search():k_prev_words = torch.full((k, 1), SOS_TOKEN, dtype=torch.long) # (k, 1)# 此时输出序列中只有sos tokenseqs = k_prev_words #当前路径(k, 1)# 初始化scores向量为0top_k_scores = torch.zeros(k, 1)complete_seqs = [] #已完成序列complete_seqs_scores = [] #已完成序列的得分step = 1hidden = torch.zeros(1, k, hidden_size) # encoder的输出: (1, k, hidden_size)while True:outputs, hidden = decoder(k_prev_words, hidden) # outputs: (k, seq_len, vocab_size)next_token_logits = outputs[:,-1,:] # (k, vocab_size)if step == 1:# 因为最开始解码的时候只有一个结点,所以只需要取其中一个结点计算topktop_k_scores, top_k_words = next_token_logits[0].topk(k, dim=0, largest=True, sorted=True)else:# 此时要先展开再计算topk,如上图所示。# top_k_scores: (k) top_k_words: (k)top_k_scores, top_k_words = next_token_logits.view(-1).topk(k, 0, True, True)prev_word_inds = top_k_words / vocab_size  # (k)  实际是beam_id,哪个beam就是哪条最优路径next_word_inds = top_k_words % vocab_size  # (k)  实际是token_id,在单词表中的index# seqs: (k, step) ==> (k, step+1)seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)# 当前输出的单词不是eos的有哪些(输出其在next_wod_inds中的位置, 实际是beam_id)incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) ifnext_word != vocab['']]# 输出已经遇到eos的句子的beam id(即seqs中的句子索引)complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))if len(complete_inds) > 0:complete_seqs.extend(seqs[complete_inds].tolist()) # 加入句子complete_seqs_scores.extend(top_k_scores[complete_inds]) # 加入句子对应的累加log_prob# 减掉已经完成的句子的数量,更新k, 下次就不用执行那么多topk了,因为若干句子已经被解码出来了k -= len(complete_inds) if k == 0: # 完成break# 更新下一次迭代数据, 仅专注于那些还没完成的句子 seqs = seqs[incomplete_inds]hidden = hidden[prev_word_inds[incomplete_inds]]top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)   #(s, 1) s < kk_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) #(s, 1) s < kif step > max_length: # decode太长后,直接break掉breakstep += 1i = complete_seqs_scores.index(max(complete_seqs_scores)) # 寻找score最大的序列# 有些许问题,在训练初期一直碰不到eos时,此时complete_seqs为空seq = complete_seqs[i] return seq


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部