🚀 [E5-V: マルチモーダル大規模言語モデルによる汎用埋め込み]
E5-Vは、マルチモーダル埋め込みを実現するためのフレームワークです。lmms-lab/llama3-llava-next-8bをベースに微調整されており、異なるタイプの入力間のモダリティギャップを効果的に埋め、微調整なしでもマルチモーダル埋め込みにおいて強力な性能を発揮します。
🚀 クイックスタート
E5-Vは、マルチモーダル大規模言語モデルを適応させてマルチモーダル埋め込みを実現するためのフレームワークです。このフレームワークは、異なるタイプの入力間のモダリティギャップを効果的に埋め、微調整なしでもマルチモーダル埋め込みにおいて強力な性能を発揮します。また、E5-Vには単一モダリティのトレーニングアプローチも提案されており、このアプローチではモデルはテキストペアのみでトレーニングされ、マルチモーダルトレーニングよりも良好な性能を示します。
詳細は、https://github.com/kongds/E5-V を参照してください。
✨ 主な機能
- E5-Vは、lmms-lab/llama3-llava-next-8bをベースに微調整されています。
- マルチモーダル埋め込みを実現するためのフレームワークを提案しています。
- 異なるタイプの入力間のモダリティギャップを効果的に埋めます。
- 微調整なしでもマルチモーダル埋め込みにおいて強力な性能を発揮します。
- 単一モダリティのトレーニングアプローチを提案しており、マルチモーダルトレーニングよりも良好な性能を示します。
💻 使用例
基本的な使用法
import torch
import torch.nn.functional as F
import requests
from PIL import Image
from transformers import AutoTokenizer
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'
processor = LlavaNextProcessor.from_pretrained('royokong/e5-v')
model = LlavaNextForConditionalGeneration.from_pretrained('royokong/e5-v', torch_dtype=torch.float16).cuda()
img_prompt = llama3_template.format('<image>\nSummary above image in one word: ')
text_prompt = llama3_template.format('<sent>\nSummary above sentence in one word: ')
urls = ['https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg',
'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/179px-Felis_catus-cat_on_snow.jpg']
images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
texts = ['A dog sitting in the grass.',
'A cat standing in the snow.']
text_inputs = processor([text_prompt.replace('<sent>', text) for text in texts], return_tensors="pt", padding=True).to('cuda')
img_inputs = processor([img_prompt]*len(images), images, return_tensors="pt", padding=True).to('cuda')
with torch.no_grad():
text_embs = model(**text_inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :]
img_embs = model(**img_inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :]
text_embs = F.normalize(text_embs, dim=-1)
img_embs = F.normalize(img_embs, dim=-1)
print(text_embs @ img_embs.t())