🚀 基于 RoBERTa 微调的医学问诊意图识别模型
本项目是中科围绕心理健康大模型研发的对话导诊系统中的意图识别任务,能够对用户输入的文本进行意图识别,判断是【问诊】还是【闲聊】,为医学对话系统提供了精准的意图判别能力。
🚀 快速开始
本项目是中科(安徽)G60智慧健康创新研究院围绕心理健康大模型研发的对话导诊系统中的意图识别任务,下面为你详细介绍该项目的相关信息。
✨ 主要特性
- 精准意图识别:能够准确判别用户输入文本是【问诊】还是【闲聊】。
- 数据融合构建:融合开源对话数据集与中科内部垂域医学对话数据集。
- 良好微调效果:在测试集上取得了较高的准确率和 F1 分数。
📦 安装指南
在 Featurize 在线平台实例 上进行实验,需手动安装以下库:
pip install transformers datasets evaluate accelerate
📚 详细文档
项目来源
本项目来源于中科(安徽)G60智慧健康创新研究院围绕心理健康大模型研发的对话导诊系统,本项目为其中的意图识别任务。
模型用途
将用户输入对话系统中的 query
文本进行意图识别,判别其意向是【问诊】or【闲聊】。
数据描述
- 数据来源:由 Hugging Face 的开源对话数据集,以及中科内部的垂域医学对话数据集经过清洗和预处理融合构建而成。
- 数据划分:共计 6000 条样本,其中,训练集 4800 条,测试集 1200 条,并在数据构建过程中确保了正负样例的平衡。
- 数据样例:
[
{
"query": "最近热门的5部电影叫什么名字",
"label": "nonmed"
},
{
"query": "关节疼痛,足痛可能是什么原因",
"label": "med"
},
{
"query": "最近出冷汗,肚子疼,恶心与呕吐,严重影响学习工作",
"label": "med"
}
]
实验环境
Featurize 在线平台实例:
- CPU:6核 E5-2680 V4
- GPU:RTX3060,12.6GB显存
- 预装镜像:Ubuntu 20.04,Python 3.9/3.10,PyTorch 2.0.1,TensorFlow 2.13.0,Docker 20.10.10, CUDA 尽量维持在最新版本
训练方式
基于 Hugging Face 的 transformers
库对哈工大讯飞联合实验室 (HFL) 发布的 chinese-roberta-wwm-ext 中文预训练模型进行微调。
训练参数、效果与局限性
{
output_dir: "output",
num_train_epochs: 2,
learning_rate: 3e-5,
lr_scheduler_type: "cosine",
per_device_train_batch_size: 16,
per_device_eval_batch_size: 16,
weight_decay: 0.01,
warmup_ratio: 0.02,
logging_steps: 0.01,
logging_strategy: "steps",
fp16: True,
eval_strategy: "steps",
eval_steps: 0.1,
save_strategy: 'epoch'
}
- 微调效果:
| 数据集 | 准确率 | F1分数 |
| ------ | ------ | ------ |
| 测试集 | 0.99 | 0.98 |
- 局限性:整体而言,微调后模型对于医学问诊的意图识别效果不错;但碍于本次用于模型训练的数据量终究有限且样本多样性欠佳,故在某些情况下的效果可能存在偏差。
💻 使用示例
基础用法
单样本推理示例:
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
ID2LABEL = {0: "闲聊", 1: "问诊"}
MODEL_NAME = 'HZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
torch_dtype='auto'
)
query = '这孩子目前28岁,情绪不好时经常无征兆吐血,呼吸系统和消化系统做过多次检查,没有检查出结果,最近三天连续早晨出现吐血现象'
tokenized_query = tokenizer(query, return_tensors='pt')
tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()}
outputs = model(**tokenized_query)
pred_id = outputs.logits.argmax(-1).item()
intent = ID2LABEL[pred_id]
print(intent)
终端结果:
问诊
高级用法
批次数据推理示例:
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
ID2LABEL = {0: "闲聊", 1: "问诊"}
MODEL_NAME = 'HZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side='left')
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
torch_dtype='auto'
)
query = [
'胃痛,连续拉肚子好几天了,有时候半夜还呕吐',
'腿上的毛怎样去掉,不用任何药学和医学器械',
'你好,感冒咳嗽用什么药?',
'你觉得今天天气如何?我感觉咱可以去露营了!'
]
tokenized_query = tokenizer(query, return_tensors='pt', padding=True, truncation=True)
tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()}
outputs = model(**tokenized_query)
pred_ids = outputs.logits.argmax(-1).tolist()
intent = [ID2LABEL[pred_id] for pred_id in pred_ids]
print(intent)
终端结果:
["问诊", "闲聊", "问诊", "闲聊"]
🔧 技术细节
- 模型基础:基于哈工大讯飞联合实验室 (HFL) 发布的 chinese-roberta-wwm-ext 中文预训练模型。
- 训练框架:使用 Hugging Face 的
transformers
库进行微调。
- 评估指标:使用混淆矩阵、准确率、F1 分数等指标进行评估。
📄 许可证
本项目采用 Apache-2.0 许可证。
属性 |
详情 |
模型类型 |
基于 RoBERTa 微调的文本分类模型 |
训练数据 |
由 Hugging Face 的开源对话数据集与中科内部垂域医学对话数据集融合构建 |
基础模型 |
hfl/chinese-roberta-wwm-ext |
任务类型 |
文本分类 |
标签 |
医学 |
评估指标 |
混淆矩阵、准确率、F1 分数 |