🚀 強化學習GUI定位模型GTA1
本項目藉助強化學習(如GRPO)實現GUI定位,通過獎勵成功點擊來直接激勵可操作且基於實際的響應,而非依賴冗長的思維鏈推理。我們分享了使用GRPO訓練的最先進的GUI定位模型。
🚀 快速開始
本項目主要聚焦於強化學習在GUI定位中的應用,藉助GRPO算法訓練模型,以實現更精準的GUI定位。通過在多個挑戰性數據集上的測試,驗證了模型的性能。
✨ 主要特性
- 目標對齊:強化學習(如GRPO)因其固有的目標對齊特性,即獎勵成功點擊,而非鼓勵冗長的文本思維鏈(CoT)推理,有助於實現定位。
- 直接激勵:與嚴重依賴冗長CoT推理的方法不同,GRPO直接激勵可操作且基於實際的響應。
- 性能卓越:在多個挑戰性數據集上進行基準測試,模型在所有開源模型家族中始終取得最佳結果。
📚 詳細文檔
模型性能
我們遵循標準評估協議,在三個具有挑戰性的數據集上對模型進行基準測試。以下是比較結果:
模型 |
大小 |
開源 |
ScreenSpot-V2 |
ScreenSpotPro |
OSWORLD-G |
OpenAI CUA |
— |
❌ |
87.9 |
23.4 |
— |
Claude 3.7 |
— |
❌ |
87.6 |
27.7 |
— |
JEDI - 7B |
7B |
✅ |
91.7 |
39.5 |
54.1 |
SE - GUI |
7B |
✅ |
90.3 |
47.0 |
— |
UI - TARS |
7B |
✅ |
91.6 |
35.7 |
47.5 |
UI - TARS - 1.5* |
7B |
✅ |
89.7* |
42.0* |
64.2* |
UGround - v1 - 7B |
7B |
✅ |
— |
31.1 |
36.4 |
Qwen2.5 - VL - 32B - Instruct |
32B |
✅ |
91.9* |
48.0 |
59.6* |
UGround - v1 - 72B |
72B |
✅ |
— |
34.5 |
— |
Qwen2.5 - VL - 72B - Instruct |
72B |
✅ |
94.00* |
53.3 |
62.2* |
UI - TARS |
72B |
✅ |
90.3 |
38.1 |
— |
GTA1 (我們的模型) |
7B |
✅ |
92.4 (∆ +2.7) |
50.1(∆ +8.1) |
67.7 (∆ +3.5) |
GTA1 (我們的模型) |
32B |
✅ |
93.2 (∆ +1.3) |
53.6 (∆ +5.6) |
61.9(∆ +2.3) |
GTA1 (我們的模型) |
72B |
✅ |
94.8(∆ +0.8) |
58.4 (∆ +5.1) |
66.7(∆ +4.5) |
⚠️ 重要提示
- 模型大小以十億(B)參數表示。
- 短橫線(—)表示當前無法獲取的結果。
- 上標星號(﹡)表示我們的評估結果。
- UI - TARS - 1.5 7B、Qwen2.5 - VL - 32B - Instruct 和 Qwen2.5 - VL - 72B - Instruct 作為我們的基線模型。
- ∆ 表示我們的模型與其基線相比的性能提升。
💻 使用示例
基礎用法
以下是一個代碼片段,展示瞭如何使用訓練好的模型進行推理:
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch
import re
SYSTEM_PROMPT = '''
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
Output the coordinate pair exactly:
(x,y)
'''
SYSTEM_PROMPT=SYSTEM_PROMPT.strip()
def extract_coordinates(raw_string):
try:
matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
return [tuple(map(int, match)) for match in matches][0]
except:
return 0,0
model_path = "HelloKKMe/GTA1-32B"
max_new_tokens = 32
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto"
)
processor = AutoProcessor.from_pretrained(
model_path,
min_pixels=3136,
max_pixels= 4096 * 2160
)
image = Image.open("file path")
instruction = "description"
width, height = image.width, image.height
resized_height, resized_width = smart_resize(
image.height,
image.width,
factor=processor.image_processor.patch_size * processor.image_processor.merge_size,
min_pixels=processor.image_processor.min_pixels,
max_pixels=processor.image_processor.max_pixels,
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height
system_message = {
"role": "system",
"content": SYSTEM_PROMPT.format(height=resized_height,width=resized_width)
}
user_message = {
"role": "user",
"content": [
{"type": "image", "image": resized_image},
{"type": "text", "text": instruction}
]
}
image_inputs, video_inputs = process_vision_info([system_message, user_message])
text = processor.apply_chat_template([system_message, user_message], tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
inputs = inputs.to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=1.0, use_cache=True)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
pred_x, pred_y = extract_coordinates(output_text)
pred_x*=scale_x
pred_y*=scale_y
print(pred_x,pred_y)
更多詳細信息請參考我們的 代碼。