🚀 Google Safesearch Mini V2
Google Safesearch Mini V2は、超精密なマルチクラス画像分類器で、露骨なコンテンツを正確に検出します。これは、ソーシャルメディアのモデレーションやデータセットのフィルタリングに役立ち、Stable Diffusionのセーフティチェッカーと比べて1.0GBのRAMとディスク容量を節約できます。
🚀 クイックスタート
Safesearch v3.1がリリースされました
Google Safesearch Mini V2は、V1とは異なるアプローチでトレーニングされました。InceptionResNetV2アーキテクチャと、インターネットからランダムに収集された約3,400,000枚の画像のデータセットを使用し、その一部はデータ拡張によって生成されました。トレーニングデータと検証データは、Google Images、Reddit、Kaggle、Imgurから収集され、企業、Google SafeSearch、モデレーターによって安全または不適切と分類されました。
モデルは交差エントロピー損失を使用して5エポックトレーニングされ、トレーニングセットと検証セットの両方で評価されました。予測確率が0.90未満の画像を特定し、データセットに必要な修正を加えた後、さらに8エポックトレーニングされました。その後、分類が困難なさまざまなケースでモデルをテストしたところ、茶色の猫の毛を人間の肌と誤認していることがわかりました。精度を向上させるために、Kaggleの15の追加データセットで1エポックファインチューニングし、最後のエポックではトレーニングデータとテストデータを組み合わせてトレーニングしました。これにより、トレーニングデータと検証データの両方で97%の精度が達成されました。
✨ 主な機能
- 超精密なマルチクラス画像分類器で、露骨なコンテンツを正確に検出します。
- ソーシャルメディアのモデレーションやデータセットのフィルタリングに役立ちます。
- Stable Diffusionのセーフティチェッカーと比べて1.0GBのRAMとディスク容量を節約できます。
📦 インストール
PyTorch
pip install --upgrade torchvision
💻 使用例
基本的な使用法
import torch, os
from torchvision import transforms
from PIL import Image
import urllib.request
import timm
image_path = "https://www.allaboutcats.ca/wp-content/uploads/sites/235/2022/03/shutterstock_320462102-2500-e1647917149997.jpg"
device = "cuda"
def preprocess_image(image_path):
transform = transforms.Compose([
transforms.Resize(299),
transforms.CenterCrop(299),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
if image_path.startswith('http://') or image_path.startswith('https://'):
import requests
from io import BytesIO
response = requests.get(image_path)
img = Image.open(BytesIO(response.content)).convert('RGB')
else:
img = Image.open(image_path).convert('RGB')
img = transform(img).unsqueeze(0)
img = img.cuda() if device.lower() == "cuda" else img.cpu()
return img
def eval():
model = timm.create_model("hf_hub:FredZhang7/google-safesearch-mini-v2", pretrained=True)
model.to(device)
img = preprocess_image(image_path)
with torch.no_grad():
out = model(img)
_, predicted = torch.max(out.data, 1)
classes = {
0: 'nsfw_gore',
1: 'nsfw_suggestive',
2: 'safe'
}
print('\n\033[1;31m' + classes[predicted.item()] + '\033[0m' if predicted.item() != 2 else '\033[1;32m' + classes[predicted.item()] + '\033[0m\n')
if __name__ == '__main__':
eval()
📄 ライセンス
このプロジェクトはApache-2.0ライセンスの下で公開されています。
Property |
Details |
Pipeline Tag |
image-classification |
Tags |
safety-checker, explicit-filter |
Metrics |
accuracy |
Library Name |
timm |