Zero-shot text classification (model based on albert-xxlarge-v2) trained with self-supervised tuning
This is a zero-shot text classification model trained with self-supervised tuning (SSTuning). It was introduced in the paper Zero-Shot Text Classification via Self-Supervised Tuning by Chaoqun Liu, Wenxuan Zhang, Guizhen Chen, Xiaobao Wu, Anh Tuan Luu, Chip Hong Chang, Lidong Bing and first released in this repository. The model backbone is albert-xxlarge-v2.
Features
- Trained with self-supervised tuning (SSTuning) for zero-shot text classification.
- Utilizes first sentence prediction (FSP) as the learning objective.
- Three model variations are available with different backbones and performance characteristics.
Installation
No specific installation steps are provided. The usage example below imports transformers and torch (PyTorch), so both packages need to be installed.
Usage Examples
Basic Usage
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch, string, random

# The tokenizer comes from the backbone; the weights come from the tuned checkpoint.
tokenizer = AutoTokenizer.from_pretrained("albert-xxlarge-v2")
model = AutoModelForSequenceClassification.from_pretrained("DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT")

text = "I love this place! The food is always so fresh and delicious."
list_label = ["negative", "positive"]

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
list_ABC = [x for x in string.ascii_uppercase]

def check_text(model, text, list_label, shuffle=False):
    # End every label with a period, then pad the option list to 20 entries.
    list_label = [x+'.' if x[-1] != '.' else x for x in list_label]
    list_label_new = list_label + [tokenizer.pad_token]* (20 - len(list_label))
    if shuffle:
        random.shuffle(list_label_new)
    # Input format: "(A) option (B) option ... <sep_token> text"
    s_option = ' '.join(['('+list_ABC[i]+') '+list_label_new[i] for i in range(len(list_label_new))])
    text = f'{s_option} {tokenizer.sep_token} {text}'

    model.to(device).eval()
    encoding = tokenizer([text],truncation=True, max_length=512,return_tensors='pt')
    item = {key: val.to(device) for key, val in encoding.items()}
    # Inference only, so gradient tracking is not needed.
    with torch.no_grad():
        logits = model(**item).logits

    # Without shuffling, only the first len(list_label) logits belong to real labels.
    logits = logits if shuffle else logits[:,0:len(list_label)]
    probs = torch.nn.functional.softmax(logits, dim = -1).tolist()
    predictions = torch.argmax(logits, dim=-1).item()
    probabilities = [round(x,5) for x in probs[0]]

    print(f'prediction: {predictions} => ({list_ABC[predictions]}) {list_label_new[predictions]}')
    print(f'probability: {round(probabilities[predictions]*100,2)}%')

check_text(model, text, list_label)
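To make the input format easier to inspect without downloading the model, the prompt construction from the example above can be sketched on its own. This is a minimal sketch: `build_prompt` and `MAX_OPTIONS` are hypothetical names, and the default `pad_token` and `sep_token` values are assumed to match the ALBERT tokenizer; in real use they should be read from the tokenizer as in the example above.

```python
import string

MAX_OPTIONS = 20  # the usage example pads the option list to 20 entries


def build_prompt(text, labels, pad_token="<pad>", sep_token="[SEP]"):
    """Return the option-prefixed input string (no shuffling)."""
    # Make sure every label ends with a period, as in the usage example.
    options = [label if label.endswith(".") else label + "." for label in labels]
    # Fill the unused option slots with the padding token.
    options += [pad_token] * (MAX_OPTIONS - len(options))
    letters = string.ascii_uppercase
    option_str = " ".join(f"({letters[i]}) {opt}" for i, opt in enumerate(options))
    return f"{option_str} {sep_token} {text}"


prompt = build_prompt("The food is fresh.", ["negative", "positive"])
print(prompt)
```

With two labels, the first two slots hold the real labels and the remaining eighteen slots, (C) through (T), hold the padding token.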
Advanced Usage
You can try the model with the Colab Notebook.
Documentation
Model description
The model is tuned with unlabeled data using a learning objective called first sentence prediction (FSP). The FSP task is designed by considering both the nature of the unlabeled corpus and the input/output format of classification tasks. The training and validation sets are constructed from the unlabeled corpus using FSP.
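The description above does not spell out how an FSP example is built, so the following is only an illustrative sketch. It assumes that the first sentence of a paragraph serves as the correct option, that first sentences of other paragraphs serve as distractors, and that the rest of the paragraph is the text to classify. The function name `make_fsp_example`, the naive sentence split, and the sample paragraphs are all hypothetical.

```python
import random


def make_fsp_example(paragraphs, target_idx, num_options, rng):
    """Build one (options, text, label) triple from unlabeled paragraphs."""

    def split_first(paragraph):
        # Naive split: everything up to the first ". " is the first sentence.
        first, _, rest = paragraph.partition(". ")
        return first + ".", rest

    positive, text = split_first(paragraphs[target_idx])
    # Distractors: first sentences of other, randomly chosen paragraphs.
    others = [i for i in range(len(paragraphs)) if i != target_idx]
    chosen = rng.sample(others, num_options - 1)
    negatives = [split_first(paragraphs[i])[0] for i in chosen]
    options = negatives + [positive]
    rng.shuffle(options)
    label = options.index(positive)  # the index the model has to predict
    return options, text, label


paragraphs = [
    "Cats sleep a lot. They can nap for sixteen hours a day.",
    "The stock market fell. Investors reacted to new inflation data.",
    "Rain is expected tomorrow. Bring an umbrella if you go out.",
    "The team won the final. Fans celebrated late into the night.",
]
options, text, label = make_fsp_example(paragraphs, 0, 3, random.Random(0))
print(options, text, label)
```

The resulting triple has the same shape as a classification input: a list of candidate labels, a text, and the index of the correct label.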
During tuning, BERT-like pre-trained masked language models such as RoBERTa and ALBERT are employed as the backbone, and an output layer for classification is added. The learning objective for FSP is to predict the index of the correct label. A cross-entropy loss is used for tuning the model.
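As a small worked illustration of the objective, the cross-entropy loss for one example is the negative log of the softmax probability assigned to the correct option index. The sketch below uses plain Python rather than the training code, which is not shown in this document; `cross_entropy` is a hypothetical helper and the logits are made-up numbers.

```python
import math


def cross_entropy(logits, target_index):
    """Cross-entropy for one example: -log softmax(logits)[target_index]."""
    # Subtract the maximum logit for numerical stability.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[target_index]


# The loss is small when the correct index has the largest logit ...
loss_good = cross_entropy([0.1, 3.0, -1.2], target_index=1)
# ... and large when the correct index has a low logit.
loss_bad = cross_entropy([0.1, 3.0, -1.2], target_index=2)
print(loss_good, loss_bad)
```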
Model variations
Three versions of the model have been released:
| Model | Backbone | #params | Accuracy | Speed | Training data |
| --- | --- | --- | --- | --- | --- |
| zero-shot-classify-SSTuning-base | roberta-base | 125M | Low | High | 20.48M |
| zero-shot-classify-SSTuning-large | roberta-large | 355M | Medium | Medium | 5.12M |
| zero-shot-classify-SSTuning-ALBERT | albert-xxlarge-v2 | 235M | High | Low | 5.12M |
Note that zero-shot-classify-SSTuning-base is trained with more data (20.48M) than in the paper, because the additional data improves its accuracy.
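Because the tokenizer is loaded from the backbone while the weights are loaded from the tuned checkpoint, it can help to keep the variants listed above in one lookup table. This is only a sketch: `VARIANTS` and `pick_variant` are hypothetical names, and the hub ids for the base and large variants are assumed by analogy with the ALBERT id used in the usage example, so they should be verified before use.

```python
# Backbones and ratings come from the model variations listed above.
# ASSUMPTION: the "base" and "large" hub ids follow the pattern of the ALBERT id.
VARIANTS = {
    "base": {
        "hub_id": "DAMO-NLP-SG/zero-shot-classify-SSTuning-base",
        "backbone": "roberta-base",
        "accuracy": "Low",
        "speed": "High",
    },
    "large": {
        "hub_id": "DAMO-NLP-SG/zero-shot-classify-SSTuning-large",
        "backbone": "roberta-large",
        "accuracy": "Medium",
        "speed": "Medium",
    },
    "albert": {
        "hub_id": "DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT",
        "backbone": "albert-xxlarge-v2",
        "accuracy": "High",
        "speed": "Low",
    },
}

RANK = {"Low": 0, "Medium": 1, "High": 2}


def pick_variant(prefer="accuracy"):
    """Return the variant name with the best rating for 'accuracy' or 'speed'."""
    return max(VARIANTS, key=lambda name: RANK[VARIANTS[name][prefer]])


print(pick_variant("accuracy"), pick_variant("speed"))
```

The `backbone` entry is the name to pass to the tokenizer, and `hub_id` is the name to pass to the model loader.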
Intended uses & limitations
The model can be used for zero-shot text classification such as sentiment analysis and topic classification. No further finetuning is needed.
The number of labels should be between 2 and 20.
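A caller can enforce this limit before running the model and then normalize only the logits that correspond to real labels, mirroring the non-shuffled path of the usage example. The sketch below works on a plain list of logits so that it runs without the model; `label_probabilities` is a hypothetical helper, and the logits are made-up values.

```python
import math

MIN_LABELS = 2
MAX_LABELS = 20


def label_probabilities(logits, labels):
    """Softmax over the first len(labels) logits, keyed by label."""
    if not MIN_LABELS <= len(labels) <= MAX_LABELS:
        raise ValueError(
            f"expected {MIN_LABELS} to {MAX_LABELS} labels, got {len(labels)}"
        )
    # Only the leading logits belong to real labels; the rest are padding slots.
    used = logits[: len(labels)]
    max_logit = max(used)  # subtract the maximum for numerical stability
    exps = [math.exp(x - max_logit) for x in used]
    total = sum(exps)
    return {label: e / total for label, e in zip(labels, exps)}


# Made-up logits for the 20 output slots; only the first two are used here.
probs = label_probabilities([1.0, 3.0] + [0.0] * 18, ["negative", "positive"])
print(probs)
```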
Technical Details
The model is based on the self-supervised tuning (SSTuning) method, which uses first sentence prediction (FSP) as the learning objective. The FSP task is designed to leverage unlabeled data for training and validation. BERT-like pre-trained masked language models are used as the backbone, and an output layer for classification is added. A cross-entropy loss is used for tuning the model.
License
The model is released under the MIT license.
BibTeX entry and citation info
@inproceedings{acl23/SSTuning,
  author    = {Chaoqun Liu and
               Wenxuan Zhang and
               Guizhen Chen and
               Xiaobao Wu and
               Anh Tuan Luu and
               Chip Hong Chang and
               Lidong Bing},
  title     = {Zero-Shot Text Classification via Self-Supervised Tuning},
  booktitle = {Findings of the Association for Computational Linguistics: ACL 2023},
  year      = {2023},
  url       = {https://arxiv.org/abs/2305.11442},
}