🚀 パリジェンマ-3b-mix-448-ft-テーブル検出
このモデルは、google/paligemma-3b-mix-448 を ucsahin/pubtables-detection-1500-samples データセットでミックス精度でファインチューニングしたバージョンです。
評価セットでは以下の結果を達成しています。
🚀 クイックスタート
モデルの読み込み
Transformersでは、以下のようにモデルを読み込むことができます。
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
model_id = "ucsahin/paligemma-3b-mix-448-ft-TableDetection"
device = "cuda:0"
dtype = torch.bfloat16
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=dtype,
device_map=device
)
processor = PaliGemmaProcessor.from_pretrained(model_id)
推論の実行
推論には、以下のコードを使用できます。
prompt = "detect table"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=128, do_sample=False)
generation = generation[0][input_len:]
bbox_text = processor.decode(generation, skip_special_tokens=True)
print(bbox_text)
4ビット量子化モデルの読み込み
bitsandbytes
を使用して4ビットまたは8ビットの量子化モデルを読み込むこともできます。ただし、モデルは、例えば4つではなく5つの位置タグ "<loc[value]>" や、"table" 以外の異なるラベルを含む出力を生成する可能性があり、さらなる後処理が必要になる場合があります。提供されている後処理スクリプトは、前者のケースを処理するはずです。
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, BitsAndBytesConfig
import torch
model_id = "ucsahin/paligemma-3b-mix-448-ft-TableDetection"
device = "cuda:0"
dtype = torch.bfloat16
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=dtype
)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=dtype,
device_map=device,
quantization_config=bnb_config
)
processor = PaliGemmaProcessor.from_pretrained(model_id)
✨ 主な機能
このモデルは、テキストプロンプトを与えて画像内のテーブルを検出するタスク用にファインチューニングされたマルチモーダル言語モデルです。モデルは、画像とテキストの入力を組み合わせて、提供された画像内のテーブルの周囲のバウンディングボックスを予測します。
このモデルの主な目的は、画像内のテーブル検出プロセスを自動化することです。画像内のテーブルを識別することが重要な文書処理、データ抽出、画像分析などのさまざまなアプリケーションで利用できます。
入力
- 画像: モデルには、1つ以上のテーブルを含む画像を入力する必要があります。画像はJPEGやPNGなどの標準形式である必要があります。
- テキストプロンプト: さらに、モデルの注意力をテーブル検出タスクに向けるために、テキストプロンプトが必要です。プロンプトは、目的のアクションを明確に示す必要があります。テキストプロンプトとして "detect table" を使用してください。
出力
- バウンディングボックス: モデルは、バウンディングボックスの座標を特殊な <loc[value]> トークンの形式で出力します。ここで、valueは正規化された座標を表す数値です。各検出は、y_min、x_min、y_max、x_maxの順序で4つの位置座標で表され、その後にそのボックス内で検出されたラベルが続きます。値を座標に変換するには、まず数値を1024で割り、次にyを画像の高さで、xを画像の幅で乗算する必要があります。これにより、元の画像サイズに対するバウンディングボックスの座標が得られます。
すべてが順調にいけば、モデルは、画像内で検出されたテーブルの数に応じて、"<loc[value]><loc[value]><loc[value]><loc[value]> table; <loc[value]><loc[value]><loc[value]><loc[value]> table" のようなテキストを出力します。次に、以下のスクリプトを使用して、テキスト出力をPASCAL VOC形式のバウンディングボックスに変換できます。
import re
def post_process(bbox_text, image_width, image_height):
loc_values_str = [bbox.strip() for bbox in bbox_text.split(";")]
converted_bboxes = []
for loc_value_str in loc_values_str:
loc_values = re.findall(r'<loc(\d+)>', loc_value_str)
loc_values = [int(x) for x in loc_values]
loc_values = loc_values[:4]
loc_values = [value/1024 for value in loc_values]
loc_values = [
int(loc_values[1]*image_width), int(loc_values[0]*image_height),
int(loc_values[3]*image_width), int(loc_values[2]*image_height),
]
converted_bboxes.append(loc_values)
return converted_bboxes
🔧 技術詳細
学習ハイパーパラメータ
学習中に使用されたハイパーパラメータは以下の通りです。
- 学習率: 0.0001
- 学習バッチサイズ: 4
- 評価バッチサイズ: 4
- シード: 42
- 勾配累積ステップ: 4
- bf16: True ミックス精度
- 総学習バッチサイズ: 16
- オプティマイザ: Adam (betas=(0.9,0.999), epsilon=1e-08)
- 学習率スケジューラの種類: 線形
- 学習率スケジューラのウォームアップステップ: 5
- エポック数: 3
学習結果
学習損失 |
エポック |
ステップ |
検証損失 |
2.957 |
0.1775 |
15 |
2.1300 |
1.9656 |
0.3550 |
30 |
1.8421 |
1.6716 |
0.5325 |
45 |
1.6898 |
1.5514 |
0.7101 |
60 |
1.5803 |
1.5851 |
0.8876 |
75 |
1.5271 |
1.4134 |
1.0651 |
90 |
1.4771 |
1.3566 |
1.2426 |
105 |
1.4528 |
1.3093 |
1.4201 |
120 |
1.4227 |
1.2897 |
1.5976 |
135 |
1.4115 |
1.256 |
1.7751 |
150 |
1.4007 |
1.2666 |
1.9527 |
165 |
1.3678 |
1.2213 |
2.1302 |
180 |
1.3744 |
1.0999 |
2.3077 |
195 |
1.3633 |
1.1931 |
2.4852 |
210 |
1.3606 |
1.0722 |
2.6627 |
225 |
1.3619 |
1.1485 |
2.8402 |
240 |
1.3544 |
フレームワークのバージョン
- PEFT 0.11.1
- Transformers 4.42.0.dev0
- Pytorch 2.3.0+cu121
- Datasets 2.19.1
- Tokenizers 0.19.1
📄 ライセンス
このモデルはgemmaライセンスの下で提供されています。
バイアス、リスク、制限事項
バイアス、リスク、制限事項については、google/paligemma-3b-mix-448 を参照してください。