🚀 CodeT5-base代码摘要生成模型
CodeT5-base是一个用于代码摘要生成的模型。它基于CodeT5-base,在CodeSearchNet数据集上进行了多语言(Ruby、JavaScript、Go、Python、Java、PHP)的微调训练。该模型由Yue Wang、Weishi Wang、Shafiq Joty和Steven C.H. Hoi等人在2021年的EMNLP会议论文CodeT5: Identifier-aware Unified Pre-trained Encoder-Decoder Models for Code Understanding and Generation中提出。更多详情请查看此仓库。
🚀 快速开始
以下是使用该模型的示例代码:
from transformers import RobertaTokenizer, T5ForConditionalGeneration
if __name__ == '__main__':
tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base-multi-sum')
model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base-multi-sum')
text = """def svg_to_image(string, size=None):
if isinstance(string, unicode):
string = string.encode('utf-8')
renderer = QtSvg.QSvgRenderer(QtCore.QByteArray(string))
if not renderer.isValid():
raise ValueError('Invalid SVG data.')
if size is None:
size = renderer.defaultSize()
image = QtGui.QImage(size, QtGui.QImage.Format_ARGB32)
painter = QtGui.QPainter(image)
renderer.render(painter)
return image"""
input_ids = tokenizer(text, return_tensors="pt").input_ids
generated_ids = model.generate(input_ids, max_length=20)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
📦 安装指南
文档未提及安装步骤,如需使用该模型,请参考transformers库的安装说明。
💻 使用示例
基础用法
from transformers import RobertaTokenizer, T5ForConditionalGeneration
if __name__ == '__main__':
tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base-multi-sum')
model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base-multi-sum')
text = """def svg_to_image(string, size=None):
if isinstance(string, unicode):
string = string.encode('utf-8')
renderer = QtSvg.QSvgRenderer(QtCore.QByteArray(string))
if not renderer.isValid():
raise ValueError('Invalid SVG data.')
if size is None:
size = renderer.defaultSize()
image = QtGui.QImage(size, QtGui.QImage.Format_ARGB32)
painter = QtGui.QPainter(image)
renderer.render(painter)
return image"""
input_ids = tokenizer(text, return_tensors="pt").input_ids
generated_ids = model.generate(input_ids, max_length=20)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
📚 详细文档
微调数据
我们使用了来自CodeXGLUE基准测试的过滤版本的CodeSearchNet数据[Husain et al., 2019]进行代码摘要生成的微调。数据使用我们预训练的特定于代码的BPE(字节对编码)分词器进行分词。可以使用codet5-base的词汇文件,通过RobertaTokenizer
为模型准备文本(或代码)。
数据统计
编程语言 |
训练集 |
验证集 |
测试集 |
Python |
251,820 |
13,914 |
14,918 |
PHP |
241,241 |
12,982 |
14,014 |
Go |
167,288 |
7,325 |
8,122 |
Java |
164,923 |
5,183 |
10,955 |
JavaScript |
58,025 |
3,885 |
3,291 |
Ruby |
24,927 |
1,400 |
1,261 |
训练过程
我们在多任务学习设置下,对这六种编程语言(Ruby、JavaScript、Go、Python、Java、PHP)进行了codet5-base
的微调。我们采用了平衡采样的方法,以避免偏向高资源任务。更多详细信息请参考论文。
评估结果
与论文中允许为不同编程语言(PL)选择不同的最佳检查点不同,这里我们对所有PL使用一个检查点。此外,我们在训练和推理中移除了指定PL的任务控制前缀。测试集上的结果如下:
模型 |
Ruby |
JavaScript |
Go |
Python |
Java |
PHP |
总体 |
Seq2Seq |
9.64 |
10.21 |
13.98 |
15.93 |
15.09 |
21.08 |
14.32 |
Transformer |
11.18 |
11.59 |
16.38 |
15.81 |
16.26 |
22.12 |
15.56 |
RoBERTa |
11.17 |
11.90 |
17.72 |
18.14 |
16.47 |
24.02 |
16.57 |
CodeBERT |
12.16 |
14.90 |
18.07 |
19.06 |
17.65 |
25.16 |
17.83 |
PLBART |
14.11 |
15.56 |
18.91 |
19.30 |
18.45 |
23.58 |
18.32 |
CodeT5-small |
14.87 |
15.32 |
19.25 |
20.04 |
19.92 |
25.46 |
19.14 |
CodeT5-base |
15.24 |
16.16 |
19.56 |
20.01 |
20.31 |
26.03 |
19.55 |
CodeT5-base-multi-sum |
15.24 |
16.18 |
19.95 |
20.42 |
20.26 |
26.10 |
19.69 |
📄 许可证
本项目采用BSD 3-Clause许可证。
🔗 引用
@inproceedings{
wang2021codet5,
title={CodeT5: Identifier-aware Unified Pre-trained Encoder-Decoder Models for Code Understanding and Generation},
author={Yue Wang, Weishi Wang, Shafiq Joty, Steven C.H. Hoi},
booktitle={Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, EMNLP 2021},
year={2021},
}