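🚀 LLM2Vec: Large Language Models Are Secretly Powerful Text Encoders
LLM2Vec is a simple recipe for converting decoder-only large language models (LLMs) into text encoders. It consists of three steps: 1) enabling bidirectional attention; 2) masked next token prediction (MNTP); and 3) unsupervised contrastive learning. The resulting model can be further fine-tuned to reach state-of-the-art performance.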
- Repository: https://github.com/McGill-NLP/llm2vec
- Paper: https://arxiv.org/abs/2404.05961
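🚀 Quick Start
LLM2Vec turns a decoder-only LLM into a text encoder through the three steps above. Optional fine-tuning on top of those steps can improve performance further.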
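✨ Key Features
- Simple and efficient: three steps convert a decoder-only LLM into a text encoder.
- Fine-tunable: the model can be further fine-tuned to reach state-of-the-art performance.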
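The first two steps of the recipe can be illustrated with a small, dependency-free sketch. It is not the library's implementation: `causal_mask`, `bidirectional_mask`, and `mntp_prediction_pairs` are hypothetical helper names introduced only for illustration. The sketch assumes that a decoder-only model normally uses a lower-triangular (causal) attention mask, that enabling bidirectional attention amounts to replacing it with an all-ones mask, and that, as described in the paper, a masked token at position i is predicted from the representation at position i - 1.

```python
def causal_mask(n):
    """Lower-triangular mask: position i may attend only to positions 0..i."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


def bidirectional_mask(n):
    """All-ones mask: every position may attend to every other position."""
    return [[1] * n for _ in range(n)]


def mntp_prediction_pairs(tokens, masked_positions):
    """Pair each masked position i > 0 with the position that supplies its logits.

    Returns (i - 1, original_token) tuples. Position 0 is skipped in this
    simplified sketch because it has no previous position.
    """
    return [(i - 1, tokens[i]) for i in sorted(masked_positions) if i > 0]


if __name__ == "__main__":
    print("causal:")
    for row in causal_mask(4):
        print(row)
    print("bidirectional:")
    for row in bidirectional_mask(4):
        print(row)

    tokens = ["the", "cat", "sat", "down"]
    # The token at index 2 is masked, so it is predicted from index 1.
    print(mntp_prediction_pairs(tokens, {2}))
```

With the causal mask, the representation of the first token never sees the rest of the sentence. The all-ones mask removes that restriction, which is why a training step such as MNTP is then used to adapt the model to the new attention pattern.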
📦 Installation
Install llm2vec with the following command:
pip install llm2vec
💻 Usage Examples
Basic usage
from llm2vec import LLM2Vec

import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import PeftModel

# Load the base model together with the custom code (trust_remote_code=True)
# that enables bidirectional connections in a decoder-only LLM.
tokenizer = AutoTokenizer.from_pretrained(
    "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp"
)
config = AutoConfig.from_pretrained(
    "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp", trust_remote_code=True
)
model = AutoModel.from_pretrained(
    "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
    trust_remote_code=True,
    config=config,
    torch_dtype=torch.bfloat16,
    device_map="cuda" if torch.cuda.is_available() else "cpu",
)

# Load the MNTP adapter weights on top of the base model.
model = PeftModel.from_pretrained(
    model,
    "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
)

# Wrapper that handles encoding and pooling.
l2v = LLM2Vec(model, tokenizer, pooling_mode="mean", max_length=512)

# Encode queries; each query is paired with an instruction.
instruction = (
    "Given a web search query, retrieve relevant passages that answer the query:"
)
queries = [
    [instruction, "how much protein should a female eat"],
    [instruction, "summit define"],
]
q_reps = l2v.encode(queries)

# Encode documents; no instruction is needed for documents.
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
]
d_reps = l2v.encode(documents)

# Compute cosine similarity between every query and every document.
q_reps_norm = torch.nn.functional.normalize(q_reps, p=2, dim=1)
d_reps_norm = torch.nn.functional.normalize(d_reps, p=2, dim=1)
cos_sim = torch.mm(q_reps_norm, d_reps_norm.transpose(0, 1))

print(cos_sim)
"""
tensor([[0.7740, 0.5580],
        [0.4845, 0.4993]])
"""
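To make the pooling and scoring steps in the example more concrete, here is a minimal pure-Python sketch of mean pooling and cosine similarity on toy vectors, followed by a ranking step applied to the similarity matrix printed above. The helpers `mean_pool`, `cosine_similarity`, and `best_document_per_query` are hypothetical names for illustration only and are not part of the llm2vec API. The library's own pooling may differ in detail, for example in how instruction tokens are treated.

```python
import math


def mean_pool(token_vectors, attention_mask):
    """Average the token vectors whose attention-mask entry is 1 (padding is skipped)."""
    kept = [vec for vec, keep in zip(token_vectors, attention_mask) if keep]
    dim = len(token_vectors[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


def cosine_similarity(a, b):
    """Dot product of a and b divided by the product of their L2 norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def best_document_per_query(similarity_rows):
    """Return, for each query row, the index of the highest-scoring document."""
    return [max(range(len(row)), key=row.__getitem__) for row in similarity_rows]


if __name__ == "__main__":
    # Toy sequence: two real tokens and one padding token that must be ignored.
    tokens = [[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]]
    mask = [1, 1, 0]
    print(mean_pool(tokens, mask))

    # Similarity matrix copied from the example output above.
    cos_sim = [[0.7740, 0.5580], [0.4845, 0.4993]]
    print(best_document_per_query(cos_sim))
```

Applied to the matrix above, the ranking step pairs the protein query with the first document and the "summit define" query with the second, which matches the intended query-document pairs.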
📄 License
This project is released under the MIT License.
❓ Questions
If you have any questions about the code, feel free to email Parishad (parishad.behnamghader@mila.quebec) and Vaibhav (vaibhav.adlakha@mila.quebec).