模型概述
模型特點
模型能力
使用案例
🚀 MrT5 Large模型卡片
MrT5(MergeT5)是ByT5 (Xue et al., 2022)的高效變體,它在編碼器中集成了令牌刪除機制,能夠動態縮短輸入序列長度。該模型解決了現有字節級模型的實際限制問題,通過將關鍵信息“合併”到更緊湊的序列中,提供了一種有效的解決方案。
🚀 快速開始
和ByT5一樣,MrT5可以處理原始的UTF - 8字節數據,並且無需分詞器即可使用。請確保設置trust_remote_code=True
來加載MrT5代碼:
from transformers import AutoModelForSeq2SeqLM
import torch
model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-large', trust_remote_code=True)
input_ids = torch.tensor([list("Life is like a box of chocolates.".encode("utf-8"))]) + 3 # add 3 for special tokens
labels = torch.tensor([list("La vie est comme une boîte de chocolat.".encode("utf-8"))]) + 3 # add 3 for special tokens
# Forward pass with hard deletion
loss = model(input_ids, labels=labels, hard_delete=True).loss
對於批量推理和訓練,你可以使用ByT5的分詞器類:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-large', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('google/byt5-large')
model_inputs = tokenizer(["Life is like a box of chocolates.", "Today is Monday."], padding="longest", return_tensors="pt")
labels = tokenizer(["La vie est comme une boîte de chocolat.", "Aujourd'hui c'est lundi."], padding="longest", return_tensors="pt").input_ids
# Forward pass with hard deletion
loss = model(**model_inputs, labels=labels, hard_delete=True).loss
✨ 主要特性
- 動態令牌合併:MrT5在編碼器中集成了令牌刪除機制,通過學習的刪除門動態決定哪些令牌要刪除,哪些要保留,有效縮短輸入序列長度,解決了現有字節級模型的實際限制問題。
- 高效性能:該模型能夠將編碼器序列長度平均減少約50%,同時僅引入了額外的3000個參數。
- 多語言支持:支持15種類型多樣的語言,包括英語、法語、西班牙語、德語、希臘語、保加利亞語、俄語、土耳其語、阿拉伯語、越南語、泰語、中文、印地語、斯瓦希里語和烏爾都語。
📦 安裝指南
文檔未提及安裝步驟,故跳過該章節。
💻 使用示例
基礎用法
from transformers import AutoModelForSeq2SeqLM
import torch
model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-large', trust_remote_code=True)
input_ids = torch.tensor([list("Life is like a box of chocolates.".encode("utf-8"))]) + 3 # add 3 for special tokens
labels = torch.tensor([list("La vie est comme une boîte de chocolat.".encode("utf-8"))]) + 3 # add 3 for special tokens
# Forward pass with hard deletion
loss = model(input_ids, labels=labels, hard_delete=True).loss
高級用法
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
model = AutoModelForSeq2SeqLM.from_pretrained('stanfordnlp/mrt5-large', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('google/byt5-large')
model_inputs = tokenizer(["Life is like a box of chocolates.", "Today is Monday."], padding="longest", return_tensors="pt")
labels = tokenizer(["La vie est comme une boîte de chocolat.", "Aujourd'hui c'est lundi."], padding="longest", return_tensors="pt").input_ids
# Forward pass with hard deletion
loss = model(**model_inputs, labels=labels, hard_delete=True).loss
📚 詳細文檔
模型詳情
這是參數為12.3億的MrT5 Large (mrt5-large
)的模型卡片,它是ByT5 Large (google/byt5-large
)的更高效變體。該模型經過訓練,平均可將序列長度減少約50%。
- 開發者:Julie Kallini、Shikhar Murty、Christopher D. Manning、Christopher Potts、Róbert Csordás
- 模型類型:MrT5
- 支持語言:英語、法語、西班牙語、德語、希臘語、保加利亞語、俄語、土耳其語、阿拉伯語、越南語、泰語、中文、印地語、斯瓦希里語和烏爾都語
- 微調基礎模型:google/byt5-large
- 更多信息來源:
模型架構
MrT5 Large採用了標準ByT5 Large的模型配置,其前饋維度為3840,模型維度為1536,有36個編碼器層、12個解碼器層,每層有16個注意力頭,總參數為12.3億。
MrT5有一個額外的刪除門,可動態減少編碼器序列長度。在該模型中,刪除門位於第三個編碼器層之後,後續所有層都在縮減後的序列上運行。該模型的訓練刪除率δ = 0.5,這意味著模型在第三層之後將其編碼器序列長度減少約50%。MrT5的門控機制僅引入了額外的3000個參數。
MrT5 Large從ByT5 Large初始化,並在相同的訓練目標上進行微調。只有MrT5的刪除門在訓練前是隨機初始化的。MrT5的另一個顯著特點是在其注意力機制中使用了softmax1。
使用場景
該模型是一種編碼器 - 解碼器架構,主要設計用於序列到序列任務。雖然它可以直接用於探索性或學術目的,但建議進行微調以在特定下游任務上實現最佳性能。
要利用模型的刪除功能,請使用配套倉庫中提供的自定義MrT5Trainer。這個專門的訓練器可確保在微調過程中正確維護和集成刪除機制。
由於這是一個為學術和研究探索而構建的基礎模型,不適合用於生產級部署。用戶應仔細評估模型的輸出,特別是在對可靠性和魯棒性要求較高的任何場景中。
偏差、風險和侷限性
語言模型已知會表現出各種形式的社會偏差,並且可能產生有害或冒犯性內容(Bender et al., 2021; Bommasani et al., 2022; Liang et al., 2022)。與其他語言模型一樣,該模型可能會產生有偏差或有害的輸出。它尚未針對安全性進行微調,因此應謹慎使用,尤其是在敏感環境中。
訓練詳情
訓練數據
在持續預訓練中,我們使用了多語言C4 (mC4)語料庫(Raffel et al., 2020; Xue et al., 2021)。MrT5在15種類型多樣的語言上進行訓練:英語、法語、西班牙語、德語、希臘語、保加利亞語、俄語、土耳其語、阿拉伯語、越南語、泰語、中文、印地語、斯瓦希里語和烏爾都語。
為避免對模型進行多個週期的訓練,我們確保從mC4語料庫中抽取的樣本足夠大。此外,我們從mC4訓練分割中為每種語言抽取相同大小(按字節計算)的樣本。
訓練過程
MrT5在ByT5的跨度損壞預訓練目標上進行訓練。在這個任務中,未標記文本數據中的令牌跨度被每個跨度一個單獨的哨兵令牌 ID 替換,模型必須填充缺失的令牌。對於ByT5和MrT5,這些是字節跨度,並且掩碼可能會干擾單詞邊界。
預處理
在針對跨度損壞目標進行訓練時,我們計算損壞的跨度,使得平均掩碼跨度長度為20個令牌,噪聲密度為15%,即序列中15%的令牌被掩碼,遵循ByT5論文中概述的規範。
優化
MrT5在2^20個令牌的批次上進行5000次梯度步驟的訓練(即編碼器序列長度為1024,有效批次大小為1024)。我們使用AdamW優化器,初始學習率為1e - 4,採用線性衰減且無預熱。
為實現特定的序列長度減少率,我們使用PI控制器,目標刪除率δ = 0.5,如論文第3.2節所述。我們還使用了注意力分數正則化,如論文附錄D所述。
環境影響
- 硬件類型:NVIDIA A100 - SXM4 - 80GB
- GPU數量:4
- 使用時長:約73小時
- 雲服務提供商:斯坦福NLP集群
模型卡片作者
Julie Kallini kallini@stanford.edu
🔧 技術細節
文檔未提供符合要求的技術細節內容,故跳過該章節。
📄 許可證
文檔未提及許可證信息,故跳過該章節。
📚 引用信息
如果你使用此模型,請引用MrT5論文:
@inproceedings{
kallini2025mrt,
title={MrT5: Dynamic Token Merging for Efficient Byte-level Language Models},
author={Julie Kallini and Shikhar Murty and Christopher D Manning and Christopher Potts and R{\'o}bert Csord{\'a}s},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=VYWBMq1L7H}
}
同時引用ByT5論文:
@article{xue-etal-2022-byt5,
title = "{B}y{T}5: Towards a Token-Free Future with Pre-trained Byte-to-Byte Models",
author = "Xue, Linting and
Barua, Aditya and
Constant, Noah and
Al-Rfou, Rami and
Narang, Sharan and
Kale, Mihir and
Roberts, Adam and
Raffel, Colin",
editor = "Roark, Brian and
Nenkova, Ani",
journal = "Transactions of the Association for Computational Linguistics",
volume = "10",
year = "2022",
address = "Cambridge, MA",
publisher = "MIT Press",
url = "https://aclanthology.org/2022.tacl-1.17",
doi = "10.1162/tacl_a_00461",
pages = "291--306",
}



