🚀 韓國重排器在Amazon SageMaker上的訓練
本項目提供了用於開發 韓語重排器 的微調指南。ko-reranker 是基於 BAAI/bge-reranker-larger 在韓語數據上進行微調的模型。更多詳細信息,請參考 korean-reranker-git / AWS博客,利用韓語重排器提升檢索增強生成(RAG)性能。
✨ 主要特性
- 與嵌入模型不同,重排器將問題和文檔作為輸入,並直接輸出相似度,而非嵌入向量。
- 向重排器輸入問題和段落,可獲得相關性得分。
- 重排器基於交叉熵損失進行優化,因此相關性得分不受特定範圍限制。
📦 安裝指南
文檔未提及安裝步驟,此部分跳過。
💻 使用示例
基礎用法
def exp_normalize(x):
b = x.max()
y = np.exp(x - b)
return y / y.sum()
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()
pairs = [["나는 너를 싫어해", "나는 너를 사랑해"], \
["나는 너를 좋아해", "너에 대한 나의 감정은 사랑 일 수도 있어"]]
with torch.no_grad():
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
scores = exp_normalize(scores.numpy())
print (f'first: {scores[0]}, second: {scores[1]}')
高級用法
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel
try:
role = sagemaker.get_execution_role()
except ValueError:
iam = boto3.client('iam')
role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
hub = {
'HF_MODEL_ID':'Dongjin-kr/ko-reranker',
'HF_TASK':'text-classification'
}
huggingface_model = HuggingFaceModel(
transformers_version='4.28.1',
pytorch_version='2.0.0',
py_version='py310',
env=hub,
role=role,
)
predictor = huggingface_model.deploy(
initial_instance_count=1,
instance_type='ml.g5.large'
)
runtime_client = boto3.Session().client('sagemaker-runtime')
payload = json.dumps(
{
"inputs": [
{"text": "나는 너를 싫어해", "text_pair": "나는 너를 사랑해"},
{"text": "나는 너를 좋아해", "text_pair": "너에 대한 나의 감정은 사랑 일 수도 있어"}
]
}
)
response = runtime_client.invoke_endpoint(
EndpointName="<endpoint-name>",
ContentType="application/json",
Accept="application/json",
Body=payload
)
out = json.loads(response['Body'].read().decode())
print (f'Response: {out}')
📚 詳細文檔
背景知識
- 上下文順序會影響準確性(迷失在中間,Liu等人,2023)。
- 使用重排器的原因:
- 當前大語言模型並非輸入的上下文越多越好,相關內容排在前面才能更好地給出答案。
- 語義搜索中使用的相似度(相關性)分數不夠精確。(即排名靠前的內容一定比排名靠後的內容與問題更相似嗎?)
- 嵌入向量擅長捕捉文檔背後的含義。
- 問題和答案在語義上並不相同。(假設文檔嵌入)
- 使用近似最近鄰搜索(ANNs)會帶來一定的懲罰。
重排器模型
數據集
- msmarco-triplets:
- 來自MS MARCO段落數據集的(問題,答案,負樣本)三元組,共499,184個樣本。
- 該數據集由英文組成,通過Amazon Translate進行翻譯後使用。
- 格式:
{"query": str, "pos": List[str], "neg": List[str]}
查詢是問題,pos是正文本列表,neg是負文本列表。如果查詢沒有負文本,可以從整個語料庫中隨機抽取一些作為負文本。
{"query": "대한민국의 수도는?", "pos": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 한국은 서울이다."], "neg": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 북한은 평양이다."]}
性能
模型 |
上下文中包含正確答案的比例 |
平均倒數排名(MRR) |
無重排器(默認) |
0.93 |
0.80 |
使用重排器(bge-reranker-large) |
0.95 |
0.84 |
使用重排器(使用韓語微調) |
0.96 |
0.87 |
./dataset/evaluation/eval_dataset.csv
{
"learning_rate": 5e-6,
"fp16": True,
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 32,
"train_group_size": 3,
"max_len": 512,
"weight_decay": 0.01
}
🔧 技術細節
文檔未提及技術細節相關內容,此部分跳過。
📄 許可證
FlagEmbedding採用 MIT許可證。
致謝
部分代碼基於 FlagEmbedding 和 KoSimCSE-SageMaker 開發。
引用
如果您發現此倉庫有用,請考慮點贊 ⭐ 並引用。
貢獻者
- Jang Dongjin博士(AWS人工智能/機器學習專家解決方案架構師) | 郵箱 | 領英 | GitHub
統計
