大模型的生成策略有常见的以下几种:
- 贪婪算法
- 采样
- topp采样
- topk采样
- topk – topp采样
- beam_search
1 贪婪算法
贪婪算法的本质就是从输出概率中找到概率最大的token作为生成token。
简易代码实现如下:
- 1 取输出token的概率
- 2 softmax,argmax
def greedy(logits):
"""
:param logits: (batch_size, seq_len, hidden_size)
:return: torch.long
"""
logit = logits[:, -1, :]
logit_probs = F.softmax(logit, dim=-1)
next_idx = torch.argmax(logit_probs, dim=-1)
return next_idx
2 topk算法
topk算法就是保留输出token概率topk个最大的token,在k中进行采样,采样结果作为生成token。
简易代码实现如下:
- 1 取输出token概率
- 2 softmax,topk
- 3 topk中进行采样
def topk(logits, k):
logit = logits[:, -1, :]
value, index = torch.topk(logit, k)
logit[logit < value[:, -1]] = -float('Inf')
logit_probs = F.softmax(logit, dim=-1)
next_idx = torch.multinomial(logit_probs, 1)
return next_idx
3 topp算法
topp算法和topk算法很像,但是topp算法是从累计概率大于p的token中进行采样作为生成token。
简易代码实现如下:
- 1 取输出token概率
- 2 softmax,sort,cumsum
- 3 累计概率小于p的置无穷大
- 4 采样
def topp(logits, p):
logit = logits[:, -1, :]
logit = F.softmax(logit, dim=-1)
sorted_logit, sorted_indice = torch.sort(logit, dim=-1, descending=True)
logit = torch.cumsum(sorted_logit, dim=-1)
logit[logit < p] = -float('Inf')
logit_probs = F.softmax(logit, dim=-1)
next_idx = torch.multinomial(logit_probs, 1)
return next_idx
5 beam_search算法
beam_search是贪婪算法的升级版,beam_search每次都会保留k个最大概率token计算序列概率,每次生成都保留k个最大概率序列。
即每一个序列都需要做k次的生成(总共k*k次生成),然后保留k个最大概率序列,最终遇到eos token,只输出最大概率的token序列。
由于beam_search最终的输出是最终时刻概率最大的token序列,遇到流式输出的时候,可能会出现时刻1概率最大的序列和时刻2概率最大的序列不同的情况,所以使用前瞻算法,即如果连续k次都是该序列概率最大,则最终就输出该条序列。
具体实现代码如下:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
def beam_search_streaming_lookahead(model, tokenizer, input_text, beam_width=3, max_length=50, lookahead_steps=5):
input_ids = tokenizer.encode(input_text, return_tensors='pt')
input_ids = input_ids.to(model.device)
# Initialize the beams
beams = [(input_ids, 0)]
for step in range(max_length):
new_beams = []
for beam_input_ids, beam_score in beams:
outputs = model(beam_input_ids)
next_token_logits = outputs.logits[:, -1, :]
probs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
top_k_probs, top_k_ids = torch.topk(probs, beam_width, dim=-1)
for i in range(beam_width):
next_token_id = top_k_ids[0, i].unsqueeze(0)
next_token_prob = top_k_probs[0, i].item()
new_input_ids = torch.cat([beam_input_ids, next_token_id.unsqueeze(0)], dim=-1)
new_score = beam_score + next_token_prob
new_beams.append((new_input_ids, new_score))
# Select the top `beam_width` beams
new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
beams = new_beams
# Lookahead to determine the best beam for streaming output
if step >= lookahead_steps:
best_beam = beams[0][0]
output_text = tokenizer.decode(best_beam[0], skip_special_tokens=True)
print("Streaming output:", output_text)
# Check for end-of-sequence token
if tokenizer.eos_token_id in best_beam:
break
# Return the final best sequence
final_output_text = tokenizer.decode(best_beam[0], skip_special_tokens=True)
return final_output_text
if __name__ == "__main__":
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model.eval()
model.to('cuda' if torch.cuda.is_available() else 'cpu')
input_text = "Once upon a time"
beam_width = 5
max_length = 100
lookahead_steps = 5
final_output_text = beam_search_streaming_lookahead(model, tokenizer, input_text, beam_width, max_length, lookahead_steps)
print("Final generated text:", final_output_text)
6 model.generate
这里详细探讨一下model.generate中有关token生成的一些参数:
- 1 do_sample
- 2 topk
- 3 topp
- 4 repeate_penny
存在以下这些情况:
- 1 do_sample=False,未设置topk和topp,则使用贪婪算法
- 2 do_sample=True, 设置topp,则使用topp采样算法
- 3 do_sample=True, 设置topk,则使用topk采样算法
- 4 do_sample=True,设置topk和topp,则先使用topk,后使用topp算法
repeate_penny一般为1.0,如果需要设置重复惩罚,则将重复惩罚设置为大于1.0即可。
原理:如果该token前面已经出现过,则下一次该tokend的概率会除以重复惩罚。
7 常见推理模版
推理是有一套公式的:
- 1 加载model和tokenizer
- 2 tokenizer.encode
- 3 model.generate
- 4 tokenizer.decode_batch
简易代码如下:
input_ids = tokenizer.encode(message, return_tensors="pt").cuda()
response = model.generate(
input_ids, max_new_tokens=max_new_tokens, do_sample=True,
top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty
)
response = tokenizer.batch_decode(response)[0]
return response