🚀 基于 RoBERTa 微调の医学问诊意図識別モデル
このプロジェクトは、心理健康の大規模モデルを開発するための対話診断システムにおける意図識別タスクです。ユーザーが対話システムに入力した query
テキストの意図を識別し、【診察】または【雑談】であるかを判断します。
🚀 クイックスタート
プロジェクト概要
- プロジェクトの由来:中科(安徽)G60智慧健康創新研究院(以下「中科」と略称)が心理健康の大規模モデル開発に関連して構築した対話診断システムの一部としての意図識別タスクです。
- モデルの用途:ユーザーが対話システムに入力した
query
テキストの意図を識別し、【診察】または【雑談】であるかを判断します。
データの説明
- データの出所:Hugging Faceのオープンソース対話データセットと中科の内部の医学対話データセットをクリーニングと前処理を行った後に融合して構築されました。
- データの分割:合計6000件のサンプルがあり、そのうち訓練セットに4800件、テストセットに1200件が含まれています。データ構築過程では正負のサンプルのバランスが保たれています。
- データの例:
[
{
"query": "最近人気のある5本の映画の名前は何ですか",
"label": "nonmed"
},
{
"query": "関節痛や足の痛みの原因は何でしょうか",
"label": "med"
},
{
"query": "最近冷汗が出て、お腹が痛く、吐き気と嘔吐があり、勉強や仕事に大きく影響しています",
"label": "med"
}
]
実験環境
Featurizeオンラインプラットフォームインスタンス:
- CPU:6コア E5 - 2680 V4
- GPU:RTX3060、12.6GBビデオメモリ
- プリインストールされたイメージ:Ubuntu 20.04、Python 3.9/3.10、PyTorch 2.0.1、TensorFlow 2.13.0、Docker 20.10.10、CUDAは最新バージョンを維持
- 手動でインストールする必要があるライブラリ:
pip install transformers datasets evaluate accelerate
訓練方法
Hugging Faceの transformers
ライブラリを使用して、哈工大讯飞共同研究所 (HFL) が公開した chinese - roberta - wwm - ext 中国語事前学習モデルを微調整します。
訓練パラメータ、結果と制限
{
output_dir: "output",
num_train_epochs: 2,
learning_rate: 3e-5,
lr_scheduler_type: "cosine",
per_device_train_batch_size: 16,
per_device_eval_batch_size: 16,
weight_decay: 0.01,
warmup_ratio: 0.02,
logging_steps: 0.01,
logging_strategy: "steps",
fp16: True,
eval_strategy: "steps",
eval_steps: 0.1,
save_strategy: 'epoch'
}
- 微調整の結果
| データセット | 正解率 | F1スコア |
| ------ | ------ | ------ |
| テストセット | 0.99 | 0.98 |
- 制限
全体的に、微調整後のモデルは医学診察の意図識別において良好な結果を示します。しかし、今回のモデル訓練に使用されたデータ量が限られており、サンプルの多様性が不十分であるため、一部のケースでは結果に偏差が生じる可能性があります。
使い方
基本的な使用法
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
ID2LABEL = {0: "雑談", 1: "診察"}
MODEL_NAME = 'HZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
torch_dtype='auto'
)
query = 'この子は現在28歳で、気分が悪いときにしばしば予兆なく吐血し、呼吸器系と消化器系の検査を何度も受けていますが、結果は出ていません。最近3日間、毎朝吐血する現象が続いています'
tokenized_query = tokenizer(query, return_tensors='pt')
tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()}
outputs = model(**tokenized_query)
pred_id = outputs.logits.argmax(-1).item()
intent = ID2LABEL[pred_id]
print(intent)
ターミナルの結果
診察
高度な使用法
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
ID2LABEL = {0: "雑談", 1: "診察"}
MODEL_NAME = 'HZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side='left')
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
torch_dtype='auto'
)
query = [
'胃が痛く、何日も続けて下痢しています。時々真夜中に嘔吐することもあります',
'足の毛を薬や医療器具を使わずに取り除く方法はありますか',
'こんにちは、風邪で咳が出る場合はどんな薬を使えばいいですか?',
'今日の天気はどうですか?キャンプに行くのはいい天気だと思います!'
]
tokenized_query = tokenizer(query, return_tensors='pt', padding=True, truncation=True)
tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()}
outputs = model(**tokenized_query)
pred_ids = outputs.logits.argmax(-1).tolist()
intent = [ID2LABEL[pred_id] for pred_id in pred_ids]
print(intent)
ターミナルの結果
["診察", "雑談", "診察", "雑談"]
📄 ライセンス
このプロジェクトはApache - 2.0ライセンスの下で公開されています。
📚 その他の情報
項目 |
詳細 |
ベースモデル |
hfl/chinese - roberta - wwm - ext |
パイプラインタグ |
テキスト分類 |
タグ |
医学 |
評価指標 |
混同行列、正解率、F1スコア |