🚀 RuBERT for NLI (自然言語推論)
このモデルは、2つの短いテキスト間の論理関係(含意、矛盾、中立)を予測するために、DeepPavlov/rubert-base-cased をファインチューニングしたものです。
🚀 クイックスタート
このモデルは、自然言語推論(NLI)とゼロショット短テキスト分類に使用できます。
💻 使用例
基本的な使用法
NLIのためにモデルを実行する方法:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
model_checkpoint = 'cointegrated/rubert-base-cased-nli-threeway'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
if torch.cuda.is_available():
model.cuda()
text1 = 'Сократ - человек, а все люди смертны.'
text2 = 'Сократ никогда не умрёт.'
with torch.inference_mode():
out = model(**tokenizer(text1, text2, return_tensors='pt').to(model.device))
proba = torch.softmax(out.logits, -1).cpu().numpy()[0]
print({v: proba[k] for k, v in model.config.id2label.items()})
高度な使用法
このモデルをゼロショット短テキスト分類(ラベルのみで)に使用することもできます。例えば、センチメント分析に使用する場合:
def predict_zero_shot(text, label_texts, model, tokenizer, label='entailment', normalize=True):
label_texts
tokens = tokenizer([text] * len(label_texts), label_texts, truncation=True, return_tensors='pt', padding=True)
with torch.inference_mode():
result = torch.softmax(model(**tokens.to(model.device)).logits, -1)
proba = result[:, model.config.label2id[label]].cpu().numpy()
if normalize:
proba /= sum(proba)
return proba
classes = ['Я доволен', 'Я недоволен']
predict_zero_shot('Какая гадость эта ваша заливная рыба!', classes, model, tokenizer)
predict_zero_shot('Какая вкусная эта ваша заливная рыба!', classes, model, tokenizer)
あるいは、推論に Huggingface pipelines を使用することもできます。
📚 ドキュメント
モデルのソース
このモデルは、英語から自動翻訳された一連のNLIデータセットで学習されています。
ほとんどのデータセットは Felipe Salvatoreのリポジトリ から取得されました。
一部のデータセットはオリジナルのソースから取得されました。
パフォーマンス
以下の表は、対応するdevセットにおける5つのモデルのROC AUC(1クラス対残り)を示しています。
モデル |
add_one_rte |
anli_r1 |
anli_r2 |
anli_r3 |
copa |
fever |
help |
iie |
imppres |
joci |
mnli |
monli |
mpe |
scitail |
sick |
snli |
terra |
合計 |
n_observations |
387 |
1000 |
1000 |
1200 |
200 |
20474 |
3355 |
31232 |
7661 |
939 |
19647 |
269 |
1000 |
2126 |
500 |
9831 |
307 |
101128 |
tiny/entailment |
0.77 |
0.59 |
0.52 |
0.53 |
0.53 |
0.90 |
0.81 |
0.78 |
0.93 |
0.81 |
0.82 |
0.91 |
0.81 |
0.78 |
0.93 |
0.95 |
0.67 |
0.77 |
twoway/entailment |
0.89 |
0.73 |
0.61 |
0.62 |
0.58 |
0.96 |
0.92 |
0.87 |
0.99 |
0.90 |
0.90 |
0.99 |
0.91 |
0.96 |
0.97 |
0.97 |
0.87 |
0.86 |
threeway/entailment |
0.91 |
0.75 |
0.61 |
0.61 |
0.57 |
0.96 |
0.56 |
0.61 |
0.99 |
0.90 |
0.91 |
0.67 |
0.92 |
0.84 |
0.98 |
0.98 |
0.90 |
0.80 |
vicgalle-xlm/entailment |
0.88 |
0.79 |
0.63 |
0.66 |
0.57 |
0.93 |
0.56 |
0.62 |
0.77 |
0.80 |
0.90 |
0.70 |
0.83 |
0.84 |
0.91 |
0.93 |
0.93 |
0.78 |
facebook-bart/entailment |
0.51 |
0.41 |
0.43 |
0.47 |
0.50 |
0.74 |
0.55 |
0.57 |
0.60 |
0.63 |
0.70 |
0.52 |
0.56 |
0.68 |
0.67 |
0.72 |
0.64 |
0.58 |
threeway/contradiction |
|
0.71 |
0.64 |
0.61 |
|
0.97 |
|
|
1.00 |
0.77 |
0.92 |
|
0.89 |
|
0.99 |
0.98 |
|
0.85 |
threeway/neutral |
|
0.79 |
0.70 |
0.62 |
|
0.91 |
|
|
0.99 |
0.68 |
0.86 |
|
0.79 |
|
0.96 |
0.96 |
|
0.83 |
評価(および tiny と twoway モデルの学習)には、いくつかの追加データセットが使用されました。
📄 その他の情報
属性 |
詳情 |
パイプラインタグ |
ゼロショット分類 |
タグ |
rubert, russian, nli, rte, ゼロショット分類 |
データセット |
cointegrated/nli-rus-translated-v2021 |