🚀 使用HuggingFace Transformers进行摘要任务
本项目利用HuggingFace Transformers库实现文本摘要任务,通过预训练模型对输入文本进行处理,生成相应的摘要内容。
🚀 快速开始
安装依赖
确保你已经安装了HuggingFace Transformers库,可以使用以下命令进行安装:
pip install transformers
代码示例
以下是使用预训练模型进行文本摘要的示例代码:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained("AlekseyKulnevich/Pegasus-Summarization")
tokenizer = AutoTokenizer.from_pretrained('google/pegasus-large')
input_text =
input_ = tokenizer.batch_encode_plus([input_text], max_length=1024, pad_to_max_length=True,
truncation=True, padding='longest', return_tensors='pt')
input_ids = input_['input_ids']
input_mask = input_['attention_mask']
summary = model.generate(input_ids=input_ids,
attention_mask=input_mask,
num_beams=32,
min_length=100,
no_repeat_ngram_size=2,
early_stopping=True,
num_return_sequences=10)
questions = tokenizer.batch_decode(summary, skip_special_tokens=True)
💻 使用示例
基础用法
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained("AlekseyKulnevich/Pegasus-Summarization")
tokenizer = AutoTokenizer.from_pretrained('google/pegasus-large')
input_text =
input_ = tokenizer.batch_encode_plus([input_text], max_length=1024, pad_to_max_length=True,
truncation=True, padding='longest', return_tensors='pt')
input_ids = input_['input_ids']
input_mask = input_['attention_mask']
summary = model.generate(input_ids=input_ids,
attention_mask=input_mask,
num_beams=32,
min_length=100,
no_repeat_ngram_size=2,
early_stopping=True,
num_return_sequences=10)
questions = tokenizer.batch_decode(summary, skip_special_tokens=True)
高级用法
解码器配置示例
你可以参考以下链接中的输入文本进行测试:输入文本示例
summary = model.generate(input_ids=input_ids,
attention_mask=input_mask,
num_beams=32,
min_length=100,
no_repeat_ngram_size=2,
early_stopping=True,
num_return_sequences=1)
tokenizer.batch_decode(summary, skip_special_tokens=True)
输出示例:
- 根据政府间气候变化专门委员会(IPCC)和美国国家海洋和大气管理局(NOAA)发表的一项新研究,全球变暖将扩大热带气旋在世界中纬度地区的活动范围。该研究表明,气候变暖将使这类风暴能够在比过去300万年更广泛的范围内形成。“随着气候变暖,这些风暴可能会变得更加频繁和强烈,”该研究的作者说。
summary = model.generate(input_ids=input_ids,
attention_mask=input_mask,
top_k=30,
no_repeat_ngram_size=2,
early_stopping=True,
min_length=100,
num_return_sequences=1)
tokenizer.batch_decode(summary, skip_special_tokens=True)
输出示例:
- 根据政府间气候变化专门委员会(IPCC)发表的一项关于人类活动引起的气候变化对风暴影响的新研究,世界中纬度地区的热带气旋可能会形成更多这类风暴。该研究表明,气候变暖将增加亚热带气旋在比过去300万年更广泛的纬度范围内形成的可能性,包括赤道地区,并且更有可能在热带地区形成。
你还可以在generate
方法中调整以下参数:
有关生成文本参数的含义,你可以参考:参数含义说明