Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 93 additions & 36 deletions kittentts/onnx_model.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,73 @@
import re

from misaki import en, espeak
import numpy as np
import phonemizer
import soundfile as sf
import onnxruntime as ort

from .preprocess import TextPreprocessor


def basic_english_tokenize(text):
    """Split text into word and punctuation tokens.

    Each token is either a maximal run of word characters or a single
    non-space, non-word character (punctuation), in order of appearance.

    Args:
        text: Input string to tokenize.

    Returns:
        list[str]: Word and punctuation tokens.
    """
    # Uses the module-level `re` import; the previous function-local
    # `import re` was redundant and has been removed.
    return re.findall(r"\w+|[^\w\s]", text)


def ensure_punctuation(text):
    """Ensure text ends with punctuation. If not, add a comma.

    Args:
        text: Sentence fragment; surrounding whitespace is stripped.

    Returns:
        str: The stripped text, guaranteed to end in one of ``.!?,;:``
        (a trailing comma is appended when missing). Empty or
        whitespace-only input yields "".
    """
    text = text.strip()
    if not text:
        return text
    # Merge artifact removed: this terminal-punctuation check was
    # duplicated (old single-quote + new double-quote versions).
    if text[-1] not in ".!?,;:":
        text = text + ","
    return text


def _protect_inline_abbreviations(text):
"""Protect period-bearing abbreviations from sentence splitting."""

placeholders = {}
patterns = [
(r"\ba\.m\.(?=\s|$|[,;:!?])", "a.m."),
(r"\bp\.m\.(?=\s|$|[,;:!?])", "p.m."),
(r"\bDr\.(?=\s|$)", "Dr."),
(r"\bMr\.(?=\s|$)", "Mr."),
(r"\bMrs\.(?=\s|$)", "Mrs."),
(r"\bMs\.(?=\s|$)", "Ms."),
(r"\bProf\.(?=\s|$)", "Prof."),
]

for index, (pattern, original) in enumerate(patterns):
placeholder = f"__ABBR_{index}__"
text, count = re.subn(pattern, placeholder, text, flags=re.IGNORECASE)
if count:
placeholders[placeholder] = original

return text, placeholders


def _restore_inline_abbreviations(text, placeholders):
"""Restore protected abbreviations after sentence splitting."""
for placeholder, original in placeholders.items():
text = text.replace(placeholder, original)
return text


def chunk_text(text, max_len=400):
"""Split text into chunks for processing long texts."""
import re

sentences = re.split(r'[.!?]+', text)

protected_text, placeholders = _protect_inline_abbreviations(text)

sentences = re.split(r"[.!?]+", protected_text)
chunks = []

for sentence in sentences:
sentence = sentence.strip()
sentence = _restore_inline_abbreviations(sentence, placeholders).strip()
if not sentence:
continue

if len(sentence) <= max_len:
chunks.append(ensure_punctuation(sentence))
else:
Expand All @@ -48,19 +83,19 @@ def chunk_text(text, max_len=400):
temp_chunk = word
if temp_chunk:
chunks.append(ensure_punctuation(temp_chunk.strip()))

return chunks


class TextCleaner:
def __init__(self, dummy=None):
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»"" '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)

dicts = {}
for i in range(len(symbols)):
dicts[symbols[i]] = i
Expand All @@ -78,7 +113,13 @@ def __call__(self, text):


class KittenTTS_1_Onnx:
def __init__(self, model_path="kitten_tts_nano_preview.onnx", voices_path="voices.npz", speed_priors={}, voice_aliases={}):
def __init__(
self,
model_path="kitten_tts_nano_preview.onnx",
voices_path="voices.npz",
speed_priors={},
voice_aliases={},
):
"""Initialize KittenTTS with model and voice data.

Args:
Expand All @@ -104,73 +145,90 @@ def __init__(self, model_path="kitten_tts_nano_preview.onnx", voices_path="voice
self.voice_aliases = voice_aliases

self.preprocessor = TextPreprocessor(remove_punctuation=False)

def _prepare_inputs(self, text: str, voice: str, speed: float = 1.0) -> dict:
"""Prepare ONNX model inputs from text and voice parameters."""
if voice in self.voice_aliases:
voice = self.voice_aliases[voice]

if voice not in self.available_voices:
raise ValueError(f"Voice '{voice}' not available. Choose from: {self.available_voices}")

if voice in self.speed_priors:
speed = speed * self.speed_priors[voice]

# Phonemize the input text
phonemes_list = self.phonemizer.phonemize([text])

# Process phonemes to get token IDs
phonemes = basic_english_tokenize(phonemes_list[0])
phonemes = ' '.join(phonemes)
phonemes = " ".join(phonemes)
tokens = self.text_cleaner(phonemes)

# Add start and end tokens
tokens.insert(0, 0)
tokens.append(10)
tokens.append(0)

input_ids = np.array([tokens], dtype=np.int64)
ref_id = min(len(text), self.voices[voice].shape[0] - 1)
ref_s = self.voices[voice][ref_id:ref_id+1]
ref_id = min(len(text), self.voices[voice].shape[0] - 1)
ref_s = self.voices[voice][ref_id : ref_id + 1]

return {
"input_ids": input_ids,
"style": ref_s,
"speed": np.array([speed], dtype=np.float32),
}

def generate(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0, clean_text: bool=True) -> np.ndarray:

def generate(
self,
text: str,
voice: str = "expr-voice-5-m",
speed: float = 1.0,
clean_text: bool = True,
) -> np.ndarray:

out_chunks = []
if clean_text:
text = self.preprocessor(text)

for text_chunk in chunk_text(text):
out_chunks.append(self.generate_single_chunk(text_chunk, voice, speed))
chunk = self.generate_single_chunk(text_chunk, voice, speed)
out_chunks.append(chunk)
return np.concatenate(out_chunks, axis=-1)

def generate_single_chunk(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0) -> np.ndarray:
def generate_single_chunk(
self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0
) -> np.ndarray:
"""Synthesize speech from text.

Args:
text: Input text to synthesize
voice: Voice to use for synthesis
speed: Speech speed (1.0 = normal)

Returns:
Audio data as numpy array
"""
onnx_inputs = self._prepare_inputs(text, voice, speed)

outputs = self.session.run(None, onnx_inputs)

# Trim audio
audio = outputs[0][..., :-5000]

return audio

def generate_to_file(self, text: str, output_path: str, voice: str = "expr-voice-5-m",
speed: float = 1.0, sample_rate: int = 24000, clean_text: bool=True) -> None:

def generate_to_file(
self,
text: str,
output_path: str,
voice: str = "expr-voice-5-m",
speed: float = 1.0,
sample_rate: int = 24000,
clean_text: bool = True,
) -> None:
"""Synthesize speech and save to file.

Args:
text: Input text to synthesize
output_path: Path to save the audio file
Expand All @@ -182,4 +240,3 @@ def generate_to_file(self, text: str, output_path: str, voice: str = "expr-voice
audio = self.generate(text, voice, speed, clean_text=clean_text)
sf.write(output_path, audio, sample_rate)
print(f"Audio saved to {output_path}")