Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions simulstream/server/speech_processors/base_streamatt.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@


BOW_PREFIX = "\u2581"
STRONG_PUNCTUATION = [".", "!", "?", ":", ";", "。"]


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -182,8 +183,7 @@ def _update_speech_history(self, discarded_text: int, cross_attn: torch.Tensor)
# Check audio history not exceeding maximum allowed length
self._cut_audio_exceeding_maxlen()

@staticmethod
def _strip_incomplete_words(tokens: List[str]) -> List[str]:
def _strip_incomplete_words(self, tokens: List[str]) -> List[str]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _strip_incomplete_words(self, tokens: List[str]) -> List[str]:
@staticmethod
def _strip_incomplete_words(tokens: List[str]) -> List[str]:

"""
Remove last incomplete word(s) from the new hypothesis.

Expand All @@ -193,6 +193,19 @@ def _strip_incomplete_words(tokens: List[str]) -> List[str]:
Returns:
List[str]: A list of generated tokens from which partial words are removed.
"""
# Some tokenizers emit a trailing empty token after punctuation/EOS; drop it first so
# complete outputs like [" output", ".", ""] are not mistaken for incomplete words
while tokens and tokens[-1] == "":
tokens = tokens[:-1]

if not tokens:
return []

last_token = tokens[-1].strip()
# If the hypothesis already ends with punctuation, keep it as a complete segment
if last_token and last_token[-1] in STRONG_PUNCTUATION:
return tokens

Comment on lines +196 to +208
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should not assume this is the right thing to do. We can try to add this variant, but at the very minimum I'd add a flag in the configuration to control this.

tokens_to_write = []
# iterate from the end and count how many trailing tokens to drop
num_tokens_incomplete = 0
Expand Down Expand Up @@ -305,8 +318,6 @@ class PunctuationTextHistory:
The current implementation supports only SentencePiece.
"""

STRONG_PUNCTUATION = [".", "!", "?", ":", ";", "。"]

def __init__(self, config: SimpleNamespace):
self.config = config

Expand All @@ -317,7 +328,7 @@ def select_text_history(self, text_history):
for token in reversed(text_history):
prefix_token = token
contains_punctuation = False
for punct in self.STRONG_PUNCTUATION:
for punct in STRONG_PUNCTUATION:
if punct in prefix_token:
contains_punctuation = True
break
Expand Down
75 changes: 74 additions & 1 deletion uts/speech_processors/test_streamatt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import unittest
from types import SimpleNamespace

from simulstream.server.speech_processors.base_streamatt import PunctuationTextHistory
from simulstream.server.speech_processors.base_streamatt import (
PunctuationTextHistory, BaseStreamAtt)


class TestPunctuationTextHistory(unittest.TestCase):
Expand Down Expand Up @@ -60,5 +61,77 @@ def test_no_strong_punctuation(self):
self.assertEqual(selected_history, ['回', '到', '纽', '约', '后', ',', '我'])


class TestStripIncompleteWords(unittest.TestCase):
    """Unit tests for ``BaseStreamAtt._strip_incomplete_words``."""

    def setUp(self):
        self.config = SimpleNamespace()
        self._strip_incomplete_words = BaseStreamAtt._strip_incomplete_words

    def _strip(self, tokens):
        # The function is grabbed unbound in setUp; the TestCase instance
        # stands in for `self` since the implementation does not read
        # instance state.
        return self._strip_incomplete_words(self, tokens)

    def test_incomplete_word_is_stripped(self):
        """A trailing word with no closing token must be dropped."""
        result = self._strip(["▁U", "ser", "▁Inter", "ac"])
        self.assertEqual(result, ["▁U", "ser"])

    def test_single_incomplete_word_returns_empty(self):
        """A single incomplete word leaves nothing to emit."""
        result = self._strip(["▁Inter", "ac"])
        self.assertEqual(result, [])

    def test_multiple_incomplete_tokens_all_stripped(self):
        """Every continuation token after the last BOW marker is dropped."""
        result = self._strip(["▁U", "ser", "▁Inter", "ac", "ti"])
        self.assertEqual(result, ["▁U", "ser"])

    def test_ends_with_period_kept(self):
        """A trailing period is strong punctuation, so the list is kept whole."""
        result = self._strip(["▁U", "ser", "▁Inter", "ac", "tion", "."])
        self.assertEqual(result, ["▁U", "ser", "▁Inter", "ac", "tion", "."])

    def test_ends_with_multiple_periods(self):
        """Several trailing periods still count as a complete segment."""
        result = self._strip(
            ["▁U", "ser", "▁Inter", "ac", "tion", ".", ".", "."])
        self.assertEqual(result, ["▁U", "ser", "▁Inter", "ac", "tion", ".", ".", "."])

    def test_ends_with_non_strong_punctuation(self):
        """A non-strong mark (e.g. '-') is treated like any ordinary token."""
        result = self._strip(["▁Hello", "-"])
        self.assertEqual(result, [])

    def test_ends_with_question_mark(self):
        """A question mark is strong punctuation and preserves the list."""
        result = self._strip(["▁Is", "▁this", "▁work", "ing", "?"])
        self.assertEqual(result, ["▁Is", "▁this", "▁work", "ing", "?"])

    def test_trailing_empty_token_stripped_before_check(self):
        """Empty trailing tokens are discarded before the punctuation check."""
        result = self._strip(["▁output", ".", ""])
        self.assertEqual(result, ["▁output", "."])

    def test_multiple_trailing_empty_tokens(self):
        """A run of trailing empty tokens is discarded entirely."""
        result = self._strip(["▁Hello", ".", "", ""])
        self.assertEqual(result, ["▁Hello", "."])

    def test_only_empty_tokens_returns_empty(self):
        """A list made only of empty tokens collapses to nothing."""
        result = self._strip(["", "", ""])
        self.assertEqual(result, [])

    def test_empty_input(self):
        """An empty token list is returned unchanged as an empty list."""
        result = self._strip([])
        self.assertEqual(result, [])

    def test_single_bow_token_incomplete(self):
        """A lone BOW-prefixed token has no continuation, so it is incomplete."""
        result = self._strip(["▁Hello"])
        self.assertEqual(result, [])

    def test_no_bow_prefix_at_all(self):
        """With no BOW marker anywhere, every token is partial — nothing kept."""
        result = self._strip(["ac", "tion"])
        self.assertEqual(result, [])


# Allow running this test module directly (e.g. `python test_streamatt.py`).
if __name__ == "__main__":
    unittest.main()
Loading