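🚀 DNA and Block Diffusion Model
This project uses the block diffusion architecture (code, paper) with AgroNT's 6-nucleotide tokens (4^6 = 4,096). Starting from the dna-blockdiff-2 weights, the model was trained for one epoch on the papaya genome. Output tokens are restricted so that the model does not emit single-nucleotide or N (unknown nucleotide) tokens.
The training loss fluctuated, but the validation curve (on the human genome) improved steadily.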
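The 6-mer tokenization and the output restriction described above can be sketched as follows. This is a minimal illustration, not the AgroNT tokenizer itself. The helper names are hypothetical. It also assumes two things: leftover bases at the end of a sequence become single-nucleotide tokens, and the restriction amounts to masking the logits of every token that is not a full A/C/G/T 6-mer.

```python
import math
from itertools import product

NUCLEOTIDES = "ACGT"
K = 6

# Every 6-mer over A/C/G/T: 4**6 = 4,096 tokens.
KMER_VOCAB = ["".join(p) for p in product(NUCLEOTIDES, repeat=K)]


def tokenize_6mers(seq: str) -> list[str]:
    """Split a DNA string into non-overlapping 6-mers.

    Assumption: leftover bases (fewer than 6 at the end) become
    single-nucleotide tokens.
    """
    seq = seq.upper()
    n_full = len(seq) // K * K
    tokens = [seq[i:i + K] for i in range(0, n_full, K)]
    tokens.extend(seq[n_full:])  # each leftover base is its own token
    return tokens


def allowed_output_token(token: str) -> bool:
    """Hypothetical output filter: only full 6-mers of A/C/G/T are allowed,
    so single nucleotides and anything containing N are excluded."""
    return len(token) == K and set(token) <= set(NUCLEOTIDES)


def mask_logits(logits: list[float], vocab: list[str]) -> list[float]:
    """Set the logit of every disallowed token to -inf before sampling."""
    return [x if allowed_output_token(t) else -math.inf for x, t in zip(logits, vocab)]


print(len(KMER_VOCAB))                    # 4096
print(tokenize_6mers("AAATGGTTATTGCA"))   # ['AAATGG', 'TTATTG', 'C', 'A']
```

Masking at the logit level, rather than filtering after sampling, guarantees that disallowed tokens receive zero probability.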
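The sketch below illustrates the block-diffusion sampling loop. Blocks are generated left to right. Within each block, masked positions are gradually unmasked over `T` denoising steps, and each new token is drawn by nucleus (top-p) sampling. It is a toy under stated assumptions, not the bd3lms implementation:

- `toy_denoiser` is a hypothetical stand-in that returns uniform logits instead of running the trained transformer.
- The masking schedule is assumed to be linear.
- There is no KV cache.

With `model.length=256` and `block_size=4` as in the commands below, the outer loop runs 256 / 4 = 64 times. This plausibly corresponds to the `64/64` progress bar in the sample output.

```python
import math
import random

MASK = -1  # sentinel id for a still-masked position


def nucleus_sample(logits, p, rng):
    """Top-p sampling: keep the most likely tokens until their total
    probability reaches p, then sample among them."""
    m = max(logits)
    probs = [math.exp(x - m) for x in logits]
    total = sum(probs)
    probs = [q / total for q in probs]
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= p:
            break
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


def toy_denoiser(prefix, block, vocab_size):
    """Hypothetical stand-in for the trained network. A real model would
    condition on the clean prefix and the partially masked block; this toy
    returns uniform logits for every position in the block."""
    return [[0.0] * vocab_size for _ in block]


def sample_block_diffusion(length, block_size, steps, vocab_size, p=0.9, seed=0):
    rng = random.Random(seed)
    seq = []
    for _ in range(length // block_size):  # one outer iteration per block
        block = [MASK] * block_size
        for step in range(steps, 0, -1):
            t, s = step / steps, (step - 1) / steps
            logits = toy_denoiser(seq, block, vocab_size)
            # Assumed linear schedule: a masked token is revealed with
            # probability (t - s) / t, which is exactly 1 at the final step.
            unmask_prob = (t - s) / t
            for i, tok in enumerate(block):
                if tok == MASK and rng.random() < unmask_prob:
                    block[i] = nucleus_sample(logits[i], p, rng)
        seq.extend(block)  # the finished block becomes fixed context
    return seq


out = sample_block_diffusion(length=256, block_size=4, steps=10, vocab_size=8)
print(len(out), MASK in out)  # 256 False
```

Finished blocks are never revised, so the prefix behaves like autoregressive context. This is what allows a real implementation to cache keys and values for it, which is the idea behind `sampling.kv_cache=true`.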
🚀 Quick Start
Loading the model
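from transformers import AutoModelForMaskedLM

# trust_remote_code=True lets transformers run the custom model code
# stored in the model repository.
m = AutoModelForMaskedLM.from_pretrained(
    "monsoon-nlp/dna-blockdiff-papaya",
    trust_remote_code=True,
)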
Computing sequence perplexity
Use this code branch:
cd bd3lms && python -u main.py \
loader.eval_batch_size=1 \
model=small \
algo=bd3lm \
algo.T=5000 \
algo.backbone=hf_dit \
data=instadeep \
model.length=256 \
block_size=4 \
wandb=null \
mode=ppl_eval \
eval.checkpoint_path="monsoon-nlp/dna-blockdiff-papaya" \
model.attn_backend=sdpa \
sampling.nucleus_p=0.9 \
sampling.kv_cache=true \
sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize4 \
data.tokenizer_name_or_path="monsoon-nlp/dna-blockdiff-papaya"
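Perplexity is the exponential of the mean per-token negative log-likelihood. The helper below is a minimal sketch of that definition; the function name is ours. We assume that, as is common for diffusion language models, `ppl_eval` reports a value derived from a variational bound rather than an exact likelihood, so treat it as an upper bound. As a reference point, a model that is uniform over the 4,096 6-mer tokens has a perplexity of 4,096.

```python
import math


def perplexity(token_nlls):
    """exp of the mean per-token negative log-likelihood (natural log)."""
    if not token_nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(token_nlls) / len(token_nlls))


# A model that is uniform over 4,096 tokens assigns each token NLL = ln(4096).
print(round(perplexity([math.log(4096)] * 256)))  # 4096
```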
Text generation
Use this code branch and keep the parameters unchanged where possible:
cd bd3lms && python -u main.py \
loader.eval_batch_size=1 \
model=small \
algo=bd3lm \
algo.T=5000 \
algo.backbone=hf_dit \
data=instadeep \
model.length=256 \
block_size=4 \
wandb=null \
mode=sample_eval \
eval.checkpoint_path="monsoon-nlp/dna-blockdiff-papaya" \
model.attn_backend=sdpa \
sampling.nucleus_p=0.9 \
sampling.kv_cache=true \
sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize4 \
data.tokenizer_name_or_path="monsoon-nlp/dna-blockdiff-papaya"
100% 64/64 [00:07<00:00, 8.73it/s]
Sliding Window Gen PPL: 100% 1/1 [00:00<00:00, 4.26it/s]
Text samples: ['<cls> AAATGG TTATTG CAAATC TCTAAA GAAGTA TTAAGA GAATGA TAAGAT ATGTTG AGAGAA TTACAC AGCATT GAGAAG TCTAAA TTGAAA AACCAT AAAAAT GTGAGT AGGTCA GTATGT AAGAAT TGTGTT GAACTT ATCAAT ATGTAG ACATCA TTTTGA TATAAA TATATA AAGAAA ATTTAA AAAAAA TAATAA ATAACT TTAAAA TGTTAA TAATAT TAAAAT GGAGAA GAATAA CCTTTA TTATCT ATTACA ATAATA ATTATA TTTTGG ATGAAA CATTCA GAATAT TAGATA ATTTTT ATTAAT GTATCT TCAAAT GAACAA ACTTAT ATTTAA AAACTC TAAAAT ATTTAT AGACTA AAAACT AGAGAA ATTAAT AATAAA AATAAA AAACAC AAATTT ATAAAA CCAAAT AAAGGT AATAAA AACAAA ATATTT ACAAAT AACTAT TAATGA AGTTAA AAAATG AATAAA TTTATA ATAAAA TATTTA TGTTTT AAATTA AAAATT TGAATA AAACTC ACAAAT TATTTA AATACT AATATG TATTTA TATAAT AATATA TGAAAA AATTAT GAATTT TAATTA AAATTT TTATAT TTATAA AAATTT ATATTA ATTAAT TTTTAA CAACTT AAATAA AAAGGA ATATTA AAGTCA ATAATT ATATAT TACTTA TAGACA AATAAA AAAATT CTCAAT AAAATT TAAAAT ATTAAA ATTTTG AAATTA AAATAA AAATAT AATAAT TCACTT CACACA ATACAA CTAACT TATACA ATTAAT TTAAAA GATTAA TTGAAT AAAATT ATTATC ACATGA AATTGG AATAAA CAAAAT AATATA TAAATA TATCAA AAATTG ATATAT GAAAAT CTTTAT GTGAAA TTTTAA GAAATA AATTTA ATATGC TGTTTT AAATTT TTTAAA TTTATT AAATTA AATTAA TATTAA ATTTTA ATAATA AAAATT TATAAT AATTAA TAATTT ATTAGC TTAAAA TTAAAT ATTTTA ATGTAA AAACTA TAATGC AATTTA AAGATT TTTTTA AATTAT ATAAGT TAATAA CTATAA TAATAC ATTTCT TTAATT AAAGAA GAAATT TTAAAT TTAAAT TTTTAA GTTAGA ATTACA TTAAAA TATAAA TATAAT AAATAA TAATTA TTAAAA TATACT AAATAG TTTATT AATTAT ATACTT AATATA ATATTT AATATT ATTATA AAAAAT AATCAT ATATAT ATAATT TTTTTT CTTTTT AACTTA TAAATT AATCAG TTATGA TACTTT ATAAAT ATTTGT TAATGG TGAATG AATATG CTTGAA AAGAAC AAAGAA GAAATT AAGAGA ACTTGA ATTTGG TGGTTA ATAAAT CTAATT ATATAT ATTATA TAAAAA TAGGAA TAATTT GAAAAT TAATAG AAAAGA AAAAGA ATAATT TTATGC TTCTTT ATATAA TTTAAC AAATAT TTTTTT ATAATA ATAATA TAATTA AACTTA AATTAT ATTATA TTCATC ATTATA']
Generative perplexity: tensor(19.5559, device='cuda:0')
Entropy: tensor(3.4023, device='cuda:0')
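To turn a text sample like the one above back into a plain DNA string, drop the special tokens and join the space-separated 6-mers. The sketch also computes the Shannon entropy, in nats, of the token distribution within one sample. We assume this is what the `Entropy` line reports, but check the code branch to confirm. Both helper names are hypothetical. For a sample of n tokens, the entropy is at most ln n, reached when every token is distinct, so lower values indicate more repeated tokens.

```python
import math
from collections import Counter


def sample_to_dna(sample: str) -> str:
    """Drop special tokens such as <cls> and join the space-separated 6-mers."""
    return "".join(tok for tok in sample.split() if not tok.startswith("<"))


def token_entropy(tokens: list[str]) -> float:
    """Shannon entropy, in nats, of the token frequencies within one sample."""
    n = len(tokens)
    return -sum(c / n * math.log(c / n) for c in Counter(tokens).values())


sample = "<cls> AAATGG TTATTG CAAATC TCTAAA"
print(sample_to_dna(sample))  # AAATGGTTATTGCAAATCTCTAAA
print(round(token_entropy(sample.split()[1:]), 4))  # 1.3863
```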
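⚠️ Important Note
The script computes the perplexity of the generated sequence under gpt2-large, which may not be useful for judging the accuracy of DNA sequences.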
📄 License
This project is licensed under apache-2.0.
📦 Model Information
| Property | Details |
| --- | --- |
| Library name | transformers |
| Base model | monsoon-nlp/dna-blockdiff-2 |
| Tags | biology, bd3lm |
| License | apache-2.0 |