🚀 西班牙RoBERTa2RoBERTa (roberta-base-bne) 在MLSUM ES数据集上微调用于摘要生成
本项目是将西班牙RoBERTa2RoBERTa (roberta-base-bne) 模型在MLSUM ES数据集上进行微调,以实现文本摘要生成功能。该模型能够对文本进行有效概括,为用户提供简洁的文本摘要。
🚀 快速开始
模型使用示例
以下是使用该模型进行文本摘要生成的Python代码示例:
import torch
from transformers import RobertaTokenizerFast, EncoderDecoderModel
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckpt = 'Narrativa/bsc_roberta2roberta_shared-spanish-finetuned-mlsum-summarization'
tokenizer = RobertaTokenizerFast.from_pretrained(ckpt)
model = EncoderDecoderModel.from_pretrained(ckpt).to(device)
def generate_summary(text):
inputs = tokenizer([text], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device)
output = model.generate(input_ids, attention_mask=attention_mask)
return tokenizer.decode(output[0], skip_special_tokens=True)
text = "Your text here..."
generate_summary(text)
✨ 主要特性
📦 安装指南
文档中未提及具体安装步骤,若需使用该模型,可参考Hugging Face相关库的安装方式,确保安装 torch
和 transformers
库。
💻 使用示例
基础用法
import torch
from transformers import RobertaTokenizerFast, EncoderDecoderModel
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckpt = 'Narrativa/bsc_roberta2roberta_shared-spanish-finetuned-mlsum-summarization'
tokenizer = RobertaTokenizerFast.from_pretrained(ckpt)
model = EncoderDecoderModel.from_pretrained(ckpt).to(device)
def generate_summary(text):
inputs = tokenizer([text], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device)
output = model.generate(input_ids, attention_mask=attention_mask)
return tokenizer.decode(output[0], skip_special_tokens=True)
text = "Your text here..."
generate_summary(text)
高级用法
文档中未提及高级用法相关代码,可根据模型的API文档进一步探索模型的参数调整等高级功能。
📚 详细文档
模型
使用的基础模型为 BSC-TeMU/roberta-base-bne ,这是一个RoBERTa检查点模型。
数据集
MLSUM 是第一个大规模多语言摘要数据集。它从在线报纸获取,包含超过150万篇文章及其摘要对,涵盖五种不同语言,即法语、德语、西班牙语、俄语和土耳其语。与流行的CNN/Daily mail数据集中的英文报纸文章一起,收集的数据形成了一个大规模多语言数据集,为文本摘要社区开辟了新的研究方向。
MLSUM es
结果
属性 |
详情 |
测试集Rouge2 - mid - precision |
11.42 |
测试集Rouge2 - mid - recall |
10.58 |
测试集Rouge2 - mid - fmeasure |
10.69 |
测试集Rouge1 - fmeasure |
28.83 |
测试集RougeL - fmeasure |
23.15 |
原始指标计算代码如下:
rouge = datasets.load_metric("rouge")
rouge.compute(predictions=results["pred_summary"], references=results["summary"])
{'rouge1': AggregateScore(low=Score(precision=0.30393366820245, recall=0.27905239591639935, fmeasure=0.283148902808752), mid=Score(precision=0.3068521142101569, recall=0.2817252494122592, fmeasure=0.28560373425206464), high=Score(precision=0.30972608774202665, recall=0.28458152325781716, fmeasure=0.2883786700591887)),
'rougeL': AggregateScore(low=Score(precision=0.24184668819794716, recall=0.22401171380621518, fmeasure=0.22624104698839514), mid=Score(precision=0.24470388406868163, recall=0.22665793214539162, fmeasure=0.2289118878817394), high=Score(precision=0.2476594458951327, recall=0.22932683203591905, fmeasure=0.23153001570662513))}
rouge.compute(predictions=results["pred_summary"], references=results["summary"], rouge_types=["rouge2"])["rouge2"].mid
Score(precision=0.11423200347113865, recall=0.10588038944902506, fmeasure=0.1069921217219595)
🔧 技术细节
文档中未提及具体技术细节相关内容,可参考模型的官方文档或相关研究论文获取更多信息。
📄 许可证
文档中未提及许可证相关信息。
本项目由 Narrativa 创建。关于Narrativa:专注于自然语言生成 (NLG) ,其基于机器学习的平台Gabriele可构建和部署自然语言解决方案。 #NLG #AI