🚀 ruRoPEBert经典俄语模型
这是一个来自 Tochka AI 的编码器模型,基于 RoPEBert 架构,采用了我们在Habr上发表的文章中描述的克隆方法。
该模型使用 CulturaX 数据集进行训练。以 ai-forever/ruBert-base 模型为原型,根据 encodechka 基准测试,此模型在质量上超越了原型模型。
模型源代码可在文件 modeling_rope_bert.py 中获取。
该模型在长度 最多为512个标记 的上下文上进行训练,但也可用于更大的上下文。为获得更好的质量,可使用此模型的扩展上下文版本 - Tochka-AI/ruRoPEBert-classic-base-2k。
🚀 快速开始
本模型使用时,建议 transformers
库的版本为4.37.2及以上。要正确加载模型,必须启用从模型仓库下载代码的功能:trust_remote_code=True
,这样会下载 modeling_rope_bert.py 脚本并将权重加载到正确的架构中。
否则,你可以手动下载此脚本并直接使用其中的类来加载模型。
✨ 主要特性
- 基于RoPEBert架构:采用先进的架构,提升模型性能。
- 超越原型模型:在质量上超越了 ai-forever/ruBert-base 模型。
- 支持扩展上下文:可使用扩展上下文版本,处理更大的上下文。
📦 安装指南
确保你已安装 transformers
库,版本为4.37.2及以上。
💻 使用示例
基础用法(无高效注意力机制)
model_name = 'Tochka-AI/ruRoPEBert-classic-base-512'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation='eager')
高级用法(使用SDPA高效注意力机制)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation='sdpa')
获取嵌入向量
正确的池化器 (mean
) 已 内置在模型架构中,它会根据注意力掩码对嵌入向量进行平均。你也可以选择池化器类型 (first_token_transform
),它会对第一个标记执行可学习的线性变换。
要更改内置池化器的实现,请在 AutoModel.from_pretrained
函数中使用 pooler_type
参数。
test_batch = tokenizer.batch_encode_plus(["Привет, чем занят?", "Здравствуйте, чем вы занимаетесь?"], return_tensors='pt', padding=True)
with torch.inference_mode():
pooled_output = model(**test_batch).pooler_output
此外,你可以使用归一化和矩阵乘法计算批次中文本之间的余弦相似度:
import torch.nn.functional as F
F.normalize(pooled_output, dim=1) @ F.normalize(pooled_output, dim=1).T
用作分类器
要加载带有可训练分类头的模型(更改 num_labels
参数):
model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True, attn_implementation='sdpa', num_labels=4)
使用RoPE缩放
RoPE缩放允许的类型为:linear
和 dynamic
。要扩展模型的上下文窗口,需要更改分词器的最大长度并添加 rope_scaling
参数。
如果你想将模型上下文缩放2倍:
tokenizer.model_max_length = 1024
model = AutoModel.from_pretrained(model_name,
trust_remote_code=True,
attn_implementation='sdpa',
rope_scaling={'type': 'dynamic','factor': 2.0}
)
注意:别忘了指定所需的数据类型和设备,以有效利用资源。
📚 详细文档
模型评估指标
此模型在encodechka基准测试中的评估结果如下:
模型名称 |
STS |
PI |
NLI |
SA |
TI |
IA |
IC |
ICX |
NE1 |
NE2 |
平均S(不含NE) |
平均S+W(含NE) |
ruRoPEBert-classic-base-512 |
0.695 |
0.605 |
0.396 |
0.794 |
0.975 |
0.797 |
0.769 |
0.386 |
0.410 |
0.609 |
0.677 |
0.630 |
ai-forever/ruBert-base |
0.670 |
0.533 |
0.391 |
0.773 |
0.975 |
0.783 |
0.765 |
0.384 |
- |
- |
0.659 |
- |
👨💻 作者