🚀 Cadet-Tinyとは?
Allen AIのCosmo-XLにインスパイアされたCadet-Tinyは、SODAデータセットを使って学習させた 非常に小さな 対話モデルです。Cadet-Tinyは、エッジデバイス(2GB RAMのRaspberry Piのような小型デバイス)での推論を想定しています。
Cadet-Tinyは、Googleのt5-small事前学習モデルをベースに学習させており、その結果、Cosmo-3Bモデルの約2%のサイズになっています。
これは私が初めて作ったSEQ2SEQのNLPモデルです!HuggingFaceでこれを共有できてとても嬉しいです!:)
何か質問や改善点に関するコメントがあれば、tcgoldfarb@gmail.comまでご連絡ください。
📦 インストール
このモデルを使用するには、以下の依存関係をインストールする必要があります。
pip install torch transformers colorful
💻 使用例
基本的な使用法
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import colorful as cf
cf.use_true_colors()
cf.use_style('monokai')
class CadetTinyAgent:
def __init__(self):
print(cf.bold | cf.purple("Waking up Cadet-Tiny..."))
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = AutoTokenizer.from_pretrained("t5-small", model_max_length=512)
self.model = AutoModelForSeq2SeqLM.from_pretrained("ToddGoldfarb/Cadet-Tiny", low_cpu_mem_usage=True).to(self.device)
self.conversation_history = ""
def observe(self, observation):
self.conversation_history = self.conversation_history + observation
if len(self.conversation_history) > 400:
self.conversation_history = self.conversation_history[112:]
def set_input(self, situation_narrative="", role_instruction=""):
input_text = "dialogue: "
if situation_narrative != "":
input_text = input_text + situation_narrative
if role_instruction != "":
input_text = input_text + " <SEP> " + role_instruction
input_text = input_text + " <TURN> " + self.conversation_history
return input_text
def generate(self, situation_narrative, role_instruction, user_response):
user_response = user_response + " <TURN> "
self.observe(user_response)
input_text = self.set_input(situation_narrative, role_instruction)
inputs = self.tokenizer([input_text], return_tensors="pt").to(self.device)
outputs = self.model.generate(inputs["input_ids"], max_new_tokens=512, temperature=0.75, top_p=.95,
do_sample=True)
cadet_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
added_turn = cadet_response + " <TURN> "
self.observe(added_turn)
return cadet_response
def reset_history(self):
self.conversation_history = []
def run(self):
def get_valid_input(prompt, default):
while True:
user_input = input(prompt)
if user_input in ["Y", "N", "y", "n"]:
return user_input
if user_input == "":
return default
while True:
continue_chat = ""
situation_narrative = "Imagine you are Cadet-Tiny talking to ???."
role_instruction = "You are Cadet-Tiny, and you are talking to ???."
self.chat(situation_narrative, role_instruction)
continue_chat = get_valid_input(cf.purple("Start a new conversation with new setup? [Y/N]:"), "Y")
if continue_chat in ["N", "n"]:
break
print(cf.blue("CT: See you!"))
def chat(self, situation_narrative, role_instruction):
print(cf.green(
"Cadet-Tiny is running! Input [RESET] to reset the conversation history and [END] to end the conversation."))
while True:
user_input = input("You: ")
if user_input == "[RESET]":
self.reset_history()
print(cf.green("[Conversation history cleared. Chat with Cadet-Tiny!]"))
continue
if user_input == "[END]":
break
response = self.generate(situation_narrative, role_instruction, user_input)
print(cf.blue("CT: " + response))
def main():
print(cf.bold | cf.blue("LOADING MODEL"))
CadetTiny = CadetTinyAgent()
CadetTiny.run()
if __name__ == '__main__':
main()
📚 ドキュメント
Google Colabリンク
こちらがGoogle Colabファイルのリンクです。ここでは、モデルの学習プロセスと、AI2のSODA公開データセットの使用方法を解説しています。
https://colab.research.google.com/drive/1cx3Yujr_jGQkseqzXZW-2L0vEyEjds_s?usp=sharing
📄 ライセンス
このモデルは、OpenRAILライセンスの下で公開されています。
📖 引用と特別な感謝
Hyunwoo Kim氏には、SODAデータセットの最適な使用方法について議論してくれたことに特別な感謝を申し上げます。彼らのSODA、Prosocial-Dialog、またはCOSMOに関する研究を見ていない場合は、是非見てみることをおすすめします!また、SODAに関する論文も読んでみてください。論文は以下に示します。
@article{kim2022soda,
title={SODA: Million-scale Dialogue Distillation with Social Commonsense Contextualization},
author={Hyunwoo Kim and Jack Hessel and Liwei Jiang and Peter West and Ximing Lu and Youngjae Yu and Pei Zhou and Ronan Le Bras and Malihe Alikhani and Gunhee Kim and Maarten Sap and Yejin Choi},
journal={ArXiv},
year={2022},
volume={abs/2212.10465}
}