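🚀 Large Language Model for Mental Health Assessment
This large language model is designed mainly to assess the severity of mental health issues by analyzing text or speech input from users (speakers, writers, patients, etc.). The training dataset consists of diagnoses made by psychiatrists from the text or speech of patients whose mental health issues vary in severity.
The model has several uses. For example, it can help clinicians assess the mental health of patients or specific individuals, help individuals self-assess their own mental health, and analyze the psychological traits of characters in fictional narratives.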
On the test set (30,477 rows), the model achieves an accuracy of 0.78 and an F1 score of 0.77.
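The card does not say how the F1 score was averaged across the six labels. As a minimal sketch, assuming macro- and weighted-averaged F1 as the two likely candidates, both metrics can be computed from gold and predicted labels as follows. The label lists are made-up toy values, not the 30,477-row test set, and the helper names are hypothetical.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the gold label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def f1_scores(y_true, y_pred, labels=range(6)):
    """Return (macro_f1, weighted_f1) over the given labels (0-5 by default)."""
    support = Counter(y_true)  # number of gold examples per label
    per_label = {}
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        per_label[label] = 2 * precision * recall / denom if denom else 0.0
    # Macro: unweighted mean over labels; weighted: mean weighted by gold support.
    macro = sum(per_label.values()) / len(per_label)
    weighted = sum(per_label[l] * support[l] for l in per_label) / len(y_true)
    return macro, weighted


# Hypothetical toy labels, not real evaluation data.
gold = [0, 1, 2, 2, 3, 5]
pred = [0, 1, 2, 3, 3, 4]
print(accuracy(gold, pred))
print(f1_scores(gold, pred))
```

Macro averaging treats every severity level equally, while weighted averaging follows the label distribution of the test data, so the two can differ noticeably on imbalanced data.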
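This model is part of my project on fine-tuning open-source large language models to predict various human cognitive traits (e.g., personality, attitudes, and mental states).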
🚀 Quick Start
The following test examples can be tried in the API widget:
- "I was okay just a moment ago. I will learn how to be okay again."
- "There were days when she was unhappy; she did not know why, when it did not seem worthwhile to be glad or sorry, to be alive or dead; when life appeared to her like a grotesque pandemonium and humanity like worms struggling blindly toward inevitable annihilation."
- "I hope to one day see a sea of people all wearing silver ribbons as a sign that they understand the secret battle and as a celebration of the victories made each day as we individually pull ourselves up out of our foxholes to see our scars heal and to remember what the sun looks like."
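The output assigns a label from 0 to 5 that classifies the severity of mental health issues. A label of 0 indicates the lowest severity, meaning few or no symptoms of mental health issues. Conversely, a label of 5 indicates the highest severity, reflecting a serious mental health condition that may require immediate and comprehensive intervention. The higher the label, the more severe the condition is likely to be. Please bear this in mind.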
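The predicted label is the class with the highest probability among six scores. As a dependency-free sketch of that step (the same softmax-then-argmax logic the usage example performs with `torch.softmax` and `torch.max`), the helper below turns six raw scores into a severity label. The function names and the input scores are hypothetical, chosen only for illustration.

```python
import math

NUM_LABELS = 6  # severity labels 0 (lowest) to 5 (highest)


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def severity_label(logits):
    """Return (label, probability) for the most likely severity label."""
    if len(logits) != NUM_LABELS:
        raise ValueError(f"expected {NUM_LABELS} scores, got {len(logits)}")
    probs = softmax(logits)
    label = max(range(NUM_LABELS), key=probs.__getitem__)
    return label, probs[label]


# Hypothetical raw scores for one input text.
label, prob = severity_label([0.2, 2.5, 1.0, -0.3, -1.2, -2.0])
print(f"Predicted severity: {label}, probability: {prob:.4f}")
```

Subtracting the maximum score before exponentiating does not change the result but avoids overflow for large scores.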
💻 Usage Example
Basic Usage
import torch
from transformers import AutoConfig, BertForSequenceClassification, BertTokenizer

model_path = "KevSun/mentalhealth_LM"
# Six output labels (severity 0-5), single-label classification.
config = AutoConfig.from_pretrained(model_path, num_labels=6, problem_type="single_label_classification")
# BertTokenizer is the slow tokenizer; `use_fast` is an AutoTokenizer option, so it is omitted here.
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path, config=config, ignore_mismatched_sizes=True)
model.eval()  # inference mode (disables dropout)

def predict_text(text, model, tokenizer):
    # Tokenize, truncating to BERT's 512-token limit.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Convert logits to probabilities over the six severity labels.
    probabilities = torch.softmax(logits, dim=-1)
    # The most probable label is the predicted severity.
    max_probability, predicted_class_index = torch.max(probabilities, dim=-1)
    return predicted_class_index.item(), max_probability.item(), probabilities.numpy()

text = "I was okay just a moment ago. I will learn how to be okay again."
predicted_class, max_prob, probs = predict_text(text, model, tokenizer)
print(f"Predicted class: {predicted_class}, Probability: {max_prob:.4f}")
📄 License
This project is licensed under the Apache-2.0 License.