🚀 ccdv/lsg-bart-base-4096-pubmed
该模型是基于 scientific_papers pubmed 数据集对 ccdv/lsg-bart-base-4096 进行微调后的版本。它利用 Local-Sparse-Global 注意力机制处理长序列,能在文本摘要等任务中取得较好效果。
⚠️ 重要提示
此模型依赖自定义建模文件,需要添加 trust_remote_code=True
。请确保使用的 Transformers >= 4.36.1
,详见 #13467。
🚀 快速开始
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bart-base-4096-pubmed", trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained("ccdv/lsg-bart-base-4096-pubmed", trust_remote_code=True)
text = "Replace by what you want."
pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0)
generated_text = pipe(
text,
truncation=True,
max_length=64,
no_repeat_ngram_size=7,
num_beams=2,
early_stopping=True
)
✨ 主要特性
- 长序列处理:该模型依赖 Local-Sparse-Global 注意力机制来处理长序列,其架构如图所示:

- 参数规模:模型约有 1.45 亿个参数,包含 6 个编码器层和 6 个解码器层。
- 微调基础:模型从 BART-base 进行热启动,转换为处理长序列(仅编码器)并进行微调。
💻 使用示例
基础用法
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bart-base-4096-pubmed", trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained("ccdv/lsg-bart-base-4096-pubmed", trust_remote_code=True)
text = "Replace by what you want."
pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0)
generated_text = pipe(
text,
truncation=True,
max_length=64,
no_repeat_ngram_size=7,
num_beams=2,
early_stopping=True
)
📚 详细文档
该模型在测试集上取得了以下结果:
较大块大小
长度 |
稀疏类型 |
块大小 |
稀疏度 |
连接数 |
R1 |
R2 |
RL |
RLsum |
4096 |
Local |
256 |
0 |
768 |
47.37 |
21.74 |
28.59 |
43.67 |
4096 |
Local |
128 |
0 |
384 |
47.02 |
21.33 |
28.34 |
43.31 |
4096 |
Pooling |
128 |
4 |
644 |
47.11 |
21.42 |
28.43 |
43.40 |
4096 |
Stride |
128 |
4 |
644 |
47.16 |
21.49 |
28.38 |
43.44 |
4096 |
Block Stride |
128 |
4 |
644 |
47.13 |
21.46 |
28.39 |
43.42 |
4096 |
Norm |
128 |
4 |
644 |
47.09 |
21.44 |
28.40 |
43.36 |
4096 |
LSH |
128 |
4 |
644 |
47.11 |
21.41 |
28.41 |
43.42 |
较小块大小(资源需求较低)
长度 |
稀疏类型 |
块大小 |
稀疏度 |
连接数 |
R1 |
R2 |
RL |
RLsum |
4096 |
Local |
64 |
0 |
192 |
45.74 |
20.26 |
27.51 |
41.99 |
4096 |
Local |
32 |
0 |
96 |
42.69 |
17.83 |
25.62 |
38.89 |
4096 |
Pooling |
32 |
4 |
160 |
44.60 |
19.35 |
26.83 |
40.85 |
4096 |
Stride |
32 |
4 |
160 |
45.52 |
20.07 |
27.39 |
41.75 |
4096 |
Block Stride |
32 |
4 |
160 |
45.30 |
19.89 |
27.22 |
41.54 |
4096 |
Norm |
32 |
4 |
160 |
44.30 |
19.05 |
26.57 |
40.47 |
4096 |
LSH |
32 |
4 |
160 |
44.53 |
19.27 |
26.84 |
40.74 |
🔧 技术细节
训练超参数
在训练过程中使用了以下超参数:
- 学习率:8e-05
- 训练批次大小:8
- 随机种子:42
- 梯度累积步数:4
- 总训练批次大小:32
- 优化器:Adam(β1 = 0.9,β2 = 0.999,ε = 1e-08)
- 学习率调度器类型:线性
- 学习率调度器热身比例:0.1
- 训练轮数:8.0
生成超参数
在生成过程中使用了以下超参数:
- 数据集名称:scientific_papers
- 数据集配置名称:pubmed
- 评估批次大小:8
- 评估样本数:6658
- 提前停止:True
- 忽略填充标记的损失:True
- 长度惩罚:2.0
- 最大长度:512
- 最小长度:128
- 束搜索数:5
- 无重复 n-gram 大小:None
- 随机种子:123
框架版本
- Transformers 4.18.0
- Pytorch 1.10.1+cu102
- Datasets 2.1.0
- Tokenizers 0.11.6
此外,LSG ArXiv 论文 提供了更多理论细节,Github/转换脚本可在 此链接 找到。