T5-small Model for WikiKG90Mv2
This is a T5-small model trained from scratch on the WikiKG90Mv2 dataset. It focuses on the tail entity prediction task, aiming to predict the object entity given a subject entity and a relation.
Quick Start
Model Information
- Model Type: T5-small
- Training Data: WikiKG90Mv2 dataset
Installation
To use this model, install the transformers library with pip:
pip install transformers
Usage Examples
Basic Usage
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
model = AutoModelForSeq2SeqLM.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
Advanced Usage
import torch

def getScores(ids, scores, pad_token_id):
    """get sequence scores from model.generate output"""
    # scores is a tuple with one (batch, vocab) tensor per generation step
    scores = torch.stack(scores, dim=1)
    log_probs = torch.log_softmax(scores, dim=2)
    # remove the decoder start token
    ids = ids[:, 1:]
    # gather the log probability of each generated token
    x = ids.unsqueeze(-1).expand(log_probs.shape)
    needed_logits = torch.gather(log_probs, 2, x)
    final_logits = needed_logits[:, :, 0]
    # padded positions must not contribute to the sequence score
    padded_mask = (ids == pad_token_id)
    final_logits[padded_mask] = 0
    final_scores = final_logits.sum(dim=-1)
    return final_scores.cpu().detach().numpy()

def topkSample(input, model, tokenizer,
               num_samples=5,
               num_beams=1,
               max_output_length=30):
    tokenized = tokenizer(input, return_tensors="pt")
    out = model.generate(**tokenized,
                         do_sample=True,
                         num_return_sequences=num_samples,
                         num_beams=num_beams,
                         eos_token_id=tokenizer.eos_token_id,
                         pad_token_id=tokenizer.pad_token_id,
                         output_scores=True,
                         return_dict_in_generate=True,
                         max_length=max_output_length,)
    out_tokens = out.sequences
    out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
    out_scores = getScores(out_tokens, out.scores, tokenizer.pad_token_id)
    # pair each decoded string with its score, highest score first
    pair_list = [(x[0], x[1]) for x in zip(out_str, out_scores)]
    sorted_pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
    return sorted_pair_list

def greedyPredict(input, model, tokenizer):
    input_ids = tokenizer([input], return_tensors="pt").input_ids
    out_tokens = model.generate(input_ids)
    out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
    return out_str[0]

input = "Sophie Valdemarsdottir| noble title"
out = topkSample(input, model, tokenizer, num_samples=5)
out
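To make the scoring step easier to follow, here is a minimal pure-Python sketch of the same idea as getScores on toy numbers: take a log-softmax at every step, pick the log probability of the token that was generated, skip padding, and sum. The names log_softmax and sequence_log_prob and the toy logits are illustrative assumptions and are not part of the model's code.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def sequence_log_prob(token_ids, step_logits, pad_token_id):
    """Sum the log probabilities of the generated tokens, ignoring padding.

    token_ids: generated ids with the decoder start token already removed.
    step_logits: one list of vocabulary logits per generation step.
    """
    total = 0.0
    for tok, logits in zip(token_ids, step_logits):
        if tok == pad_token_id:
            continue  # padded positions contribute 0, as in getScores
        total += log_softmax(logits)[tok]
    return total

# Toy example (made-up numbers): vocabulary of 3 tokens, id 0 is padding.
toy_logits = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [5.0, 1.0, 1.0]]
toy_ids = [1, 2, 0]
score = sequence_log_prob(toy_ids, toy_logits, pad_token_id=0)
print(score)
```

The first two steps are uniform over three tokens, so each contributes log(1/3); the last position is padding and is skipped.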
Documentation
Model Training
- Training Task: Tail entity prediction task. Input should be in the form of "<entity text>| <relation text>".
- Entity and Relation Representations: We used the raw text titles and descriptions from the OGB dataset to get textual representations of entities and relations. Each entity is represented by its title; the description is used to disambiguate when two entities share the same title. If disambiguation is still not possible, we use the Wikidata ID (e.g. Q123456).
- Training Setup: The model was trained on WikiKG90Mv2 for approximately 1.5 epochs on 4x1080Ti GPUs. The training time for 1 epoch was approximately 5.5 days.
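The input format and the disambiguation rule described above can be sketched as follows. This is a rough illustration under assumptions, not the original preprocessing code: the helper names build_entity_names and make_input, the way title and description are joined (a single space), and the toy entity IDs are all hypothetical.

```python
from collections import Counter

def build_entity_names(entities):
    """Map entity ID -> unique text name.

    entities: dict of wikidata_id -> (title, description).
    Order of preference: title, then title plus description, then the ID.
    """
    title_counts = Counter(title for title, _ in entities.values())
    names = {}
    for ent_id, (title, description) in entities.items():
        if title_counts[title] == 1:
            names[ent_id] = title
        else:
            # Assumption: title and description are joined with a space.
            names[ent_id] = f"{title} {description}".strip()
    # Fall back to the ID when the description did not disambiguate.
    name_counts = Counter(names.values())
    return {ent_id: (name if name_counts[name] == 1 else ent_id)
            for ent_id, name in names.items()}

def make_input(entity_text, relation_text):
    """Build the model input string "<entity text>| <relation text>"."""
    return f"{entity_text}| {relation_text}"

# Toy data with made-up IDs and descriptions.
entities = {
    "Q900001": ("Paris", "capital of France"),
    "Q900002": ("Paris", "city in Texas"),
    "Q900003": ("Berlin", "capital of Germany"),
    "Q900004": ("Springfield", ""),
    "Q900005": ("Springfield", ""),
}
names = build_entity_names(entities)
print(names)
print(make_input("Sophie Valdemarsdottir", "noble title"))
```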
Model Evaluation
- Evaluation Method: We sample 300 times from the decoder for each input (s,r) pair. We then remove predictions that do not map back to a valid entity and rank the remaining predictions by their log probabilities. Filtering is applied after ranking.
- Evaluation Result: We achieve a validation MRR of 0.22 (the full leaderboard is available at https://ogb.stanford.edu/docs/lsc/leaderboards/#wikikg90mv2).
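As a rough sketch of the ranking step and of how a reciprocal rank is derived from it, the snippet below works on toy (text, log probability) pairs such as those returned by topkSample. The function names, the choice to keep the best score for repeated strings, and the toy data are assumptions made for illustration; the original evaluation code may differ.

```python
def rank_predictions(samples, valid_entities):
    """Drop invalid strings, deduplicate, and sort by log probability."""
    best = {}
    for text, log_prob in samples:
        if text not in valid_entities:
            continue  # prediction does not map back to a valid entity
        # Assumption: repeated samples of one string keep their best score.
        if text not in best or log_prob > best[text]:
            best[text] = log_prob
    return sorted(best, key=best.get, reverse=True)

def reciprocal_rank(ranked, target):
    """Return 1 / rank of the target, or 0.0 if it was never predicted."""
    for position, text in enumerate(ranked, start=1):
        if text == target:
            return 1.0 / position
    return 0.0

# Toy samples (made-up strings and scores); "qeen" is not a valid entity.
samples = [("duke", -0.9), ("queen", -0.4), ("duke", -0.9),
           ("qeen", -0.2), ("countess", -1.5)]
valid_entities = {"duke", "queen", "countess"}

ranked = rank_predictions(samples, valid_entities)
rr = reciprocal_rank(ranked, "duke")
print(ranked, rr)
```

MRR is the mean of this reciprocal rank over all evaluated (s,r) pairs.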
Further Processing
To get final predictions in the required format, you can load the list of entity aliases, keep only the predictions that are valid entities, and build a reverse mapping from alias to integer ID. Note that holding these aliases in memory as a dictionary requires a lot of RAM, and you first need to download the alias files: entity aliases at https://storage.googleapis.com/kgt5-wikikg90mv2/ent_alias_list.pickle and relation aliases at https://storage.googleapis.com/kgt5-wikikg90mv2/rel_alias_list.pickle.
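A minimal sketch of the reverse mapping is shown below. It assumes the alias file unpickles to a list in which position i holds the alias of the entity with integer ID i; that layout is a guess based on the file name and should be checked against the actual file. The helper names and the toy alias list are hypothetical.

```python
def build_alias_to_id(alias_list):
    """Reverse mapping alias -> integer ID (ID assumed to be the list index)."""
    return {alias: idx for idx, alias in enumerate(alias_list)}

def predictions_to_ids(predictions, alias_to_id):
    """Keep only predictions that are valid aliases and convert them to IDs."""
    return [alias_to_id[p] for p in predictions if p in alias_to_id]

# Toy stand-in for the real alias list, which would be loaded with
# pickle.load() from ent_alias_list.pickle (layout assumed, not verified).
toy_alias_list = ["queen", "duke", "countess"]
alias_to_id = build_alias_to_id(toy_alias_list)

predicted_ids = predictions_to_ids(["duke", "qeen", "queen"], alias_to_id)
print(predicted_ids)
```

The order of the model's predictions is preserved, so the resulting ID list is still ranked by score.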
Code Example for Evaluation
# Notebook shell command; in a terminal, run it without the leading "!".
!wget https://storage.googleapis.com/kgt5-wikikg90mv2/valid.txt
fname = 'valid.txt'
# Each line holds an input string and its target, separated by a tab.
with open(fname) as f:
    valid_lines = [line.rstrip() for line in f]
print(valid_lines[0])
from tqdm.auto import tqdm
k = 1
count_at_k = 0
max_predictions = k
max_points = 1000
eval_lines = valid_lines[:max_points]
for line in tqdm(eval_lines):
    input, target = line.split('\t')
    model_output = topkSample(input, model, tokenizer, num_samples=max_predictions)
    prediction_strings = [x[0] for x in model_output]
    if target in prediction_strings:
        count_at_k += 1
# Divide by the number of lines actually evaluated.
print('Hits at {0} unfiltered: {1}'.format(k, count_at_k/len(eval_lines)))
Technical Details
Input Format
The input should be provided in the form of "<entity text>| <relation text>".
Sampling and Ranking
During evaluation, we sample 300 times from the decoder for each input (s,r) pair. Then we remove invalid predictions and rank the remaining predictions by their log probabilities.
Filtering
After ranking, we perform filtering to get the final results.
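The text does not spell out the filtering rule. The sketch below assumes the usual filtered setting in knowledge graph link prediction: other tails that are already known to be correct for the same (s,r) pair are removed from the ranked list before the target's rank is read off. The function name and the toy data are hypothetical.

```python
def filter_known_tails(ranked, known_tails, target):
    """Remove known-correct tails other than the target from a ranked list."""
    return [p for p in ranked if p == target or p not in known_tails]

# Toy ranked predictions for one (s, r) pair, best first (made-up data).
ranked_predictions = ["queen", "duke", "countess"]
# Tails known to be true for this (s, r) pair, e.g. from the training triples.
known_true_tails = {"queen", "duke"}

filtered = filter_known_tails(ranked_predictions, known_true_tails, target="duke")
filtered_rank = filtered.index("duke") + 1
print(filtered, filtered_rank)
```

Under this assumption the target moves from rank 2 to rank 1, because the higher-ranked "queen" is itself a correct answer and is not counted against it.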
License
This project is licensed under the MIT License.