🚀 vit-rugpt2-image-captioning: the first image captioning model for the Russian language
This image captioning model was trained on an English-to-Russian (en-ru) translated version of the COCO2014 dataset and generates captions for images in Russian.
✨ Features
- Trained on a translated version of the COCO2014 dataset (en-ru).
- Initialized with `google/vit-base-patch16-224-in21k` for the encoder and `sberbank-ai/rugpt3large_based_on_gpt2` for the decoder.
📦 Installation
The original README does not include installation instructions. To run the examples below, install the libraries used in the sample code, `transformers`, `torch`, and `Pillow`, e.g. with `pip install transformers torch Pillow`.
💻 Usage Examples
Basic Usage
```python
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import torch
from PIL import Image

# Load the model, feature extractor and tokenizer
model = VisionEncoderDecoderModel.from_pretrained("vit-rugpt2-image-captioning")
feature_extractor = ViTFeatureExtractor.from_pretrained("vit-rugpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("vit-rugpt2-image-captioning")

# Run on GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Generation settings: captions up to 16 tokens, beam search with 4 beams
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def predict_caption(image_paths):
    # Load the images and make sure they are in RGB mode
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")
        images.append(i_image)

    # Preprocess the images and move the pixel values to the model's device
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # Generate token ids and decode them into caption strings
    output_ids = model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds

predict_caption(['train2014/COCO_train2014_000000295442.jpg'])
```
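Since `predict_caption` takes a list of paths, several images can be captioned in one batch, for example:

```python
# Both file names are the sample COCO images referenced in this README;
# any local image files will work.
predict_caption([
    "train2014/COCO_train2014_000000295442.jpg",
    "train2014/COCO_train2014_000000296754.jpg",
])
```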
Advanced Usage
```python
from transformers import pipeline

# The same model via the high-level image-to-text pipeline
image_to_text = pipeline("image-to-text", model="vit-rugpt2-image-captioning")
image_to_text("train2014/COCO_train2014_000000296754.jpg")
```
📚 Documentation
Model Details
The model was initialized with `google/vit-base-patch16-224-in21k` for the encoder and `sberbank-ai/rugpt3large_based_on_gpt2` for the decoder.
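The training code itself is not included in this README. As a rough sketch, an encoder-decoder pair like this is typically assembled with `VisionEncoderDecoderModel.from_encoder_decoder_pretrained`; only the two checkpoint names come from the description above, the rest is an assumption and not the authors' actual setup:

```python
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

# Assumed initialization, not the authors' training script:
# combine the ViT encoder with the Russian GPT-2 decoder
# (cross-attention layers are added to the decoder automatically).
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k",
    "sberbank-ai/rugpt3large_based_on_gpt2",
)

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = AutoTokenizer.from_pretrained("sberbank-ai/rugpt3large_based_on_gpt2")

# Generation needs these ids set explicitly
# (assumes the tokenizer defines bos/eos/pad tokens; add them if it does not).
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
```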
Metrics on test data
- BLEU: 8.672
- BLEU precision 1: 30.567
- BLEU precision 2: 7.895
- BLEU precision 3: 3.261
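The evaluation script is not part of this README. Corpus-level BLEU with per-n-gram precisions like those above can be computed, for example, with the `evaluate` library; the captions below are placeholder data, not taken from the test set:

```python
import evaluate

# Placeholder data: generated captions and (possibly multiple) reference captions per image.
predictions = ["Собака бежит по пляжу."]
references = [["Собака бежит по песчаному пляжу.", "Пёс бегает у моря."]]

bleu = evaluate.load("bleu")
results = bleu.compute(predictions=predictions, references=references)

# results["bleu"] is the overall corpus score; results["precisions"][0..3] are the
# 1- to 4-gram precisions (compare with "BLEU precision 1/2/3" above).
print(results["bleu"], results["precisions"])
```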
📄 License
The license information is not provided in the original README.
🔗 Contact
For any help, you can reach out to the model authors through their contact links.