🚀 韩国重排器在Amazon SageMaker上的训练
本项目提供了用于开发 韩语重排器 的微调指南。ko-reranker 是基于 BAAI/bge-reranker-larger 在韩语数据上进行微调的模型。更多详细信息,请参考 korean-reranker-git / AWS博客,利用韩语重排器提升检索增强生成(RAG)性能。
✨ 主要特性
- 与嵌入模型不同,重排器将问题和文档作为输入,并直接输出相似度,而非嵌入向量。
- 向重排器输入问题和段落,可获得相关性得分。
- 重排器基于交叉熵损失进行优化,因此相关性得分不受特定范围限制。
📦 安装指南
文档未提及安装步骤,此部分跳过。
💻 使用示例
基础用法
def exp_normalize(x):
b = x.max()
y = np.exp(x - b)
return y / y.sum()
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()
pairs = [["나는 너를 싫어해", "나는 너를 사랑해"], \
["나는 너를 좋아해", "너에 대한 나의 감정은 사랑 일 수도 있어"]]
with torch.no_grad():
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
scores = exp_normalize(scores.numpy())
print (f'first: {scores[0]}, second: {scores[1]}')
高级用法
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel
try:
role = sagemaker.get_execution_role()
except ValueError:
iam = boto3.client('iam')
role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
hub = {
'HF_MODEL_ID':'Dongjin-kr/ko-reranker',
'HF_TASK':'text-classification'
}
huggingface_model = HuggingFaceModel(
transformers_version='4.28.1',
pytorch_version='2.0.0',
py_version='py310',
env=hub,
role=role,
)
predictor = huggingface_model.deploy(
initial_instance_count=1,
instance_type='ml.g5.large'
)
runtime_client = boto3.Session().client('sagemaker-runtime')
payload = json.dumps(
{
"inputs": [
{"text": "나는 너를 싫어해", "text_pair": "나는 너를 사랑해"},
{"text": "나는 너를 좋아해", "text_pair": "너에 대한 나의 감정은 사랑 일 수도 있어"}
]
}
)
response = runtime_client.invoke_endpoint(
EndpointName="<endpoint-name>",
ContentType="application/json",
Accept="application/json",
Body=payload
)
out = json.loads(response['Body'].read().decode())
print (f'Response: {out}')
📚 详细文档
背景知识
- 上下文顺序会影响准确性(迷失在中间,Liu等人,2023)。
- 使用重排器的原因:
- 当前大语言模型并非输入的上下文越多越好,相关内容排在前面才能更好地给出答案。
- 语义搜索中使用的相似度(相关性)分数不够精确。(即排名靠前的内容一定比排名靠后的内容与问题更相似吗?)
- 嵌入向量擅长捕捉文档背后的含义。
- 问题和答案在语义上并不相同。(假设文档嵌入)
- 使用近似最近邻搜索(ANNs)会带来一定的惩罚。
重排器模型
数据集
- msmarco-triplets:
- 来自MS MARCO段落数据集的(问题,答案,负样本)三元组,共499,184个样本。
- 该数据集由英文组成,通过Amazon Translate进行翻译后使用。
- 格式:
{"query": str, "pos": List[str], "neg": List[str]}
查询是问题,pos是正文本列表,neg是负文本列表。如果查询没有负文本,可以从整个语料库中随机抽取一些作为负文本。
{"query": "대한민국의 수도는?", "pos": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 한국은 서울이다."], "neg": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 북한은 평양이다."]}
性能
模型 |
上下文中包含正确答案的比例 |
平均倒数排名(MRR) |
无重排器(默认) |
0.93 |
0.80 |
使用重排器(bge-reranker-large) |
0.95 |
0.84 |
使用重排器(使用韩语微调) |
0.96 |
0.87 |
./dataset/evaluation/eval_dataset.csv
{
"learning_rate": 5e-6,
"fp16": True,
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 32,
"train_group_size": 3,
"max_len": 512,
"weight_decay": 0.01
}
🔧 技术细节
文档未提及技术细节相关内容,此部分跳过。
📄 许可证
FlagEmbedding采用 MIT许可证。
致谢
部分代码基于 FlagEmbedding 和 KoSimCSE-SageMaker 开发。
引用
如果您发现此仓库有用,请考虑点赞 ⭐ 并引用。
贡献者
- Jang Dongjin博士(AWS人工智能/机器学习专家解决方案架构师) | 邮箱 | 领英 | GitHub
统计
