🚀 最先进的大规模预训练回复生成模型 (DialoGPT)
DialoGPT 是用于多轮对话的最先进的大规模预训练对话回复生成模型。人工评估结果表明,在单轮对话图灵测试中,DialoGPT 生成的回复质量可与人类回复相媲美。该模型在来自 Reddit 讨论线程的 1.47 亿条多轮对话上进行训练。
🚀 快速开始
现在,让我们来尝试一下该模型作为聊天伙伴的表现!
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
for step in range(5):
new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')
bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))
✨ 主要特性
- 多轮对话生成:能够处理多轮对话,生成连贯的回复。
- 高质量回复:在单轮对话图灵测试中,生成的回复质量可与人类回复相媲美。
- 大规模预训练:在 1.47 亿条多轮对话上进行训练。
💻 使用示例
基础用法
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
for step in range(5):
new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')
bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))
多轮对话示例
角色 |
回复 |
用户 |
金钱能买来幸福吗? |
机器人 |
这取决于你在这上面花了多少钱。 |
用户 |
购买幸福的最佳方式是什么? |
机器人 |
你得在 20 岁出头就成为百万富翁,然后你就会幸福了。 |
用户 |
这太难了! |
机器人 |
你根本不知道成为百万富翁并获得幸福有多难。有钱人之所以有钱是有原因的。 |
📚 详细文档
有关 DialoGPT 的预处理、训练和完整详细信息,请参阅 原始 DialoGPT 仓库。
ArXiv 论文:https://arxiv.org/abs/1911.00536
📊 评估结果
详细结果可查看 此处
指标 |
值 |
平均值 |
25.02 |
ARC (25-shot) |
25.77 |
HellaSwag (10-shot) |
25.79 |
MMLU (5-shot) |
25.81 |
TruthfulQA (0-shot) |
47.49 |
Winogrande (5-shot) |
50.28 |
GSM8K (5-shot) |
0.0 |
DROP (3-shot) |
0.0 |
📄 许可证
本项目采用 MIT 许可证。