🚀 基於 RoBERTa 微調的醫學問診意圖識別模型
本項目是中科圍繞心理健康大模型研發的對話導診系統中的意圖識別任務,能夠對用戶輸入的文本進行意圖識別,判斷是【問診】還是【閒聊】,為醫學對話系統提供了精準的意圖判別能力。
🚀 快速開始
本項目是中科(安徽)G60智慧健康創新研究院圍繞心理健康大模型研發的對話導診系統中的意圖識別任務,下面為你詳細介紹該項目的相關信息。
✨ 主要特性
- 精準意圖識別:能夠準確判別用戶輸入文本是【問診】還是【閒聊】。
- 數據融合構建:融合開源對話數據集與中科內部垂域醫學對話數據集。
- 良好微調效果:在測試集上取得了較高的準確率和 F1 分數。
📦 安裝指南
在 Featurize 在線平臺實例 上進行實驗,需手動安裝以下庫:
pip install transformers datasets evaluate accelerate
📚 詳細文檔
項目來源
本項目來源於中科(安徽)G60智慧健康創新研究院圍繞心理健康大模型研發的對話導診系統,本項目為其中的意圖識別任務。
模型用途
將用戶輸入對話系統中的 query
文本進行意圖識別,判別其意向是【問診】or【閒聊】。
數據描述
- 數據來源:由 Hugging Face 的開源對話數據集,以及中科內部的垂域醫學對話數據集經過清洗和預處理融合構建而成。
- 數據劃分:共計 6000 條樣本,其中,訓練集 4800 條,測試集 1200 條,並在數據構建過程中確保了正負樣例的平衡。
- 數據樣例:
[
{
"query": "最近熱門的5部電影叫什麼名字",
"label": "nonmed"
},
{
"query": "關節疼痛,足痛可能是什麼原因",
"label": "med"
},
{
"query": "最近出冷汗,肚子疼,噁心與嘔吐,嚴重影響學習工作",
"label": "med"
}
]
實驗環境
Featurize 在線平臺實例:
- CPU:6核 E5-2680 V4
- GPU:RTX3060,12.6GB顯存
- 預裝鏡像:Ubuntu 20.04,Python 3.9/3.10,PyTorch 2.0.1,TensorFlow 2.13.0,Docker 20.10.10, CUDA 儘量維持在最新版本
訓練方式
基於 Hugging Face 的 transformers
庫對哈工大訊飛聯合實驗室 (HFL) 發佈的 chinese-roberta-wwm-ext 中文預訓練模型進行微調。
訓練參數、效果與侷限性
{
output_dir: "output",
num_train_epochs: 2,
learning_rate: 3e-5,
lr_scheduler_type: "cosine",
per_device_train_batch_size: 16,
per_device_eval_batch_size: 16,
weight_decay: 0.01,
warmup_ratio: 0.02,
logging_steps: 0.01,
logging_strategy: "steps",
fp16: True,
eval_strategy: "steps",
eval_steps: 0.1,
save_strategy: 'epoch'
}
- 微調效果:
| 數據集 | 準確率 | F1分數 |
| ------ | ------ | ------ |
| 測試集 | 0.99 | 0.98 |
- 侷限性:整體而言,微調後模型對於醫學問診的意圖識別效果不錯;但礙於本次用於模型訓練的數據量終究有限且樣本多樣性欠佳,故在某些情況下的效果可能存在偏差。
💻 使用示例
基礎用法
單樣本推理示例:
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
ID2LABEL = {0: "閒聊", 1: "問診"}
MODEL_NAME = 'HZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
torch_dtype='auto'
)
query = '這孩子目前28歲,情緒不好時經常無徵兆吐血,呼吸系統和消化系統做過多次檢查,沒有檢查出結果,最近三天連續早晨出現吐血現象'
tokenized_query = tokenizer(query, return_tensors='pt')
tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()}
outputs = model(**tokenized_query)
pred_id = outputs.logits.argmax(-1).item()
intent = ID2LABEL[pred_id]
print(intent)
終端結果:
問診
高級用法
批次數據推理示例:
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
ID2LABEL = {0: "閒聊", 1: "問診"}
MODEL_NAME = 'HZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side='left')
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
torch_dtype='auto'
)
query = [
'胃痛,連續拉肚子好幾天了,有時候半夜還嘔吐',
'腿上的毛怎樣去掉,不用任何藥學和醫學器械',
'你好,感冒咳嗽用什麼藥?',
'你覺得今天天氣如何?我感覺咱可以去露營了!'
]
tokenized_query = tokenizer(query, return_tensors='pt', padding=True, truncation=True)
tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()}
outputs = model(**tokenized_query)
pred_ids = outputs.logits.argmax(-1).tolist()
intent = [ID2LABEL[pred_id] for pred_id in pred_ids]
print(intent)
終端結果:
["問診", "閒聊", "問診", "閒聊"]
🔧 技術細節
- 模型基礎:基於哈工大訊飛聯合實驗室 (HFL) 發佈的 chinese-roberta-wwm-ext 中文預訓練模型。
- 訓練框架:使用 Hugging Face 的
transformers
庫進行微調。
- 評估指標:使用混淆矩陣、準確率、F1 分數等指標進行評估。
📄 許可證
本項目採用 Apache-2.0 許可證。
屬性 |
詳情 |
模型類型 |
基於 RoBERTa 微調的文本分類模型 |
訓練數據 |
由 Hugging Face 的開源對話數據集與中科內部垂域醫學對話數據集融合構建 |
基礎模型 |
hfl/chinese-roberta-wwm-ext |
任務類型 |
文本分類 |
標籤 |
醫學 |
評估指標 |
混淆矩陣、準確率、F1 分數 |