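🚀 DNA and Block Diffusion Models
This project uses the block diffusion architecture (code, paper) with AgroNT's 6-nucleotide tokens (4^6 = 4,096). Starting from the dna-blockdiff-2 weights, it was trained for one epoch on the papaya genome. Output tokens are restricted so that the model does not emit single-nucleotide or N (unknown nucleotide) tokens.
The training loss fluctuated, but the validation curve (on the human genome) improved steadily.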
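One way to picture the output restriction is as a logit mask. The sketch below assumes the vocabulary is a plain list of token strings and keeps only the 4^6 = 4,096 tokens that are complete A/C/G/T 6-mers. The toy vocabulary and helper names are hypothetical, and this is not the model's actual masking code.

```python
from itertools import product

NUCLEOTIDES = "ACGT"
K = 6  # AgroNT-style 6-nucleotide tokens


def six_mer_vocab() -> list[str]:
    """All 4**6 = 4,096 possible 6-mers over A/C/G/T."""
    return ["".join(p) for p in product(NUCLEOTIDES, repeat=K)]


def allowed_token_mask(vocab: list[str]) -> list[bool]:
    """True only for full A/C/G/T 6-mers. Single nucleotides, tokens
    containing N, and special tokens such as <cls> are all disallowed."""
    valid = set(six_mer_vocab())
    return [tok in valid for tok in vocab]


def mask_logits(logits: list[float], mask: list[bool]) -> list[float]:
    """Replace disallowed logits with -inf so sampling can never pick them."""
    return [x if ok else float("-inf") for x, ok in zip(logits, mask)]


# Hypothetical toy vocabulary, for illustration only.
toy_vocab = ["<cls>", "A", "N", "AAAAAA", "ACGTAC", "NNNNNN"]
mask = allowed_token_mask(toy_vocab)
print(len(six_mer_vocab()))  # 4096
print(mask)  # [False, False, False, True, True, False]
print(mask_logits([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], mask))  # [-inf, -inf, -inf, 0.4, 0.5, -inf]
```

Setting a logit to -inf gives that token zero probability after softmax, so any sampler built on the masked logits cannot select it.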
🚀 Quick Start
Load the model
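```python
from transformers import AutoModelForMaskedLM

# trust_remote_code=True lets transformers run the custom model code
# hosted in the Hub repository for this checkpoint.
m = AutoModelForMaskedLM.from_pretrained(
    "monsoon-nlp/dna-blockdiff-papaya",
    trust_remote_code=True,
)
```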
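The loaded model operates on 6-mer token IDs rather than raw bases. As a rough, pure-Python illustration of how a DNA string might map onto such tokens, the sketch below assumes AgroNT-style non-overlapping 6-mers, falling back to single-character tokens for N-containing stretches and for a trailing remainder shorter than six bases. The fallback rule and the `kmer_tokenize` helper are assumptions for illustration; the tokenizer published with the checkpoint is authoritative.

```python
VALID = set("ACGT")


def kmer_tokenize(seq: str, k: int = 6) -> list[str]:
    """Greedy non-overlapping k-mer split (an approximation, not the real
    tokenizer): emit a k-mer when the next k bases are all A/C/G/T,
    otherwise emit one single-character token and advance by one."""
    seq = seq.upper()
    tokens: list[str] = []
    i = 0
    while i < len(seq):
        chunk = seq[i : i + k]
        if len(chunk) == k and set(chunk) <= VALID:
            tokens.append(chunk)
            i += k
        else:
            tokens.append(seq[i])
            i += 1
    return tokens


print(kmer_tokenize("AAATGGTTATTGCA"))  # ['AAATGG', 'TTATTG', 'C', 'A']
print(kmer_tokenize("AAANAAAAAA"))      # ['A', 'A', 'A', 'N', 'AAAAAA']
```

The single-character tokens produced here are exactly the kind the model is restricted from emitting during generation.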
Sequence perplexity
Use this fork of the code:
```bash
cd bd3lms && python -u main.py \
loader.eval_batch_size=1 \
model=small \
algo=bd3lm \
algo.T=5000 \
algo.backbone=hf_dit \
data=instadeep \
model.length=256 \
block_size=4 \
wandb=null \
mode=ppl_eval \
eval.checkpoint_path="monsoon-nlp/dna-blockdiff-papaya" \
model.attn_backend=sdpa \
sampling.nucleus_p=0.9 \
sampling.kv_cache=true \
sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize4 \
data.tokenizer_name_or_path="monsoon-nlp/dna-blockdiff-papaya"
```
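With model.length=256 and block_size=4, each 256-token window is handled as 256 / 4 = 64 blocks of four tokens. In block diffusion, blocks are produced left to right; at sampling time the tokens inside a block are denoised together while attending to all earlier blocks. The sketch below, with hypothetical helper names, shows that arithmetic and the block-causal attention rule it implies. Reading the 64/64 progress bar in the sampling log as one iteration per block is an assumption.

```python
MODEL_LENGTH = 256  # model.length
BLOCK_SIZE = 4      # block_size


def split_into_blocks(tokens: list, block_size: int) -> list[list]:
    """Split a sequence into consecutive, equally sized blocks."""
    if len(tokens) % block_size:
        raise ValueError("sequence length must be a multiple of block_size")
    return [tokens[i : i + block_size] for i in range(0, len(tokens), block_size)]


def can_attend(query_pos: int, key_pos: int, block_size: int) -> bool:
    """Block-causal rule: a token sees every token in its own block
    (bidirectionally) and in all earlier blocks, never later blocks."""
    return key_pos // block_size <= query_pos // block_size


blocks = split_into_blocks(list(range(MODEL_LENGTH)), BLOCK_SIZE)
print(len(blocks))                   # 64
print(can_attend(5, 7, BLOCK_SIZE))  # True: same block (positions 4-7)
print(can_attend(3, 4, BLOCK_SIZE))  # False: position 4 is in a later block
```

Because earlier blocks never change once generated, their keys and values can be cached, which is what `sampling.kv_cache=true` suggests is being enabled.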
Text generation
Use this fork of the code, keeping the parameters as close to the ones below as possible:
```bash
cd bd3lms && python -u main.py \
loader.eval_batch_size=1 \
model=small \
algo=bd3lm \
algo.T=5000 \
algo.backbone=hf_dit \
data=instadeep \
model.length=256 \
block_size=4 \
wandb=null \
mode=sample_eval \
eval.checkpoint_path="monsoon-nlp/dna-blockdiff-papaya" \
model.attn_backend=sdpa \
sampling.nucleus_p=0.9 \
sampling.kv_cache=true \
sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize4 \
data.tokenizer_name_or_path="monsoon-nlp/dna-blockdiff-papaya"
```
100% 64/64 [00:07<00:00, 8.73it/s]
Sliding Window Gen PPL: 100% 1/1 [00:00<00:00, 4.26it/s]
Text samples: ['<cls> AAATGG TTATTG CAAATC TCTAAA GAAGTA TTAAGA GAATGA TAAGAT ATGTTG AGAGAA TTACAC AGCATT GAGAAG TCTAAA TTGAAA AACCAT AAAAAT GTGAGT AGGTCA GTATGT AAGAAT TGTGTT GAACTT ATCAAT ATGTAG ACATCA TTTTGA TATAAA TATATA AAGAAA ATTTAA AAAAAA TAATAA ATAACT TTAAAA TGTTAA TAATAT TAAAAT GGAGAA GAATAA CCTTTA TTATCT ATTACA ATAATA ATTATA TTTTGG ATGAAA CATTCA GAATAT TAGATA ATTTTT ATTAAT GTATCT TCAAAT GAACAA ACTTAT ATTTAA AAACTC TAAAAT ATTTAT AGACTA AAAACT AGAGAA ATTAAT AATAAA AATAAA AAACAC AAATTT ATAAAA CCAAAT AAAGGT AATAAA AACAAA ATATTT ACAAAT AACTAT TAATGA AGTTAA AAAATG AATAAA TTTATA ATAAAA TATTTA TGTTTT AAATTA AAAATT TGAATA AAACTC ACAAAT TATTTA AATACT AATATG TATTTA TATAAT AATATA TGAAAA AATTAT GAATTT TAATTA AAATTT TTATAT TTATAA AAATTT ATATTA ATTAAT TTTTAA CAACTT AAATAA AAAGGA ATATTA AAGTCA ATAATT ATATAT TACTTA TAGACA AATAAA AAAATT CTCAAT AAAATT TAAAAT ATTAAA ATTTTG AAATTA AAATAA AAATAT AATAAT TCACTT CACACA ATACAA CTAACT TATACA ATTAAT TTAAAA GATTAA TTGAAT AAAATT ATTATC ACATGA AATTGG AATAAA CAAAAT AATATA TAAATA TATCAA AAATTG ATATAT GAAAAT CTTTAT GTGAAA TTTTAA GAAATA AATTTA ATATGC TGTTTT AAATTT TTTAAA TTTATT AAATTA AATTAA TATTAA ATTTTA ATAATA AAAATT TATAAT AATTAA TAATTT ATTAGC TTAAAA TTAAAT ATTTTA ATGTAA AAACTA TAATGC AATTTA AAGATT TTTTTA AATTAT ATAAGT TAATAA CTATAA TAATAC ATTTCT TTAATT AAAGAA GAAATT TTAAAT TTAAAT TTTTAA GTTAGA ATTACA TTAAAA TATAAA TATAAT AAATAA TAATTA TTAAAA TATACT AAATAG TTTATT AATTAT ATACTT AATATA ATATTT AATATT ATTATA AAAAAT AATCAT ATATAT ATAATT TTTTTT CTTTTT AACTTA TAAATT AATCAG TTATGA TACTTT ATAAAT ATTTGT TAATGG TGAATG AATATG CTTGAA AAGAAC AAAGAA GAAATT AAGAGA ACTTGA ATTTGG TGGTTA ATAAAT CTAATT ATATAT ATTATA TAAAAA TAGGAA TAATTT GAAAAT TAATAG AAAAGA AAAAGA ATAATT TTATGC TTCTTT ATATAA TTTAAC AAATAT TTTTTT ATAATA ATAATA TAATTA AACTTA AATTAT ATTATA TTCATC ATTATA']
Generative perplexity: tensor(19.5559, device='cuda:0')
Entropy: tensor(3.4023, device='cuda:0')
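To inspect a sample like the one above, the space-separated tokens can be checked and joined back into a plain DNA string. The sketch below strips special tokens such as <cls>, confirms that every remaining token is an A/C/G/T 6-mer (as the output restriction intends), and computes a unigram token entropy. Treating the logged Entropy value as this quantity (natural log, over one sample's tokens) is an assumption about the script, so the sketch is not guaranteed to reproduce 3.4023.

```python
import math
from collections import Counter


def strip_special(sample: str) -> list[str]:
    """Split on whitespace and drop special tokens like <cls>."""
    return [t for t in sample.split() if not (t.startswith("<") and t.endswith(">"))]


def is_six_mer(token: str) -> bool:
    """True for a token of exactly six A/C/G/T characters."""
    return len(token) == 6 and set(token) <= set("ACGT")


def unigram_entropy(tokens: list[str]) -> float:
    """Shannon entropy (nats) of the token frequency distribution."""
    counts = Counter(tokens)
    total = sum(counts.values())
    return -sum((c / total) * math.log(c / total) for c in counts.values())


# First few tokens of the logged sample.
sample = "<cls> AAATGG TTATTG CAAATC TCTAAA GAAGTA"
tokens = strip_special(sample)
print(all(is_six_mer(t) for t in tokens))  # True
print("".join(tokens))                     # AAATGGTTATTGCAAATCTCTAAAGAAGTA
print(round(unigram_entropy(tokens), 4))   # 1.6094 (= ln 5, all five tokens distinct)
```

Low entropy relative to ln(number of tokens) would indicate a repetitive sample, which is one quick sanity check on generated DNA.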
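⚠️ Important Note
The script computes the perplexity of the generated sequence under gpt2-large, which may not be a useful measure of how accurate a DNA sequence is.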
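For reference, perplexity is the exponential of the average per-token negative log-likelihood that a scoring model assigns to a sequence, so its value depends entirely on which model does the scoring. The minimal sketch below takes those per-token log-probabilities as plain inputs (the numbers are made up for illustration, not outputs of gpt2-large). A scorer trained on English text is not a meaningful judge of DNA, so the generative perplexity in the log should be read with caution.

```python
import math


def perplexity(token_logprobs: list[float]) -> float:
    """exp(mean negative log-likelihood) from natural-log token probabilities."""
    if not token_logprobs:
        raise ValueError("need at least one token")
    mean_nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(mean_nll)


# A scorer that gives every token probability 0.5 has perplexity 2.
print(perplexity([math.log(0.5)] * 4))
# Uniform guessing over 4,096 six-mers gives perplexity 4,096.
print(perplexity([-math.log(4096)] * 10))
```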
📄 License
This project is licensed under apache-2.0.
📦 Model Information

| Property | Details |
|---|---|
| Library name | transformers |
| Base model | monsoon-nlp/dna-blockdiff-2 |
| Tags | biology, bd3lm |
| License | apache-2.0 |