🚀 Geometric Shapes Classification Model
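The Geometric Shapes Classification model is an image-classification vision-language encoder. It is fine-tuned from google/siglip2-base-patch16-224 for multi-class shape recognition and uses the SiglipForImageClassification architecture to assign each image to one of eight geometric shape classes.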

🚀 Quick Start
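Install dependencies (the leading "!" runs the command from a Jupyter or Colab cell; omit it in a terminal):
!pip install -q transformers torch pillow gradio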
Run the code
import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, SiglipForImageClassification

# Load the fine-tuned checkpoint and its matching image processor from the Hugging Face Hub.
model_name = "prithivMLmods/Geometric-Shapes-Classification"
model = SiglipForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Class index (as a string) -> display name for the eight shape classes.
labels = {
    "0": "Circle ◯",
    "1": "Kite ⬰",
    "2": "Parallelogram ▰",
    "3": "Rectangle ▭",
    "4": "Rhombus ◆",
    "5": "Square ◼",
    "6": "Trapezoid ⏢",
    "7": "Triangle ▲",
}


def classify_shape(image):
    """Classifies the geometric shape in the input image."""
    # Gradio passes the upload as a NumPy array; convert it to an RGB PIL image.
    image = Image.fromarray(image).convert("RGB")
    # Preprocess (resize and normalize) and return PyTorch tensors.
    inputs = processor(images=image, return_tensors="pt")

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Softmax over the class dimension, then drop the batch dimension of size 1.
    probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist()
    # Map each class index to its display name with a probability rounded to 3 decimals.
    predictions = {labels[str(i)]: round(probs[i], 3) for i in range(len(probs))}
    return predictions


# Web UI: image in, per-class confidence scores out.
iface = gr.Interface(
    fn=classify_shape,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(label="Prediction Scores"),
    title="Geometric Shapes Classification",
    description="Upload an image to classify geometric shapes such as circle, triangle, square, and more.",
)

if __name__ == "__main__":
    iface.launch()
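The last step of classify_shape turns the model's raw logits into the label-to-probability dictionary that gr.Label displays. The sketch below reproduces that post-processing in plain Python so it can be inspected without downloading the model. The logits are made-up illustration values, and softmax, logits_to_scores and top_k are hypothetical helpers written for this example, not part of the model or transformers API.

```python
import math

# Same index-to-name mapping as the `labels` dict in the quick-start code.
LABELS = {
    "0": "Circle ◯",
    "1": "Kite ⬰",
    "2": "Parallelogram ▰",
    "3": "Rectangle ▭",
    "4": "Rhombus ◆",
    "5": "Square ◼",
    "6": "Trapezoid ⏢",
    "7": "Triangle ▲",
}


def softmax(logits):
    """Numerically stable softmax: subtract the largest logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_scores(logits, ndigits=3):
    """Hypothetical helper mirroring classify_shape: label -> rounded probability."""
    probs = softmax(logits)
    return {LABELS[str(i)]: round(p, ndigits) for i, p in enumerate(probs)}


def top_k(scores, k=3):
    """Return the k highest-scoring (label, probability) pairs, best first."""
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:k]


if __name__ == "__main__":
    # Illustrative logits for a single image (not real model output).
    fake_logits = [0.2, -1.0, 0.5, 1.1, -0.3, 4.0, 0.0, -2.0]
    scores = logits_to_scores(fake_logits)
    for label, p in top_k(scores):
        print(f"{label}: {p}")
```

Subtracting the maximum logit leaves the probabilities unchanged but avoids overflow in math.exp for large logits; torch.nn.functional.softmax in the quick-start code handles this internally.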
📚 Detailed Documentation
Classification Report
Per-class results on 12,000 evaluation images (1,500 per class):
| Class | Precision | Recall | F1-score | Support |
|---|---|---|---|---|
| Circle ◯ | 0.9921 | 0.9987 | 0.9953 | 1500 |
| Kite ⬰ | 0.9927 | 0.9927 | 0.9927 | 1500 |
| Parallelogram ▰ | 0.9926 | 0.9840 | 0.9883 | 1500 |
| Rectangle ▭ | 0.9993 | 0.9913 | 0.9953 | 1500 |
| Rhombus ◆ | 0.9846 | 0.9820 | 0.9833 | 1500 |
| Square ◼ | 0.9914 | 0.9987 | 0.9950 | 1500 |
| Trapezoid ⏢ | 0.9966 | 0.9793 | 0.9879 | 1500 |
| Triangle ▲ | 0.9772 | 0.9993 | 0.9881 | 1500 |
| Accuracy | | | 0.9908 | 12000 |
| Macro avg | 0.9908 | 0.9908 | 0.9907 | 12000 |
| Weighted avg | 0.9908 | 0.9908 | 0.9907 | 12000 |
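The summary rows follow from the per-class rows: F1 is the harmonic mean of precision and recall, the macro average is the unweighted mean over the eight classes, and the weighted average weights each class by its support. Because every class has the same support (1500), the weighted and macro averages coincide, and accuracy equals the support-weighted recall. The sketch below recomputes these figures from the table's values; since those inputs are already rounded to four decimals, small differences in the last digit are expected. The helper names are illustrative only.

```python
# Per-class (precision, recall, f1, support) copied from the classification report.
REPORT = {
    "Circle ◯": (0.9921, 0.9987, 0.9953, 1500),
    "Kite ⬰": (0.9927, 0.9927, 0.9927, 1500),
    "Parallelogram ▰": (0.9926, 0.9840, 0.9883, 1500),
    "Rectangle ▭": (0.9993, 0.9913, 0.9953, 1500),
    "Rhombus ◆": (0.9846, 0.9820, 0.9833, 1500),
    "Square ◼": (0.9914, 0.9987, 0.9950, 1500),
    "Trapezoid ⏢": (0.9966, 0.9793, 0.9879, 1500),
    "Triangle ▲": (0.9772, 0.9993, 0.9881, 1500),
}
PRECISION, RECALL, F1, SUPPORT = 0, 1, 2, 3  # column indices into each tuple


def f1(precision, recall):
    """Harmonic mean of precision and recall (0 when both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def macro_avg(column):
    """Unweighted mean of one column over all classes."""
    values = [row[column] for row in REPORT.values()]
    return sum(values) / len(values)


def weighted_avg(column):
    """Mean of one column, weighting each class by its support."""
    total = sum(row[SUPPORT] for row in REPORT.values())
    return sum(row[column] * row[SUPPORT] for row in REPORT.values()) / total


def accuracy_from_recall():
    """Accuracy = correct / total, where recall * support ~ correct predictions per class."""
    correct = sum(round(row[RECALL] * row[SUPPORT]) for row in REPORT.values())
    total = sum(row[SUPPORT] for row in REPORT.values())
    return correct / total


if __name__ == "__main__":
    for name, (p, r, reported_f1, _) in REPORT.items():
        print(f"{name}: recomputed F1 {f1(p, r):.4f} vs reported {reported_f1:.4f}")
    print(f"macro F1 {macro_avg(F1):.4f}, weighted F1 {weighted_avg(F1):.4f}")
    print(f"accuracy {accuracy_from_recall():.4f}")
```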
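Classes
The model classifies images into the following classes:
- Class 0: Circle ◯
- Class 1: Kite ⬰
- Class 2: Parallelogram ▰
- Class 3: Rectangle ▭
- Class 4: Rhombus ◆
- Class 5: Square ◼
- Class 6: Trapezoid ⏢
- Class 7: Triangle ▲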
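The quick-start code looks up class names with string keys (labels[str(i)]), following the class indices listed above. A small sketch, assuming you want both directions of the mapping (index to name and name to index, in the same shape as the id2label and label2id fields of a transformers model config), can derive everything from one ordered list so the indices cannot drift apart. CLASS_NAMES and predicted_name are illustrative names defined here, not part of the model repository.

```python
# Class order as listed in the model card (index = position in this list).
CLASS_NAMES = [
    "Circle ◯", "Kite ⬰", "Parallelogram ▰", "Rectangle ▭",
    "Rhombus ◆", "Square ◼", "Trapezoid ⏢", "Triangle ▲",
]

# Integer index -> name, and name -> integer index.
id2label = {i: name for i, name in enumerate(CLASS_NAMES)}
label2id = {name: i for i, name in id2label.items()}

# String-keyed variant matching the `labels` dict used in the quick-start code.
labels = {str(i): name for i, name in id2label.items()}


def predicted_name(probs):
    """Return the class name with the highest probability (first one on ties)."""
    best = max(range(len(probs)), key=lambda i: probs[i])
    return id2label[best]


if __name__ == "__main__":
    print(labels["5"])
    print(label2id["Kite ⬰"])
    print(predicted_name([0.05] * 7 + [0.65]))
```

Keeping a single source list is a design choice for consistency; the published checkpoint stores its own mapping in its config, which remains the authoritative one.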
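Intended Use
The Geometric Shapes Classification model is designed to recognize basic geometric shapes in images. Example use cases:
- Educational tools: teaching and learning geometry visually.
- Computer vision projects: serving as a shape-recognition component in robotics or automation.
- Image analysis: identifying symbols in diagrams or engineering drawings.
- Assistive technology: supporting shape recognition in applications for visually impaired users.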
📄 License
This project is released under the Apache-2.0 license.
📦 Model Information
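| Property | Details |
|---|---|
| Model type | Vision-language encoder model for image classification |
| Base model | google/siglip2-base-patch16-224 |
| Training dataset | prithivMLmods/Math-Shapes |
| Library name | transformers |
| Tags | Shapes, Geometric, SigLIP2, art |
| Pipeline tag | image-classification |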