🚀 ccdv/lsg-bart-base-4096-pubmed
該模型是基於 scientific_papers pubmed 數據集對 ccdv/lsg-bart-base-4096 進行微調後的版本。它利用 Local-Sparse-Global 注意力機制處理長序列,能在文本摘要等任務中取得較好效果。
⚠️ 重要提示
此模型依賴自定義建模文件,需要添加 trust_remote_code=True
。請確保使用的 Transformers >= 4.36.1
,詳見 #13467。
🚀 快速開始
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bart-base-4096-pubmed", trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained("ccdv/lsg-bart-base-4096-pubmed", trust_remote_code=True)
text = "Replace by what you want."
pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0)
generated_text = pipe(
text,
truncation=True,
max_length=64,
no_repeat_ngram_size=7,
num_beams=2,
early_stopping=True
)
✨ 主要特性
- 長序列處理:該模型依賴 Local-Sparse-Global 注意力機制來處理長序列,其架構如圖所示:

- 參數規模:模型約有 1.45 億個參數,包含 6 個編碼器層和 6 個解碼器層。
- 微調基礎:模型從 BART-base 進行熱啟動,轉換為處理長序列(僅編碼器)並進行微調。
💻 使用示例
基礎用法
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bart-base-4096-pubmed", trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained("ccdv/lsg-bart-base-4096-pubmed", trust_remote_code=True)
text = "Replace by what you want."
pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0)
generated_text = pipe(
text,
truncation=True,
max_length=64,
no_repeat_ngram_size=7,
num_beams=2,
early_stopping=True
)
📚 詳細文檔
該模型在測試集上取得了以下結果:
較大塊大小
長度 |
稀疏類型 |
塊大小 |
稀疏度 |
連接數 |
R1 |
R2 |
RL |
RLsum |
4096 |
Local |
256 |
0 |
768 |
47.37 |
21.74 |
28.59 |
43.67 |
4096 |
Local |
128 |
0 |
384 |
47.02 |
21.33 |
28.34 |
43.31 |
4096 |
Pooling |
128 |
4 |
644 |
47.11 |
21.42 |
28.43 |
43.40 |
4096 |
Stride |
128 |
4 |
644 |
47.16 |
21.49 |
28.38 |
43.44 |
4096 |
Block Stride |
128 |
4 |
644 |
47.13 |
21.46 |
28.39 |
43.42 |
4096 |
Norm |
128 |
4 |
644 |
47.09 |
21.44 |
28.40 |
43.36 |
4096 |
LSH |
128 |
4 |
644 |
47.11 |
21.41 |
28.41 |
43.42 |
較小塊大小(資源需求較低)
長度 |
稀疏類型 |
塊大小 |
稀疏度 |
連接數 |
R1 |
R2 |
RL |
RLsum |
4096 |
Local |
64 |
0 |
192 |
45.74 |
20.26 |
27.51 |
41.99 |
4096 |
Local |
32 |
0 |
96 |
42.69 |
17.83 |
25.62 |
38.89 |
4096 |
Pooling |
32 |
4 |
160 |
44.60 |
19.35 |
26.83 |
40.85 |
4096 |
Stride |
32 |
4 |
160 |
45.52 |
20.07 |
27.39 |
41.75 |
4096 |
Block Stride |
32 |
4 |
160 |
45.30 |
19.89 |
27.22 |
41.54 |
4096 |
Norm |
32 |
4 |
160 |
44.30 |
19.05 |
26.57 |
40.47 |
4096 |
LSH |
32 |
4 |
160 |
44.53 |
19.27 |
26.84 |
40.74 |
🔧 技術細節
訓練超參數
在訓練過程中使用了以下超參數:
- 學習率:8e-05
- 訓練批次大小:8
- 隨機種子:42
- 梯度累積步數:4
- 總訓練批次大小:32
- 優化器:Adam(β1 = 0.9,β2 = 0.999,ε = 1e-08)
- 學習率調度器類型:線性
- 學習率調度器熱身比例:0.1
- 訓練輪數:8.0
生成超參數
在生成過程中使用了以下超參數:
- 數據集名稱:scientific_papers
- 數據集配置名稱:pubmed
- 評估批次大小:8
- 評估樣本數:6658
- 提前停止:True
- 忽略填充標記的損失:True
- 長度懲罰:2.0
- 最大長度:512
- 最小長度:128
- 束搜索數:5
- 無重複 n-gram 大小:None
- 隨機種子:123
框架版本
- Transformers 4.18.0
- Pytorch 1.10.1+cu102
- Datasets 2.1.0
- Tokenizers 0.11.6
此外,LSG ArXiv 論文 提供了更多理論細節,Github/轉換腳本可在 此鏈接 找到。