🚀 推文垃圾郵件檢測模型
本模型用於將來自X(原Twitter)的推文分類為“垃圾郵件”(1)或“優質內容”(0),有效助力用戶識別推文質量,提升信息獲取效率。
🚀 快速開始
本模型可將來自X(原Twitter)的推文分類為“垃圾郵件”(1)或“優質內容”(0)。
✨ 主要特性
- 基於
FacebookAI/xlm - roberta - large
基礎模型,具有強大的特徵提取能力。
- 針對推文垃圾郵件檢測進行了微調,能有效區分垃圾推文和優質推文。
📦 安裝指南
文檔未提及具體安裝步驟,故跳過此章節。
💻 使用示例
基礎用法
def classify_texts(df, text_col, model_path="cja5553/xlm-roberta-Twitter-spam-classification", batch_size=24):
'''
Classifies texts as either "Quality" or "Spam" using a pre-trained sequence classification model.
Parameters:
-----------
df : pandas.DataFrame
DataFrame containing the texts to classify.
text_col : str
Name of the column in that contains the text data to be classified.
model_path : str, default="cja5553/xlm-roberta-Twitter-spam-classification"
Path to the pre-trained model for sequence classification.
batch_size : int, optional, default=24
Batch size for loading and processing data in batches. Adjust based on available GPU memory.
Returns:
--------
pandas.DataFrame
The original DataFrame with an additional column `spam_prediction`, containing the predicted labels ("Quality" or "Spam") for each text.
'''
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path).to("cuda")
model.eval()
df["text"] = df[text_col].astype(str)
text_dataset = Dataset.from_pandas(df)
def tokenize_function(example):
return tokenizer(
example["text"],
padding="max_length",
truncation=True,
max_length=512
)
text_dataset = text_dataset.map(tokenize_function, batched=True)
text_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
text_loader = DataLoader(text_dataset, batch_size=batch_size)
predictions = []
with torch.no_grad():
for batch in tqdm_notebook(text_loader):
input_ids = batch['input_ids'].to("cuda")
attention_mask = batch['attention_mask'].to("cuda")
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits
preds = torch.argmax(logits, dim=-1).cpu().numpy()
predictions.extend(preds)
id2label = {0: "Quality", 1: "Spam"}
predicted_labels = [id2label[pred] for pred in predictions]
df["spam_prediction"] = predicted_labels
return df
spam_df_classification = classify_texts(df, "text_col")
print(spam_df_classification)
高級用法
文檔未提及高級用法相關代碼,故不展示此部分。
📚 詳細文檔
訓練數據集
該模型在UtkMl的Twitter垃圾郵件檢測數據集上進行了微調,使用FacebookAI/xlm - roberta - large
作為基礎模型。
指標
基於80 - 10 - 10的訓練 - 驗證 - 測試集劃分,在測試集上取得了以下結果:
屬性 |
詳情 |
準確率 |
0.974555 |
精確率 |
0.97457 |
召回率 |
0.97455 |
F1分數 |
0.97455 |
代碼
用於訓練這些模型的代碼可在GitHub上獲取:github.com/cja5553/Twitter_spam_detection
問題諮詢
如有問題,請聯繫:alba@wustl.edu
📄 許可證
本項目採用MIT許可證。