🚀 基于 RoBERTa 微调の医学问诊意図識別モデル
このプロジェクトは、中科(安徽)G60 智慧健康創新研究院が開発する対話導診システムの一部で、ユーザーが入力した query
テキストの意図を識別し、【问诊】または【閑聊】を判別することができます。
🚀 クイックスタート
プロジェクト概要
- プロジェクトの起源:中科(安徽)G60 智慧健康創新研究院(以下「中科」と略称)が精神健康大モデルの開発に関連して構築した対話導診システムの意図識別タスクです。
- モデルの用途:ユーザーが対話システムに入力した
query
テキストの意図を識別し、【问诊】または【閑聊】を判別します。
データ説明
- データの出所:Hugging Face のオープンソース対話データセットと中科の内部の医療分野の対話データセットをクリーニングと前処理して融合して構築されました。
- データの分割:合計 6000 件のサンプルで、トレーニングセットに 4800 件、テストセットに 1200 件が含まれ、データ構築の過程で正負のサンプルのバランスが確保されています。
- データの例:
[
{
"query": "最近热门的5部电影叫什么名字",
"label": "nonmed"
},
{
"query": "关节疼痛,足痛可能是什么原因",
"label": "med"
},
{
"query": "最近出冷汗,肚子疼,恶心与呕吐,严重影响学习工作",
"label": "med"
}
]
実験環境
Featurize オンラインプラットフォームインスタンス:
- CPU:6 コア E5 - 2680 V4
- GPU:RTX3060、12.6GB ビデオメモリ
- プリインストールされたイメージ:Ubuntu 20.04、Python 3.9/3.10、PyTorch 2.0.1、TensorFlow 2.13.0、Docker 20.10.10、CUDA は最新バージョンを維持
- 手動でインストールする必要があるライブラリ:
pip install transformers datasets evaluate accelerate
学習方法
Hugging Face の transformers
ライブラリを使用して、哈工大讯飞联合实验室 (HFL) が公開した chinese - roberta - wwm - ext 中国語事前学習モデルをファインチューニングします。
学習パラメータ、結果と制限
{
output_dir: "output",
num_train_epochs: 2,
learning_rate: 3e-5,
lr_scheduler_type: "cosine",
per_device_train_batch_size: 16,
per_device_eval_batch_size: 16,
weight_decay: 0.01,
warmup_ratio: 0.02,
logging_steps: 0.01,
logging_strategy: "steps",
fp16: True,
eval_strategy: "steps",
eval_steps: 0.1,
save_strategy: 'epoch'
}
- ファインチューニングの結果
| データセット | 正解率 | F1 スコア |
| ------ | ------ | ------ |
| テストセット | 0.99 | 0.98 |
- 制限
全体的に、ファインチューニング後のモデルは医学问诊の意図識別において良好な結果を示しています。しかし、今回のモデル学習に使用されたデータ量が限られており、サンプルの多様性が不十分であるため、一部のケースでは結果にずれが生じる可能性があります。
使い方
基本的な使用法
単サンプル推論の例
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
ID2LABEL = {0: "闲聊", 1: "问诊"}
MODEL_NAME = 'HZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
torch_dtype='auto'
)
query = '这孩子目前28岁,情绪不好时经常无征兆吐血,呼吸系统和消化系统做过多次检查,没有检查出结果,最近三天连续早晨出现吐血现象'
tokenized_query = tokenizer(query, return_tensors='pt')
tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()}
outputs = model(**tokenized_query)
pred_id = outputs.logits.argmax(-1).item()
intent = ID2LABEL[pred_id]
print(intent)
ターミナルの結果
问诊
高度な使用法
バッチデータ推論の例
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
ID2LABEL = {0: "闲聊", 1: "问诊"}
MODEL_NAME = 'HZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side='left')
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
torch_dtype='auto'
)
query = [
'胃痛,连续拉肚子好几天了,有时候半夜还呕吐',
'腿上的毛怎样去掉,不用任何药学和医学器械',
'你好,感冒咳嗽用什么药?',
'你觉得今天天气如何?我感觉咱可以去露营了!'
]
tokenized_query = tokenizer(query, return_tensors='pt', padding=True, truncation=True)
tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()}
outputs = model(**tokenized_query)
pred_ids = outputs.logits.argmax(-1).tolist()
intent = [ID2LABEL[pred_id] for pred_id in pred_ids]
print(intent)
ターミナルの結果
["问诊", "闲聊", "问诊", "闲聊"]
📄 ライセンス
このプロジェクトは Apache - 2.0 ライセンスの下で公開されています。
情報テーブル
属性 |
詳細 |
モデルタイプ |
テキスト分類 |
学習データ |
Hugging Face のオープンソース対話データセットと中科の内部の医療分野の対話データセットを融合したもの |
パイプラインタグ |
テキスト分類 |
タグ |
医療 |
評価指標 |
混同行列、正解率、F1 スコア |
ベースモデル |
hfl/chinese - roberta - wwm - ext |