🚀 BMRetriever-410M Model
This model is a large language model specialized for text retrieval in the biomedical and biological domains. It was fine-tuned following the approach described in the paper "BMRetriever: Tuning Large Language Models as Better Biomedical Text Retrievers".
🔍 Model Information
| Attribute | Details |
|-----------|---------|
| Model type | Fine-tuned large language model |
| Training data | MedRAG/textbooks, MedRAG/pubmed, MedRAG/statpearls, mteb/raw_biorxiv, mteb/raw_medrxiv, ms_marco, BMRetriever/biomed_retrieval_dataset |
🚀 Quick Start
This model was fine-tuned following the approach described in the paper "BMRetriever: Tuning Large Language Models as Better Biomedical Text Retrievers" (published at EMNLP 2024). The accompanying GitHub repository is available here.
The model has 410 million parameters. Please refer to the paper for details.
📦 Installation
The pretrained model can be loaded with the Hugging Face transformers library.
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("BMRetriever/BMRetriever-410M")
tokenizer = AutoTokenizer.from_pretrained("BMRetriever/BMRetriever-410M")
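As a quick sanity check on the parameter count mentioned above, you can inspect the loaded model directly; this is a minimal sketch and not part of the original card:
num_params = sum(p.numel() for p in model.parameters())
print(f"Loaded model with {num_params / 1e6:.0f}M parameters")  # expected to be roughly 410M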
💻 Usage Examples
Basic Usage
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    # Use the hidden state of the last non-padding token as the sequence embedding.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        embedding = last_hidden[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden.shape[0]
        embedding = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
    return embedding
def get_detailed_instruct_query(task_description: str, query: str) -> str:
    return f'{task_description}\nQuery: {query}'

def get_detailed_instruct_passage(passage: str) -> str:
    return f'Represent this passage\npassage: {passage}'
task = 'Given a scientific claim, retrieve documents that support or refute the claim'
queries = [
get_detailed_instruct_query(task, 'Cis-acting lncRNAs control the expression of genes that are positioned in the vicinity of their transcription sites.'),
get_detailed_instruct_query(task, 'Forkhead 0 (fox0) transcription factors are involved in apoptosis.')
]
documents = [
get_detailed_instruct_passage("Gene regulation by the act of long non-coding RNA transcription Long non-protein-coding RNAs (lncRNAs) are proposed to be the largest transcript class in the mouse and human transcriptomes. Two important questions are whether all lncRNAs are functional and how they could exert a function. Several lncRNAs have been shown to function through their product, but this is not the only possible mode of action. In this review we focus on a role for the process of lncRNA transcription, independent of the lncRNA product, in regulating protein-coding-gene activity in cis. We discuss examples where lncRNA transcription leads to gene silencing or activation, and describe strategies to determine if the lncRNA product or its transcription causes the regulatory effect."),
get_detailed_instruct_passage("Noncoding transcription at enhancers: general principles and functional models. Mammalian genomes are extensively transcribed outside the borders of protein-coding genes. Genome-wide studies recently demonstrated that cis-regulatory genomic elements implicated in transcriptional control, such as enhancers and locus-control regions, represent major sites of extragenic noncoding transcription. Enhancer-templated transcripts provide a quantitatively small contribution to the total amount of cellular nonribosomal RNA; nevertheless, the possibility that enhancer transcription and the resulting enhancer RNAs may, in some cases, have functional roles, rather than represent mere transcriptional noise at accessible genomic regions, is supported by an increasing amount of experimental data. In this article we review the current knowledge on enhancer transcription and its functional implications.")
]
input_texts = queries + documents
max_length = 512
# Tokenize without padding or tensor conversion so an EOS token can be appended to each sequence
batch_dict = tokenizer(input_texts, max_length=max_length - 1, return_attention_mask=False, padding=False, truncation=True)
batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
# Pad to the longest sequence in the batch, build attention masks, and move to the GPU
batch_dict = tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt').to("cuda")
model = model.to("cuda")  # the tokenized batch above lives on the GPU, so the model must as well
model.eval()
with torch.no_grad():
    outputs = model(**batch_dict)
    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
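At this point, embeddings holds one vector per entry in input_texts (the two queries first, then the two passages). If you want to hand these vectors to an external index or store them, a minimal sketch is to move them back to the CPU first:
print(embeddings.shape)  # (number of input texts, hidden size)
vectors = embeddings.cpu().numpy()  # plain NumPy array for downstream indexing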
Advanced Usage
Similarity scores between texts are obtained by taking the dot product of their embeddings.
scores = (embeddings[:2] @ embeddings[2:].T)
print(scores.tolist())
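The scores above are raw dot products. Since torch.nn.functional is already imported as F, a cosine-similarity variant only requires L2-normalizing the embeddings first; this is an optional alternative rather than the recipe from the original card:
# L2-normalize so that the dot product becomes cosine similarity
normalized = F.normalize(embeddings, p=2, dim=1)
cosine_scores = normalized[:2] @ normalized[2:].T
print(cosine_scores.tolist())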
📄 License
This repository is released under the MIT license.
📚 Citation
If you find this repository helpful, please consider citing the corresponding paper. Thank you!
@inproceedings{xu2024bmretriever,
title={BMRetriever: Tuning Large Language Models as Better Biomedical Text Retrievers},
author={Ran Xu and Wenqi Shi and Yue Yu and Yuchen Zhuang and Yanqiao Zhu and May D. Wang and Joyce C. Ho and Chao Zhang and Carl Yang},
year={2024},
booktitle={Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing},
}