🚀 日本語句子BERT模型(版本2)
這是一個用於日語的Sentence - BERT模型(版本2)。它使用了比版本1更好的損失函數MultipleNegativesRankingLoss進行訓練,是一個改良版本。在內部的非公開數據集上,該模型的精度比版本1高出約1.5 - 2個百分點。
🚀 快速開始
本模型使用了預訓練模型[cl - tohoku/bert - base - japanese - whole - word - masking](https://huggingface.co/cl - tohoku/bert - base - japanese - whole - word - masking)。因此,在執行推理時需要安裝fugashi和ipadic(可使用命令pip install fugashi ipadic
進行安裝)。
✨ 主要特性
- 改良版模型:使用了更好的損失函數
MultipleNegativesRankingLoss
進行訓練,相比版本1精度有所提升。
- 依賴預訓練模型:基於[cl - tohoku/bert - base - japanese - whole - word - masking](https://huggingface.co/cl - tohoku/bert - base - japanese - whole - word - masking)進行訓練。
📦 安裝指南
執行推理前,需要安裝fugashi和ipadic,可使用以下命令進行安裝:
pip install fugashi ipadic
💻 使用示例
基礎用法
from transformers import BertJapaneseTokenizer, BertModel
import torch
class SentenceBertJapanese:
def __init__(self, model_name_or_path, device=None):
self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)
self.model = BertModel.from_pretrained(model_name_or_path)
self.model.eval()
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
self.model.to(device)
def _mean_pooling(self, model_output, attention_mask):
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
@torch.no_grad()
def encode(self, sentences, batch_size=8):
all_embeddings = []
iterator = range(0, len(sentences), batch_size)
for batch_idx in iterator:
batch = sentences[batch_idx:batch_idx + batch_size]
encoded_input = self.tokenizer.batch_encode_plus(batch, padding="longest",
truncation=True, return_tensors="pt").to(self.device)
model_output = self.model(**encoded_input)
sentence_embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"]).to('cpu')
all_embeddings.extend(sentence_embeddings)
return torch.stack(all_embeddings)
MODEL_NAME = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"
model = SentenceBertJapanese(MODEL_NAME)
sentences = ["暴走したAI", "暴走した人工知能"]
sentence_embeddings = model.encode(sentences, batch_size=8)
print("Sentence embeddings:", sentence_embeddings)
📚 詳細文檔
舊版本說明
關於舊版本的詳細解說可查看:https://qiita.com/sonoisa/items/1df94d0a98cd4f209051。若將模型名稱替換為"sonoisa/sentence - bert - base - ja - mean - tokens - v2",則可使用本模型。
📄 許可證
本項目採用CC - BY - SA 4.0許可證。