🚀 印地语图像描述模型
这是一个基于编码器 - 解码器架构的图像描述模型,编码器采用 VIT,解码器使用 GPT2 - 印地语。这是首次尝试将 ViT 和 GPT2 - 印地语模型应用于图像描述任务。我们使用了 Kaggle 上的 Flickr8k 印地语数据集对该模型进行训练。
该模型是在 Huggingface 组织的 HuggingFace 课程社区周期间进行训练的。
🚀 快速开始
如何使用
以下是如何使用此模型为 Flickr8k 数据集中的图像生成描述的示例代码:
import torch
import requests
from PIL import Image
from transformers import ViTFeatureExtractor, AutoTokenizer, \
VisionEncoderDecoderModel
if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'
url = 'https://shorturl.at/fvxEQ'
image = Image.open(requests.get(url, stream=True).raw)
encoder_checkpoint = 'google/vit-base-patch16-224'
decoder_checkpoint = 'surajp/gpt2-hindi'
model_checkpoint = 'team-indain-image-caption/hindi-image-captioning'
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
sample = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
clean_text = lambda x: x.replace('<|endoftext|>','').split('\n')[0]
caption_ids = model.generate(sample, max_length = 50)[0]
caption_text = clean_text(tokenizer.decode(caption_ids))
print(caption_text)
📦 安装指南
文档未提及安装相关内容,故跳过此章节。
💻 使用示例
基础用法
import torch
import requests
from PIL import Image
from transformers import ViTFeatureExtractor, AutoTokenizer, \
VisionEncoderDecoderModel
if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'
url = 'https://shorturl.at/fvxEQ'
image = Image.open(requests.get(url, stream=True).raw)
encoder_checkpoint = 'google/vit-base-patch16-224'
decoder_checkpoint = 'surajp/gpt2-hindi'
model_checkpoint = 'team-indain-image-caption/hindi-image-captioning'
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
sample = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
clean_text = lambda x: x.replace('<|endoftext|>','').split('\n')[0]
caption_ids = model.generate(sample, max_length = 50)[0]
caption_text = clean_text(tokenizer.decode(caption_ids))
print(caption_text)
高级用法
文档未提及高级用法相关代码示例,故跳过此部分。
📚 详细文档
训练数据
我们使用了 Flickr8k 印地语数据集对模型进行训练,该数据集是原始 Flickr8k 数据集的翻译版本,可在 Kaggle 上获取。
训练过程
此模型是在 Huggingface 组织的 HuggingFace 课程社区周期间进行训练的,训练在 Kaggle GPU 上完成。
训练参数
- 训练轮数(epochs) = 8
- 批次大小(batch_size) = 8
- 启用混合精度训练
团队成员