模型简介
模型特点
模型能力
使用案例
🚀 MrT5 Large模型卡片
MrT5(MergeT5)是ByT5 (Xue et al., 2022)的高效变体,它在编码器中集成了令牌删除机制,能够动态缩短输入序列长度。该模型解决了现有字节级模型的实际限制问题,通过将关键信息“合并”到更紧凑的序列中,提供了一种有效的解决方案。
🚀 快速开始
和ByT5一样,MrT5可以处理原始的UTF - 8字节数据,并且无需分词器即可使用。请确保设置trust_remote_code=True
来加载MrT5代码:
from transformers import AutoModelForSeq2SeqLM
import torch
model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-large', trust_remote_code=True)
input_ids = torch.tensor([list("Life is like a box of chocolates.".encode("utf-8"))]) + 3 # add 3 for special tokens
labels = torch.tensor([list("La vie est comme une boîte de chocolat.".encode("utf-8"))]) + 3 # add 3 for special tokens
# Forward pass with hard deletion
loss = model(input_ids, labels=labels, hard_delete=True).loss
对于批量推理和训练,你可以使用ByT5的分词器类:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-large', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('google/byt5-large')
model_inputs = tokenizer(["Life is like a box of chocolates.", "Today is Monday."], padding="longest", return_tensors="pt")
labels = tokenizer(["La vie est comme une boîte de chocolat.", "Aujourd'hui c'est lundi."], padding="longest", return_tensors="pt").input_ids
# Forward pass with hard deletion
loss = model(**model_inputs, labels=labels, hard_delete=True).loss
✨ 主要特性
- 动态令牌合并:MrT5在编码器中集成了令牌删除机制,通过学习的删除门动态决定哪些令牌要删除,哪些要保留,有效缩短输入序列长度,解决了现有字节级模型的实际限制问题。
- 高效性能:该模型能够将编码器序列长度平均减少约50%,同时仅引入了额外的3000个参数。
- 多语言支持:支持15种类型多样的语言,包括英语、法语、西班牙语、德语、希腊语、保加利亚语、俄语、土耳其语、阿拉伯语、越南语、泰语、中文、印地语、斯瓦希里语和乌尔都语。
📦 安装指南
文档未提及安装步骤,故跳过该章节。
💻 使用示例
基础用法
from transformers import AutoModelForSeq2SeqLM
import torch
model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-large', trust_remote_code=True)
input_ids = torch.tensor([list("Life is like a box of chocolates.".encode("utf-8"))]) + 3 # add 3 for special tokens
labels = torch.tensor([list("La vie est comme une boîte de chocolat.".encode("utf-8"))]) + 3 # add 3 for special tokens
# Forward pass with hard deletion
loss = model(input_ids, labels=labels, hard_delete=True).loss
高级用法
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-large', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('google/byt5-large')
model_inputs = tokenizer(["Life is like a box of chocolates.", "Today is Monday."], padding="longest", return_tensors="pt")
labels = tokenizer(["La vie est comme une boîte de chocolat.", "Aujourd'hui c'est lundi."], padding="longest", return_tensors="pt").input_ids
# Forward pass with hard deletion
loss = model(**model_inputs, labels=labels, hard_delete=True).loss
📚 详细文档
模型详情
这是参数为12.3亿的MrT5 Large (mrt5-large
)的模型卡片,它是ByT5 Large (google/byt5-large
)的更高效变体。该模型经过训练,平均可将序列长度减少约50%。
- 开发者:Julie Kallini、Shikhar Murty、Christopher D. Manning、Christopher Potts、Róbert Csordás
- 模型类型:MrT5
- 支持语言:英语、法语、西班牙语、德语、希腊语、保加利亚语、俄语、土耳其语、阿拉伯语、越南语、泰语、中文、印地语、斯瓦希里语和乌尔都语
- 微调基础模型:google/byt5-large
- 更多信息来源:
模型架构
MrT5 Large采用了标准ByT5 Large的模型配置,其前馈维度为3840,模型维度为1536,有36个编码器层、12个解码器层,每层有16个注意力头,总参数为12.3亿。
MrT5有一个额外的删除门,可动态减少编码器序列长度。在该模型中,删除门位于第三个编码器层之后,后续所有层都在缩减后的序列上运行。该模型的训练删除率δ = 0.5,这意味着模型在第三层之后将其编码器序列长度减少约50%。MrT5的门控机制仅引入了额外的3000个参数。
MrT5 Large从ByT5 Large初始化,并在相同的训练目标上进行微调。只有MrT5的删除门在训练前是随机初始化的。MrT5的另一个显著特点是在其注意力机制中使用了softmax1。
使用场景
该模型是一种编码器 - 解码器架构,主要设计用于序列到序列任务。虽然它可以直接用于探索性或学术目的,但建议进行微调以在特定下游任务上实现最佳性能。
要利用模型的删除功能,请使用配套仓库中提供的自定义MrT5Trainer。这个专门的训练器可确保在微调过程中正确维护和集成删除机制。
由于这是一个为学术和研究探索而构建的基础模型,不适合用于生产级部署。用户应仔细评估模型的输出,特别是在对可靠性和鲁棒性要求较高的任何场景中。
偏差、风险和局限性
语言模型已知会表现出各种形式的社会偏差,并且可能产生有害或冒犯性内容(Bender et al., 2021; Bommasani et al., 2022; Liang et al., 2022)。与其他语言模型一样,该模型可能会产生有偏差或有害的输出。它尚未针对安全性进行微调,因此应谨慎使用,尤其是在敏感环境中。
训练详情
训练数据
在持续预训练中,我们使用了多语言C4 (mC4)语料库(Raffel et al., 2020; Xue et al., 2021)。MrT5在15种类型多样的语言上进行训练:英语、法语、西班牙语、德语、希腊语、保加利亚语、俄语、土耳其语、阿拉伯语、越南语、泰语、中文、印地语、斯瓦希里语和乌尔都语。
为避免对模型进行多个周期的训练,我们确保从mC4语料库中抽取的样本足够大。此外,我们从mC4训练分割中为每种语言抽取相同大小(按字节计算)的样本。
训练过程
MrT5在ByT5的跨度损坏预训练目标上进行训练。在这个任务中,未标记文本数据中的令牌跨度被每个跨度一个单独的哨兵令牌 ID 替换,模型必须填充缺失的令牌。对于ByT5和MrT5,这些是字节跨度,并且掩码可能会干扰单词边界。
预处理
在针对跨度损坏目标进行训练时,我们计算损坏的跨度,使得平均掩码跨度长度为20个令牌,噪声密度为15%,即序列中15%的令牌被掩码,遵循ByT5论文中概述的规范。
优化
MrT5在2^20个令牌的批次上进行5000次梯度步骤的训练(即编码器序列长度为1024,有效批次大小为1024)。我们使用AdamW优化器,初始学习率为1e - 4,采用线性衰减且无预热。
为实现特定的序列长度减少率,我们使用PI控制器,目标删除率δ = 0.5,如论文第3.2节所述。我们还使用了注意力分数正则化,如论文附录D所述。
环境影响
- 硬件类型:NVIDIA A100 - SXM4 - 80GB
- GPU数量:4
- 使用时长:约73小时
- 云服务提供商:斯坦福NLP集群
模型卡片作者
Julie Kallini kallini@stanford.edu
🔧 技术细节
文档未提供符合要求的技术细节内容,故跳过该章节。
📄 许可证
文档未提及许可证信息,故跳过该章节。
📚 引用信息
如果你使用此模型,请引用MrT5论文:
@inproceedings{
kallini2025mrt,
title={MrT5: Dynamic Token Merging for Efficient Byte-level Language Models},
author={Julie Kallini and Shikhar Murty and Christopher D Manning and Christopher Potts and R{\'o}bert Csord{\'a}s},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=VYWBMq1L7H}
}
同时引用ByT5论文:
@article{xue-etal-2022-byt5,
title = "{B}y{T}5: Towards a Token-Free Future with Pre-trained Byte-to-Byte Models",
author = "Xue, Linting and
Barua, Aditya and
Constant, Noah and
Al-Rfou, Rami and
Narang, Sharan and
Kale, Mihir and
Roberts, Adam and
Raffel, Colin",
editor = "Roark, Brian and
Nenkova, Ani",
journal = "Transactions of the Association for Computational Linguistics",
volume = "10",
year = "2022",
address = "Cambridge, MA",
publisher = "MIT Press",
url = "https://aclanthology.org/2022.tacl-1.17",
doi = "10.1162/tacl_a_00461",
pages = "291--306",
}



