🚀 INTERPRESS新聞分類
本項目聚焦於INTERPRESS新聞分類,藉助特定數據集訓練模型,實現對新聞的精準分類,為新聞信息的高效處理提供了有力支持。
🚀 快速開始
本項目提供了使用Torch和Tensorflow進行新聞分類預測的方法,你可以根據自己的需求選擇合適的方式。
✨ 主要特性
- 真實數據集:使用從INTERPRESS下載的真實世界數據,經過篩選後使用了108K條數據進行模型訓練。
- 高準確率:模型在訓練數據和驗證數據上的準確率達到了97%。
- 多框架支持:支持Torch和Tensorflow兩種深度學習框架進行使用。
📦 安裝指南
Torch
pip install transformers or pip install transformers==4.3.3
Tensorflow
pip install transformers or pip install transformers==4.3.3
💻 使用示例
Torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("serdarakyol/interpress-turkish-news-classification")
model = AutoModelForSequenceClassification.from_pretrained("serdarakyol/interpress-turkish-news-classification")
import torch
if torch.cuda.is_available():
device = torch.device("cuda")
model = model.cuda()
print('There are %d GPU(s) available.' % torch.cuda.device_count())
print('GPU name is:', torch.cuda.get_device_name(0))
else:
print('No GPU available, using the CPU instead.')
device = torch.device("cpu")
import numpy as np
def prediction(news):
news=[news]
indices=tokenizer.batch_encode_plus(
news,
max_length=512,
add_special_tokens=True,
return_attention_mask=True,
padding='max_length',
truncation=True,
return_tensors='pt')
inputs = indices["input_ids"].clone().detach().to(device)
masks = indices["attention_mask"].clone().detach().to(device)
with torch.no_grad():
output = model(inputs, token_type_ids=None,attention_mask=masks)
logits = output[0]
logits = logits.detach().cpu().numpy()
pred = np.argmax(logits,axis=1)[0]
return pred
news = r"ABD'den Prens Selman'a yaptırım yok Beyaz Saray Sözcüsü Psaki, Muhammed bin Selman'a yaptırım uygulamamanın \"doğru karar\" olduğunu savundu. Psaki, \"Tarihimizde, Demokrat ve Cumhuriyetçi başkanların yönetimlerinde diplomatik ilişki içinde olduğumuz ülkelerin liderlerine yönelik yaptırım getirilmemiştir\" dedi."
labels = {
0 : "Culture-Art",
1 : "Economy",
2 : "Politics",
3 : "Education",
4 : "World",
5 : "Sport",
6 : "Technology",
7 : "Magazine",
8 : "Health",
9 : "Agenda"
}
pred = prediction(news)
print(labels[pred])
Tensorflow
import tensorflow as tf
from transformers import BertTokenizer, TFBertForSequenceClassification
import numpy as np
tokenizer = BertTokenizer.from_pretrained('serdarakyol/interpress-turkish-news-classification')
model = TFBertForSequenceClassification.from_pretrained("serdarakyol/interpress-turkish-news-classification")
news = r"ABD'den Prens Selman'a yaptırım yok Beyaz Saray Sözcüsü Psaki, Muhammed bin Selman'a yaptırım uygulamamanın \"doğru karar\" olduğunu savundu. Psaki, \"Tarihimizde, Demokrat ve Cumhuriyetçi başkanların yönetimlerinde diplomatik ilişki içinde olduğumuz ülkelerin liderlerine yönelik yaptırım getirilmemiştir\" dedi."
inputs = tokenizer(news, return_tensors="tf")
inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1))
outputs = model(inputs)
loss = outputs.loss
logits = outputs.logits
pred = np.argmax(logits,axis=1)[0]
print(labels[pred])
📚 詳細文檔
數據集
數據集從INTERPRESS下載,屬於真實世界數據。實際上有273K條數據,但為了本模型的訓練,篩選後使用了108K條數據。有關數據集的更多信息,請訪問此鏈接。
模型
模型在訓練數據和驗證數據上的準確率為97%。數據按80%訓練和20%驗證的比例進行劃分。結果如下:
分類報告

混淆矩陣

📄 許可證
文檔中未提及許可證相關信息。
感謝 @yavuzkomecoglu 的貢獻。
如果您有任何問題,請隨時與我聯繫:
