🚀 DeBERTa (1.5B) 修復版本
本項目是對 deberta-v2-xxlarge 進行更新,使其實現了 AutoModelForCausalLM
類,從而能夠生成文本。此實現基於我們的論文 "BERTs are Generative In-Context Learners"。
本倉庫還修復了 DeBERTa 在 Hugging Face 上的原始實現 中的三個問題:
- 修復了檢查點文件中輸出嵌入權重的錯誤名稱;
- 基於 原始 GitHub 倉庫 修復了增強掩碼解碼器(EMD)的實現;
- 對位置嵌入進行了限制,使其能夠處理長序列。
🚀 快速開始
本項目是對 deberta-v2-xxlarge 的更新版本,實現了 AutoModelForCausalLM
類,從而具備文本生成能力。其實現基於論文 "BERTs are Generative In-Context Learners"。同時,本倉庫修復了原始實現中的三個問題,提升了模型的性能和穩定性。
💻 使用示例
基礎用法
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("ltg/deberta-xxlarge-fixed", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("ltg/deberta-xxlarge-fixed", trust_remote_code=True).cuda().eval()
prompt = """German: Hallo, wie geht es Ihnen heute?
English:"""
prompt = prompt.replace('\n', '\\n ')
input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.cuda()
prediction = model.generate(
input_ids,
num_beams=4,
do_sample=False,
use_cache=None,
max_new_tokens=64,
eos_token_id=tokenizer(".\\", add_special_tokens=False).input_ids[1:]
)
prediction = prediction[0, input_ids.size(1):]
prediction = tokenizer.decode(prediction).rstrip('\\')
print(prediction)
📄 許可證
本項目採用 MIT 許可證。
📚 引用
如果您發現 DeBERTa 對您的工作有幫助,請引用以下論文:
@inproceedings{
samuel2024berts,
title={{BERT}s are Generative In-Context Learners},
author={David Samuel},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=BCA9NMZkLS}
}
@inproceedings{he2021deberta,
title={{DeBERTa}: Decoding-enhanced {BERT} with disentangled attention},
author={Pengcheng He and Xiaodong Liu and Jianfeng Gao and Weizhu Chen},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=XPZIaotutsD}
}