🚀 BiRefNet:高分辨率二分圖像分割的雙邊參考模型
BiRefNet是一個用於高分辨率二分圖像分割的模型,在背景去除、掩膜生成、二分圖像分割、偽裝目標檢測和顯著目標檢測等任務中表現出色。本項目提供了模型的官方實現、代碼、文檔和模型庫。
🚀 快速開始
安裝依賴包
pip install -qr https://raw.githubusercontent.com/ZhengPeng7/BiRefNet/main/requirements.txt
加載BiRefNet模型
使用HuggingFace的代碼和權重
from transformers import AutoModelForImageSegmentation
birefnet = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
使用GitHub的代碼和HuggingFace的權重
git clone https://github.com/ZhengPeng7/BiRefNet.git
cd BiRefNet
from models.birefnet import BiRefNet
birefnet = BiRefNet.from_pretrained('ZhengPeng7/BiRefNet')
使用GitHub的代碼和本地的權重
import torch
from utils import check_state_dict
birefnet = BiRefNet(bb_pretrained=False)
state_dict = torch.load(PATH_TO_WEIGHT, map_location='cpu')
state_dict = check_state_dict(state_dict)
birefnet.load_state_dict(state_dict)
使用加載好的BiRefNet進行推理
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from models.birefnet import BiRefNet
birefnet = ...
torch.set_float32_matmul_precision(['high', 'highest'][0])
birefnet.to('cuda')
birefnet.eval()
birefnet.half()
def extract_object(birefnet, imagepath):
image_size = (1024, 1024)
transform_image = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image = Image.open(imagepath)
input_images = transform_image(image).unsqueeze(0).to('cuda').half()
with torch.no_grad():
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image.size)
image.putalpha(mask)
return image, mask
plt.axis("off")
plt.imshow(extract_object(birefnet, imagepath='PATH-TO-YOUR_IMAGE.jpg')[0])
plt.show()
本地使用推理端點
import requests
import base64
from io import BytesIO
from PIL import Image
YOUR_HF_TOKEN = 'xxx'
API_URL = "xxx"
headers = {
"Authorization": "Bearer {}".format(YOUR_HF_TOKEN)
}
def base64_to_bytes(base64_string):
if "data:image" in base64_string:
base64_string = base64_string.split(",")[1]
image_bytes = base64.b64decode(base64_string)
return image_bytes
def bytes_to_base64(image_bytes):
image_stream = BytesIO(image_bytes)
image = Image.open(image_stream)
return image
def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
output = query({
"inputs": "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg",
"parameters": {}
})
output_image = bytes_to_base64(base64_to_bytes(output))
output_image
✨ 主要特性
- 多任務支持:可用於背景去除、掩膜生成、二分圖像分割、偽裝目標檢測和顯著目標檢測等多種任務。
- 高分辨率分割:能夠實現高分辨率的二分圖像分割。
- 多種使用方式:支持使用HuggingFace的代碼和權重、GitHub的代碼和HuggingFace的權重、GitHub的代碼和本地的權重等多種方式加載模型。
📦 安裝指南
安裝依賴包
pip install -qr https://raw.githubusercontent.com/ZhengPeng7/BiRefNet/main/requirements.txt
💻 使用示例
基礎用法
from transformers import AutoModelForImageSegmentation
birefnet = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
高級用法
import torch
from utils import check_state_dict
birefnet = BiRefNet(bb_pretrained=False)
state_dict = torch.load(PATH_TO_WEIGHT, map_location='cpu')
state_dict = check_state_dict(state_dict)
birefnet.load_state_dict(state_dict)
📚 詳細文檔
本項目是論文 "Bilateral Reference for High-Resolution Dichotomous Image Segmentation" (CAAI AIR 2024) 的官方實現。更多詳細信息請訪問我們的GitHub倉庫:https://github.com/ZhengPeng7/BiRefNet,包括代碼、文檔和模型庫。
在線演示
🔧 技術細節
本項目的BiRefNet模型是用於標準二分圖像分割(DIS)的,在 DIS-TR 上進行訓練,並在 DIS-TEs和DIS-VD 上進行驗證。該模型在三個任務(DIS、HRSOD和COD)上取得了SOTA性能。
📄 許可證
本項目採用MIT許可證。詳情請見 LICENSE 文件。
致謝
- 感謝 @Freepik 慷慨提供GPU資源,用於訓練更高分辨率的BiRefNet模型和進行更多探索。
- 感謝 @fal 慷慨提供GPU資源,用於訓練更通用的BiRefNet模型。
- 感謝 @not-lain 幫助我們將BiRefNet模型更好地部署到HuggingFace上。
引用
@article{zheng2024birefnet,
title={Bilateral Reference for High-Resolution Dichotomous Image Segmentation},
author={Zheng, Peng and Gao, Dehong and Fan, Deng-Ping and Liu, Li and Laaksonen, Jorma and Ouyang, Wanli and Sebe, Nicu},
journal={CAAI Artificial Intelligence Research},
volume = {3},
pages = {9150038},
year={2024}
}
示例圖片
DIS示例1 |
DIS示例2 |
 |
 |
作者信息
1 南開大學 2 西北工業大學 3 國防科技大學 4 阿爾託大學 5 上海人工智能實驗室 6 特倫託大學
相關鏈接