Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion simulstreaming/whisper/whisper_streaming/whisper_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ def send_result(self, iteration_output):
# - the next words: segment transcript
if iteration_output:
if self.out_txt:
message = "%1.0f %1.0f %s" % (iteration_output['start'] * 1000, iteration_output['end'] * 1000, iteration_output['text'])
if iteration_output.get('start'):
message = "%1.0f %1.0f %s" % (iteration_output['start'] * 1000, iteration_output['end'] * 1000, iteration_output['text'])
else:
logger.debug("No token in this segment")
return
else:
message = json.dumps(iteration_output)
print(message, flush=True, file=sys.stderr)
Expand Down
37 changes: 25 additions & 12 deletions simulstreaming_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def init(self, offset=None):
self.model.refresh_segment(complete=True)

self.unicode_buffer = [] # hide incomplete unicode character for the next iteration
self.frame_buffer = 0

def insert_audio_chunk(self, audio):
self.audio_chunks.append(torch.from_numpy(audio))
Expand All @@ -162,9 +163,10 @@ def timestamped_text(self, tokens, generation):
split_words, split_tokens = generation["result"]["split_words"], generation["result"]["split_tokens"]

frames = [p["most_attended_frames"][0] for p in pr]
if self.unicode_buffer != []:
a = [frames[0]] * len(self.unicode_buffer)
if frames and self.frame_buffer:
a = [frames[0]] * self.frame_buffer  # Buffer generation["result"] if timestamp accuracy becomes a problem
frames = a + frames
self.frame_buffer = 0

tokens = tokens.copy()
ret = []
Expand Down Expand Up @@ -193,16 +195,27 @@ def hide_incomplete_unicode(self, tokens):
starts with '�'.
This function hides the last incomplete unicode character and adds it in the next iteration.
"""
if self.unicode_buffer != []:
logger.debug(f"Hiding incomplete unicode character: {self.unicode_buffer}")
tokens = self.unicode_buffer + tokens
self.unicode_buffer = [] # clear the buffer after processing
chars, _ = self.model.tokenizer.split_tokens_on_unicode(tokens)
if len(chars) > 0 and chars[-1].endswith('�'):
self.unicode_buffer = tokens[-1:] # keep the last incomplete unicode character
logger.debug(f"Hiding incomplete unicode character: {tokens[-1:]}")
return tokens[:-1] # remove the last token, which is incomplete unicode character
return tokens
if tokens == []:
return tokens #To preserve unicode_buffer
tokens = self.unicode_buffer + tokens #Add previous buffered token
self.frame_buffer = len(self.unicode_buffer)
self.unicode_buffer = []
decoded_str_bytes = self.model.tokenizer.encoding.decode_bytes_batch(
[[t] for t in tokens] #Split bytes on token
)
decoded_str = bytearray().join(decoded_str_bytes).decode('utf-8', errors="replace")
if len(decoded_str) > 0 and decoded_str[-1].endswith('�'): #split by char won't work because multiple tokens can end up with one �.
for i in range(len(tokens)-1, 0, -1):
decoded_str_piece = bytearray().join(decoded_str_bytes[:i]).decode('utf-8', errors="replace")
if not (len(decoded_str_piece) > 0 and decoded_str_piece[-1].endswith('�')):
self.unicode_buffer = tokens[i:]
logger.debug(f"Hiding incomplete unicode character at end: {self.unicode_buffer}")
return tokens[:i]
logger.debug(f"Failed to split token, fallback to previous behaviour")
self.unicode_buffer = [tokens[-1]]
return tokens[:-1]
else:
return tokens

def process_iter(self):
if len(self.audio_chunks) == 0:
Expand Down