🚀 文本到Cypher生成模型
本模型基於Google Gemma-2-9b-it基礎模型微調而來,旨在利用Neo4j-Text2Cypher(2024)數據集提升文本到Cypher查詢語句的生成性能,為圖數據庫查詢提供更智能的解決方案。
🚀 快速開始
本模型是使用Neo4j-Text2Cypher(2024)數據集對基礎模型進行微調的示例,展示了在文本到Cypher任務上的性能提升。需要注意的是,這是正在進行的研究和探索的一部分,旨在凸顯數據集的潛力,而非一個可用於生產環境的解決方案。
基礎模型:google/gemma-2-9b-it
數據集:neo4j/text2cypher-2024v1
微調模型的概述和基準測試結果可查看 鏈接1 和 鏈接2
如果您有想法或見解,請聯繫我們:Neo4j/Team-GenAI
✨ 主要特性
- 性能提升:通過使用Neo4j-Text2Cypher(2024)數據集進行微調,在文本到Cypher任務上表現更優。
- 研究探索:作為研究項目的一部分,為圖數據庫查詢生成提供新的思路。
📦 安裝指南
文檔中未提及具體安裝步驟,故跳過此章節。
💻 使用示例
基礎用法
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "DavidLanz/text2cypher-gemma-2-9b-it-finetuned-2024v1"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32,
device_map="auto",
low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
question = "What are the movies of Tom Hanks?"
schema = "(:Actor)-[:ActedIn]->(:Movie)"
instruction = (
"Generate Cypher statement to query a graph database. "
"Use only the provided relationship types and properties in the schema. \n"
"Schema: {schema} \n Question: {question} \n Cypher output: "
)
prompt = instruction.format(schema=schema, question=question)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
model.eval()
with torch.no_grad():
outputs = model.generate(**inputs, max_new_tokens=512)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated Cypher Query:", generated_text)
def prepare_chat_prompt(question, schema):
chat = [
{
"role": "user",
"content": instruction.format(
schema=schema, question=question
),
}
]
return chat
def _postprocess_output_cypher(output_cypher: str) -> str:
partition_by = "**Explanation:**"
output_cypher, _, _ = output_cypher.partition(partition_by)
output_cypher = output_cypher.strip("`\n")
output_cypher = output_cypher.lstrip("cypher\n")
output_cypher = output_cypher.strip("`\n ")
return output_cypher
new_message = prepare_chat_prompt(question=question, schema=schema)
try:
prompt = tokenizer.apply_chat_template(new_message, add_generation_prompt=True, tokenize=False)
inputs = tokenizer(prompt, return_tensors="pt", padding=True).to("cuda")
with torch.no_grad():
outputs = model.generate(**inputs, max_new_tokens=512)
chat_generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
final_cypher = _postprocess_output_cypher(chat_generated_text)
print("Processed Cypher Query:", final_cypher)
except AttributeError:
print("Error: `apply_chat_template` not supported by this tokenizer. Check compatibility.")
高級用法
文檔中未提及高級用法相關代碼,故跳過此部分。
📚 詳細文檔
文檔中未提及詳細說明內容,故跳過此章節。
🔧 技術細節
訓練過程
使用RunPod進行訓練,配置如下:
- 1 x A100 PCIe
- 31 vCPU 117 GB RAM
- runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04
- 按需使用 - 安全雲
- 60 GB 磁盤
- 60 GB Pod 卷
訓練超參數
lora_config = LoraConfig(
r=64,
lora_alpha=64,
target_modules=target_modules,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
sft_config = SFTConfig(
dataset_text_field=dataset_text_field,
per_device_train_batch_size=4,
gradient_accumulation_steps=8,
dataset_num_proc=16,
max_seq_length=1600,
logging_dir="./logs",
num_train_epochs=1,
learning_rate=2e-5,
save_steps=5,
save_total_limit=1,
logging_steps=5,
output_dir="outputs",
optim="paged_adamw_8bit",
save_strategy="steps",
)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
📄 許可證
許可證類型:gemma