🚀 DeBERTa (1.5B) 修复版本
本项目是对 deberta-v2-xxlarge 进行更新,使其实现了 AutoModelForCausalLM
类,从而能够生成文本。此实现基于我们的论文 "BERTs are Generative In-Context Learners"。
本仓库还修复了 DeBERTa 在 Hugging Face 上的原始实现 中的三个问题:
- 修复了检查点文件中输出嵌入权重的错误名称;
- 基于 原始 GitHub 仓库 修复了增强掩码解码器(EMD)的实现;
- 对位置嵌入进行了限制,使其能够处理长序列。
🚀 快速开始
本项目是对 deberta-v2-xxlarge 的更新版本,实现了 AutoModelForCausalLM
类,从而具备文本生成能力。其实现基于论文 "BERTs are Generative In-Context Learners"。同时,本仓库修复了原始实现中的三个问题,提升了模型的性能和稳定性。
💻 使用示例
基础用法
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("ltg/deberta-xxlarge-fixed", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("ltg/deberta-xxlarge-fixed", trust_remote_code=True).cuda().eval()
prompt = """German: Hallo, wie geht es Ihnen heute?
English:"""
prompt = prompt.replace('\n', '\\n ')
input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.cuda()
prediction = model.generate(
input_ids,
num_beams=4,
do_sample=False,
use_cache=None,
max_new_tokens=64,
eos_token_id=tokenizer(".\\", add_special_tokens=False).input_ids[1:]
)
prediction = prediction[0, input_ids.size(1):]
prediction = tokenizer.decode(prediction).rstrip('\\')
print(prediction)
📄 许可证
本项目采用 MIT 许可证。
📚 引用
如果您发现 DeBERTa 对您的工作有帮助,请引用以下论文:
@inproceedings{
samuel2024berts,
title={{BERT}s are Generative In-Context Learners},
author={David Samuel},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=BCA9NMZkLS}
}
@inproceedings{he2021deberta,
title={{DeBERTa}: Decoding-enhanced {BERT} with disentangled attention},
author={Pengcheng He and Xiaodong Liu and Jianfeng Gao and Weizhu Chen},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=XPZIaotutsD}
}