模型概述
模型特點
模型能力
使用案例
🚀 [SwissBERT句子嵌入模型]
SwissBERT 模型經過微調,可用於生成句子嵌入。它使用了約150萬篇截至2023年的瑞士新聞文章,通過自監督的 SimCSE 方法進行微調。該模型支持德語、法語、意大利語和羅曼什語,在句子相似度任務中表現出色。
✨ 主要特性
- 基於 SwissBERT 模型進行微調,用於句子嵌入。
- 使用自監督的 SimCSE 方法進行訓練。
- 支持德語(de_CH)、法語(fr_CH)、意大利語(it_CH)和羅曼什語(rm_CH)。
- 採用最後隱藏狀態的平均值作為句子表示(pooler_type=avg)。
📦 安裝指南
文檔未提供具體安裝步驟,故跳過此章節。
💻 使用示例
基礎用法
import torch
from transformers import AutoModel, AutoTokenizer
# 加載用於句子嵌入的 swissBERT 模型
model_name = "jgrosjean-mathesis/sentence-swissbert"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def generate_sentence_embedding(sentence, language):
# 將適配器設置為指定語言
if "de" in language:
model.set_default_language("de_CH")
if "fr" in language:
model.set_default_language("fr_CH")
if "it" in language:
model.set_default_language("it_CH")
if "rm" in language:
model.set_default_language("rm_CH")
# 對輸入句子進行分詞
inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors="pt", max_length=512)
# 將分詞後的輸入傳入模型
with torch.no_grad():
outputs = model(**inputs)
# 通過平均池化提取句子嵌入
token_embeddings = outputs.last_hidden_state
attention_mask = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * attention_mask, 1)
sum_mask = torch.clamp(attention_mask.sum(1), min=1e-9)
embedding = sum_embeddings / sum_mask
return embedding
# 測試一下
sentence_0 = "Wir feiern am 1. August den Schweizer Nationalfeiertag."
sentence_0_embedding = generate_sentence_embedding(sentence_0, language="de")
print(sentence_0_embedding)
輸出:
tensor([[ 5.6306e-02, -2.8375e-01, -4.1495e-02, 7.4393e-02, -3.1552e-01,
1.5213e-01, -1.0258e-01, 2.2790e-01, -3.5968e-02, 3.1769e-01,
1.9354e-01, 1.9748e-02, -1.5236e-01, -2.2657e-01, 1.3345e-02,
...]])
高級用法
from sklearn.metrics.pairwise import cosine_similarity
# 定義兩個句子
sentence_1 = ["Der Zug kommt um 9 Uhr in Zürich an."]
sentence_2 = ["Le train arrive à Lausanne à 9h."]
# 計算兩個句子的嵌入
embedding_1 = generate_sentence_embedding(sentence_1, language="de")
embedding_2 = generate_sentence_embedding(sentence_2, language="fr")
# 計算餘弦相似度
cosine_score = cosine_similarity(embedding_1, embedding_2)
# 輸出相似度得分
print("The cosine score for", sentence_1, "and", sentence_2, "is", cosine_score)
輸出:
The cosine score for ['Der Zug kommt um 9 Uhr in Zürich an.'] and ['Le train arrive à Lausanne à 9h.'] is [[0.85555995]]
📚 詳細文檔
模型詳情
屬性 | 詳情 |
---|---|
開發者 | Juri Grosjean |
模型類型 | XMOD |
支持語言 (NLP) | de_CH, fr_CH, it_CH, rm_CH |
許可證 | Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) |
微調基礎模型 | SwissBERT |
偏差、風險和侷限性
該句子 SwissBERT 模型僅在新聞文章上進行了訓練。因此,它在其他文本類別上的表現可能不佳。此外,該模型特定於與瑞士相關的上下文,這意味著它在不屬於該類別的文本上的表現可能也不理想。另外,該模型未針對機器翻譯任務進行訓練和評估。
訓練詳情
訓練數據
截至2023年,Swissdox@LiRI數據庫 中的德語、法語、意大利語和羅曼什語文檔。
訓練過程
該模型通過自監督的 SimCSE 方法進行微調。正序列對由文章正文與其標題和導語組成,沒有任何難負樣本。
微調腳本可在 此處 訪問。
訓練超參數
- 訓練輪數:1
- 學習率:1e-5
- 批量大小:512
- 溫度:0.05
評估
測試數據
兩項評估任務使用了由 Kew 等人(2023)整理的 20 Minuten 數據集,該數據集包含帶有主題標籤和摘要的瑞士新聞文章。數據集的部分內容使用 Google Cloud API 自動翻譯成法語、意大利語,並通過 Textshuttle API 翻譯成羅曼什語。
通過文檔檢索進行評估
為每個文檔的摘要和內容計算嵌入。隨後,通過最大化每個摘要和內容嵌入對之間的餘弦相似度得分來匹配嵌入。
性能通過準確率來衡量,即正確匹配與總匹配的比例。評估腳本可在 此處 找到。
通過文本分類進行評估
將一些帶有定義主題標籤的文章映射到10個類別,從語料庫中過濾出來,並分為訓練數據(80%)和測試數據(20%)。隨後,為訓練數據和測試數據設置嵌入。然後,使用訓練數據通過 k 近鄰方法對測試數據進行分類。評估腳本可在 此處 找到。
注意:對於法語、意大利語和羅曼什語,訓練數據保持為德語,而測試數據包含翻譯內容。這有助於瞭解模型在跨語言遷移方面的能力。
評估結果
在這些任務中,Sentence SwissBERT 取得了與表現最佳的多語言 Sentence-BERT 模型(distiluse-base-multilingual-cased)相當或更好的結果。除了意大利語的文本分類任務外,它在所有評估任務中都優於該模型。
評估任務 | Swissbert | Sentence Swissbert | Sentence-BERT | |||
---|---|---|---|---|---|---|
準確率 | F1 分數 | 準確率 | F1 分數 | 準確率 | F1 分數 | |
文檔檢索(德語) | 87.20 % | -- | 93.40 % | -- | 91.80 % | -- |
文檔檢索(法語) | 84.97 % | -- | 93.99 % | -- | 93.19 % | -- |
文檔檢索(意大利語) | 84.17 % | -- | 92.18 % | -- | 91.58 % | -- |
文檔檢索(羅曼什語) | 83.17 % | -- | 91.58 % | -- | 73.35 % | -- |
文本分類(德語) | -- | 77.93 % | -- | 78.49 % | -- | 77.23 % |
文本分類(法語) | -- | 69.62 % | -- | 77.18 % | -- | 76.83 % |
文本分類(意大利語) | -- | 67.09 % | -- | 76.65 % | -- | 76.90 % |
文本分類(羅曼什語) | -- | 43.79 % | -- | 77.20 % | -- | 65.35 % |
基線
基線使用原始 swissbert 模型最後隱藏狀態的平均池化嵌入,以及在這些任務中表現最佳的 Sentence-BERT 模型 distiluse-base-multilingual-cased-v1。
🔧 技術細節
該模型基於 SwissBERT 進行微調,使用自監督的 SimCSE 方法。正序列對由文章正文與其標題和導語組成,沒有任何難負樣本。在訓練過程中,使用了特定的超參數,如訓練輪數為1、學習率為1e-5、批量大小為512、溫度為0.05。在評估方面,通過文檔檢索和文本分類兩項任務進行評估,使用準確率和 F1 分數等指標來衡量模型性能。
📄 許可證
該模型使用的許可證為 Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)。







