🚀 Text-to-Cypher Generation Model
This model is fine-tuned from the Google Gemma-2-9b-it base model on the Neo4j-Text2Cypher(2024) dataset to improve text-to-Cypher query generation and provide a more capable way to query graph databases.
🚀 Quick Start
This model is an example of fine-tuning a base model with the Neo4j-Text2Cypher(2024) dataset, demonstrating the resulting performance gains on the text-to-Cypher task. Note that it is part of ongoing research and exploration intended to highlight the dataset's potential, not a production-ready solution.
- Base model: google/gemma-2-9b-it
- Dataset: neo4j/text2cypher-2024v1
An overview of the fine-tuned models and benchmark results can be found at Link 1 and Link 2.
If you have ideas or insights, please reach out to us: Neo4j/Team-GenAI
✨ Key Features
- Improved performance: fine-tuning on the Neo4j-Text2Cypher(2024) dataset yields better results on the text-to-Cypher task.
- Research exploration: as part of an ongoing research project, it offers new ideas for graph database query generation.
📦 Installation
The source documentation does not provide specific installation steps, so this section is skipped.
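While no installation steps are given, the examples below assume a working Python environment with the transformers and torch packages (for inference) and, for reproducing the training configuration, peft, trl, and bitsandbytes, as implied by the imports and configuration objects used in this card.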
💻 Usage Examples
Basic Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "DavidLanz/text2cypher-gemma-2-9b-it-finetuned-2024v1"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

question = "What are the movies of Tom Hanks?"
schema = "(:Actor)-[:ActedIn]->(:Movie)"
instruction = (
    "Generate Cypher statement to query a graph database. "
    "Use only the provided relationship types and properties in the schema. \n"
    "Schema: {schema} \n Question: {question} \n Cypher output: "
)

prompt = instruction.format(schema=schema, question=question)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

model.eval()
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=512)

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated Cypher Query:", generated_text)
```
```python
def prepare_chat_prompt(question, schema):
    # Wrap the formatted instruction in a single user message for the chat template
    chat = [
        {
            "role": "user",
            "content": instruction.format(
                schema=schema, question=question
            ),
        }
    ]
    return chat

def _postprocess_output_cypher(output_cypher: str) -> str:
    # Drop any trailing "**Explanation:**" section, then remove surrounding
    # markdown fences and the "cypher" language tag, e.g.
    # "```cypher\nMATCH ...\n```" -> "MATCH ..."
    partition_by = "**Explanation:**"
    output_cypher, _, _ = output_cypher.partition(partition_by)
    output_cypher = output_cypher.strip("`\n")
    output_cypher = output_cypher.lstrip("cypher\n")
    output_cypher = output_cypher.strip("`\n ")
    return output_cypher

new_message = prepare_chat_prompt(question=question, schema=schema)

try:
    prompt = tokenizer.apply_chat_template(new_message, add_generation_prompt=True, tokenize=False)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to("cuda")

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=512)

    chat_generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    final_cypher = _postprocess_output_cypher(chat_generated_text)
    print("Processed Cypher Query:", final_cypher)
except AttributeError:
    print("Error: `apply_chat_template` not supported by this tokenizer. Check compatibility.")
```
Advanced Usage
No advanced usage code is provided in the source documentation, so this part is skipped.
📚 Documentation
No additional detailed documentation is provided in the source, so this section is skipped.
🔧 Technical Details
Training Procedure
Training was run on RunPod with the following setup:
- 1 x A100 PCIe
- 31 vCPU 117 GB RAM
- runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04
- On-Demand - Secure Cloud
- 60 GB disk
- 60 GB pod volume
Training Hyperparameters
```python
import torch
from peft import LoraConfig
from transformers import BitsAndBytesConfig
from trl import SFTConfig

# LoRA adapter configuration (target_modules is defined elsewhere in the training script)
lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules=target_modules,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Supervised fine-tuning configuration (dataset_text_field is defined elsewhere)
sft_config = SFTConfig(
    dataset_text_field=dataset_text_field,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    dataset_num_proc=16,
    max_seq_length=1600,
    logging_dir="./logs",
    num_train_epochs=1,
    learning_rate=2e-5,
    save_steps=5,
    save_total_limit=1,
    logging_steps=5,
    output_dir="outputs",
    optim="paged_adamw_8bit",
    save_strategy="steps",
)

# 4-bit NF4 quantization for QLoRA-style training
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
```
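The card does not show how these configuration objects are wired together. Below is a minimal sketch, assuming a trl `SFTTrainer`-based QLoRA setup, the `train` split of the dataset, and a preformatted text column referenced by `dataset_text_field` (all assumptions, not taken from the source):

```python
# Sketch only: one way the configs above could plug into trl's SFTTrainer for
# QLoRA fine-tuning. Split name and text column are assumptions, not from the card.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer

base_model = "google/gemma-2-9b-it"

tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,  # 4-bit NF4 quantization defined above
    device_map="auto",
)

# Assumed split; the column named by dataset_text_field must hold the full
# instruction + Cypher training text.
train_dataset = load_dataset("neo4j/text2cypher-2024v1", split="train")

trainer = SFTTrainer(
    model=model,
    args=sft_config,          # SFTConfig from above
    peft_config=lora_config,  # LoRA adapter config from above
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```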
📄 License
License type: gemma