🚀 強化学習駆動のGUI位置特定モデル
本プロジェクトでは、強化学習(GRPOなど)を活用して、より高精度なGUI要素の位置特定を実現しています。長い思考チェーンの推論に依存する方法とは異なり、GRPOは操作可能で実際に基づく応答を直接促し、GUI位置特定に新しい解決策をもたらします。
🚀 クイックスタート
モデル推論
以下は、学習済みモデルを使用して推論を行うコード例です。
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch
import re
SYSTEM_PROMPT = '''
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
Output the coordinate pair exactly:
(x,y)
'''
SYSTEM_PROMPT=SYSTEM_PROMPT.strip()
def extract_coordinates(raw_string):
try:
matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
return [tuple(map(int, match)) for match in matches][0]
except:
return 0,0
model_path = "HelloKKMe/GTA1-7B"
max_new_tokens = 32
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto"
)
processor = AutoProcessor.from_pretrained(
model_path,
min_pixels=3136,
max_pixels= 4096 * 2160
)
image = Image.open("file path")
instruction = "description"
width, height = image.width, image.height
resized_height, resized_width = smart_resize(
image.height,
image.width,
factor=processor.image_processor.patch_size * processor.image_processor.merge_size,
min_pixels=processor.image_processor.min_pixels,
max_pixels=processor.image_processor.max_pixels,
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height
system_message = {
"role": "system",
"content": SYSTEM_PROMPT.format(height=resized_height,width=resized_width)
}
user_message = {
"role": "user",
"content": [
{"type": "image", "image": resized_image},
{"type": "text", "text": instruction}
]
}
image_inputs, video_inputs = process_vision_info([system_message, user_message])
text = processor.apply_chat_template([system_message, user_message], tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
inputs = inputs.to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=1.0, use_cache=True)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
pred_x, pred_y = extract_coordinates(output_text)
pred_x*=scale_x
pred_y*=scale_y
print(pred_x,pred_y)
詳細情報については、コードリポジトリをご参照ください。
✨ 主な機能
- 強化学習駆動:GRPO強化学習手法を採用し、目標のアライメントを実現し、操作可能で実際に基づく応答を直接促します。
- 高精度な位置特定:複数の難易度の高いデータセットで、トップレベルの性能を達成しています。
📚 ドキュメント
モデル性能
我々は標準的な評価プロトコルに従い、3つの難易度の高いデータセットでモデルのベンチマークテストを行いました。我々の手法は、すべてのオープンソースモデルファミリーの中で常に最高の結果を達成しています。以下は比較結果です。
モデル |
規模 |
オープンソースか |
ScreenSpot-V2 |
ScreenSpotPro |
OSWORLD-G |
OpenAI CUA |
— |
❌ |
87.9 |
23.4 |
— |
Claude 3.7 |
— |
❌ |
87.6 |
27.7 |
— |
JEDI-7B |
7B |
✅ |
91.7 |
39.5 |
54.1 |
SE-GUI |
7B |
✅ |
90.3 |
47.0 |
— |
UI-TARS |
7B |
✅ |
91.6 |
35.7 |
47.5 |
UI-TARS-1.5* |
7B |
✅ |
89.7* |
42.0* |
64.2* |
UGround-v1-7B |
7B |
✅ |
— |
31.1 |
36.4 |
Qwen2.5-VL-32B-Instruct |
32B |
✅ |
91.9* |
48.0 |
59.6* |
UGround-v1-72B |
72B |
✅ |
— |
34.5 |
— |
Qwen2.5-VL-72B-Instruct |
72B |
✅ |
94.00* |
53.3 |
62.2* |
UI-TARS |
72B |
✅ |
90.3 |
38.1 |
— |
GTA1 (Ours) |
7B |
✅ |
92.4 (∆ +2.7) |
50.1(∆ +8.1) |
67.7 (∆ +3.5) |
GTA1 (Ours) |
32B |
✅ |
93.2 (∆ +1.3) |
53.6 (∆ +5.6) |
61.9(∆ +2.3) |
GTA1 (Ours) |
72B |
✅ |
94.8(∆ +0.8) |
58.4 (∆ +5.1) |
66.7(∆ +4.5) |
注意事項:
- モデルの規模は、パラメータ数を10億(B)単位で表しています。
- 短い横線(—)は、現在の結果が利用できないことを示します。
- 上付きの星印(﹡)は、我々の評価結果を示します。
- UI-TARS-1.5 7B、Qwen2.5-VL-32B-Instruct、Qwen2.5-VL-72B-Instructは、我々のベースラインモデルとして使用されています。
- ∆ は、我々のモデルがベースラインに対する性能向上を示します。