🚀 クイックスタート
このモデルは、顧客の要求を異なる事前定義されたカテゴリに分類する意図分類を行います。時には意図分類はトピック分類とも呼ばれます。このモデルは、顧客の要求に似た合成データを含むプロンプトでFlan - T5 - Baseモデルをファインチューニングすることで、動的に意図を分類することができます。
✨ 主な機能
- 顧客の要求を事前定義されたカテゴリに分類する意図分類を行うことができます。
- 合成データを含むプロンプトでファインチューニングすることで、動的に意図を分類できます。
📦 インストール
このREADMEには具体的なインストール手順が記載されていないため、このセクションをスキップします。
💻 使用例
基本的な使用法
class IntentClassifier:
def __init__(self, model_name="serj/intent-classifier", device="cuda"):
self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
self.tokenizer = T5Tokenizer.from_pretrained(model_name)
self.device = device
def build_prompt(text, prompt="", company_name="", company_specific=""):
if company_name == "Pizza Mia":
company_specific = "This company is a pizzeria place."
if company_name == "Online Banking":
company_specific = "This company is an online banking."
return f"Company name: {company_name} is doing: {company_specific}\nCustomer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n{prompt}\nClass name: "
def predict(self, text, prompt_options, company_name, company_portion) -> str:
input_text = build_prompt(text, prompt_options, company_name, company_portion)
input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)
output = self.model.generate(input_ids)
decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
return decoded_output
m = IntentClassifier("serj/intent-classifier")
print(m.predict("Hey, after recent changes, I want to cancel subscription, please help.",
"OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n", "Company",
"Products and subscriptions"))
📚 ドキュメント
プロンプト構造
Topic %% Customer: text.
END MESSAGE
OPTIONS:
each class separated by %
Choose one topic that matches customer's issue.
Class name:
テキストの末尾にピリオドを付ける必要があります。そうしないと、予期しない結果になります。これはモデルの学習方法によるものです。
モデル詳細
- 開発者: Serj Smorodinsky
- モデルタイプ: Flan - T5 - Base
- 言語 (NLP): [詳細情報が必要]
- ライセンス: [詳細情報が必要]
- ファインチューニング元のモデル [オプション]: Flan - T5 - Base
- リポジトリ: https://github.com/SerjSmor/intent_classification
学習詳細
学習データ
https://github.com/SerjSmor/intent_classification
将来的にHFデータセットが追加されます。
学習手順
https://github.com/SerjSmor/intent_classification/blob/main/t5_generator_trainer.py
HFトレーナーを使用しています。
training_args = TrainingArguments(
output_dir='./results',
num_train_epochs=epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
warmup_steps=500,
weight_decay=0.01,
logging_dir='./logs',
logging_steps=10,
evaluation_strategy="epoch"
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=tokenizer,
)
評価
このモデルの最新バージョンは、2つの合成データセットとclinc_oosの最初の41クラスをfew - shot方式でファインチューニングしています。すべてのデータセットはクラスごとに10 - 20サンプルを持っています。学習データにはAtisデータセットは含まれていません。
Atisゼロショットテストセットの評価: 重み付きF1 87%
次はClincテストセットです。
ハードウェア
Nvidia RTX3060 12Gb
情報テーブル
属性 |
詳情 |
モデルタイプ |
Flan - T5 - Base |
学習データ |
https://github.com/SerjSmor/intent_classification 。将来的にHFデータセットが追加されます。 |
重要提示
⚠️ 重要提示
テキストの末尾にはピリオドを付ける必要があります。そうしないと、予期しない結果になります。これはモデルの学習方法によるものです。