🚀 DistilBERT 占卜問題檢測模型
本項目提供了一個基於 DistilBERT
的占卜問題檢測模型,可用於判斷輸入文本是否為符合塔羅占卜的問題,為塔羅占卜相關應用提供了有力的文本分類支持。
🚀 快速開始
1️⃣ 安裝依賴
請確保你的環境已安裝 Python 3.8+,然後運行以下命令安裝所需的依賴庫:
pip install torch transformers fastapi uvicorn safetensors
2️⃣ 直接運行推理
如果你想直接在本地測試模型,可以運行 inference.py
:
python inference.py
基礎用法
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
model_path = "./distilbert-question-detector"
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()
text = "Is this a question?"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probabilities = torch.nn.functional.softmax(logits, dim=-1)
predicted_class = torch.argmax(probabilities, dim=-1).item()
print(f"Probabilities: {probabilities}")
print(f"Predicted class: {predicted_class}")
3️⃣ 運行 API
你也可以使用 FastAPI 部署一個 HTTP 接口,允許其他應用通過 HTTP 請求訪問模型。
uvicorn app:app --host 0.0.0.0 --port 8000
高級用法
from fastapi import FastAPI
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
app = FastAPI()
model_path = "./distilbert-question-detector/checkpoint-5150"
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()
@app.post("/predict/")
async def predict(text: str):
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probabilities = torch.nn.functional.softmax(logits, dim=-1)
predicted_class = torch.argmax(probabilities, dim=-1).item()
return {"text": text, "probabilities": probabilities.tolist(), "predicted_class": predicted_class}
API 運行後,可通過以下方式測試:
curl -X 'POST' \
'http://127.0.0.1:8000/predict/' \
-H 'Content-Type: application/json' \
-d '{"text": "Is this a valid question?"}'
📂 目錄結構
屬性 |
詳情 |
model.safetensors |
訓練好的模型權重文件 |
config.json |
模型架構的配置文件 |
tokenizer.json |
分詞器的配置文件 |
special_tokens_map.json |
特殊標記的配置文件 |
vocab.txt |
分詞器的詞彙文件 |
📌 結果說明
predicted_class: 0
代表輸入文本是符合條件
predicted_class: 1
代表輸入文本不符合條件
示例結果
{
"text": "Is this a valid question?",
"probabilities": [[0.9266, 0.0734]],
"predicted_class": 0
}
📄 許可證
本項目採用 AFL-3.0 許可證。