🚀 Flux-Prompt-Enhance
Flux-Prompt-Enhance is a text generation model based on google-t5/t5-base. It expands short prompts into more detailed, richer descriptions and is intended for text-to-text generation tasks.
🚀 Quick Start
The following example shows how to enhance a prompt with the Flux-Prompt-Enhance model:
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

# Use a GPU when available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and seq2seq model from the Hub
model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

enhancer = pipeline('text2text-generation',
                    model=model,
                    tokenizer=tokenizer,
                    repetition_penalty=1.2,
                    device=device)

max_target_length = 256
prefix = "enhance prompt: "

short_prompt = "beautiful house with text 'hello'"
answer = enhancer(prefix + short_prompt, max_length=max_target_length)
final_answer = answer[0]['generated_text']
print(final_answer)
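For repeated calls it can be convenient to hide the prefix handling behind a small helper. The sketch below is not part of the original card: it continues from the quick-start snippet above, reusing its enhancer pipeline and prefix string, and the enhance helper name is hypothetical.
# Hypothetical convenience wrapper (assumption): reuses `enhancer` and `prefix`
# defined in the quick-start snippet above.
def enhance(short_prompt: str, max_length: int = 256) -> str:
    result = enhancer(prefix + short_prompt, max_length=max_length)
    return result[0]['generated_text']

print(enhance("beautiful house with text 'hello'"))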
💻 Usage Examples
Basic Usage
Basic usage is identical to the quick-start example above: build a text2text-generation pipeline from gokaygokay/Flux-Prompt-Enhance with repetition_penalty=1.2, then call it with the "enhance prompt: " prefix and max_length=256.
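The same pipeline also accepts a list of inputs, which is handy when enhancing several prompts at once. A minimal batch sketch, assuming the enhancer, prefix and max_target_length objects from the quick-start example; the example prompts are only illustrative.
# Batch enhancement sketch (assumption): for a list of inputs the pipeline
# returns one list of candidates per input prompt.
short_prompts = [
    "beautiful house with text 'hello'",
    "portrait of an old sailor",
]
answers = enhancer([prefix + p for p in short_prompts],
                   max_length=max_target_length)
for original, enhanced in zip(short_prompts, answers):
    print(original, "->", enhanced[0]['generated_text'])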
Advanced Usage
The advanced example uses a stronger repetition_penalty (1.5) and a longer max_length (300):
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

device = "cuda" if torch.cuda.is_available() else "cpu"

model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# Stronger repetition penalty than the basic example
enhancer = pipeline('text2text-generation',
                    model=model,
                    tokenizer=tokenizer,
                    repetition_penalty=1.5,
                    device=device)

# Allow longer enhanced prompts
max_target_length = 300
prefix = "enhance prompt: "

short_prompt = "beautiful house with text 'hello'"
answer = enhancer(prefix + short_prompt, max_length=max_target_length)
final_answer = answer[0]['generated_text']
print(final_answer)
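Beyond repetition_penalty, other generate() arguments can be forwarded through the pipeline call. The sketch below reuses the enhancer, prefix, short_prompt and max_target_length objects defined above; the sampling values (do_sample=True, temperature=0.9, num_return_sequences=3) are illustrative assumptions, not recommendations from the model card.
# Sampling sketch (assumption): forward generate() arguments through the
# pipeline call to obtain several different enhanced variants of one prompt.
variants = enhancer(prefix + short_prompt,
                    max_length=max_target_length,
                    do_sample=True,
                    temperature=0.9,
                    num_return_sequences=3)
for variant in variants:
    print(variant['generated_text'])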
📄 License
This project is released under the apache-2.0 license.
📚 Documentation
Model Information
| Attribute | Details |
| --- | --- |
| Base model | google-t5/t5-base |
| Training dataset | gokaygokay/prompt-enhancer-dataset |
| Language | en |
| Library | transformers |
| Task type | text2text-generation |
| License | apache-2.0 |
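Since the card lists transformers as the library and text2text-generation as the task, the checkpoint can also be loaded by name in a single pipeline call. A minimal sketch, assuming default generation settings:
# Minimal load sketch: pass the checkpoint name directly to the pipeline
# instead of constructing the tokenizer and model objects yourself.
from transformers import pipeline

enhancer = pipeline('text2text-generation', model='gokaygokay/Flux-Prompt-Enhance')
result = enhancer("enhance prompt: beautiful house with text 'hello'", max_length=256)
print(result[0]['generated_text'])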