🚀 T5 large LM Adapt for Text to SQL
このモデルは、自然言語のプロンプトから構造化されたSQLクエリを生成することを目的としています。
🚀 クイックスタート
このモデルは、自然言語で提示された質問からSQLクエリを生成するText2SQLタスクに特化しています。通常のアプローチでは、SQLクエリが未知の列を含んだり、特定のデータベーススキーマを考慮しないことがあります。そこで、当社のアプローチでは、学習時にデータベーススキーマを入力質問に組み込むことで、適用可能なSQLクエリを生成するために利用できる列と関係を指定します。
✨ 主な機能
- 自然言語の質問から構造化されたSQLクエリを生成します。
- データベーススキーマを入力に組み込むことで、未知のスキーマに対する汎化能力を向上させます。
📦 インストール
このモデルを使用するには、🤗 Transformersライブラリが必要です。以下のコマンドでインストールできます。
pip install transformers
💻 使用例
基本的な使用法
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
model_path = 'gaussalgo/T5-LM-Large-text2sql-spider'
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
question = "What is the average, minimum, and maximum age for all French musicians?"
schema = """
"stadium" "Stadium_ID" int , "Location" text , "Name" text , "Capacity" int , "Highest" int , "Lowest" int , "Average" int , foreign_key: primary key: "Stadium_ID" [SEP] "singer" "Singer_ID" int , "Name" text , "Country" text , "Song_Name" text , "Song_release_year" text , "Age" int , "Is_male" bool , foreign_key: primary key: "Singer_ID" [SEP] "concert" "concert_ID" int , "concert_Name" text , "Theme" text , "Year" text , foreign_key: "Stadium_ID" text from "stadium" "Stadium_ID" , primary key: "concert_ID" [SEP] "singer_in_concert" foreign_key: "concert_ID" int from "concert" "concert_ID" , "Singer_ID" text from "singer" "Singer_ID" , primary key: "concert_ID" "Singer_ID"
"""
input_text = " ".join(["Question: ",question, "Schema:", schema])
model_inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**model_inputs, max_length=512)
output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print("SQL Query:")
print(output_text)
出力結果:
SQL Query:
SELECT avg(age), min(age), max(age) FROM singer WHERE country = 'France'
📚 ドキュメント
ベースモデル
このモデルは、t5-large-LM-adaptのチェックポイントからファインチューニングされています。
SpiderとSpider-Synデータセット
このモデルは、SpiderとSpider-Synデータセットのトレーニング分割でファインチューニングされました。質問だけでなく、データベーススキーマも質問に追加し、モデルが特定のデータベースに対する質問を生成できるようにしました。
入力プロンプトの例:
Question: What is the average, minimum, and maximum age for all French musicians?
Schema: "stadium" "Stadium_ID" int , "Location" text , "Name" text , "Capacity" int , "Highest" int , "Lowest" int ,
"Average" int , foreign_key: primary key: "Stadium_ID" [SEP] "singer" "Singer_ID" int , "Name" text , "Country" text ,
"Song_Name" text , "Song_release_year" text , "Age" int , "Is_male" bool ,
foreign_key: primary key: "Singer_ID" [SEP],
"concert" "concert_ID" int , "concert_Name" text , "Theme" text , "Year" text , foreign_key: "Stadium_ID" text from "stadium",
"Stadium_ID" , primary key: "concert_ID" [SEP] "singer_in_concert",
foreign_key: "concert_ID" int from "concert",
"concert_ID" , "Singer_ID" text from "singer" "Singer_ID" , primary key: "concert_ID" "Singer_ID"
期待される出力:
SELECT avg(age), min(age), max(age) FROM singer WHERE country = 'France'
データベースをクエリした結果:
[[34.5, 25, 43]]
データベーススキーマの形式
このモデルが学習した標準化されたデータベーススキーマは以下の通りです。
table_name column1_name column1_type column2_name column2_type ... foreign_key: FK_name FK_type from table_name column_name primary key: column_name [SEP]
table_name2 ...
評価
評価は、SpiderとSpider-synデータセットの開発分割で行われました。開発分割に含まれるデータベースは、トレーニング分割のデータベースと重複していません。これにより、モデルがトレーニング中に評価対象のデータベースに触れないようにしています。評価は、生成されたクエリと参照クエリを使用してデータベースをクエリした結果を比較することで行われました。SpiderとSpider-Synの開発分割にはそれぞれ1032サンプルが含まれています。
- Spider開発セットの正解率: 49.2%
- Spider Syn開発セットの正解率: 39.5%
トレーニング
このモデルは、Adaptorライブラリ 0.2.1を使用して、SpiderとSpider-synデータセットのトレーニング分割で以下のパラメータでトレーニングされました。
training_arguments = AdaptationArguments(output_dir="train_dir",
learning_rate=5e-5,
stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
stopping_patience=8,
save_total_limit=8,
do_train=True,
do_eval=True,
bf16=True,
warmup_steps=1000,
gradient_accumulation_steps=8,
logging_steps=10,
eval_steps=200,
save_steps=1000,
num_train_epochs=10,
evaluation_strategy="steps")
トレーニングは比較的再現しやすいですが、依存しているSpiderデータセットの修正コピーを公開することは希望していません。もしこの方向でさらに調査したい場合は、新しいPRを通じて、またはstefanik(at)gaussalgo.comまでメールでご連絡ください。
📄 ライセンス
原文書にライセンス情報は記載されていません。