🚀 iSEEEK
A universal approach for integrating super large-scale single-cell transcriptomes by exploring gene rankings
🚀 Quick start
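The examples below require PyTorch, Hugging Face Transformers, and Scanpy (plus leidenalg for Leiden clustering and tqdm for progress bars). A minimal setup sketch; the unpinned package list is inferred from the imports used below:

pip install torch transformers scanpy leidenalg tqdm

The pretrained tokenizer and model are loaded from the Hugging Face Hub:

from transformers import PreTrainedTokenizerFast, BertForMaskedLM

tokenizer = PreTrainedTokenizerFast.from_pretrained("TJMUCH/transcriptome-iseeek")
model = BertForMaskedLM.from_pretrained("TJMUCH/transcriptome-iseeek").bert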
💻 Usage examples
Basic usage
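Each cell is represented as a string of gene symbols ranked by expression. The snippet below encodes these rankings with the pretrained BERT encoder, takes the [CLS] token embedding as the cell-level feature, and then clusters and visualizes the cells with Scanpy.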
import torch
import gzip
import re
from tqdm import tqdm
import numpy as np
import scanpy as sc
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizerFast, BertForMaskedLM
class LineDataset(Dataset):
    def __init__(self, lines):
        self.lines = lines
        self.regex = re.compile(r'\-|\.')

    def __getitem__(self, i):
        # Replace '-' and '.' in gene symbols with '_' to match the vocabulary
        return self.regex.sub('_', self.lines[i])

    def __len__(self):
        return len(self.lines)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_num_threads(2)
tokenizer = PreTrainedTokenizerFast.from_pretrained("TJMUCH/transcriptome-iseeek")
model = BertForMaskedLM.from_pretrained("TJMUCH/transcriptome-iseeek").bert
model = model.to(device)
model.eval()
# One whitespace-separated gene ranking per line, one line per cell
lines = [s.strip().decode() for s in gzip.open("pbmc_ranking.txt.gz")]
labels = [s.strip().decode() for s in gzip.open("pbmc_label.txt.gz")]
labels = np.asarray(labels)
ds = LineDataset(lines)
dl = DataLoader(ds, batch_size=80)
features = []
for a in tqdm(dl, total=len(dl)):
    batch = tokenizer(a, max_length=128, truncation=True,
                      padding=True, return_tensors="pt")
    for k, v in batch.items():
        batch[k] = v.to(device)
    with torch.no_grad():
        out = model(**batch)
    # [CLS] token embedding as the cell-level representation
    f = out.last_hidden_state[:, 0, :]
    features.extend(f.tolist())
features = np.stack(features)
adata = sc.AnnData(features)
adata.obs['celltype'] = labels
adata.obs.celltype = adata.obs.celltype.astype("category")
sc.pp.neighbors(adata, use_rep='X')
sc.tl.umap(adata)
sc.tl.leiden(adata)
sc.pl.umap(adata, color=['celltype', 'leiden'], save="UMAP")
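As a quick sanity check (not part of the original pipeline), the agreement between the Leiden clusters and the provided cell-type labels can be scored with scikit-learn, assuming it is installed:

from sklearn.metrics import adjusted_rand_score

# Higher ARI means the unsupervised clusters better recover the annotated cell types
ari = adjusted_rand_score(adata.obs["celltype"], adata.obs["leiden"])
print(f"ARI: {ari:.3f}")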
Advanced usage
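Instead of collapsing each cell into a single [CLS] vector, this variant keeps per-gene signal: for every cell, the L2 norm of each gene token's contextual embedding is written into a cell-by-vocabulary matrix x. It reuses the tokenizer, model, lines, and dl objects from the basic example.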
cell_counts = len(lines)
x = np.zeros((cell_counts, len(tokenizer)), dtype=np.float16)
counter = 0  # next row of x to fill

for a in tqdm(dl, total=len(dl)):
    batch = tokenizer(a, max_length=128, truncation=True,
                      padding=True, return_tensors="pt")
    for k, v in batch.items():
        batch[k] = v.to(device)
    with torch.no_grad():
        out = model(**batch)
    # Index of the last non-padding token ([SEP]) in each sequence
    eos_idxs = batch.attention_mask.sum(dim=1) - 1
    f = out.last_hidden_state
    batch_size = f.shape[0]
    input_ids = batch.input_ids
    for i in range(batch_size):
        eos = eos_idxs[i].item()
        # L2 norm of each gene token's embedding, excluding [CLS] and [SEP]
        token_norms = [f[i][j].norm().item() for j in range(1, eos)]
        idxs = input_ids[i].tolist()[1:eos]
        x[counter, idxs] = token_norms
        counter += 1
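One possible downstream step (a sketch, not from the original code): wrap the token-level matrix in an AnnData, mapping token ids back to vocabulary strings via the tokenizer; treating every vocabulary entry as a gene symbol is an assumption.

# Map token ids back to vocabulary strings (assumed to be gene symbols)
id2token = {v: k for k, v in tokenizer.get_vocab().items()}

adata_tokens = sc.AnnData(x.astype(np.float32))
adata_tokens.var_names = [id2token[i] for i in range(len(tokenizer))]
adata_tokens.obs['celltype'] = labels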