🚀 强化学习GUI定位模型GTA1
本项目借助强化学习(如GRPO)实现GUI定位,通过奖励成功点击来直接激励可操作且基于实际的响应,而非依赖冗长的思维链推理。我们分享了使用GRPO训练的最先进的GUI定位模型。
🚀 快速开始
本项目主要聚焦于强化学习在GUI定位中的应用,借助GRPO算法训练模型,以实现更精准的GUI定位。通过在多个挑战性数据集上的测试,验证了模型的性能。
✨ 主要特性
- 目标对齐:强化学习(如GRPO)因其固有的目标对齐特性,即奖励成功点击,而非鼓励冗长的文本思维链(CoT)推理,有助于实现定位。
- 直接激励:与严重依赖冗长CoT推理的方法不同,GRPO直接激励可操作且基于实际的响应。
- 性能卓越:在多个挑战性数据集上进行基准测试,模型在所有开源模型家族中始终取得最佳结果。
📚 详细文档
模型性能
我们遵循标准评估协议,在三个具有挑战性的数据集上对模型进行基准测试。以下是比较结果:
模型 |
大小 |
开源 |
ScreenSpot-V2 |
ScreenSpotPro |
OSWORLD-G |
OpenAI CUA |
— |
❌ |
87.9 |
23.4 |
— |
Claude 3.7 |
— |
❌ |
87.6 |
27.7 |
— |
JEDI - 7B |
7B |
✅ |
91.7 |
39.5 |
54.1 |
SE - GUI |
7B |
✅ |
90.3 |
47.0 |
— |
UI - TARS |
7B |
✅ |
91.6 |
35.7 |
47.5 |
UI - TARS - 1.5* |
7B |
✅ |
89.7* |
42.0* |
64.2* |
UGround - v1 - 7B |
7B |
✅ |
— |
31.1 |
36.4 |
Qwen2.5 - VL - 32B - Instruct |
32B |
✅ |
91.9* |
48.0 |
59.6* |
UGround - v1 - 72B |
72B |
✅ |
— |
34.5 |
— |
Qwen2.5 - VL - 72B - Instruct |
72B |
✅ |
94.00* |
53.3 |
62.2* |
UI - TARS |
72B |
✅ |
90.3 |
38.1 |
— |
GTA1 (我们的模型) |
7B |
✅ |
92.4 (∆ +2.7) |
50.1(∆ +8.1) |
67.7 (∆ +3.5) |
GTA1 (我们的模型) |
32B |
✅ |
93.2 (∆ +1.3) |
53.6 (∆ +5.6) |
61.9(∆ +2.3) |
GTA1 (我们的模型) |
72B |
✅ |
94.8(∆ +0.8) |
58.4 (∆ +5.1) |
66.7(∆ +4.5) |
⚠️ 重要提示
- 模型大小以十亿(B)参数表示。
- 短横线(—)表示当前无法获取的结果。
- 上标星号(﹡)表示我们的评估结果。
- UI - TARS - 1.5 7B、Qwen2.5 - VL - 32B - Instruct 和 Qwen2.5 - VL - 72B - Instruct 作为我们的基线模型。
- ∆ 表示我们的模型与其基线相比的性能提升。
💻 使用示例
基础用法
以下是一个代码片段,展示了如何使用训练好的模型进行推理:
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch
import re
SYSTEM_PROMPT = '''
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
Output the coordinate pair exactly:
(x,y)
'''
SYSTEM_PROMPT=SYSTEM_PROMPT.strip()
def extract_coordinates(raw_string):
try:
matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
return [tuple(map(int, match)) for match in matches][0]
except:
return 0,0
model_path = "HelloKKMe/GTA1-32B"
max_new_tokens = 32
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto"
)
processor = AutoProcessor.from_pretrained(
model_path,
min_pixels=3136,
max_pixels= 4096 * 2160
)
image = Image.open("file path")
instruction = "description"
width, height = image.width, image.height
resized_height, resized_width = smart_resize(
image.height,
image.width,
factor=processor.image_processor.patch_size * processor.image_processor.merge_size,
min_pixels=processor.image_processor.min_pixels,
max_pixels=processor.image_processor.max_pixels,
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height
system_message = {
"role": "system",
"content": SYSTEM_PROMPT.format(height=resized_height,width=resized_width)
}
user_message = {
"role": "user",
"content": [
{"type": "image", "image": resized_image},
{"type": "text", "text": instruction}
]
}
image_inputs, video_inputs = process_vision_info([system_message, user_message])
text = processor.apply_chat_template([system_message, user_message], tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
inputs = inputs.to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=1.0, use_cache=True)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
pred_x, pred_y = extract_coordinates(output_text)
pred_x*=scale_x
pred_y*=scale_y
print(pred_x,pred_y)
更多详细信息请参考我们的 代码。