🚀 GPT-J 6B
GPT-J 6Bは、Transformerモデルで、Ben WangのMesh Transformer JAX を使用して学習されました。「GPT-J」はモデルのクラスを指し、「6B」は学習可能なパラメータの数を表します。
🚀 クイックスタート
このモデルは、AutoModelForCausalLM
機能を使用して簡単に読み込むことができます。
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B" )
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B" )
✨ 主な機能
GPT-Jは、英語の内部表現を学習し、下流タスクに役立つ特徴を抽出するために使用できます。このモデルは、プロンプトからテキストを生成することに最適です。
📚 ドキュメント
モデルの説明
GPT-J 6Bは、モデル次元が4096、フィードフォワード次元が16384の28層から構成されています。モデル次元は16のヘッドに分割され、各ヘッドの次元は256です。各ヘッドの64次元には、Rotary Position Embedding (RoPE)が適用されています。このモデルは、GPT-2/GPT-3と同じBPEセットを使用して、50257のトークン化語彙で学習されています。
ハイパーパラメータ
値
\(n_{parameters}\)
6053381344
\(n_{layers}\)
28*
\(d_{model}\)
4096
\(d_{ff}\)
16384
\(n_{heads}\)
16
\(d_{head}\)
256
\(n_{ctx}\)
2048
\(n_{vocab}\)
50257/50400† (GPT-2/3と同じトークナイザー)
位置符号化
Rotary Position Embedding (RoPE)
RoPE次元
64
* 各層は、1つのフィードフォワードブロックと1つの自己注意ブロックで構成されています。
† 埋め込み行列のサイズは50400ですが、GPT-2トークナイザーで使用されるのは50257エントリだけです。
想定される使用方法と制限事項
想定外の使用方法
GPT-J-6Bは、ファインチューニング、監視、および/またはモデレーションなしでのデプロイを想定していません。これ自体は製品ではなく、人との対話には使用できません。たとえば、このモデルは有害または不快なテキストを生成する可能性があります。特定の使用ケースに関連するリスクを評価してください。
GPT-J-6Bは英語のデータセットで学習されているため、翻訳や他の言語のテキスト生成には適していません。
GPT-J-6Bは、文章の執筆や商用チャットボットなど、言語モデルが一般的にデプロイされる下流のコンテキストでファインチューニングされていません。これは、ChatGPTのような製品が人間のフィードバックによる強化学習(RLHF)などの方法を使用してファインチューニングされているのとは異なり、GPT-J-6Bは与えられたプロンプトに対してChatGPTのように応答しないことを意味します。
制限事項とバイアス
GPT-Jの核心機能は、テキストの文字列を受け取り、次のトークンを予測することです。言語モデルはこれ以外のタスクにも広く使用されていますが、この作業には多くの未知の要素があります。GPT-Jにプロンプトを与える際には、統計的に最も可能性の高い次のトークンが必ずしも最も「正確」なテキストを生成するわけではないことを覚えておくことが重要です。GPT-Jが事実的に正確な出力を生成することに決して依存しないでください。
GPT-Jは、不適切な言葉や露骨な表現などを含むことが知られているPileデータセットで学習されています。使用ケースによっては、GPT-Jが社会的に受け入れられないテキストを生成する可能性があります。Pileのバイアスに関する詳細な分析については、Pile論文のセクション5と6 を参照してください。
すべての言語モデルと同様に、GPT-Jが特定のプロンプトにどのように応答するかを事前に予測することは困難であり、不快なコンテンツが突然発生する可能性があります。結果を公開する前に、人間が出力を選別またはフィルタリングすることをお勧めします。これにより、望ましくないコンテンツを削除し、結果の品質を向上させることができます。
学習データ
GPT-J 6Bは、EleutherAI によって作成された大規模な精選データセットであるthe Pile で学習されました。
学習手順
このモデルは、TPU v3 - 256ポッドで383,500ステップにわたって4020億トークンで学習されました。自己回帰型言語モデルとして学習され、次のトークンを正しく予測する尤度を最大化するために交差エントロピー損失を使用しました。
評価結果
モデル
公開状況
学習FLOPs
LAMBADA PPL ↓
LAMBADA Acc ↑
Winogrande ↑
Hellaswag ↑
PIQA ↑
データセットサイズ (GB)
ランダム予測
✓
0
~多い
~0%
50%
25%
25%
0
GPT-3 Ada‡
✗
-----
9.95
51.6%
52.9%
43.4%
70.5%
-----
GPT-2 1.5B
✓
-----
10.63
51.21%
59.4%
50.9%
70.8%
40
GPT-Neo 1.3B‡
✓
3.0e21
7.50
57.2%
55.0%
48.9%
71.1%
825
Megatron-2.5B*
✗
2.4e21
-----
61.7%
-----
-----
-----
174
GPT-Neo 2.7B‡
✓
6.8e21
5.63
62.2%
56.5%
55.8%
73.0%
825
GPT-3 1.3B*‡
✗
2.4e21
5.44
63.6%
58.7%
54.7%
75.1%
~800
GPT-3 Babbage‡
✗
-----
5.58
62.4%
59.0%
54.5%
75.5%
-----
Megatron-8.3B*
✗
7.8e21
-----
66.5%
-----
-----
-----
174
GPT-3 2.7B*‡
✗
4.8e21
4.60
67.1%
62.3%
62.8%
75.6%
~800
Megatron-11B†
✓
1.0e22
-----
-----
-----
-----
-----
161
GPT-J 6B‡
✓
1.5e22
3.99
69.7%
65.3%
66.1%
76.5%
825
GPT-3 6.7B*‡
✗
1.2e22
4.00
70.3%
64.5%
67.4%
78.0%
~800
GPT-3 Curie‡
✗
-----
4.00
69.3%
65.6%
68.5%
77.9%
-----
GPT-3 13B*‡
✗
2.3e22
3.56
72.5%
67.9%
70.9%
78.5%
~800
GPT-3 175B*‡
✗
3.1e23
3.00
76.2%
70.2%
78.9%
81.0%
~800
GPT-3 Davinci‡
✗
-----
3.0
75%
72%
78%
80%
-----
モデルは、性能で概ねソートされています。性能が利用できない場合は、FLOPsでソートされています。
* 評価数値はそれぞれの著者によって報告されています。他のすべての数値は、lm-evaluation-harness
をリリースされた重みまたはAPIアクセスで実行することによって提供されています。微妙な実装の違いや異なるゼロショットタスクのフレーミングのため、これらは直接比較できない場合があります。詳細については、このブログ記事 を参照してください。
† Megatron - 11Bは比較可能なメトリクスを提供しておらず、リリースされた重みを使用するいくつかの実装では、生成品質と評価が再現されていません。(1
2 3 )したがって、評価は試行されていません。
‡ これらのモデルは、テストセットの汚染の可能性があるデータで学習されています。OpenAIのGPT - 3モデルは、特定のテストセットの学習データの重複排除に失敗しています。一方、GPT - Neoモデルやこのモデルは、いずれのテストセットに対しても重複排除されていないPileで学習されています。
引用と関連情報
BibTeXエントリ
このモデルを引用するには、以下のようにしてください。
@misc{gpt-j,
author = {Wang, Ben and Komatsuzaki, Aran},
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
このモデルを学習したコードベースを引用するには、以下のようにしてください。
@misc{mesh-transformer-jax,
author = {Wang, Ben},
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
このモデルを使用した場合、是非ご意見をお聞かせください!GitHub 、Discordで連絡するか、Benにメールを送ってください。
📄 ライセンス
このモデルは、Apache - 2.0ライセンスの下で提供されています。
謝辞
このプロジェクトは、GoogleがTPU Research Cloud を通じて提供してくれた計算資源、およびCloud TPU VM Alphaへの早期アクセスを提供してくれたCloud TPUチームのおかげで可能になりました。
様々な形で支援してくれた皆さんに感謝します(アルファベット順)。