🚀 DistilBERT 占い質問検出モデル
本プロジェクトでは、DistilBERT
ベースの占い質問検出モデルを提供しています。このモデルは、入力されたテキストがタロット占いに適した質問かどうかを判断することができます。
📂 ディレクトリ構造
model.safetensors
: 学習済みのモデルの重みです。
config.json
: モデルアーキテクチャの設定ファイルです。
tokenizer.json
: トークナイザーの設定です。
special_tokens_map.json
: 特殊トークンの設定です。
vocab.txt
: トークナイザーの語彙ファイルです。
🚀 クイックスタート
1️⃣ 依存関係のインストール
Python 3.8 以上がインストールされた環境で、以下のコマンドを実行して必要な依存ライブラリをインストールしてください。
pip install torch transformers fastapi uvicorn safetensors
2️⃣ 直接推論を実行する
モデルをローカルで直接テストしたい場合は、inference.py
を実行します。
python inference.py
サンプルコード(inference.py)
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
model_path = "./distilbert-question-detector"
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()
text = "Is this a question?"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probabilities = torch.nn.functional.softmax(logits, dim=-1)
predicted_class = torch.argmax(probabilities, dim=-1).item()
print(f"Probabilities: {probabilities}")
print(f"Predicted class: {predicted_class}")
3️⃣ API を実行する
FastAPI を使用して HTTP インターフェースをデプロイすることもできます。これにより、他のアプリケーションが HTTP リクエストを介してモデルにアクセスできるようになります。
uvicorn app:app --host 0.0.0.0 --port 8000
サンプル API コード(app.py)
from fastapi import FastAPI
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
app = FastAPI()
model_path = "./distilbert-question-detector/checkpoint-5150"
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()
@app.post("/predict/")
async def predict(text: str):
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probabilities = torch.nn.functional.softmax(logits, dim=-1)
predicted_class = torch.argmax(probabilities, dim=-1).item()
return {"text": text, "probabilities": probabilities.tolist(), "predicted_class": predicted_class}
API を起動した後、以下のコマンドでテストできます。
curl -X 'POST' \
'http://127.0.0.1:8000/predict/' \
-H 'Content-Type: application/json' \
-d '{"text": "Is this a valid question?"}'
📌 結果の説明
predicted_class: 0
は、入力テキストが条件を満たすことを意味します。
predicted_class: 1
は、入力テキストが条件を満たさないことを意味します。
サンプル結果
{
"text": "Is this a valid question?",
"probabilities": [[0.9266, 0.0734]],
"predicted_class": 0
}
📄 ライセンス
このプロジェクトは AFL-3.0 ライセンスの下で公開されています。
プロパティ |
詳細 |
モデルタイプ |
DistilBERT ベースのシーケンス分類モデル |
パイプラインタグ |
テキスト分類 |
評価指標 |
正確度 |
ベースモデル |
distilbert/distilbert-base-uncased |
タグ |
tarot、question-detector |