🚀 Pythia-160m Fine-tuned with Cell2Sentence
This is a Pythia-160m model fine-tuned using Cell2Sentence on full scRNA-seq cells, adapting large language models to single-cell transcriptomics.
🚀 Quick Start
This is the Pythia-160m model developed by EleutherAI and fine-tuned using Cell2Sentence on full scRNA-seq cells. Cell2Sentence is a novel approach for adapting large language models to single-cell transcriptomics. We transform single-cell RNA sequencing data into sequences of gene names ordered by expression level, called "cell sentences". This model was trained on the immune tissue dataset from Domínguez et al. using 8 A100 40GB GPUs for about 20 hours on tasks like conditional cell generation, unconditional cell generation, and cell type prediction.
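To make the idea of a cell sentence concrete, here is a minimal sketch of the ordering step. The function name `expression_to_cell_sentence`, the toy gene list, and the counts are illustrative assumptions, not part of the Cell2Sentence codebase; the official preprocessing (normalization, filtering, vocabulary handling) is described in the paper and GitHub repository.

```python
from typing import Sequence


def expression_to_cell_sentence(
    expression: Sequence[float],
    gene_names: Sequence[str],
) -> str:
    """Return the names of expressed genes, highest expression first."""
    # Keep only genes with nonzero expression.
    pairs = [(value, name) for value, name in zip(expression, gene_names) if value > 0]
    # list.sort is stable, so genes with equal expression keep their input order.
    pairs.sort(key=lambda pair: pair[0], reverse=True)
    return " ".join(name for _, name in pairs)


# Toy cell with made-up counts (hypothetical data, for illustration only).
genes = ["CD3E", "MALAT1", "B2M", "CD8A", "GNLY"]
counts = [3, 12, 7, 0, 1]
sentence = expression_to_cell_sentence(counts, genes)
print(sentence)
```

Genes with zero expression are dropped, which matches the "nonzero expression, from highest to lowest" wording of the generation prompt used in the usage example.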
✨ Features
- Cell Generation: Capable of both conditional and unconditional cell generation.
- Cell Type Prediction: Can predict cell types based on gene expression.
- Post-processing: Comes with a post-processing function to clean up generated cell sentences.
📦 Installation
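No separate installation steps are provided for this model. The usage example below imports torch and transformers; loading the model with attn_implementation="flash_attention_2" additionally requires the flash-attn package and a CUDA GPU that supports it.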
💻 Usage Examples
Basic Usage
We provide an example of how to use the model to conditionally generate a cell, together with a post-processing function that removes invalid genes and merges duplicates.
import json
import re
from collections import Counter
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def post_process_generated_cell_sentences(
    cell_sentence: str,
    gene_dictionary: List[str]
) -> List[str]:
    """
    Post-processing function for generated cell sentences.
    Invalid genes are removed and ranks of duplicated genes are averaged.

    Arguments:
        cell_sentence: generated cell sentence string
        gene_dictionary: list of gene vocabulary (all uppercase)

    Returns:
        post_processed_sentence: list of gene names after post-processing steps
    """
    # Use a set for fast membership checks against the vocabulary.
    valid_genes = set(gene_dictionary)
    generated_gene_names = cell_sentence.split(" ")
    generated_gene_names = [generated_gene.upper() for generated_gene in generated_gene_names]
    # Drop anything that is not a known gene name.
    generated_gene_names = [gene_name for gene_name in generated_gene_names if gene_name in valid_genes]

    gene_name_to_occurrences = Counter(generated_gene_names)
    post_processed_sentence = generated_gene_names.copy()

    for gene_name in gene_name_to_occurrences:
        if gene_name_to_occurrences[gene_name] > 1:
            # Replace all occurrences of a duplicated gene with a single one at the average position.
            occurrence_positions = [idx for idx, elem in enumerate(post_processed_sentence) if elem == gene_name]
            average_position = int(sum(occurrence_positions) / len(occurrence_positions))
            post_processed_sentence = [elem for elem in post_processed_sentence if elem != gene_name]
            post_processed_sentence.insert(average_position, gene_name)

    return post_processed_sentence


# Load the gene vocabulary used to validate generated gene names.
vocab_path = "pbmc_vocab.json"
with open(vocab_path, "r") as f:
    gene_dictionary = json.load(f)

model_name = "vandijklab/pythia-160m-c2s"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2"
).to(torch.device("cuda"))
tokenizer = AutoTokenizer.from_pretrained(model_name)
cell_type = "T Cell"
ccg = f"Enumerate the genes in a {cell_type} cell with nonzero expression, from highest to lowest."
tokens = tokenizer(ccg, return_tensors='pt')
input_ids = tokens['input_ids'].to(torch.device("cuda"))
attention_mask = tokens['attention_mask'].to(torch.device("cuda"))

with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=1024,
        top_k=50,
        top_p=0.95,
    )

output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Split on "?", "." or ":" and drop the first piece, which is the prompt.
cell_sentence = "".join(re.split(r"\?|\.|:", output_text)[1:]).strip()
processed_genes = post_process_generated_cell_sentences(cell_sentence, gene_dictionary)
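To see what the post-processing step does without loading the model, the following self-contained sketch applies the same logic to a toy sentence. The name `clean_cell_sentence`, the four-gene vocabulary, and the input string are assumptions made for illustration; it uses `str.split()` rather than `split(" ")`, which only differs in how repeated whitespace is handled.

```python
from collections import Counter
from typing import List


def clean_cell_sentence(cell_sentence: str, vocabulary: List[str]) -> List[str]:
    """Drop unknown genes and collapse duplicates to their average position."""
    vocab = set(vocabulary)
    genes = [g.upper() for g in cell_sentence.split() if g.upper() in vocab]
    cleaned = genes.copy()
    for gene, n in Counter(genes).items():
        if n > 1:
            # Positions of every copy of this gene in the current list.
            positions = [i for i, g in enumerate(cleaned) if g == gene]
            mean_position = int(sum(positions) / len(positions))
            # Remove all copies, then re-insert one at the averaged index.
            cleaned = [g for g in cleaned if g != gene]
            cleaned.insert(mean_position, gene)
    return cleaned


# Hypothetical vocabulary and model output, for illustration only.
toy_vocab = ["CD3E", "B2M", "MALAT1", "GNLY"]
raw = "malat1 B2M FOO CD3E B2M GNLY"
result = clean_cell_sentence(raw, toy_vocab)
print(result)
```

Here `FOO` is discarded because it is not in the vocabulary, and the two copies of `B2M` (positions 1 and 3 after filtering) are replaced by a single copy at position 2.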
Advanced Usage
To generate full cells, set the max_length generation parameter to 9200. Full cell generation is slower and needs more memory, so we recommend an A100 GPU for it. Prompts for unconditional cell generation and cell type prediction are also available, but we do not include an example cell sentence to format the prompt. Refer to the paper and GitHub repository for instructions on how to transform expression vectors into cell sentences.
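One way to switch between the short example and full cell generation is to keep the sampling settings in a small helper. This is only a sketch: `build_generation_kwargs` is a hypothetical name, and the values are the ones quoted in this README (1024 for the basic example, 9200 for full cells, `top_k=50`, `top_p=0.95`).

```python
def build_generation_kwargs(full_cell: bool = False) -> dict:
    """Return keyword arguments for model.generate, as used in this README."""
    # 9200 tokens for full cells, 1024 for the shorter basic example.
    max_length = 9200 if full_cell else 1024
    return {
        "do_sample": True,
        "max_length": max_length,
        "top_k": 50,
        "top_p": 0.95,
    }


full_kwargs = build_generation_kwargs(full_cell=True)
print(full_kwargs)

# Intended use with the model and tensors from the basic example:
# outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **full_kwargs)
```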
📚 Documentation
Evaluation
This model was evaluated on KNN classification and Gromov-Wasserstein (GW) distance. The label for a generated cell is the cell type used in its generation prompt. Ground truth cells were sampled with replacement from a held-out test dataset. The generated cells are converted to expression vectors using the method described in the paper.
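The paper's exact conversion procedure is not reproduced here. As a rough illustration of how a ranked gene list can be turned back into an expression vector, the sketch below assumes expression decreases linearly with the base-10 log of a gene's rank. The function name and the `slope` and `intercept` values are placeholders chosen for this example; in practice they would have to be fitted on real data.

```python
import math
from typing import List


def cell_sentence_to_expression(
    genes: List[str],
    vocabulary: List[str],
    slope: float,
    intercept: float,
) -> List[float]:
    """Map 1-based rank r to intercept + slope * log10(r); absent genes get 0."""
    index = {gene: i for i, gene in enumerate(vocabulary)}
    vector = [0.0] * len(vocabulary)
    for rank, gene in enumerate(genes, start=1):
        if gene in index:
            # Clamp at zero so very low ranks never produce negative expression.
            vector[index[gene]] = max(0.0, intercept + slope * math.log10(rank))
    return vector


# Hypothetical vocabulary, sentence, and linear-model parameters.
vocab = ["B2M", "CD3E", "GNLY", "MALAT1"]
ranked_genes = ["MALAT1", "B2M", "CD3E"]
vector = cell_sentence_to_expression(ranked_genes, vocab, slope=-1.0, intercept=3.0)
print(vector)
```

The output vector follows the vocabulary order, so generated and ground truth cells can be compared position by position.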
| Property | Details |
|---|---|
| Model Type | Pythia-160m fine-tuned with Cell2Sentence |
| Training Data | Immune tissue dataset from Domínguez et al. |
| Model | k=3 NN (↑) | k=5 NN (↑) | k=10 NN (↑) | k=25 NN (↑) | GW (↓) |
|---|---|---|---|---|---|
| scGEN | 0.2376 | 0.2330 | 0.2377 | 0.2335 | 315.9505 |
| scVI | 0.2436 | 0.2400 | 0.2425 | 0.2348 | 302.1285 |
| scDiffusion | 0.2335 | 0.2288 | 0.2368 | 0.2306 | 72.0208 |
| scGPT | 0.1838 | 0.1788 | 0.1811 | 0.1882 | 2989.8066 |
| C2S (Pythia-160m) | 0.2588 | 0.2565 | 0.2746 | 0.2715 | 54.3040 |
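The k-NN scores reported above can be read as classification accuracy: each generated cell is assigned the majority label of its k nearest ground truth cells, and the prediction is compared with the cell type from its prompt. The sketch below is one plausible reading of that setup, not the authors' evaluation code; the function name, the squared Euclidean distance, and the toy two-dimensional vectors are all assumptions.

```python
from collections import Counter
from typing import List, Sequence


def knn_accuracy(
    reference_vectors: Sequence[Sequence[float]],
    reference_labels: Sequence[str],
    generated_vectors: Sequence[Sequence[float]],
    generated_labels: Sequence[str],
    k: int,
) -> float:
    """Fraction of generated cells whose k-NN majority label matches their prompt label."""
    correct = 0
    for vector, label in zip(generated_vectors, generated_labels):
        # Squared Euclidean distance to every reference cell, nearest first.
        distances = sorted(
            (sum((a - b) ** 2 for a, b in zip(vector, ref)), idx)
            for idx, ref in enumerate(reference_vectors)
        )
        neighbor_labels: List[str] = [reference_labels[idx] for _, idx in distances[:k]]
        predicted = Counter(neighbor_labels).most_common(1)[0][0]
        correct += predicted == label
    return correct / len(generated_labels)


# Toy data: two well-separated clusters standing in for two cell types.
reference = [[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]]
reference_types = ["T", "T", "T", "B", "B", "B"]
generated = [[0.2, 0.1], [5.5, 5.2], [4.8, 5.1]]
prompt_types = ["T", "B", "T"]  # the last cell was prompted as "T" but lands near "B"
accuracy = knn_accuracy(reference, reference_types, generated, prompt_types, k=3)
print(accuracy)
```

In this toy case two of the three generated cells fall next to reference cells of the prompted type, so the accuracy is 2/3.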
📄 License
This project is licensed under the CC BY-NC-ND 4.0 license.