🚀 意图分类模型
本模型可对客户请求进行意图分类,通过微调T5模型,利用包含合成数据的提示,能动态地将客户请求分类到预定义的类别中。
🚀 快速开始
模型使用示例
class IntentClassifier:
def __init__(self, model_name="serj/intent-classifier", device="cuda"):
self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
self.tokenizer = T5Tokenizer.from_pretrained(model_name)
self.device = device
def build_prompt(text, prompt="", company_name="", company_specific=""):
if company_name == "Pizza Mia":
company_specific = "This company is a pizzeria place."
if company_name == "Online Banking":
company_specific = "This company is an online banking."
return f"Company name: {company_name} is doing: {company_specific}\nCustomer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n{prompt}\nClass name: "
def predict(self, text, prompt_options, company_name, company_portion) -> str:
input_text = build_prompt(text, prompt_options, company_name, company_portion)
input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)
output = self.model.generate(input_ids)
decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
return decoded_output
m = IntentClassifier("serj/intent-classifier")
print(m.predict("Hey, after recent changes, I want to cancel subscription, please help.",
"OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n", "Company",
"Products and subscriptions"))
提示结构说明
Topic %% Customer: text.
END MESSAGE
OPTIONS:
each class separated by %
Choose one topic that matches customer's issue.
Class name:
你必须在文本末尾加上句号,否则会得到奇怪的结果,这是模型的训练要求。
✨ 主要特性
本模型通过微调Flan - T5 - Base模型,利用包含合成数据的提示对客户请求进行意图分类,可动态地将客户请求分类到预定义的类别中。
📦 安装指南
文档未提供具体安装步骤,暂不展示。
📚 详细文档
模型详情
模型描述
这是一个🤗 transformers模型的模型卡片,已推送到Hub,此模型卡片是自动生成的。
- 开发者:Serj Smorodinsky
- 模型类型:Flan - T5 - Base
- 语言(NLP):[待补充更多信息]
- 许可证:[待补充更多信息]
- 微调基础模型:Flan - T5 - Base
模型来源
- 仓库地址:https://github.com/SerjSmor/intent_classification
训练详情
训练数据
- 训练数据仓库:https://github.com/SerjSmor/intent_classification
- 未来将添加HF数据集。
训练过程
- 训练脚本地址:https://github.com/SerjSmor/intent_classification/blob/main/t5_generator_trainer.py
- 使用HF trainer进行训练:
training_args = TrainingArguments(
output_dir='./results',
num_train_epochs=epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
warmup_steps=500,
weight_decay=0.01,
logging_dir='./logs',
logging_steps=10,
evaluation_strategy="epoch"
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=tokenizer,
)
评估
最新版本的模型以少样本方式在2个合成数据集和clinc_oos的前41个类别上进行了微调。所有数据集每个类别有10 - 20个样本,训练数据不包括Atis数据集。
- Atis零样本测试集评估:加权F1值为87%
- 接下来将进行Clinc测试集评估。
硬件
Nvidia RTX3060 12Gb