🚀 Intent Classification Model
This project focuses on intent classification: assigning a customer's request to one of a set of predefined categories. The model is a T5 model fine-tuned on prompts built from synthetic data. Because the candidate categories are supplied in the prompt, intents are classified dynamically.
🚀 Quick Start
Basic Usage
```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer


def build_prompt(text, prompt="", company_name="", company_specific=""):
    # Built-in descriptions for two demo companies; any other name uses company_specific as given.
    if company_name == "Pizza Mia":
        company_specific = "This company is a pizzeria place."
    if company_name == "Online Banking":
        company_specific = "This company is an online banking."
    # Keep this exact format (including the period after {text}): it is the format the model was trained on.
    return f"Company name: {company_name} is doing: {company_specific}\nCustomer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n{prompt}\nClass name: "


class IntentClassifier:
    def __init__(self, model_name="serj/intent-classifier", device=None):
        # Use the GPU when available, otherwise fall back to the CPU.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(self.device)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)

    def predict(self, text, prompt_options, company_name, company_portion) -> str:
        input_text = build_prompt(text, prompt_options, company_name, company_portion)
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)
        # The model generates the class name as text.
        output = self.model.generate(input_ids)
        return self.tokenizer.decode(output[0], skip_special_tokens=True)


m = IntentClassifier("serj/intent-classifier")
print(m.predict("Hey, after recent changes, I want to cancel subscription, please help.",
                "OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n", "Company",
                "Products and subscriptions"))
```
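The model generates the class name as free text, so nothing forces its output to match one of the offered options character for character. As a minimal post-processing sketch, assuming you want every prediction mapped onto the option list, the hypothetical helper snap_to_option below (not part of the model or of transformers) accepts an exact case-insensitive match and otherwise falls back to difflib similarity; the 0.6 cutoff is an arbitrary choice:

```python
import difflib
from typing import List, Optional


def snap_to_option(generated: str, options: List[str], cutoff: float = 0.6) -> Optional[str]:
    """Return the option that best matches the generated text, or None (hypothetical helper)."""
    cleaned = generated.strip().lower()
    # Map lower-cased options back to their original spelling.
    by_lower = {opt.lower(): opt for opt in options}
    # An exact match, ignoring case and surrounding whitespace, wins outright.
    if cleaned in by_lower:
        return by_lower[cleaned]
    # Otherwise take the most similar option, if any clears the cutoff.
    matches = difflib.get_close_matches(cleaned, list(by_lower), n=1, cutoff=cutoff)
    return by_lower[matches[0]] if matches else None


options = ["refund", "cancel subscription", "damaged item", "return item"]
print(snap_to_option("Cancel Subscription", options))   # cancel subscription
print(snap_to_option("cancel subscriptions", options))  # cancel subscription
```

Returning None instead of forcing the nearest label keeps clearly off-list outputs visible, so the caller can decide how to handle them.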
✨ Features
- Dynamic Intent Classification: all candidate categories are listed in the prompt, so the model classifies against whatever set of intents you supply.
- Fine-Tuned Model: based on Flan-T5-Base and fine-tuned on synthetic data.
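Since the categories live in the prompt, changing the label set only means changing the OPTIONS text. A small sketch that rebuilds the OPTIONS string used in the Quick Start example from a Python list; format_options is a hypothetical helper, and it assumes the newline-and-space layout of that example (the Prompt Structure section describes %-separated classes instead):

```python
def format_options(labels):
    """Build an OPTIONS block shaped like the Quick Start example (hypothetical helper)."""
    # Each label goes on its own line, preceded by a single space,
    # and the block ends with a trailing newline.
    return "OPTIONS:\n" + "".join(f" {label}\n" for label in labels)


labels = ["refund", "cancel subscription", "damaged item", "return item"]
prompt_options = format_options(labels)
print(repr(prompt_options))
# 'OPTIONS:\n refund\n cancel subscription\n damaged item\n return item\n'
```

The result can be passed as the prompt_options argument of predict; adding an intent is then a matter of appending a label to the list.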
📦 Installation
The Quick Start example needs PyTorch, Hugging Face Transformers, and SentencePiece (which T5Tokenizer depends on). Install them with: pip install torch transformers sentencepiece
📚 Documentation
Prompt Structure
```text
Topic %% Customer: text.
END MESSAGE
OPTIONS:
each class separated by %
Choose one topic that matches customer's issue.
Class name:
```
The customer text must end with a period; otherwise the output becomes unreliable, because that is how the model was trained.
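If you fill the template by hand, a tiny normalization step can enforce the period rule. This is a sketch with a hypothetical helper name; it is not needed with the Quick Start build_prompt, which already writes a period after the customer text:

```python
def ensure_trailing_period(text: str) -> str:
    """Return the customer text ending in exactly one period (hypothetical helper)."""
    stripped = text.rstrip()
    # Drop any existing trailing periods, then add a single one back.
    return stripped.rstrip(".") + "."


print(ensure_trailing_period("I want to cancel my subscription"))  # I want to cancel my subscription.
print(ensure_trailing_period("Please help.. "))                    # Please help.
```

Other final punctuation (such as "?" or "!") is left in place and a period is appended after it; adjust this if your inputs need different handling.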
Model Details
| Property | Details |
| --- | --- |
| Model Type | Flan-T5-Base |
| Developed by | Serj Smorodinsky |
| Language(s) (NLP) | [More Information Needed] |
| License | [More Information Needed] |
| Finetuned from model | Flan-T5-Base |
| Repository | https://github.com/SerjSmor/intent_classification |
Training Details
Training Data
See the project's GitHub repository for the training data: https://github.com/SerjSmor/intent_classification. A Hugging Face dataset will be added in the future.
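The Evaluation section states that each training dataset has 10-20 samples per class. The actual data preparation lives in the GitHub repository and is not reproduced here; as a rough sketch, assuming a larger pool of labeled (text, label) pairs, a per-class cap could be drawn like this (sample_per_class is hypothetical, and k=15 is just one value in that range):

```python
import random
from collections import defaultdict


def sample_per_class(examples, k, seed=0):
    """Keep at most k (text, label) examples per label, chosen at random (hypothetical helper)."""
    by_label = defaultdict(list)
    for text, label in examples:
        by_label[label].append((text, label))
    # A seeded generator makes the draw reproducible.
    rng = random.Random(seed)
    subset = []
    for label in sorted(by_label):
        items = by_label[label]
        # random.sample raises ValueError if k exceeds the population, so cap it.
        subset.extend(rng.sample(items, min(k, len(items))))
    return subset


pool = [(f"message {i}", "refund" if i % 2 else "return item") for i in range(50)]
pool.append(("my parcel arrived broken", "damaged item"))
subset = sample_per_class(pool, k=15)
print(len(subset))  # 31: 15 + 15 + the single "damaged item" example
```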
Training Procedure
Training uses the Hugging Face Trainer. The relevant code is:
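```python
from transformers import Trainer, TrainingArguments

# model, tokenizer, train_dataset, val_dataset, epochs and batch_size
# are defined elsewhere in the training script.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    evaluation_strategy="epoch",  # renamed eval_strategy in newer transformers releases
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,  # newer releases prefer processing_class=tokenizer
)
trainer.train()
```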
Evaluation
The newest version of the model is fine-tuned, in a few-shot manner, on two synthetic datasets and the first 41 classes of clinc_oos. Every dataset has 10-20 samples per class. The training data did not include the ATIS dataset.
- ATIS test set (zero-shot): weighted F1 of 87%
- clinc_oos test set: evaluation coming next.
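The reported score is a weighted F1: per-class F1 averaged with weights proportional to how often each class occurs in the true labels (the quantity scikit-learn computes with f1_score(..., average="weighted")). A minimal pure-Python sketch of that metric; the labels below are made up for illustration and are not from the ATIS evaluation:

```python
from collections import Counter


def weighted_f1(y_true, y_pred):
    """Per-class F1 averaged with weights equal to each class's share of y_true."""
    support = Counter(y_true)
    total = len(y_true)
    score = 0.0
    for label, n_true in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
        n_pred = sum(1 for p in y_pred if p == label)
        precision = tp / n_pred if n_pred else 0.0
        recall = tp / n_true
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        # Weight each class's F1 by its support in the true labels.
        score += f1 * n_true / total
    return score


y_true = ["refund", "refund", "return item", "cancel subscription"]
y_pred = ["refund", "return item", "return item", "cancel subscription"]
print(round(weighted_f1(y_true, y_pred), 4))  # 0.75
```

Classes that appear only in the predictions get zero weight, so a generated label outside the option list lowers the score only through the recall of the true class it replaced.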
Hardware
NVIDIA RTX 3060 (12 GB)