import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
from rubyinserter import add_ruby
device = "cuda:0"if torch.cuda.is_available() else"cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("2121-8/japanese-parler-tts-mini-bate").to(device)
tokenizer = AutoTokenizer.from_pretrained("2121-8/japanese-parler-tts-mini-bate")
prompt = "こんにちは、今日はどのようにお過ごしですか?"
description = "A female speaker with a slightly high-pitched voice delivers her words at a moderate speed with a quite monotone tone in a confined environment, resulting in a quite clear audio recording."
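
# add_ruby annotates the prompt with ruby (furigana) readings so kanji are pronounced correctly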
prompt = add_ruby(prompt)
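
# Tokenize the voice description and the ruby-annotated prompt separately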
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
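
# Generate speech conditioned on both the description and the prompt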
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
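
# Move the waveform to the CPU, convert it to a NumPy array, and save it as a WAV
# file at the model's sampling rate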
audio_arr = generation.cpu().numpy().squeeze()
sf.write("parler_tts_japanese_out.wav", audio_arr, model.config.sampling_rate)