🚀 MPT-7B-Instruct-8k
MPT-7B-Instruct-8kは、長文形式の指示に従うためのモデルで、特に長い文書に対する質問応答や要約に優れています。このモデルは、Dolly HHRLHF 上で MPT-7B-8k を微調整することで構築されています。ここでの Dolly HHRLHF は、Databricks Dolly-15k と Anthropic Helpful and Harmless (HH-RLHF) データセットに由来しています。さらに、このモデルは Competition Math、Duorc、CoT GSM8k、Qasper、Quality、Summ Screen FD、Spider でも訓練されています。これは MPT-30B-Instruct の訓練データセットと同じです。
このモデルは MosaicML によって訓練され、修正されたデコーダー専用のトランスフォーマーアーキテクチャを採用しています。
🚀 クイックスタート
モデル情報
プロパティ |
詳細 |
モデルタイプ |
MPT-7B-Instruct-8kは、長文形式の指示に従うためのモデルで、特に長い文書の質問応答や要約に適しています。 |
訓練データ |
このモデルは、Competition Math、Databricks Dolly-15k、Anthropic HH-RLHF、Duorc、CoT GSM8k、Qasper、Quality、Summ Screen FD、Spider などの複数のデータセットで訓練されています。 |
使用例
基本的な使用法
import transformers
model = transformers.AutoModelForCausalLM.from_pretrained(
'mosaicml/mpt-7b-instruct-8k',
trust_remote_code=True
)
注意:このモデルでは、from_pretrained
メソッドを呼び出す際に trust_remote_code=True
を渡す必要があります。これは、Hugging Face の transformers
パッケージにまだ含まれていないカスタムの MPT
モデルアーキテクチャを使用しているためです。MPT
には、FlashAttention、ALiBi、QK LayerNorm など、多くの訓練効率特性オプションが含まれています。
高度な使用法
triton 実装 の最適化された FlashAttention を使用するには、GPU (cuda:0
) 上で attn_impl='triton'
と bfloat16
精度でモデルをロードできます。
import torch
import transformers
name = 'mosaicml/mpt-7b-instruct-8k'
config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'triton'
config.init_device = 'cuda:0'
model = transformers.AutoModelForCausalLM.from_pretrained(
name,
config=config,
torch_dtype=torch.bfloat16,
trust_remote_code=True
)
このモデルは最初に 2048 のシーケンス長で訓練され、最大 8192 のシーケンス長に適応するための追加の事前訓練段階が行われました。ただし、ALiBi を使用すると、微調整または推論中に最大シーケンス長をさらに増やすことができます。例えば:
import transformers
name = 'mosaicml/mpt-7b-instruct-8k'
config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
config.max_seq_len = 16384
model = transformers.AutoModelForCausalLM.from_pretrained(
name,
config=config,
trust_remote_code=True
)
このモデルは EleutherAI/gpt-neox-20b トークナイザーに基づく MPT-7B-chat トークナイザーを使用し、追加の ChatML トークンを含んでいます。
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('mosaicml/mpt-7b-8k')
その後、このモデルをテキスト生成パイプラインで使用できます。
注意:低精度で Torch モジュールを実行する場合は、torch.autocast コンテキストマネージャー を使用することをお勧めします。
from transformers import pipeline
with torch.autocast('cuda', dtype=torch.bfloat16):
inputs = tokenizer('以下は、無肉のバナナパンのレシピです:\n', return_tensors="pt").to('cuda')
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
pipe = pipeline('text-generation', model=model, tokenizer=tokenizer, device='cuda:0')
with torch.autocast('cuda', dtype=torch.bfloat16):
print(
pipe('以下は、無肉のバナナパンのレシピです:\n',
max_new_tokens=100,
do_sample=True,
use_cache=True))
📚 ドキュメント
🔧 技術詳細
モデルアーキテクチャ
このアーキテクチャは、標準のデコーダー専用トランスフォーマーを修正したものです。
モデルは、以下の点で標準トランスフォーマーを修正しています。
ハイパーパラメータ |
値 |
パラメータ数 |
67億 |
レイヤー数 |
32 |
ヘッド数 |
32 |
モデル次元 |
4096 |
語彙サイズ |
50432 |
シーケンス長 |
2048 |
データミックス
このモデルは、以下のデータミックスで訓練されています。
データソース |
ソース内のトークン数 |
割合 |
competition_math |
160万 |
3.66% |
cot_gsm8k |
336万 |
7.67% |
dialogsum |
10万 |
0.23% |
dolly_hhrlhf |
589万 |
13.43% |
duorc |
780万 |
17.80% |
qasper |
872万 |
19.90% |
quality |
1129万 |
25.78% |
scrolls/summ_screen_fd |
497万 |
11.33% |
spider |
8.9万 |
0.20% |
訓練設定
このモデルは MosaicML プラットフォーム を使用して、8台の80GB A100 GPU で約6.3時間訓練されました。
モデルは FSDP を使用してデータ並列訓練を行い、AdamW オプティマイザーを使用しました。
🚫 制限事項とバイアス
以下の内容は EleutherAI の GPT-NeoX-20B から引用しています。
MPT-7B-Instruct-8k は事実誤りのある出力を生成する可能性があるため、事実的に正確な情報を提供するためにこれに依存しないでください。
MPT-7B-Instruct-8k はさまざまな公開データセットで訓練されています。
事前訓練データのクリーニングに多大な努力を払っているにもかかわらず、このモデルは下品、偏見のある、またはその他の不快な出力を生成する可能性があります。
🙏 謝辞
このモデルは MosaicML NLP チームによって微調整されました。
⚠️ 免責事項
このモデルのライセンスは法律上のアドバイスを構成するものではありません。私たちは、このモデルを使用する第三者の行為について責任を負いません。このモデルを商用目的で使用する前に、弁護士に相談してください。
🛠️ MosaicML プラットフォーム
MosaicML プラットフォームで 訓練 や デプロイ を行う場合は、こちらで登録 してください。
📝 引用
このモデルを引用する場合は、以下の形式を使用してください。
@online{MosaicML2023Introducing,
author = {MosaicML NLP Team},
title = {Introducing MPT-30B: Raising the bar
for open-source foundation models},
year = {2023},
url = {www.mosaicml.com/blog/mpt-30b},
note = {Accessed: 2023-06-22},
urldate = {2023-06-22}
}
📄 ライセンス
Apache 2.0