Medical Diagnosis AI Model - Powered by Mistral-7B & LoRA
This medical diagnosis model, fine-tuned from Mistral-7B with LoRA, suggests likely diagnoses from described symptoms and explains them with step-by-step reasoning. It is designed to assist healthcare professionals in making clinical decisions.
Quick Start
Use the following code to start using the model:
!pip install -q -U transformers accelerate
!pip install -q -U bitsandbytes
!pip install -q -U peft
!pip install -q -U trl
!pip install -q -U tensorboardX
!pip install -q wandb
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "ritvik77/Medical_Doctor_AI_LoRA-Mistral-7B-Instruct_FullModel"
# device_map="auto" (requires the accelerate package) places the weights on the available GPU(s)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Patient reports chest pain and dizziness with a nosebleed. What's the likely diagnosis? Could it be cancer?"
# Tokenize the prompt and move the tensors to the model's device
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=300)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Advanced Usage
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "ritvik77/Medical_Doctor_AI_LoRA-Mistral-7B-Instruct_FullModel"
# device_map="auto" requires the accelerate package
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
# Load the tokenizer from the same repository as the model
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Patient reports a long-term cough and fatigue. What could be the diagnosis?"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=300)
# Decode only the newly generated tokens so the prompt is not echoed back
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))
✨ Features
- Diagnostic Suggestions: Suggests likely diagnoses for symptoms such as chest pain, dizziness, and breathlessness.
- Step-by-Step Reasoning: Uses Chain-of-Thought (CoT) prompting for step-by-step medical reasoning.
- Efficient Inference: bnb_4bit quantization reduces VRAM usage, making the model practical on GPUs with limited memory.
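Step-by-step reasoning can also be encouraged from the caller's side. The sketch below shows one way to phrase a Chain-of-Thought style prompt. The helper `build_cot_prompt` and its wording are hypothetical illustrations, not the template this model was fine-tuned with; the returned string can be used as `prompt` in the Quick Start code.

```python
# Hypothetical helper: wraps a symptom description in a Chain-of-Thought style
# instruction. The wording is an illustrative assumption, not the model's
# training template.
def build_cot_prompt(symptoms: str) -> str:
    """Return a prompt asking the model to reason step by step before answering."""
    symptoms = symptoms.strip()
    if not symptoms:
        raise ValueError("symptoms must be a non-empty description")
    return (
        f"Patient presentation: {symptoms}\n"
        "Think through this step by step: list the key findings, "
        "consider the differential diagnoses, and then state the most likely diagnosis.\n"
        "Reasoning:"
    )


prompt = build_cot_prompt("Persistent cough for six weeks and fatigue.")
print(prompt)
```

Ending the prompt with an open "Reasoning:" cue nudges the model to write out its reasoning before the diagnosis; whether this improves answers for this particular fine-tune is untested here.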
Documentation
Model Details
- Base Model: Mistral-7B (~7.3 billion parameters)
- Fine-Tuning Method: LoRA (Low-Rank Adaptation)
- Quantization: bnb_4bit (4-bit bitsandbytes quantization; reduces the memory footprint while largely retaining performance)
- Original Mistral-7B Parameters: ~7.3 billion
- LoRA Fine-Tuned Parameters: ~4.48% of total model parameters (~340 million)
- Final Merged Model Size (bnb_4bit quantized): ~4.5 GB
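The quantized model size can be sanity-checked with back-of-the-envelope arithmetic. This is a rough estimate under assumptions that are not stated in this card: Mistral-7B v0.3 layer shapes, every attention and MLP projection stored in 4 bits, embeddings and `lm_head` kept in bf16, and one fp32 scale per 64-weight block (the bitsandbytes default without double quantization).

```python
# Rough size estimate for a 4-bit bitsandbytes checkpoint of Mistral-7B.
# Assumptions (not stated in this model card):
#   * Mistral-7B v0.3 shapes: hidden 4096, MLP 14336, 32 layers, 8 KV heads x 128 dims
#   * every attention/MLP projection is stored in 4 bits
#   * embeddings and lm_head stay in bf16 (2 bytes per weight)
#   * one fp32 absmax scale per 64-weight block (no double quantization)
#   * RMSNorm weights (well under 1 MB in total) are ignored
HIDDEN = 4096
INTERMEDIATE = 14336
KV_DIM = 8 * 128
LAYERS = 32
VOCAB = 32768  # v0.3 vocabulary; earlier versions use 32000


def linear_params_per_layer() -> int:
    attention = 2 * HIDDEN * HIDDEN + 2 * HIDDEN * KV_DIM  # q/o and k/v projections
    mlp = 3 * HIDDEN * INTERMEDIATE  # gate, up and down projections
    return attention + mlp


quantized_params = linear_params_per_layer() * LAYERS
bf16_params = 2 * VOCAB * HIDDEN  # input embeddings + lm_head

size_bytes = (
    quantized_params * 0.5  # 4-bit weights
    + quantized_params / 64 * 4  # fp32 absmax scale per 64-weight block
    + bf16_params * 2  # bf16 weights
)
total_gb = size_bytes / 1e9

print(f"4-bit linear weights: {quantized_params / 1e9:.2f}B parameters")
print(f"estimated size: {total_gb:.2f} GB")
```

Under these assumptions the estimate comes to about 4.46 GB, in line with the ~4.5 GB listed in Model Details.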
Model Description
This model builds on Mistral-7B, a general-purpose language model with strong reasoning and language-understanding capabilities. LoRA fine-tuning adapts it to medical tasks such as diagnosing conditions from symptoms and providing detailed medical reasoning.
- Developed by: Ritvik Gaur
- Model type: Medical LLM
- License: Apache-2.0
- Fine-tuned from model: Mistral-7B-Instruct-v3
Training Procedure
Training Hyperparameters
| Parameter | Value | Description |
| --- | --- | --- |
| Base Model | mistralai/Mistral-7B-Instruct | Chosen for its strong reasoning capabilities. |
| Fine-Tuning Framework | LoRA (Low-Rank Adaptation) | Fine-tunes only ~4.48% of total parameters. |
| Quantization | bnb_4bit | Reduces VRAM consumption. |
| Train Batch Size | 12 | Balances GPU utilization and convergence. |
| Eval Batch Size | 12 | Matches the training batch size for consistent evaluation. |
| Gradient Accumulation Steps | 3 | Effective batch size of 12 × 3 = 36 per device, for improved stability. |
| Learning Rate | 3e-5 | Kept low for smoother convergence. |
| Warmup Ratio | 0.2 | Ramps the learning rate up gradually over the first 20% of training steps. |
| Scheduler Type | Cosine | Decays the learning rate smoothly after warmup. |
| Number of Epochs | 5 | Chosen to reach convergence without overfitting. |
| Max Gradient Norm | 0.5 | Clips gradients to prevent them from exploding. |
| Weight Decay | 0.08 | Regularization for better generalization. |
| bf16 Precision | True | Trains in bfloat16, which lowers memory use and speeds up computation while keeping fp32's dynamic range. |
| Gradient Checkpointing | Enabled | Reduces memory usage during training by recomputing activations in the backward pass. |
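The batch-size and learning-rate settings in the table combine as sketched below. This assumes a single GPU and a Hugging Face `Trainer`-style schedule: linear warmup followed by cosine decay to zero, the same shape as `transformers.get_cosine_schedule_with_warmup`. The dataset size `NUM_EXAMPLES` is hypothetical because the card does not state it.

```python
import math

# Values from the training hyperparameter table.
PER_DEVICE_BATCH = 12
GRAD_ACCUM_STEPS = 3
BASE_LR = 3e-5
WARMUP_RATIO = 0.2
EPOCHS = 5

# Hypothetical values: the card states neither the dataset size nor the GPU count.
NUM_EXAMPLES = 9_000
NUM_GPUS = 1

effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM_STEPS * NUM_GPUS  # examples per optimizer step
steps_per_epoch = math.ceil(NUM_EXAMPLES / effective_batch)
total_steps = steps_per_epoch * EPOCHS
warmup_steps = math.ceil(total_steps * WARMUP_RATIO)


def lr_at(step: int) -> float:
    """Linear warmup to BASE_LR, then cosine decay to zero at total_steps."""
    if step < warmup_steps:
        return BASE_LR * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return BASE_LR * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


print(f"effective batch: {effective_batch}, total steps: {total_steps}, warmup steps: {warmup_steps}")
for step in (0, warmup_steps // 2, warmup_steps, total_steps // 2, total_steps):
    print(step, f"{lr_at(step):.2e}")
```

With the hypothetical 9,000 examples, the learning rate climbs to 3e-5 over the first 250 of 1,250 optimizer steps and then decays along a half cosine; in a real run the step count also depends on how the data loader handles the final partial batch.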
LoRA Configuration
| Parameter | Value | Description |
| --- | --- | --- |
| Rank Dimension (r) | 128 | Balances expressiveness against memory overhead. |
| LoRA Alpha | 128 | With alpha equal to r, the update scaling factor alpha / r is 1, keeping update magnitudes stable. |
| LoRA Dropout | 0.1 | Helps prevent overfitting. |
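The ~340 million trainable parameters can be approximately reproduced if one assumes rank-128 adapters on every attention and MLP projection of every decoder layer. The card does not list the target modules, so that choice, together with the Mistral-7B layer shapes used below, is an assumption.

```python
# Count LoRA trainable parameters for Mistral-7B.
# Assumptions (not stated in the card): adapters on all seven attention/MLP
# projections of every decoder layer, Mistral-7B layer shapes, no bias terms.
R = 128  # LoRA rank from the table
ALPHA = 128  # LoRA alpha from the table
HIDDEN, INTERMEDIATE, KV_DIM, LAYERS = 4096, 14336, 1024, 32

# (in_features, out_features) for each projection in one decoder layer,
# keyed by the module names used in the Hugging Face Mistral implementation.
PROJECTIONS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features: int, out_features: int, rank: int) -> int:
    # A is (rank x in_features), B is (out_features x rank)
    return rank * in_features + out_features * rank


per_layer = sum(lora_params(i, o, R) for i, o in PROJECTIONS.values())
trainable = per_layer * LAYERS
scaling = ALPHA / R  # the low-rank update B A is multiplied by alpha / r

print(f"trainable LoRA parameters: {trainable:,}")
print(f"update scaling alpha / r: {scaling}")
```

This gives 335,544,320 adapter parameters, close to the ~340 million reported in Model Details, and a scaling factor of 1.0.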
🔧 Technical Details
The model starts from the Mistral-7B base model and fine-tunes it with LoRA, applying bnb_4bit quantization to reduce the memory footprint. The training hyperparameters were chosen to balance performance, convergence, and generalization.
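The LoRA mechanism, and the merge that produces a single "full model" checkpoint, can be illustrated without any ML framework. This toy, pure-Python sketch shows one LoRA-adapted linear layer, y = Wx + (alpha / r) * B(Ax), and folds the adapter back into W. The dimensions and random values are hypothetical, dropout is omitted, and real fine-tuning would use PyTorch with a library such as peft.

```python
import random

# Toy, pure-Python LoRA layer: y = W x + (alpha / r) * B (A x).
# Sizes and values are hypothetical; dropout (training-only) is omitted.


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def matmul(a, b):
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def lora_forward(W, A, B, x, alpha, r):
    # Frozen base output plus the scaled low-rank update
    scale = alpha / r
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + scale * u for b, u in zip(base, update)]


def merge(W, A, B, alpha, r):
    # Fold the adapter into the base weight: W' = W + (alpha / r) * B A
    scale = alpha / r
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]


random.seed(0)
d_in, d_out, r, alpha = 6, 4, 2, 2  # alpha == r, as in the card's LoRA configuration
W = [[random.uniform(-1, 1) for _ in range(d_in)] for _ in range(d_out)]  # frozen weight
A = [[random.uniform(-1, 1) for _ in range(d_in)] for _ in range(r)]  # r x d_in, random init
B = [[0.0] * r for _ in range(d_out)]  # d_out x r, zero init
x = [random.uniform(-1, 1) for _ in range(d_in)]

# With B = 0 the adapted layer reproduces the base layer exactly.
init_gap = max(abs(p - q) for p, q in zip(lora_forward(W, A, B, x, alpha, r), matvec(W, x)))

# Pretend training has updated B, then check the merged weight gives the same output.
B = [[random.uniform(-1, 1) for _ in range(r)] for _ in range(d_out)]
merged = merge(W, A, B, alpha, r)
merge_gap = max(abs(p - q) for p, q in zip(lora_forward(W, A, B, x, alpha, r), matvec(merged, x)))

print(f"init gap: {init_gap}, merge gap: {merge_gap:.1e}")
```

Zero-initialising B means fine-tuning starts exactly from the base model, and because the merged weight is just W plus a scaled low-rank product, inference after merging needs no extra adapter computation.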
License
This model is licensed under the Apache-2.0 license.
⚠️ Important Note
Do not rely solely on this model for real-life diagnosis of illness. It is intended only to support verified health applications that require an LLM.
💡 Usage Tip
Users (both direct and downstream) should be aware of the risks, biases, and limitations of the model.