🚀 基於 RoBERTa 微調的醫學問診意圖識別模型
本項目是中科圍繞心理健康大模型研發的對話導診系統中的意圖識別任務,能對用戶輸入的 query
文本進行意圖識別,判斷是【問診】還是【閒聊】,為醫學對話交互提供精準支持。
🚀 快速開始
單樣本推理示例
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
ID2LABEL = {0: "閒聊", 1: "問診"}
MODEL_NAME = 'HZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
torch_dtype='auto'
)
query = '這孩子目前28歲,情緒不好時經常無徵兆吐血,呼吸系統和消化系統做過多次檢查,沒有檢查出結果,最近三天連續早晨出現吐血現象'
tokenized_query = tokenizer(query, return_tensors='pt')
tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()}
outputs = model(**tokenized_query)
pred_id = outputs.logits.argmax(-1).item()
intent = ID2LABEL[pred_id]
print(intent)
終端結果
問診
批次數據推理示例
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
ID2LABEL = {0: "閒聊", 1: "問診"}
MODEL_NAME = 'HZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side='left')
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
torch_dtype='auto'
)
query = [
'胃痛,連續拉肚子好幾天了,有時候半夜還嘔吐',
'腿上的毛怎樣去掉,不用任何藥學和醫學器械',
'你好,感冒咳嗽用什麼藥?',
'你覺得今天天氣如何?我感覺咱可以去露營了!'
]
tokenized_query = tokenizer(query, return_tensors='pt', padding=True, truncation=True)
tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()}
outputs = model(**tokenized_query)
pred_ids = outputs.logits.argmax(-1).tolist()
intent = [ID2LABEL[pred_id] for pred_id in pred_ids]
print(intent)
終端結果
["問診", "閒聊", "問診", "閒聊"]
✨ 主要特性
- 精準意圖識別:能準確判別用戶輸入文本是【問診】還是【閒聊】意圖。
- 數據融合構建:融合開源與內部垂域醫學對話數據集,確保數據多樣性。
- 微調預訓練模型:基於
transformers
庫微調 chinese - roberta - wwm - ext
模型,提升性能。
📦 安裝指南
在 Featurize 在線平臺實例 上,需手動安裝以下庫:
pip install transformers datasets evaluate accelerate
平臺環境信息:
- CPU:6核 E5 - 2680 V4
- GPU:RTX3060,12.6GB顯存
- 預裝鏡像:Ubuntu 20.04,Python 3.9/3.10,PyTorch 2.0.1,TensorFlow 2.13.0,Docker 20.10.10, CUDA 儘量維持在最新版本
📚 詳細文檔
項目簡介
- 項目來源:中科(安徽)G60智慧健康創新研究院(以下簡稱 “中科”)圍繞心理健康大模型研發的對話導診系統,本項目為其中的意圖識別任務。
- 模型用途:將用戶輸入對話系統中的
query
文本進行意圖識別,判別其意向是【問診】or【閒聊】。
數據描述
- 數據來源:由 Hugging Face 的開源對話數據集,以及中科內部的垂域醫學對話數據集經過清洗和預處理融合構建而成。
- 數據劃分:共計 6000 條樣本,其中,訓練集 4800 條,測試集1200 條,並在數據構建過程中確保了正負樣例的平衡。
- 數據樣例:
[
{
"query": "最近熱門的5部電影叫什麼名字",
"label": "nonmed"
},
{
"query": "關節疼痛,足痛可能是什麼原因",
"label": "med"
},
{
"query": "最近出冷汗,肚子疼,噁心與嘔吐,嚴重影響學習工作",
"label": "med"
}
]
訓練方式
基於 Hugging Face 的 transformers
庫對哈工大訊飛聯合實驗室 (HFL) 發佈的 [chinese - roberta - wwm - ext](https://github.com/ymcui/Chinese - BERT - wwm) 中文預訓練模型進行微調。
訓練參數、效果與侷限性
訓練參數
{
output_dir: "output",
num_train_epochs: 2,
learning_rate: 3e-5,
lr_scheduler_type: "cosine",
per_device_train_batch_size: 16,
per_device_eval_batch_size: 16,
weight_decay: 0.01,
warmup_ratio: 0.02,
logging_steps: 0.01,
logging_strategy: "steps",
fp16: True,
eval_strategy: "steps",
eval_steps: 0.1,
save_strategy: 'epoch'
}
微調效果
數據集 |
準確率 |
F1分數 |
測試集 |
0.99 |
0.98 |
侷限性
整體而言,微調後模型對於醫學問診的意圖識別效果不錯;但礙於本次用於模型訓練的數據量終究有限且樣本多樣性欠佳,故在某些情況下的效果可能存在偏差。
🔧 技術細節
- 模型選擇:選擇
chinese - roberta - wwm - ext
預訓練模型,因其在中文任務上有較好的表現,通過微調可適配醫學問診意圖識別任務。
- 數據處理:融合開源與內部數據集,清洗和預處理確保數據質量,劃分訓練集和測試集保證模型泛化能力。
- 訓練優化:使用
transformers
庫進行微調,設置合適的訓練參數,如學習率、批次大小等,提升模型性能。
📄 許可證
本項目採用 apache - 2.0
許可證。
📋 其他信息
屬性 |
詳情 |
模型類型 |
基於 RoBERTa 微調的醫學問診意圖識別模型 |
訓練數據 |
由 Hugging Face 的開源對話數據集和中科內部垂域醫學對話數據集融合構建,共 6000 條樣本 |
評估指標 |
混淆矩陣、準確率、F1分數 |
應用領域 |
醫學 |