Huggingface model.generate: How to Get Both Greedy Search and Sampling Outputs in One Call

It has been too long since I last updated this blog. Today I ran into a problem: I needed the outputs of both Greedy Search and Sampling at the same time, but Huggingface's generate method only produces one decoding strategy per call.

Calling generate twice wastes time. I couldn't find anything about this online, so I wrote my own solution and am updating the blog along the way.

First I looked at the Huggingface documentation: it does not provide what we need.

Then I checked GenerationConfig, which cannot take a list of decoding strategies either.

But recall that Greedy Search is really a special case of Sampling in which every logit except the maximum is set to negative infinity. The answer then suggests itself: use a LogitsProcessor to adjust the logits of the first sequence only.
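This observation can be sanity-checked without any model: once every logit except the maximum is set to negative infinity, the softmax collapses to a one-hot distribution, so any sampler is forced to pick the argmax. A minimal pure-Python sketch (the function names here are just for illustration):

```python
import math

def mask_to_greedy(logits):
    """Set every logit except the maximum to -inf."""
    max_idx = logits.index(max(logits))
    return [x if i == max_idx else -math.inf for i, x in enumerate(logits)]

def softmax(logits):
    # math.exp(-inf) == 0.0, so masked positions get zero probability
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1.3, 4.2, 0.7, 2.9]
probs = softmax(mask_to_greedy(logits))
print(probs)  # → [0.0, 1.0, 0.0, 0.0], one-hot at the argmax
```

Temperature scaling and top-p filtering cannot change this either, since only one token has non-zero probability left.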

```python
import torch
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList

class MaxScoreLogitsProcessor(LogitsProcessor):
    def __init__(self):
        super().__init__()

    def __call__(self, input_ids, scores) -> torch.Tensor:
        # Find the argmax position in scores[0] (the first sequence)
        max_idx = torch.argmax(scores[0])
        # Set every other logit of the first sequence to -inf,
        # so sampling on it always picks the argmax (i.e. greedy search)
        scores[0][0:max_idx] = -float("inf")
        scores[0][max_idx + 1:] = -float("inf")
        return scores
```

Below is an experiment using the Meta-Llama-3-8B-Instruct model.

Plain greedy search:

```python
def get_response(prompt):
    generation_config = dict(
        do_sample=False,
        max_new_tokens=100,
        return_dict_in_generate=True,
        output_hidden_states=True,
        output_scores=True,
        num_beams=1)

    input_ids = llmtokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
    )
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(model.device)

    outputs = model.generate(
        input_ids,
        pad_token_id=llmtokenizer.eos_token_id,
        attention_mask=input_ids.ne(llmtokenizer.eos_token_id),
        **generation_config
    )
    response = llmtokenizer.decode(outputs.sequences[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response
```

Greedy search + sampling (the first returned element is the greedy search result):

```python
logits_processor = LogitsProcessorList([MaxScoreLogitsProcessor()])

def get_response2(prompt):
    generation_config = dict(
        do_sample=True,
        max_new_tokens=100,
        return_dict_in_generate=True,
        output_hidden_states=True,
        output_scores=True,
        top_p=0.95,
        temperature=1.2,
        num_return_sequences=3)

    input_ids = llmtokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
    )
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(model.device)

    outputs = model.generate(
        input_ids,
        pad_token_id=llmtokenizer.eos_token_id,
        attention_mask=input_ids.ne(llmtokenizer.eos_token_id),
        logits_processor=logits_processor,
        **generation_config
    )
    responses = []
    for i in range(len(outputs.sequences)):
        response = llmtokenizer.decode(outputs.sequences[i][input_ids.shape[-1]:], skip_special_tokens=True)
        responses.append(response)
    return responses
```
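One caveat worth noting: with a batch of several prompts and num_return_sequences=n, generate expands the inputs so that (if I read the batch expansion correctly) the first copy of prompt k sits at row k*n, while the processor above only masks row 0. A hypothetical batched variant would mask every row whose index is a multiple of n. Here is the row-selection logic sketched on plain lists for clarity; a real implementation would do the same on the scores tensor inside a LogitsProcessor subclass:

```python
import math

def force_greedy_rows(scores, num_return_sequences):
    """For each prompt in the expanded batch, mask row
    i with i % num_return_sequences == 0 so only its argmax
    survives; other rows are left untouched for sampling."""
    out = []
    for i, row in enumerate(scores):
        if i % num_return_sequences == 0:
            max_idx = row.index(max(row))
            row = [x if j == max_idx else -math.inf for j, x in enumerate(row)]
        out.append(row)
    return out

# Two prompts, num_return_sequences=2 -> 4 rows; rows 0 and 2 become greedy
scores = [[0.1, 2.0], [0.1, 2.0], [3.0, 0.5], [3.0, 0.5]]
masked = force_greedy_rows(scores, 2)
```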

The test:

```python
questions = [
    "How to make a chatbot?",
    "How to make a cake?",
    "Why is the sky blue?"]

for q in questions:
    print(q)
    print("====================================")
    print("Greedy:")
    print(get_response(q))
    print("====================================")
    print("Sampling with max score:")
    rs = get_response2(q)
    print(rs[0])
    print("====================================")
    print("Sampling:")
    print(rs[1])
    print("====================================")
    print("Sampling:")
    print(rs[2])
    print("====================================")
```

The result is exactly what we wanted (only the output for the first question is shown):

```
How to make a chatbot?
====================================
Greedy:
Making a chatbot involves several steps, including planning, designing, building, and testing. Here's a step-by-step guide to help you get started:

**Step 1: Plan and Define Your Chatbot's Purpose**

1. Identify the purpose of your chatbot: What do you want your chatbot to do? (e.g., answer customer queries, provide product information, book appointments)
2. Define your target audience: Who will be interacting with your chatbot? (e.g.,
====================================
Sampling with max score:
Making a chatbot involves several steps, including planning, designing, building, and testing. Here's a step-by-step guide to help you get started:

**Step 1: Plan and Define Your Chatbot's Purpose**

1. Identify the purpose of your chatbot: What do you want your chatbot to do? (e.g., answer customer queries, provide product information, book appointments)
2. Define your target audience: Who will be interacting with your chatbot? (e.g.,
====================================
Sampling:
Creating a chatbot involves several steps, from defining the chatbot's functionality to deploying it in a platform. Here's a step-by-step guide to help you get started:

1. **Define the chatbot's purpose and functionality**:
* Determine the chatbot's objective (e.g., provide customer support, answer questions, or entertain).
* Identify the user interactions and goals you want to support (e.g., booking a flight, making a payment, or asking a question).
2
====================================
Sampling:
Creating a chatbot involves several steps, from planning and designing to developing and testing. Here's a general guide to help you get started:

**Planning and Designing (1-3 days)**

1. **Define your chatbot's purpose**: Determine what your chatbot will do, such as answering customer questions, providing product information, or entertaining users.
2. **Identify your target audience**: Who will be using your chatbot? Understanding your audience will help you design a chatbot that
====================================
```

Of course, one limitation of this approach is that, given the same random seed, the sampled results may differ from what two separate generate calls would produce. Then again, they may not; it depends on the implementation details. I don't care about this case, so I didn't experiment further.

Also, a rough timing of the test above in a Jupyter Notebook: 24.2 s in total, and 12.5 s once the get_response time is excluded, which shows the trick introduces essentially no extra overhead.

Time for bed!


https://bebr2.com/2024/11/26/Huggingface model.generate如何既生成Greedy Search又生成Sampling的输出.md/
Author: BeBr2 · Published November 26, 2024