🚀 long-t5-tglobal-xl + BookSum
本项目基于long-t5-tglobal-xl
模型在kmfoda/booksum
数据集上微调而来,能够对长文本进行总结,为你提供类似SparkNotes的各主题摘要。它在学术和叙述性文本上泛化能力较好,XL版本在人工评估中能生成更优质的摘要。
🚀 快速开始
安装依赖
安装或更新transformers
库:
pip install -U transformers
文本总结示例
使用pipeline
进行文本总结:
import torch
from transformers import pipeline
summarizer = pipeline(
"summarization",
"pszemraj/long-t5-tglobal-xl-16384-book-summary",
device=0 if torch.cuda.is_available() else -1,
)
long_text = "Here is a lot of text I don't want to read. Replace me"
result = summarizer(long_text)
print(result[0]["summary_text"])
⚠️ 重要提示
根据此讨论,我们发现long-t5
模型版本 >= 4.23.0 存在问题。请使用pip install transformers==4.22.0
以确保该模型性能良好。
简单概念验证
以下是对著名的海豹突击队复制粘贴文本的总结:
In this chapter, the monster explains how he intends to exact revenge on "the little b****" who insulted him. He tells the kiddo that he is a highly trained and experienced killer who will use his arsenal of weapons--including his access to the internet--to exact justice on the little brat.
虽然这是一个粗糙的例子,但你可以将这段复制粘贴文本输入其他总结模型,看看理解能力的差异(即使它甚至不是“长”文本!)。
✨ 主要特性
- 能够对长文本进行总结,生成类似SparkNotes的各主题摘要。
- 在学术和叙述性文本上有较好的泛化能力。
- XL版本在人工评估中能生成更优质的摘要。
📦 安装指南
安装transformers
库
pip install -U transformers
安装bitsandbytes
和accelerate
库(用于LLM.int8量化)
pip install -U transformers bitsandbytes accelerate
安装textsum
包(可选)
pip install textsum
💻 使用示例
基础用法
import torch
from transformers import pipeline
summarizer = pipeline(
"summarization",
"pszemraj/long-t5-tglobal-xl-16384-book-summary",
device=0 if torch.cuda.is_available() else -1,
)
long_text = "Here is a lot of text I don't want to read. Replace me"
result = summarizer(long_text)
print(result[0]["summary_text"])
高级用法
调整参数
在调用summarizer
时传递其他与波束搜索文本生成相关的参数,以获得更高质量的结果。
LLM.int8量化
通过此PR,long-t5
模型现在支持LLM.int8量化。
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained(
"pszemraj/long-t5-tglobal-xl-16384-book-summary"
)
model = AutoModelForSeq2SeqLM.from_pretrained(
"pszemraj/long-t5-tglobal-xl-16384-book-summary",
load_in_8bit=True,
device_map="auto",
)
使用textsum
包
from textsum.summarize import Summarizer
summarizer = Summarizer(
model_name_or_path="pszemraj/long-t5-tglobal-xl-16384-book-summary"
)
long_string = "This is a long string of text that will be summarized."
out_str = summarizer.summarize_string(long_string)
print(f"summary: {out_str}")
📚 详细文档
预期用途和限制
虽然该模型似乎提高了事实一致性,但不要将总结视为万无一失的,对于看起来奇怪的内容要进行检查。特别是否定陈述(即模型说:“这个东西没有[属性]”,而实际上应该说“这个东西有很多[属性]”)。你通常可以通过将特定陈述与周围句子的含义进行比较来检查。
训练和评估数据
使用HuggingFace上的kmfoda/booksum
数据集,阅读原始论文。
- 初始微调:出于内存原因,仅使用输入标记为12288或更少且输出标记为1024或更少的输入文本(即,在训练前丢弃更长的行)。经过快速分析,该数据集中12288 - 16384范围内的总结占少数。此外,初始训练将训练集和验证集合并,并对它们进行整体训练,以增加功能数据集的大小。因此,对验证集结果要持保留态度;主要指标应该(始终)是测试集。
- 最终微调阶段:使用标准的16384输入/1024输出约定,保留标准的输入/输出长度(并截断较长的序列)。这似乎对损失/性能影响不大。
评估结果
将使用模型评估器计算并公布官方结果。由于训练方法的原因,验证集上的性能看起来比测试集上的结果要好。该模型在评估集上取得以下结果:
-
eval_loss: 1.2756
-
eval_rouge1: 41.8013
-
eval_rouge2: 12.0895
-
eval_rougeL: 21.6007
-
eval_rougeLsum: 39.5382
-
eval_gen_len: 387.2945
-
eval_runtime: 13908.4995
-
eval_samples_per_second: 0.107
-
eval_steps_per_second: 0.027
***** predict/test metrics (initial) *****
predict_gen_len = 506.4368
predict_loss = 2.028
predict_rouge1 = 36.8815
predict_rouge2 = 8.0625
predict_rougeL = 17.6161
predict_rougeLsum = 34.9068
predict_runtime = 2:04:14.37
predict_samples = 1431
predict_samples_per_second = 0.192
predict_steps_per_second = 0.048
常见问题解答
如何在CPU上运行推理?
暂未提供相关详细说明。
如何对非常长(30k + 标记)的文档进行批量推理?
请参阅我的Hugging Face空间文档总结代码中的summarize.py
。你也可以使用相同的代码将文档拆分为4096等批次,并使用模型对它们进行迭代。这在CUDA内存有限的情况下很有用。
如何进一步微调模型?
请参阅使用脚本进行训练和总结脚本。
是否有更简单的运行方法?
可以使用textsum包,它提供了易于使用的接口,可将总结模型应用于任意长度的文本文档。目前实现的接口包括Python API、CLI和可共享的演示应用程序。
训练过程
更新
相关时,将在此处发布此模型/模型卡的更新。该模型似乎已基本收敛;如果使用BookSum
数据集可以进行更新/改进,此仓库将进行更新。
训练超参数
训练期间使用以下超参数:
- learning_rate: 0.0006
- train_batch_size: 1
- eval_batch_size: 1
- seed: 10350
- distributed_type: multi-GPU
- num_devices: 4
- gradient_accumulation_steps: 32
- total_train_batch_size: 128
- total_eval_batch_size: 4
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: constant
- num_epochs: 1.0
框架版本
- Transformers 4.25.0.dev0
- Pytorch 1.13.0+cu117
- Datasets 2.6.1
- Tokenizers 0.13.1
📄 许可证
本项目使用以下许可证: