# Model Card for neo4j/text2cypher-gemma-2-9b-it-finetuned-2024v1
This model demonstrates how fine-tuning foundational models with the Neo4j-Text2Cypher (2024) dataset can enhance performance on the Text2Cypher task. It is part of ongoing research, highlighting the dataset's potential rather than offering a production-ready solution.
## Quick Start
The model can be used to generate Cypher statements for querying graph databases. You can follow the example code in the "Usage Examples" section below to get started.
## Features
- Demonstrates the effectiveness of fine-tuning on the Text2Cypher task using the Neo4j-Text2Cypher (2024) dataset.
- Shares finetuned model overviews and benchmarking results through provided links.
- Allows users to generate Cypher statements based on given schemas and questions.
## Installation
The framework version used is PEFT 0.12.0. To run the example below you also need the packages it imports, e.g. `pip install peft transformers torch bitsandbytes`.
## Usage Examples
### Basic Usage
```python
from peft import PeftModel, PeftConfig
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

instruction = (
    "Generate Cypher statement to query a graph database. "
    "Use only the provided relationship types and properties in the schema. \n"
    "Schema: {schema} \n Question: {question} \n Cypher output: "
)

def prepare_chat_prompt(question, schema) -> list[dict]:
    chat = [
        {
            "role": "user",
            "content": instruction.format(schema=schema, question=question),
        }
    ]
    return chat

def _postprocess_output_cypher(output_cypher: str) -> str:
    # Drop any appended explanation and strip Markdown code-fence characters.
    partition_by = "**Explanation:**"
    output_cypher, _, _ = output_cypher.partition(partition_by)
    output_cypher = output_cypher.strip("`\n")
    output_cypher = output_cypher.lstrip("cypher\n")
    output_cypher = output_cypher.strip("`\n ")
    return output_cypher

model_name = "neo4j/text2cypher-gemma-2-9b-it-finetuned-2024v1"

# Load the model in 4-bit to reduce memory usage.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    low_cpu_mem_usage=True,
)

question = "What are the movies of Tom Hanks?"
schema = "(:Actor)-[:ActedIn]->(:Movie)"

new_message = prepare_chat_prompt(question=question, schema=schema)
prompt = tokenizer.apply_chat_template(new_message, add_generation_prompt=True, tokenize=False)
inputs = tokenizer(prompt, return_tensors="pt", padding=True)

model_generate_parameters = {
    "top_p": 0.9,
    "temperature": 0.2,
    "max_new_tokens": 512,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}

inputs.to(model.device)
model.eval()
with torch.no_grad():
    tokens = model.generate(**inputs, **model_generate_parameters)
    # Keep only the newly generated tokens (drop the prompt).
    tokens = tokens[:, inputs.input_ids.shape[1] :]
    raw_outputs = tokenizer.batch_decode(tokens, skip_special_tokens=True)
    outputs = [_postprocess_output_cypher(output) for output in raw_outputs]

print(outputs)
# > ["MATCH (a:Actor {Name: 'Tom Hanks'})-[:ActedIn]->(m:Movie) RETURN m"]
```
### Advanced Usage
You can adjust the training hyperparameters and model generation parameters according to your specific needs. For example, modify `lora_config`, `sft_config`, and `bnb_config` in the training section, or change `top_p` and `temperature` in the model generation parameters, as sketched below.
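For instance, a small sketch of swapping the sampling setup above for greedy decoding (reusing `model`, `inputs`, and `tokenizer` from the Basic Usage example; the parameter values here are illustrative, not recommendations from the original card):

```python
# Sketch only: deterministic (greedy) generation instead of sampling.
greedy_parameters = {
    "do_sample": False,           # disable sampling; top_p/temperature are then unused
    "max_new_tokens": 512,
    "pad_token_id": tokenizer.eos_token_id,
}

with torch.no_grad():
    tokens = model.generate(**inputs, **greedy_parameters)
    tokens = tokens[:, inputs.input_ids.shape[1] :]
    print(tokenizer.batch_decode(tokens, skip_special_tokens=True))
```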
## Documentation
### Model Details
#### Model Description
This model showcases how fine-tuning foundational models with the Neo4j-Text2Cypher (2024) Dataset (link) can boost performance on the Text2Cypher task. Note that this is part of ongoing research, aiming to highlight the dataset's potential rather than being a production-ready solution.
- Base model: google/gemma-2-9b-it
- Dataset: neo4j/text2cypher-2024v1
An overview of the finetuned models and benchmarking results is shared at Link1 and Link2.
Have ideas or insights? Contact us: [Neo4j/Team-GenAI](mailto:team-gen-ai@neo4j.com)
### Training Details
#### Training Procedure
Used RunPod with the following setup:
- 1 x A100 PCIe
- 31 vCPU 117 GB RAM
- runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04
- On-Demand, Secure Cloud
- 60 GB Disk
- 60 GB Pod Volume
#### Training Hyperparameters
```python
lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules=target_modules,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
sft_config = SFTConfig(
    dataset_text_field=dataset_text_field,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    dataset_num_proc=16,
    max_seq_length=1600,
    logging_dir="./logs",
    num_train_epochs=1,
    learning_rate=2e-5,
    save_steps=5,
    save_total_limit=1,
    logging_steps=5,
    output_dir="outputs",
    optim="paged_adamw_8bit",
    save_strategy="steps",
)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
```
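These configurations would typically be wired together with TRL's `SFTTrainer`, roughly as in the hedged sketch below. The dataset split, `target_modules`, and `dataset_text_field` values are assumptions for illustration, not the exact training script used for this model.

```python
# Hedged sketch of how the configs above could be combined with TRL's SFTTrainer.
# It reuses lora_config, sft_config, and bnb_config from the listing above;
# the names below marked as assumed are placeholders, not the original values.
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import SFTTrainer

dataset_text_field = "text"                                  # assumed text column
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]    # assumed LoRA targets

train_dataset = load_dataset("neo4j/text2cypher-2024v1", split="train")

base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    quantization_config=bnb_config,   # 4-bit quantization during fine-tuning
)

trainer = SFTTrainer(
    model=base_model,
    args=sft_config,                  # SFTConfig shown above
    train_dataset=train_dataset,
    peft_config=lora_config,          # LoRA adapters attached by the trainer
)
trainer.train()
```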
#### Framework versions
- PEFT 0.12.0
#### NOTE on creating your own schemas
- In the dataset we used, the schemas are already provided. They were created either by:
  - directly using the schema the input data source provided, OR
  - creating the schema with the neo4j-graphrag package (check the `SchemaReader.get_schema(...)` function).
- For your own Neo4j database, you can utilize the neo4j-graphrag package's `SchemaReader` functions, as sketched below.
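A heavily hedged sketch of fetching a schema string from your own database is shown below. The import path and helper used are assumptions and may differ from the exact `SchemaReader` API referenced above; check the neo4j-graphrag documentation for your installed version.

```python
# Hedged sketch: derive a schema description from your own Neo4j database.
# The import path below is an assumption; the model card itself points to
# SchemaReader.get_schema(...) in the neo4j-graphrag package.
from neo4j import GraphDatabase
from neo4j_graphrag.schema import get_schema  # assumed location of the schema helper

uri = "bolt://localhost:7687"   # placeholder connection details
auth = ("neo4j", "<password>")

with GraphDatabase.driver(uri, auth=auth) as driver:
    schema = get_schema(driver)  # text description of node labels, relationships, properties
    print(schema)
```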
## Technical Details
### Bias, Risks, and Limitations
We need to be cautious about a few risks:
- In our evaluation setup, the training and test sets come from the same data distribution (sampled from a larger dataset). If the data distribution changes, the results may not follow the same pattern.
- The datasets used were gathered from publicly available sources. Over time, foundational models may access both the training and test sets, potentially achieving similar or even better results.
Also check the related blogpost: Link
## License
The license of this project is apache-2.0.
## Information Table

| Property | Details |
|----------|---------|
| Base Model | google/gemma-2-9b-it |
| Library Name | peft |
| License | apache-2.0 |
| Datasets | neo4j/text2cypher-2024v1 |
| Language | en |
| Pipeline Tag | text2text-generation |
| Tags | neo4j, cypher, text2cypher |