🚀 Pythia-2.8B
“Pythia Scaling Suite”是一系列用于促进可解释性研究的模型集合,包含不同规模的模型,在大语言模型研究方面具有重要价值。
🚀 快速开始
Pythia模型可以通过以下代码加载和使用,这里以第三个 pythia-70m-deduped
检查点为例:
from transformers import GPTNeoXForCausalLM, AutoTokenizer
model = GPTNeoXForCausalLM.from_pretrained(
"EleutherAI/pythia-70m-deduped",
revision="step3000",
cache_dir="./pythia-70m-deduped/step3000",
)
tokenizer = AutoTokenizer.from_pretrained(
"EleutherAI/pythia-70m-deduped",
revision="step3000",
cache_dir="./pythia-70m-deduped/step3000",
)
inputs = tokenizer("Hello, I am", return_tensors="pt")
tokens = model.generate(**inputs)
tokenizer.decode(tokens[0])
修订版/分支 step143000
与每个模型 main
分支上的模型检查点完全对应。有关如何使用所有Pythia模型的更多信息,请参阅 GitHub上的文档。
✨ 主要特性
- 促进研究:Pythia模型套件旨在促进大语言模型的科学研究,特别是可解释性研究。
- 模型多样:包含两组八个不同大小的模型(70M、160M、410M、1B、1.4B、2.8B、6.9B和12B),每个大小有两个模型,一个在Pile数据集上训练,一个在全局去重后的Pile数据集上训练。
- 检查点丰富:每个模型提供154个中间检查点,托管在Hugging Face上作为分支。
- 性能出色:尽管设计目标并非以提升下游性能为中心,但模型的性能与类似规模的模型(如OPT和GPT - Neo套件中的模型)相当或更优。
📦 安装指南
文档未提及安装步骤,可参考 GitHub上的文档 获取相关信息。
📚 详细文档
模型详情
Pythia模型 |
非嵌入参数 |
层数 |
模型维度 |
头数 |
批量大小 |
学习率 |
等效模型 |
70M |
18,915,328 |
6 |
512 |
8 |
2M |
1.0 x 10-3 |
— |
160M |
85,056,000 |
12 |
768 |
12 |
2M |
6.0 x 10-4 |
GPT - Neo 125M, OPT - 125M |
410M |
302,311,424 |
24 |
1024 |
16 |
2M |
3.0 x 10-4 |
OPT - 350M |
1.0B |
805,736,448 |
16 |
2048 |
8 |
2M |
3.0 x 10-4 |
— |
1.4B |
1,208,602,624 |
24 |
2048 |
16 |
2M |
2.0 x 10-4 |
GPT - Neo 1.3B, OPT - 1.3B |
2.8B |
2,517,652,480 |
32 |
2560 |
32 |
2M |
1.6 x 10-4 |
GPT - Neo 2.7B, OPT - 2.7B |
6.9B |
6,444,163,072 |
32 |
4096 |
32 |
2M |
1.2 x 10-4 |
OPT - 6.7B |
12B |
11,327,027,200 |
36 |
5120 |
40 |
2M |
1.2 x 10-4 |
— |
使用与限制
预期用途
- 研究用途:Pythia的主要预期用途是研究大语言模型的行为、功能和局限性,为进行科学实验提供可控环境。
- 微调部署:只要使用符合Apache 2.0许可证,也可对Pythia - 2.8B进行进一步微调并部署。Pythia模型可与Hugging Face的 Transformers库 配合使用。若决定使用预训练的Pythia - 2.8B作为微调模型的基础,请自行进行风险和偏差评估。
非预期用途
- 不适合部署:Pythia套件并非用于部署,本身不是产品,不能用于面向人类的交互。例如,模型可能生成有害或冒犯性文本,请评估特定用例相关的风险。
- 仅支持英语:Pythia模型仅支持英语,不适合翻译或生成其他语言的文本。
- 未针对下游场景微调:Pythia - 2.8B未针对语言模型常见的下游场景(如撰写散文或商业聊天机器人)进行微调,因此其对给定提示的响应方式与ChatGPT等产品不同。
局限性和偏差
- 输出准确性:大语言模型的核心功能是根据输入文本预测下一个标记,模型使用的标记不一定能产生最“准确”的文本,切勿依赖Pythia - 2.8B产生事实准确的输出。
- 数据偏差:该模型在 Pile 数据集上训练,该数据集包含亵渎、淫秽或其他冒犯性文本。有关性别、宗教和种族的记录偏差讨论,请参阅 Pile论文的第6节。即使提示本身不包含任何明确的冒犯性内容,Pythia - 2.8B也可能生成社会不可接受或不良的文本。
- 人工审核建议:如果计划使用通过例如托管推理API生成的文本,建议在向他人展示之前由人工对该语言模型的输出进行审核,并告知受众文本是由Pythia - 2.8B生成的。
训练
训练数据
Pile 是一个825GiB的英语通用数据集,由EleutherAI专门为训练大语言模型而创建。它包含来自22个不同来源的文本,大致分为五类:学术写作(如arXiv)、互联网(如CommonCrawl)、散文(如Project Gutenberg)、对话(如YouTube字幕)和其他(如GitHub、Enron Emails)。有关所有数据源的细分、方法和伦理影响的讨论,请参阅 Pile论文。有关Pile及其组成数据集的更详细文档,请参阅 数据表。Pile可从 官方网站 或 [社区镜像](https://the - eye.eu/public/AI/pile/) 下载。在训练Pythia - 2.8B之前,Pile数据集未进行去重处理。
训练过程
- 所有模型在完全相同的数据上按相同顺序进行训练。每个模型在训练期间处理299,892,736,000个标记,每2,097,152,000个标记保存一个检查点,从
step1000
到 step143000
(与 main
相同),共保存143个检查点。此外,还提供频繁的早期检查点:step0
和 step{1,2,4...512}
。这相当于非去重模型在Pile上训练不到1个周期,去重后的Pile上训练约1.5个周期。
- 所有 Pythia 模型以2M(2,097,152个标记)的批量大小训练143000步。有关训练过程的更多详细信息,包括 [如何复现](https://github.com/EleutherAI/pythia/blob/main/README.md#reproducing - training),请参阅 GitHub。Pythia使用与 [GPT - NeoX - 20B](https://huggingface.co/EleutherAI/gpt - neox - 20b) 相同的分词器。
评估
所有16个 Pythia 模型均使用 [LM Evaluation Harness](https://github.com/EleutherAI/lm - evaluation - harness) 进行评估。可在 GitHub仓库 的 results/json/*
中按模型和步骤访问评估结果。展开以下部分,查看所有Pythia和Pythia - deduped模型与OPT和BLOOM的评估结果对比图。
LAMBADA – OpenAI
Physical Interaction: Question Answering (PIQA)
WinoGrande
AI2 Reasoning Challenge—Easy Set
SciQ
变更日志
本节比较了之前发布的 Pythia v0 与当前模型之间的差异。有关这些更改及其背后动机的进一步讨论,请参阅Pythia论文的附录B。重新训练Pythia对基准性能没有影响。
- 统一批量大小:所有模型现在均以2M标记的统一批量大小进行训练。之前,参数大小为160M、410M和1.4B的模型以4M标记的批量大小进行训练。
- 增加检查点:除了每1000个训练步骤保存一个检查点外,还在初始化(步骤0)和步骤 {1,2,4,8,16,32,64,128,256,512} 增加了检查点。
- 使用Flash Attention:新的重新训练套件中使用了Flash Attention。
- 修正学习率调度:修正了原套件中存在的一个小不一致问题:所有参数大小为2.8B或更小的模型的学习率(LR)调度衰减到起始LR的10%,而6.9B和12B模型的LR调度衰减到0。在重新训练运行中,所有模型现在均以LR衰减到最大LR的0.1倍进行训练。
命名约定和参数数量
Pythia 模型于2023年1月进行了重命名。旧的命名约定可能仍意外存在于某些文档中。当前的命名约定(70M、160M等)基于总参数数量。
当前Pythia后缀 |
旧后缀 |
总参数 |
非嵌入参数 |
70M |
19M |
70,426,624 |
18,915,328 |
160M |
125M |
162,322,944 |
85,056,000 |
410M |
350M |
405,334,016 |
302,311,424 |
1B |
800M |
1,011,781,632 |
805,736,448 |
1.4B |
1.3B |
1,414,647,808 |
1,208,602,624 |
2.8B |
2.7B |
2,775,208,960 |
2,517,652,480 |
6.9B |
6.7B |
6,857,302,016 |
6,444,163,072 |
12B |
13B |
11,846,072,320 |
11,327,027,200 |
📄 许可证
本项目采用Apache 2.0许可证。