🚀 LLM2Vec:大語言模型是強大的文本編碼器
LLM2Vec是一種將僅解碼器的大語言模型轉換為文本編碼器的簡單方法。它包含三個簡單步驟:1) 啟用雙向注意力;2) 掩碼下一個標記預測;3) 無監督對比學習。該模型還可以進一步微調以達到最先進的性能。
- 倉庫地址:https://github.com/McGill-NLP/llm2vec
- 論文地址:https://arxiv.org/abs/2404.05961
🚀 快速開始
LLM2Vec提供了將僅解碼器的大語言模型轉換為文本編碼器的解決方案,通過特定步驟和微調可實現出色性能。
✨ 主要特性
- 簡單高效:通過三個簡單步驟將僅解碼器的大語言模型轉換為文本編碼器。
- 可微調:模型可以進一步微調以達到最先進的性能。
📦 安裝指南
使用以下命令安裝llm2vec
:
pip install llm2vec
💻 使用示例
基礎用法
from llm2vec import LLM2Vec
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import PeftModel
tokenizer = AutoTokenizer.from_pretrained(
"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp"
)
config = AutoConfig.from_pretrained(
"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp", trust_remote_code=True
)
model = AutoModel.from_pretrained(
"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
trust_remote_code=True,
config=config,
torch_dtype=torch.bfloat16,
device_map="cuda" if torch.cuda.is_available() else "cpu",
)
model = PeftModel.from_pretrained(
model,
"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
)
l2v = LLM2Vec(model, tokenizer, pooling_mode="mean", max_length=512)
instruction = (
"Given a web search query, retrieve relevant passages that answer the query:"
)
queries = [
[instruction, "how much protein should a female eat"],
[instruction, "summit define"],
]
q_reps = l2v.encode(queries)
documents = [
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
]
d_reps = l2v.encode(documents)
q_reps_norm = torch.nn.functional.normalize(q_reps, p=2, dim=1)
d_reps_norm = torch.nn.functional.normalize(d_reps, p=2, dim=1)
cos_sim = torch.mm(q_reps_norm, d_reps_norm.transpose(0, 1))
print(cos_sim)
"""
tensor([[0.7740, 0.5580],
[0.4845, 0.4993]])
"""
📄 許可證
本項目採用MIT許可證。
❓ 問題反饋
如果您對代碼有任何疑問,請隨時給Parishad (parishad.behnamghader@mila.quebec
) 和Vaibhav (vaibhav.adlakha@mila.quebec
) 發郵件。