🚀 medBERT-base
本仓库包含一个基于BERT的模型 medBERT-base,该模型在 gayanin/pubmed-gastro-maskfilling 数据集上针对**掩码语言模型(Masked Language Modeling,MLM)**任务进行了微调。该模型经过训练,可预测医学和胃肠病学文本中被掩码的标记。本项目的目标是提升模型在自然语言语境中对医学相关信息的理解和生成能力。

🚀 快速开始
本项目的medBERT-base模型是基于BERT架构,在特定医学数据集上微调得到,可用于掩码语言模型任务,帮助预测医学文本中被掩码的标记。
✨ 主要特性
- 基础模型:
bert-base-uncased
- 任务:针对医学文本的掩码语言模型(MLM)
- 分词器:BERT的WordPiece分词器
💻 使用示例
基础用法
你可以使用Hugging Face的 transformers
库加载预训练的 medBERT-base 模型:
from transformers import BertTokenizer, BertForMaskedLM
import torch
tokenizer = BertTokenizer.from_pretrained('suayptalha/medBERT-base')
model = BertForMaskedLM.from_pretrained('suayptalha/medBERT-base').to("cuda")
input_text = "Response to neoadjuvant chemotherapy best predicts survival [MASK] curative resection of gastric cancer."
inputs = tokenizer(input_text, return_tensors='pt').to("cuda")
outputs = model(**inputs)
masked_index = (inputs['input_ids'][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].item()
top_k = 5
logits = outputs.logits[0, masked_index]
top_k_ids = torch.topk(logits, k=top_k).indices.tolist()
top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_ids)
print("Top 5 prediction:")
for i, token in enumerate(top_k_tokens):
print(f"{i + 1}: {token}")
前5个预测结果:
1: from
2: of
3: after
4: by
5: through
高级用法
要在你自己的医学数据集上微调 medBERT-base 模型,请按照以下步骤操作:
- 准备文本格式的数据集(例如,医学文本或胃肠病学相关信息)。
- 对数据集进行分词并应用掩码。
- 使用提供的训练循环训练模型。
以下是训练代码链接:
https://github.com/suayptalha/medBERT-base/blob/main/medBERT-base.ipynb
🔧 技术细节
超参数
- 批次大小:16
- 学习率:5e-5
- 训练轮数:1
- 最大序列长度:512个标记
数据集
- 数据集名称:gayanin/pubmed-gastro-maskfilling
- 任务:针对医学文本的掩码语言模型(MLM)
📄 许可证
本项目采用 apache-2.0
许可证。
致谢
- gayanin/pubmed-gastro-maskfilling 数据集可在Hugging Face数据集中心获取,它为训练提供了丰富的医学和胃肠病学相关信息。
- 本模型使用了Hugging Face的
transformers
库,这是一个用于NLP模型的先进库。
支持我们:
