模型简介
模型特点
模型能力
使用案例
🚀 [SwissBERT句子嵌入模型]
SwissBERT 模型经过微调,可用于生成句子嵌入。它使用了约150万篇截至2023年的瑞士新闻文章,通过自监督的 SimCSE 方法进行微调。该模型支持德语、法语、意大利语和罗曼什语,在句子相似度任务中表现出色。
✨ 主要特性
- 基于 SwissBERT 模型进行微调,用于句子嵌入。
- 使用自监督的 SimCSE 方法进行训练。
- 支持德语(de_CH)、法语(fr_CH)、意大利语(it_CH)和罗曼什语(rm_CH)。
- 采用最后隐藏状态的平均值作为句子表示(pooler_type=avg)。
📦 安装指南
文档未提供具体安装步骤,故跳过此章节。
💻 使用示例
基础用法
import torch
from transformers import AutoModel, AutoTokenizer
# 加载用于句子嵌入的 swissBERT 模型
model_name = "jgrosjean-mathesis/sentence-swissbert"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def generate_sentence_embedding(sentence, language):
# 将适配器设置为指定语言
if "de" in language:
model.set_default_language("de_CH")
if "fr" in language:
model.set_default_language("fr_CH")
if "it" in language:
model.set_default_language("it_CH")
if "rm" in language:
model.set_default_language("rm_CH")
# 对输入句子进行分词
inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors="pt", max_length=512)
# 将分词后的输入传入模型
with torch.no_grad():
outputs = model(**inputs)
# 通过平均池化提取句子嵌入
token_embeddings = outputs.last_hidden_state
attention_mask = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * attention_mask, 1)
sum_mask = torch.clamp(attention_mask.sum(1), min=1e-9)
embedding = sum_embeddings / sum_mask
return embedding
# 测试一下
sentence_0 = "Wir feiern am 1. August den Schweizer Nationalfeiertag."
sentence_0_embedding = generate_sentence_embedding(sentence_0, language="de")
print(sentence_0_embedding)
输出:
tensor([[ 5.6306e-02, -2.8375e-01, -4.1495e-02, 7.4393e-02, -3.1552e-01,
1.5213e-01, -1.0258e-01, 2.2790e-01, -3.5968e-02, 3.1769e-01,
1.9354e-01, 1.9748e-02, -1.5236e-01, -2.2657e-01, 1.3345e-02,
...]])
高级用法
from sklearn.metrics.pairwise import cosine_similarity
# 定义两个句子
sentence_1 = ["Der Zug kommt um 9 Uhr in Zürich an."]
sentence_2 = ["Le train arrive à Lausanne à 9h."]
# 计算两个句子的嵌入
embedding_1 = generate_sentence_embedding(sentence_1, language="de")
embedding_2 = generate_sentence_embedding(sentence_2, language="fr")
# 计算余弦相似度
cosine_score = cosine_similarity(embedding_1, embedding_2)
# 输出相似度得分
print("The cosine score for", sentence_1, "and", sentence_2, "is", cosine_score)
输出:
The cosine score for ['Der Zug kommt um 9 Uhr in Zürich an.'] and ['Le train arrive à Lausanne à 9h.'] is [[0.85555995]]
📚 详细文档
模型详情
属性 | 详情 |
---|---|
开发者 | Juri Grosjean |
模型类型 | XMOD |
支持语言 (NLP) | de_CH, fr_CH, it_CH, rm_CH |
许可证 | Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) |
微调基础模型 | SwissBERT |
偏差、风险和局限性
该句子 SwissBERT 模型仅在新闻文章上进行了训练。因此,它在其他文本类别上的表现可能不佳。此外,该模型特定于与瑞士相关的上下文,这意味着它在不属于该类别的文本上的表现可能也不理想。另外,该模型未针对机器翻译任务进行训练和评估。
训练详情
训练数据
截至2023年,Swissdox@LiRI数据库 中的德语、法语、意大利语和罗曼什语文档。
训练过程
该模型通过自监督的 SimCSE 方法进行微调。正序列对由文章正文与其标题和导语组成,没有任何难负样本。
微调脚本可在 此处 访问。
训练超参数
- 训练轮数:1
- 学习率:1e-5
- 批量大小:512
- 温度:0.05
评估
测试数据
两项评估任务使用了由 Kew 等人(2023)整理的 20 Minuten 数据集,该数据集包含带有主题标签和摘要的瑞士新闻文章。数据集的部分内容使用 Google Cloud API 自动翻译成法语、意大利语,并通过 Textshuttle API 翻译成罗曼什语。
通过文档检索进行评估
为每个文档的摘要和内容计算嵌入。随后,通过最大化每个摘要和内容嵌入对之间的余弦相似度得分来匹配嵌入。
性能通过准确率来衡量,即正确匹配与总匹配的比例。评估脚本可在 此处 找到。
通过文本分类进行评估
将一些带有定义主题标签的文章映射到10个类别,从语料库中过滤出来,并分为训练数据(80%)和测试数据(20%)。随后,为训练数据和测试数据设置嵌入。然后,使用训练数据通过 k 近邻方法对测试数据进行分类。评估脚本可在 此处 找到。
注意:对于法语、意大利语和罗曼什语,训练数据保持为德语,而测试数据包含翻译内容。这有助于了解模型在跨语言迁移方面的能力。
评估结果
在这些任务中,Sentence SwissBERT 取得了与表现最佳的多语言 Sentence-BERT 模型(distiluse-base-multilingual-cased)相当或更好的结果。除了意大利语的文本分类任务外,它在所有评估任务中都优于该模型。
评估任务 | Swissbert | Sentence Swissbert | Sentence-BERT | |||
---|---|---|---|---|---|---|
准确率 | F1 分数 | 准确率 | F1 分数 | 准确率 | F1 分数 | |
文档检索(德语) | 87.20 % | -- | 93.40 % | -- | 91.80 % | -- |
文档检索(法语) | 84.97 % | -- | 93.99 % | -- | 93.19 % | -- |
文档检索(意大利语) | 84.17 % | -- | 92.18 % | -- | 91.58 % | -- |
文档检索(罗曼什语) | 83.17 % | -- | 91.58 % | -- | 73.35 % | -- |
文本分类(德语) | -- | 77.93 % | -- | 78.49 % | -- | 77.23 % |
文本分类(法语) | -- | 69.62 % | -- | 77.18 % | -- | 76.83 % |
文本分类(意大利语) | -- | 67.09 % | -- | 76.65 % | -- | 76.90 % |
文本分类(罗曼什语) | -- | 43.79 % | -- | 77.20 % | -- | 65.35 % |
基线
基线使用原始 swissbert 模型最后隐藏状态的平均池化嵌入,以及在这些任务中表现最佳的 Sentence-BERT 模型 distiluse-base-multilingual-cased-v1。
🔧 技术细节
该模型基于 SwissBERT 进行微调,使用自监督的 SimCSE 方法。正序列对由文章正文与其标题和导语组成,没有任何难负样本。在训练过程中,使用了特定的超参数,如训练轮数为1、学习率为1e-5、批量大小为512、温度为0.05。在评估方面,通过文档检索和文本分类两项任务进行评估,使用准确率和 F1 分数等指标来衡量模型性能。
📄 许可证
该模型使用的许可证为 Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)。







