🚀 RuBERTExtSumGazeta
本项目基于 rubert-base-cased 模型,开发了用于抽取式文本摘要的模型,可有效处理相关文本的摘要提取任务。
🚀 快速开始
如何使用
你可以通过以下 Colab 链接体验:link
import razdel
from transformers import AutoTokenizer, BertForTokenClassification
model_name = "IlyaGusev/rubert_ext_sum_gazeta"
tokenizer = AutoTokenizer.from_pretrained(model_name)
sep_token = tokenizer.sep_token
sep_token_id = tokenizer.sep_token_id
model = BertForTokenClassification.from_pretrained(model_name)
article_text = "..."
sentences = [s.text for s in razdel.sentenize(article_text)]
article_text = sep_token.join(sentences)
inputs = tokenizer(
[article_text],
max_length=500,
padding="max_length",
truncation=True,
return_tensors="pt",
)
sep_mask = inputs["input_ids"][0] == sep_token_id
current_token_type_id = 0
for pos, input_id in enumerate(inputs["input_ids"][0]):
inputs["token_type_ids"][0][pos] = current_token_type_id
if input_id == sep_token_id:
current_token_type_id = 1 - current_token_type_id
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits[0, :, 1]
logits = logits[sep_mask]
logits, indices = logits.sort(descending=True)
logits, indices = logits.cpu().tolist(), indices.cpu().tolist()
pairs = list(zip(logits, indices))
pairs = pairs[:3]
indices = list(sorted([idx for _, idx in pairs]))
summary = " ".join([sentences[idx] for idx in indices])
print(summary)
✨ 主要特性
本模型基于 rubert-base-cased 构建,可实现抽取式文本摘要功能。
📦 安装指南
文档未提供具体安装步骤,暂不展示。
💻 使用示例
基础用法
import razdel
from transformers import AutoTokenizer, BertForTokenClassification
model_name = "IlyaGusev/rubert_ext_sum_gazeta"
tokenizer = AutoTokenizer.from_pretrained(model_name)
sep_token = tokenizer.sep_token
sep_token_id = tokenizer.sep_token_id
model = BertForTokenClassification.from_pretrained(model_name)
article_text = "..."
sentences = [s.text for s in razdel.sentenize(article_text)]
article_text = sep_token.join(sentences)
inputs = tokenizer(
[article_text],
max_length=500,
padding="max_length",
truncation=True,
return_tensors="pt",
)
sep_mask = inputs["input_ids"][0] == sep_token_id
current_token_type_id = 0
for pos, input_id in enumerate(inputs["input_ids"][0]):
inputs["token_type_ids"][0][pos] = current_token_type_id
if input_id == sep_token_id:
current_token_type_id = 1 - current_token_type_id
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits[0, :, 1]
logits = logits[sep_mask]
logits, indices = logits.sort(descending=True)
logits, indices = logits.cpu().tolist(), indices.cpu().tolist()
pairs = list(zip(logits, indices))
pairs = pairs[:3]
indices = list(sorted([idx for _, idx in pairs]))
summary = " ".join([sentences[idx] for idx in indices])
print(summary)
高级用法
文档未提供高级用法代码示例,暂不展示。
📚 详细文档
预期用途与限制
限制与偏差
- 该模型在处理 Gazeta.ru 的文章时效果较好,但对于其他机构的文章,可能会出现领域偏移的问题。
训练数据
训练过程
文档未提供具体训练过程信息,暂不展示。
评估结果
评估脚本:https://github.com/IlyaGusev/summarus/blob/master/evaluate.py
评估参数:--language ru --tokenize-after --lower
文档未提供具体评估结果信息,暂不展示。
📄 许可证
本项目采用 apache-2.0
许可证。