# 🚀 Korean Reranker Training on Amazon SageMaker
This guide provides fine-tuning instructions for developing a Korean reranker. The ko-reranker is a model fine-tuned on Korean data from BAAI/bge-reranker-large. For more details, see korean-reranker-git and the AWS blog post Boosting Retrieval Augmented Generation (RAG) Performance with Korean Reranker.
## ✨ Features
- Unlike embedding models, rerankers take a question and a document as input and directly output a similarity score instead of an embedding.
- Feeding a question and a passage to the reranker yields a relevance score for that pair.
- Because rerankers are optimized with a cross-entropy loss, the relevance scores are not bounded to a specific range.
## 💻 Usage Examples

### Basic Usage
```python
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

def exp_normalize(x):
    # Numerically stable softmax: subtract the max before exponentiating.
    b = x.max()
    y = np.exp(x - b)
    return y / y.sum()

model_path = 'Dongjin-kr/ko-reranker'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()

# (question, passage) pairs -- "I hate you" / "I love you" and
# "I like you" / "My feelings for you might be love"
pairs = [["나는 너를 싫어해", "나는 너를 사랑해"],
         ["나는 너를 좋아해", "너에 대한 나의 감정은 사랑 일 수도 있어"]]

with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
    scores = exp_normalize(scores.numpy())
    print(f'first: {scores[0]}, second: {scores[1]}')
```
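Since `exp_normalize` applies a softmax over the two raw logits, the printed scores sum to 1. The second pair ("I like you" / "My feelings for you might be love") is semantically related while the first is contradictory, so the reranker should assign the second pair the clearly higher score.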
### Advanced Usage
```python
import json

import boto3
import sagemaker
from sagemaker.huggingface import HuggingFaceModel

# Resolve the SageMaker execution role (fall back to a named IAM role
# when running outside a SageMaker notebook).
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Hugging Face Hub configuration for the hosted model.
hub = {
    'HF_MODEL_ID': 'Dongjin-kr/ko-reranker',
    'HF_TASK': 'text-classification'
}

huggingface_model = HuggingFaceModel(
    transformers_version='4.28.1',
    pytorch_version='2.0.0',
    py_version='py310',
    env=hub,
    role=role,
)

# Deploy the model to a real-time SageMaker endpoint.
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type='ml.g5.large'
)

# Invoke the endpoint through the SageMaker runtime client.
runtime_client = boto3.Session().client('sagemaker-runtime')

# Same example pairs as in Basic Usage above.
payload = json.dumps(
    {
        "inputs": [
            {"text": "나는 너를 싫어해", "text_pair": "나는 너를 사랑해"},
            {"text": "나는 너를 좋아해", "text_pair": "너에 대한 나의 감정은 사랑 일 수도 있어"}
        ]
    }
)

response = runtime_client.invoke_endpoint(
    EndpointName="<endpoint-name>",
    ContentType="application/json",
    Accept="application/json",
    Body=payload
)

out = json.loads(response['Body'].read().decode())
print(f'Response: {out}')
```
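To avoid ongoing charges once you are done experimenting, tear the endpoint down. A minimal cleanup sketch using the predictor returned by `deploy()` above:

```python
# Delete the model and endpoint when finished to stop incurring charges.
predictor.delete_model()
predictor.delete_endpoint()
```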
## 📚 Documentation

### Background
- Context order affects accuracy (Lost in the Middle, Liu et al., 2023).
- Reasons for using a reranker:
  - Simply giving an LLM more context does not guarantee better answers; the relevant information has to be ranked near the top of the prompt (see the sketch after this list).
  - The similarity (relevance) scores used in semantic search are not precise: a top-ranked result does not always contain information more relevant to the question than lower-ranked ones.
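As an illustration, here is a minimal sketch of where the reranker sits in a RAG pipeline: the retriever's top-k candidates are rescored and reordered before being placed into the LLM prompt. The `rerank` helper and the candidate passages are hypothetical, not part of this repository.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Dongjin-kr/ko-reranker')
model = AutoModelForSequenceClassification.from_pretrained('Dongjin-kr/ko-reranker')
model.eval()

def rerank(query: str, passages: list[str]) -> list[str]:
    """Rescore retriever candidates with the reranker and return them best-first."""
    pairs = [[query, p] for p in passages]
    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True,
                           return_tensors='pt', max_length=512)
        scores = model(**inputs, return_dict=True).logits.view(-1).float()
    order = scores.argsort(descending=True)
    return [passages[i] for i in order]

# Candidates would normally come from a vector-store retriever (hypothetical data).
candidates = ["passage A ...", "passage B ...", "passage C ..."]
top_passages = rerank("example query", candidates)  # feed these to the LLM prompt
```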
### Reranker models

### Dataset
- msmarco-triplets:
  - (Question, Answer, Negative)-triplets from the MS MARCO passage dataset, 499,184 samples.
  - The source data is in English and was translated into Korean with Amazon Translate.
- Format (see the loading sketch after this list):

```json
{"query": str, "pos": List[str], "neg": List[str]}
```

The query is the question, "pos" is a list of positive texts, and "neg" is a list of negative texts. If a query has no negative texts, some can be sampled at random from the entire corpus. Example:

```json
{"query": "대한민국의 수도는?", "pos": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 한국은 서울이다."], "neg": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 북한은 평양이다."]}
```

(The query asks "What is the capital of South Korea?"; the positive text says Korea's capital is Seoul, while the negative says North Korea's capital is Pyongyang.)
### Performance
| Model | has-right-in-contexts | MRR (mean reciprocal rank) |
|---|---|---|
| without-reranker (default) | 0.93 | 0.80 |
| with-reranker (bge-reranker-large) | 0.95 | 0.84 |
| with-reranker (fine-tuned using Korean) | 0.96 | 0.87 |
Evaluation dataset: `./dataset/evaluation/eval_dataset.csv`
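MRR averages the reciprocal rank of the first relevant document over all queries. A minimal sketch of the computation (the ranks below are hypothetical):

```python
def mean_reciprocal_rank(ranks):
    # ranks[i] is the 1-based position of the first relevant document
    # for query i, or None if no relevant document was retrieved.
    return sum(1.0 / r for r in ranks if r) / len(ranks)

print(mean_reciprocal_rank([1, 2, 1, None]))  # (1 + 0.5 + 1 + 0) / 4 = 0.625
```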
Training parameters:

```python
{
    "learning_rate": 5e-6,
    "fp16": True,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 32,
    "train_group_size": 3,
    "max_len": 512,
    "weight_decay": 0.01
}
```
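These hyperparameters would typically be passed to a SageMaker training job. A minimal sketch using the Hugging Face estimator; the `entry_point`, `source_dir`, and instance type are assumptions for illustration, not the repository's actual training setup:

```python
from sagemaker.huggingface import HuggingFace

estimator = HuggingFace(
    entry_point='run.py',        # hypothetical training script
    source_dir='./src',          # hypothetical source directory
    role=role,                   # SageMaker execution role, as resolved above
    instance_type='ml.p4d.24xlarge',  # assumed GPU training instance
    instance_count=1,
    transformers_version='4.28.1',
    pytorch_version='2.0.0',
    py_version='py310',
    hyperparameters={
        'learning_rate': 5e-6,
        'fp16': True,
        'num_train_epochs': 3,
        'per_device_train_batch_size': 1,
        'gradient_accumulation_steps': 32,
        'train_group_size': 3,
        'max_len': 512,
        'weight_decay': 0.01,
    },
)
estimator.fit()  # optionally pass S3 training data channels
```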
## 📄 License
FlagEmbedding is licensed under the MIT License.
## Acknowledgement

Parts of the code are based on FlagEmbedding and KoSimCSE-SageMaker.
## Citation

If you find this repository useful, please consider giving it a star ⭐ and a citation.
## Contributors

Dongjin Jang, Ph.D. (AWS AI/ML Specialist Solutions Architect) | Mail | Linkedin | Git