🚀 最先進的大規模預訓練回覆生成模型 (DialoGPT)
DialoGPT 是用於多輪對話的最先進的大規模預訓練對話回覆生成模型。人工評估結果表明,在單輪對話圖靈測試中,DialoGPT 生成的回覆質量可與人類回覆相媲美。該模型在來自 Reddit 討論線程的 1.47 億條多輪對話上進行訓練。
🚀 快速開始
現在,讓我們來嘗試一下該模型作為聊天夥伴的表現!
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
for step in range(5):
new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')
bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))
✨ 主要特性
- 多輪對話生成:能夠處理多輪對話,生成連貫的回覆。
- 高質量回復:在單輪對話圖靈測試中,生成的回覆質量可與人類回覆相媲美。
- 大規模預訓練:在 1.47 億條多輪對話上進行訓練。
💻 使用示例
基礎用法
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
for step in range(5):
new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')
bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))
多輪對話示例
角色 |
回覆 |
用戶 |
金錢能買來幸福嗎? |
機器人 |
這取決於你在這上面花了多少錢。 |
用戶 |
購買幸福的最佳方式是什麼? |
機器人 |
你得在 20 歲出頭就成為百萬富翁,然後你就會幸福了。 |
用戶 |
這太難了! |
機器人 |
你根本不知道成為百萬富翁並獲得幸福有多難。有錢人之所以有錢是有原因的。 |
📚 詳細文檔
有關 DialoGPT 的預處理、訓練和完整詳細信息,請參閱 原始 DialoGPT 倉庫。
ArXiv 論文:https://arxiv.org/abs/1911.00536
📊 評估結果
詳細結果可查看 此處
指標 |
值 |
平均值 |
25.02 |
ARC (25-shot) |
25.77 |
HellaSwag (10-shot) |
25.79 |
MMLU (5-shot) |
25.81 |
TruthfulQA (0-shot) |
47.49 |
Winogrande (5-shot) |
50.28 |
GSM8K (5-shot) |
0.0 |
DROP (3-shot) |
0.0 |
📄 許可證
本項目採用 MIT 許可證。