GTA1 7B
G
GTA1 7B
由 HelloKKMe 开发
GTA1是一个基于强化学习(GRPO)的GUI元素定位模型,能够精准定位图形用户界面中的元素。
下载量 570
发布时间 : 4/14/2025
模型简介
该模型采用强化学习方法,直接激励可操作且基于实际的响应,为GUI定位提供了新的解决方案。
模型特点
强化学习驱动
采用GRPO强化学习方法,实现目标对齐,直接激励可操作且基于实际的响应。
高精度定位
在多个挑战性数据集上取得了领先的性能。
模型能力
GUI元素定位
图像分析
文本理解
使用案例
自动化测试
GUI元素定位
在自动化测试中精准定位GUI元素,提高测试效率。
在多个数据集上表现优异,如ScreenSpot-V2、ScreenSpotPro和OSWORLD-G。
🚀 强化学习驱动的GUI定位模型
本项目借助强化学习(如GRPO),实现了更精准的GUI元素定位。与依赖冗长思维链推理的方法不同,GRPO直接激励可操作且基于实际的响应,为GUI定位带来了新的解决方案。
🚀 快速开始
模型推理
以下是使用训练好的模型进行推理的代码示例:
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch
import re
SYSTEM_PROMPT = '''
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
Output the coordinate pair exactly:
(x,y)
'''
SYSTEM_PROMPT=SYSTEM_PROMPT.strip()
# Function to extract coordinates from model output
def extract_coordinates(raw_string):
try:
matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
return [tuple(map(int, match)) for match in matches][0]
except:
return 0,0
# Load model and processor
model_path = "HelloKKMe/GTA1-7B"
max_new_tokens = 32
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto"
)
processor = AutoProcessor.from_pretrained(
model_path,
min_pixels=3136,
max_pixels= 4096 * 2160
)
# Load and resize image
image = Image.open("file path")
instruction = "description" # Instruction for grounding
width, height = image.width, image.height
resized_height, resized_width = smart_resize(
image.height,
image.width,
factor=processor.image_processor.patch_size * processor.image_processor.merge_size,
min_pixels=processor.image_processor.min_pixels,
max_pixels=processor.image_processor.max_pixels,
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height
# Prepare system and user messages
system_message = {
"role": "system",
"content": SYSTEM_PROMPT.format(height=resized_height,width=resized_width)
}
user_message = {
"role": "user",
"content": [
{"type": "image", "image": resized_image},
{"type": "text", "text": instruction}
]
}
# Tokenize and prepare inputs
image_inputs, video_inputs = process_vision_info([system_message, user_message])
text = processor.apply_chat_template([system_message, user_message], tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
inputs = inputs.to(model.device)
# Generate prediction
output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=1.0, use_cache=True)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
# Extract and rescale coordinates
pred_x, pred_y = extract_coordinates(output_text)
pred_x*=scale_x
pred_y*=scale_y
print(pred_x,pred_y)
更多详细信息请参考我们的代码仓库。
✨ 主要特性
- 强化学习驱动:采用GRPO强化学习方法,实现目标对齐,直接激励可操作且基于实际的响应。
- 高精度定位:在多个挑战性数据集上取得了领先的性能。
📚 详细文档
模型性能
我们遵循标准评估协议,在三个具有挑战性的数据集上对模型进行了基准测试。我们的方法在所有开源模型家族中始终取得最佳结果。以下是对比结果:
模型 | 规模 | 是否开源 | ScreenSpot-V2 | ScreenSpotPro | OSWORLD-G |
---|---|---|---|---|---|
OpenAI CUA | — | ❌ | 87.9 | 23.4 | — |
Claude 3.7 | — | ❌ | 87.6 | 27.7 | — |
JEDI-7B | 7B | ✅ | 91.7 | 39.5 | 54.1 |
SE-GUI | 7B | ✅ | 90.3 | 47.0 | — |
UI-TARS | 7B | ✅ | 91.6 | 35.7 | 47.5 |
UI-TARS-1.5* | 7B | ✅ | 89.7* | 42.0* | 64.2* |
UGround-v1-7B | 7B | ✅ | — | 31.1 | 36.4 |
Qwen2.5-VL-32B-Instruct | 32B | ✅ | 91.9* | 48.0 | 59.6* |
UGround-v1-72B | 72B | ✅ | — | 34.5 | — |
Qwen2.5-VL-72B-Instruct | 72B | ✅ | 94.00* | 53.3 | 62.2* |
UI-TARS | 72B | ✅ | 90.3 | 38.1 | — |
GTA1 (Ours) | 7B | ✅ | 92.4 (∆ +2.7) | 50.1(∆ +8.1) | 67.7 (∆ +3.5) |
GTA1 (Ours) | 32B | ✅ | 93.2 (∆ +1.3) | 53.6 (∆ +5.6) | 61.9(∆ +2.3) |
GTA1 (Ours) | 72B | ✅ | 94.8(∆ +0.8) | 58.4 (∆ +5.1) | 66.7(∆ +4.5) |
注意:
- 模型规模以十亿(B)为单位表示参数数量。
- 短横线(—)表示当前结果不可用。
- 上标星号(﹡)表示我们的评估结果。
- UI-TARS-1.5 7B、Qwen2.5-VL-32B-Instruct和Qwen2.5-VL-72B-Instruct作为我们的基线模型。
- ∆ 表示我们的模型相对于基线的性能提升。
Clip Vit Large Patch14
CLIP是由OpenAI开发的视觉-语言模型,通过对比学习将图像和文本映射到共享的嵌入空间,支持零样本图像分类
图像生成文本
C
openai
44.7M
1,710
Clip Vit Base Patch32
CLIP是由OpenAI开发的多模态模型,能够理解图像和文本之间的关系,支持零样本图像分类任务。
图像生成文本
C
openai
14.0M
666
Siglip So400m Patch14 384
Apache-2.0
SigLIP是基于WebLi数据集预训练的视觉语言模型,采用改进的sigmoid损失函数,优化了图像-文本匹配任务。
图像生成文本
Transformers

S
google
6.1M
526
Clip Vit Base Patch16
CLIP是由OpenAI开发的多模态模型,通过对比学习将图像和文本映射到共享的嵌入空间,实现零样本图像分类能力。
图像生成文本
C
openai
4.6M
119
Blip Image Captioning Base
Bsd-3-clause
BLIP是一个先进的视觉-语言预训练模型,擅长图像描述生成任务,支持条件式和非条件式文本生成。
图像生成文本
Transformers

B
Salesforce
2.8M
688
Blip Image Captioning Large
Bsd-3-clause
BLIP是一个统一的视觉-语言预训练框架,擅长图像描述生成任务,支持条件式和无条件式图像描述生成。
图像生成文本
Transformers

B
Salesforce
2.5M
1,312
Openvla 7b
MIT
OpenVLA 7B是一个基于Open X-Embodiment数据集训练的开源视觉-语言-动作模型,能够根据语言指令和摄像头图像生成机器人动作。
图像生成文本
Transformers 英语

O
openvla
1.7M
108
Llava V1.5 7b
LLaVA 是一款开源多模态聊天机器人,基于 LLaMA/Vicuna 微调,支持图文交互。
图像生成文本
Transformers

L
liuhaotian
1.4M
448
Vit Gpt2 Image Captioning
Apache-2.0
这是一个基于ViT和GPT2架构的图像描述生成模型,能够为输入图像生成自然语言描述。
图像生成文本
Transformers

V
nlpconnect
939.88k
887
Blip2 Opt 2.7b
MIT
BLIP-2是一个视觉语言模型,结合了图像编码器和大型语言模型,用于图像到文本的生成任务。
图像生成文本
Transformers 英语

B
Salesforce
867.78k
359
精选推荐AI模型
Llama 3 Typhoon V1.5x 8b Instruct
专为泰语设计的80亿参数指令模型,性能媲美GPT-3.5-turbo,优化了应用场景、检索增强生成、受限生成和推理任务
大型语言模型
Transformers 支持多种语言

L
scb10x
3,269
16
Cadet Tiny
Openrail
Cadet-Tiny是一个基于SODA数据集训练的超小型对话模型,专为边缘设备推理设计,体积仅为Cosmo-3B模型的2%左右。
对话系统
Transformers 英语

C
ToddGoldfarb
2,691
6
Roberta Base Chinese Extractive Qa
基于RoBERTa架构的中文抽取式问答模型,适用于从给定文本中提取答案的任务。
问答系统 中文
R
uer
2,694
98