🚀 Amazon SageMakerでの韓国語Rerankerのトレーニング
このドキュメントは韓国語Rerankerの開発のための微調整ガイドを提供します。ko-rerankerはBAAI/bge-reranker-largerをベースに、韓国語データで微調整されたモデルです。詳細については、korean-reranker-git / AWS Blog, 韓国語Rerankerを活用した検索増強生成(RAG)のパフォーマンス向上を参照してください。
🚀 クイックスタート
韓国語Rerankerの開発に必要な微調整ガイドを提供します。このモデルは韓国語データに対して微調整されており、より高精度な検索結果を提供します。
✨ 主な機能
- Rerankerは、埋め込みモデルとは異なり、質問と文書を入力として使用し、埋め込みではなく類似度を直接出力します。
- Rerankerに質問と文節を入力すると、関連性スコアを取得できます。
- RerankerはCrossEntropy損失に基づいて最適化されるため、関連性スコアは特定の範囲に制限されません。
💻 使用例
基本的な使用法
def exp_normalize(x):
b = x.max()
y = np.exp(x - b)
return y / y.sum()
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()
pairs = [["나는 너를 싫어해", "나는 너를 사랑해"], \
["나는 너를 좋아해", "너에 대한 나의 감정은 사랑 일 수도 있어"]]
with torch.no_grad():
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
scores = exp_normalize(scores.numpy())
print (f'first: {scores[0]}, second: {scores[1]}')
高度な使用法
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel
try:
role = sagemaker.get_execution_role()
except ValueError:
iam = boto3.client('iam')
role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
hub = {
'HF_MODEL_ID':'Dongjin-kr/ko-reranker',
'HF_TASK':'text-classification'
}
huggingface_model = HuggingFaceModel(
transformers_version='4.28.1',
pytorch_version='2.0.0',
py_version='py310',
env=hub,
role=role,
)
predictor = huggingface_model.deploy(
initial_instance_count=1,
instance_type='ml.g5.large'
)
runtime_client = boto3.Session().client('sagemaker-runtime')
payload = json.dumps(
{
"inputs": [
{"text": "나는 너를 싫어해", "text_pair": "나는 너를 사랑해"},
{"text": "나는 너를 좋아해", "text_pair": "너에 대한 나의 감정은 사랑 일 수도 있어"}
]
}
)
response = runtime_client.invoke_endpoint(
EndpointName="<endpoint-name>",
ContentType="application/json",
Accept="application/json",
Body=payload
)
out = json.loads(response['Body'].read().decode())
print (f'Response: {out}')
📚 ドキュメント
背景情報
Rerankerモデル
データセット
- msmarco-triplets
- MS MARCOパッセージデータセットからの(質問, 回答, 否定)-トリプレット、499,184サンプル
- このデータセットは英語で構成されており、Amazon Translateを使用して翻訳されました。
- 形式
{"query": str, "pos": List[str], "neg": List[str]}
- クエリは質問で、posは肯定的なテキストのリスト、negは否定的なテキストのリストです。クエリに対する否定的なテキストがない場合は、全体のコーパスから一部をランダムに抽出して否定的なテキストとして使用できます。
- 例
{"query": "대한민국의 수도는?", "pos": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 한국은 서울이다."], "neg": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 북한은 평양이다."]}
パフォーマンス
モデル |
has-right-in-contexts |
mrr (平均逆順位) |
without-reranker (default) |
0.93 |
0.80 |
with-reranker (bge-reranker-large) |
0.95 |
0.84 |
with-reranker (韓国語で微調整) |
0.96 |
0.87 |
./dataset/evaluation/eval_dataset.csv
{
"learning_rate": 5e-6,
"fp16": True,
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 32,
"train_group_size": 3,
"max_len": 512,
"weight_decay": 0.01,
}
謝辞
一部のコードはFlagEmbeddingとKoSimCSE-SageMakerをベースに開発されています。
引用
このリポジトリが役に立った場合は、スター⭐を付けて引用していただけると幸いです。
貢献者
- Dongjin Jang, Ph.D. (AWS AI/ML Specislist Solutions Architect) | Mail | Linkedin | Git
📄 ライセンス
FlagEmbeddingはMITライセンスの下でライセンスされています。
アナリティクス
