🚀 bert-chunker-2
bert-chunker-2 is a BERT-based text chunker with a classifier head that predicts the start token of each text chunk (useful for scenarios such as retrieval-augmented generation, RAG). Using a sliding window, it can cut documents of arbitrary size into chunks. We see it as an alternative to semantic chunkers, with the distinction that it works not only on structured text but also on unstructured and messy text. As a new experimental version of bert-chunker, it is enhanced for article structure and aims to strike a balance between semantic chunking and structural chunking. It was produced by merging a trained semantic chunker and a trained structural chunker with linear weights of 0.1:0.9.
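For illustration, here is a minimal sketch of what such a 0.1:0.9 linear weight merge looks like. The source checkpoints are not published, so the paths below are placeholders; this is an assumed reconstruction of the recipe described above, not the exact script used:

```python
import torch
from transformers import BertForTokenClassification

# Placeholder paths; the actual semantic/structural checkpoints are not released.
semantic = BertForTokenClassification.from_pretrained("path/to/semantic-chunker")
structural = BertForTokenClassification.from_pretrained("path/to/structural-chunker")

# Parameter-wise linear merge: merged = 0.1 * semantic + 0.9 * structural.
structural_sd = structural.state_dict()
merged_sd = {
    name: 0.1 * param + 0.9 * structural_sd[name]
    for name, param in semantic.state_dict().items()
}

structural.load_state_dict(merged_sd)
structural.save_pretrained("bert-chunker-merged")  # hypothetical output name
```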
A newer and more mature version is bert-chunker-3 (see the project's GitHub).
🚀 Quick Start
Run the following code:
```python
import torch
from transformers import AutoConfig, AutoTokenizer, BertForTokenClassification
import math

model_path = "tim1900/bert-chunker-2"
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side="right",
    model_max_length=255,
    trust_remote_code=True,
)
config = AutoConfig.from_pretrained(
    model_path,
    trust_remote_code=True,
)
device = 'cpu'  # or 'cuda' if a GPU is available
model = BertForTokenClassification.from_pretrained(model_path).to(device)

def chunk_text(model, text: str, tokenizer, prob_threshold=0.5) -> tuple[list[str], list[int]]:
    # Slide a fixed-size context window over the tokenized text and let the
    # classifier head flag chunk-start tokens inside each window.
    MAX_TOKENS = 255
    tokens = tokenizer(text, return_tensors="pt", truncation=False)
    input_ids = tokens['input_ids']
    # Computed but unused below; each window builds its own all-ones mask.
    attention_mask = tokens['attention_mask'][:, 0:MAX_TOKENS]
    attention_mask = attention_mask.to(model.device)
    CLS = input_ids[:, 0].unsqueeze(0)
    SEP = input_ids[:, -1].unsqueeze(0)
    input_ids = input_ids[:, 1:-1]  # strip [CLS]/[SEP]; they are re-attached per window
    model.eval()
    split_str_poses = []
    token_pos = []
    windows_start = 0
    windows_end = 0
    # Inverse sigmoid of prob_threshold, so decisions can be made on raw logits.
    logits_threshold = math.log(1 / prob_threshold - 1)
    print(f'Processing {input_ids.shape[1]} tokens...')
    while windows_end <= input_ids.shape[1]:
        windows_end = windows_start + MAX_TOKENS - 2
        ids = torch.cat((CLS, input_ids[:, windows_start:windows_end], SEP), 1)
        ids = ids.to(model.device)
        output = model(input_ids=ids, attention_mask=torch.ones(1, ids.shape[1], device=model.device))
        logits = output['logits'][:, 1:-1, :]  # drop the [CLS]/[SEP] positions
        # Equivalent to softmax P(chunk start) > prob_threshold (see note below).
        chunk_decision = (logits[:, :, 1] > (logits[:, :, 0] - logits_threshold))
        greater_rows_indices = torch.where(chunk_decision)[1].tolist()

        # If the window found a split (other than a lone split at its first token),
        # record the character and token positions and restart from the last split;
        # otherwise advance the window past this stretch of text.
        if len(greater_rows_indices) > 0 and (not (greater_rows_indices[0] == 0 and len(greater_rows_indices) == 1)):
            split_str_pos = [tokens.token_to_chars(sp + windows_start + 1).start for sp in greater_rows_indices]
            token_pos += [sp + windows_start + 1 for sp in greater_rows_indices]
            split_str_poses += split_str_pos
            windows_start = greater_rows_indices[-1] + windows_start
        else:
            windows_start = windows_end

    substrings = [text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses + [len(text)])]
    token_pos = [0] + token_pos
    return substrings, token_pos

# Sample text: deliberately messy (missing periods, fused sentences) to
# exercise the chunker on unstructured input.
text = '''In the heart of the bustling city, where towering skyscrapers touch the clouds and the symphony
of honking cars never ceases, Sarah, an aspiring novelist, found solace in the quiet corners of the ancient library
Surrounded by shelves that whispered stories of centuries past, she crafted her own world with words, oblivious to the rush outside Dr.Alexander Thompson, aboard the spaceship 'Pandora's Venture', was en route to the newly discovered exoplanet Zephyr-7.
As the lead astrobiologist of the expedition, his mission was to uncover signs of microbial life within the planet's subterranean ice caves.
With each passing light year, the anticipation of unraveling secrets that could alter humanity's
understanding of life in the universe grew ever stronger.'''

chunks, token_pos = chunk_text(model, text, tokenizer, prob_threshold=0.5)

for i, (c, t) in enumerate(zip(chunks, token_pos)):
    print(f'-----chunk: {i}----token_idx: {t}--------')
    print(c)
```
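A note on the decision rule: `logits_threshold = math.log(1 / prob_threshold - 1)` is the inverse sigmoid of `prob_threshold`, so the comparison `logits[:, :, 1] > logits[:, :, 0] - logits_threshold` is equivalent to asking whether the softmax probability of the chunk-start class exceeds `prob_threshold`, without computing a softmax per token. A quick self-contained check with toy logits (not real model output):

```python
import math
import torch

p = 0.5
logits_threshold = math.log(1 / p - 1)  # inverse sigmoid; 0.0 when p = 0.5

l = torch.tensor([[1.2, 0.7], [-0.3, 1.9]])    # toy per-token logits [l0, l1]
prob_start = torch.softmax(l, dim=-1)[:, 1]    # explicit probability of class 1
fast = l[:, 1] > (l[:, 0] - logits_threshold)  # the logit-space shortcut
assert torch.equal(prob_start > p, fast)
```

In practice, lowering `prob_threshold` yields more and smaller chunks, while raising it yields fewer and larger ones.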
📄 License
This project is licensed under the Apache-2.0 License.