🚀 kgt5-wikikg90mv2モデル
このモデルは、WikiKG90Mv2データセットでゼロから学習されたt5-smallモデルです。方法の詳細については、https://github.com/apoorvumang/kgt5/ を参照してください。
🚀 クイックスタート
このモデルは、テールエンティティ予測タスクで学習されています。つまり、主語エンティティと関係を与えられた場合に、目的語エンティティを予測します。入力は "<エンティティテキスト>| <関係テキスト>" の形式で提供する必要があります。
✨ 主な機能
- エンティティと関係のテキスト表現を取得するために、生テキストのタイトルと説明を使用します。
- WikiKG90Mv2で約1.5エポック学習され、4x1080Ti GPUを使用しています。1エポックの学習時間は約5.5日です。
- モデルを評価する際には、各入力 (s,r) ペアに対してデコーダから300回サンプリングし、予測結果を評価します。
📦 インストール
以下のコードを使用して、事前学習済みモデルをロードできます。
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
model = AutoModelForSeq2SeqLM.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
💻 使用例
基本的な使用法
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
model = AutoModelForSeq2SeqLM.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
高度な使用法
import torch
def getScores(ids, scores, pad_token_id):
"""get sequence scores from model.generate output"""
scores = torch.stack(scores, dim=1)
log_probs = torch.log_softmax(scores, dim=2)
ids = ids[:,1:]
x = ids.unsqueeze(-1).expand(log_probs.shape)
needed_logits = torch.gather(log_probs, 2, x)
final_logits = needed_logits[:, :, 0]
padded_mask = (ids == pad_token_id)
final_logits[padded_mask] = 0
final_scores = final_logits.sum(dim=-1)
return final_scores.cpu().detach().numpy()
def topkSample(input, model, tokenizer,
num_samples=5,
num_beams=1,
max_output_length=30):
tokenized = tokenizer(input, return_tensors="pt")
out = model.generate(**tokenized,
do_sample=True,
num_return_sequences = num_samples,
num_beams = num_beams,
eos_token_id = tokenizer.eos_token_id,
pad_token_id = tokenizer.pad_token_id,
output_scores = True,
return_dict_in_generate=True,
max_length=max_output_length,)
out_tokens = out.sequences
out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
out_scores = getScores(out_tokens, out.scores, tokenizer.pad_token_id)
pair_list = [(x[0], x[1]) for x in zip(out_str, out_scores)]
sorted_pair_list = sorted(pair_list, key=lambda x:x[1], reverse=True)
return sorted_pair_list
def greedyPredict(input, model, tokenizer):
input_ids = tokenizer([input], return_tensors="pt").input_ids
out_tokens = model.generate(input_ids)
out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
return out_str[0]
input = "Sophie Valdemarsdottir| noble title"
out = topkSample(input, model, tokenizer, num_samples=5)
out
🔧 技術詳細
- エンティティ表現はタイトルに設定され、説明は2つのエンティティが同じタイトルを持つ場合の曖昧さ解消に使用されます。
- モデルを評価する際には、各入力に対して300回サンプリングし、予測結果をログ確率でランク付けします。その後、フィルタリングを行います。
- 検証セットでは0.22のMRRを達成しています (完全なリーダーボードはこちら https://ogb.stanford.edu/docs/lsc/leaderboards/#wikikg90mv2)
📄 ライセンス
このモデルはMITライセンスの下で提供されています。
# download valid.txt. you can also try same url with test.txt. however test does not contain the correct tails
!wget https://storage.googleapis.com/kgt5-wikikg90mv2/valid.txt
fname = 'valid.txt'
valid_lines = []
f = open(fname)
for line in f:
valid_lines.append(line.rstrip())
f.close()
print(valid_lines[0])
from tqdm.auto import tqdm
# try unfiltered hits@k. this is approximation since model can sample same seq multiple times
# you should run this on gpu if you want to evaluate on all points with 300 samples each
k = 1
count_at_k = 0
max_predictions = k
max_points = 1000
for line in tqdm(valid_lines[:max_points]):
input, target = line.split('\t')
model_output = topkSample(input, model, tokenizer, num_samples=max_predictions)
prediction_strings = [x[0] for x in model_output]
if target in prediction_strings:
count_at_k += 1
print('Hits at {0} unfiltered: {1}'.format(k, count_at_k/max_points))