🚀 LLM2Vec:大語言模型是強大的文本編碼器
LLM2Vec 是一種將僅解碼器的大語言模型轉換為文本編碼器的簡單方法。它包含三個簡單步驟:1) 啟用雙向注意力;2) 掩碼下一個詞預測;3) 無監督對比學習。該模型還可以進一步微調以達到最先進的性能。
🚀 快速開始
LLM2Vec 提供了一種簡潔有效的方式,將僅解碼器的大語言模型轉化為強大的文本編碼器。通過幾個簡單步驟,你就可以使用它進行文本編碼和相似度計算。
📦 安裝指南
使用以下命令安裝 llm2vec
:
pip install llm2vec
💻 使用示例
基礎用法
from llm2vec import LLM2Vec
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import PeftModel
tokenizer = AutoTokenizer.from_pretrained(
"McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp"
)
config = AutoConfig.from_pretrained(
"McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp", trust_remote_code=True
)
model = AutoModel.from_pretrained(
"McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp",
trust_remote_code=True,
config=config,
torch_dtype=torch.bfloat16,
device_map="cuda" if torch.cuda.is_available() else "cpu",
)
model = PeftModel.from_pretrained(
model,
"McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp",
)
l2v = LLM2Vec(model, tokenizer, pooling_mode="mean", max_length=512)
instruction = (
"Given a web search query, retrieve relevant passages that answer the query:"
)
queries = [
[instruction, "how much protein should a female eat"],
[instruction, "summit define"],
]
q_reps = l2v.encode(queries)
documents = [
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
]
d_reps = l2v.encode(documents)
q_reps_norm = torch.nn.functional.normalize(q_reps, p=2, dim=1)
d_reps_norm = torch.nn.functional.normalize(d_reps, p=2, dim=1)
cos_sim = torch.mm(q_reps_norm, d_reps_norm.transpose(0, 1))
print(cos_sim)
"""
tensor([[0.8180, 0.5825],
[0.1069, 0.1931]])
"""
注意事項
如果你對代碼有任何疑問,請隨時通過電子郵件聯繫 Parishad (parishad.behnamghader@mila.quebec
) 和 Vaibhav (vaibhav.adlakha@mila.quebec
)。
📄 許可證
本項目採用 MIT 許可證。
項目相關鏈接
- 代碼倉庫:https://github.com/McGill-NLP/llm2vec
- 論文鏈接:https://arxiv.org/abs/2404.05961