diff --git a/simulstream/server/speech_processors/base_streamatt.py b/simulstream/server/speech_processors/base_streamatt.py
index fa7ccd4..73a4683 100644
--- a/simulstream/server/speech_processors/base_streamatt.py
+++ b/simulstream/server/speech_processors/base_streamatt.py
@@ -26,6 +26,7 @@
 
 BOW_PREFIX = "\u2581"
+STRONG_PUNCTUATION = [".", "!", "?", ":", ";", "。"]
 
 logger = logging.getLogger(__name__)
 
 
@@ -182,8 +183,7 @@ def _update_speech_history(self, discarded_text: int, cross_attn: torch.Tensor)
         # Check audio history not exceeding maximum allowed length
         self._cut_audio_exceeding_maxlen()
 
-    @staticmethod
-    def _strip_incomplete_words(tokens: List[str]) -> List[str]:
+    def _strip_incomplete_words(self, tokens: List[str]) -> List[str]:
         """
         Remove last incomplete word(s) from the new hypothesis.
 
@@ -193,6 +193,19 @@ def _strip_incomplete_words(tokens: List[str]) -> List[str]:
         Returns:
             List[str]: A list of generated tokens from which partial words are removed.
         """
+        # Some tokenizers emit a trailing empty token after punctuation/EOS; drop it first so
+        # complete outputs like ["▁output", ".", ""] are not mistaken for incomplete words
+        while tokens and tokens[-1] == "":
+            tokens = tokens[:-1]
+
+        if not tokens:
+            return []
+
+        last_token = tokens[-1].strip()
+        # If the hypothesis already ends with punctuation, keep it as a complete segment
+        if last_token and last_token[-1] in STRONG_PUNCTUATION:
+            return tokens
+
         tokens_to_write = []
         # iterate from the end and count how many trailing tokens to drop
         num_tokens_incomplete = 0
@@ -305,8 +318,6 @@ class PunctuationTextHistory:
     The current implementation supports only SentencePiece.
     """
 
-    STRONG_PUNCTUATION = [".", "!", "?", ":", ";", "。"]
-
     def __init__(self, config: SimpleNamespace):
         self.config = config
 
@@ -317,7 +328,7 @@ def select_text_history(self, text_history):
         for token in reversed(text_history):
             prefix_token = token
             contains_punctuation = False
-            for punct in self.STRONG_PUNCTUATION:
+            for punct in STRONG_PUNCTUATION:
                 if punct in prefix_token:
                     contains_punctuation = True
                     break
diff --git a/uts/speech_processors/test_streamatt.py b/uts/speech_processors/test_streamatt.py
index 180c408..efc0bcb 100644
--- a/uts/speech_processors/test_streamatt.py
+++ b/uts/speech_processors/test_streamatt.py
@@ -15,7 +15,8 @@
 import unittest
 from types import SimpleNamespace
 
-from simulstream.server.speech_processors.base_streamatt import PunctuationTextHistory
+from simulstream.server.speech_processors.base_streamatt import (
+    PunctuationTextHistory, BaseStreamAtt)
 
 
 class TestPunctuationTextHistory(unittest.TestCase):
@@ -60,5 +61,77 @@ def test_no_strong_punctuation(self):
         self.assertEqual(selected_history, ['回', '到', '纽', '约', '后', ',', '我'])
 
+
+class TestStripIncompleteWords(unittest.TestCase):
+    def setUp(self):
+        self.config = SimpleNamespace()
+        self._strip_incomplete_words = BaseStreamAtt._strip_incomplete_words
+
+    def test_incomplete_word_is_stripped(self):
+        """Last word has no closing token — should be dropped."""
+        stripped = self._strip_incomplete_words(self, ["▁U", "ser", "▁Inter", "ac"])
+        self.assertEqual(stripped, ["▁U", "ser"])
+
+    def test_single_incomplete_word_returns_empty(self):
+        """Only one word and it's incomplete — nothing left to return."""
+        stripped = self._strip_incomplete_words(self, ["▁Inter", "ac"])
+        self.assertEqual(stripped, [])
+
+    def test_multiple_incomplete_tokens_all_stripped(self):
+        """Several continuation tokens after the last BOW — all should be dropped."""
+        stripped = self._strip_incomplete_words(self, ["▁U", "ser", "▁Inter", "ac", "ti"])
+        self.assertEqual(stripped, ["▁U", "ser"])
+
+    def test_ends_with_period_kept(self):
+        """Trailing period counts as strong punctuation — full token list returned."""
+        stripped = self._strip_incomplete_words(self, ["▁U", "ser", "▁Inter", "ac", "tion", "."])
+        self.assertEqual(stripped, ["▁U", "ser", "▁Inter", "ac", "tion", "."])
+
+    def test_ends_with_multiple_periods(self):
+        """Multiple trailing periods are all strong punctuation — full token list returned."""
+        stripped = self._strip_incomplete_words(
+            self, ["▁U", "ser", "▁Inter", "ac", "tion", ".", ".", "."])
+        self.assertEqual(stripped, ["▁U", "ser", "▁Inter", "ac", "tion", ".", ".", "."])
+
+    def test_ends_with_non_strong_punctuation(self):
+        """Non strong punctuation marks should be treated as standard tokens."""
+        stripped = self._strip_incomplete_words(self, ["▁Hello", "-"])
+        self.assertEqual(stripped, [])
+
+    def test_ends_with_question_mark(self):
+        """Question marks should be treated as strong punctuation."""
+        stripped = self._strip_incomplete_words(self, ["▁Is", "▁this", "▁work", "ing", "?"])
+        self.assertEqual(stripped, ["▁Is", "▁this", "▁work", "ing", "?"])
+
+    def test_trailing_empty_token_stripped_before_check(self):
+        """Empty trailing tokens should be dropped; remaining punctuation keeps the list intact."""
+        stripped = self._strip_incomplete_words(self, ["▁output", ".", ""])
+        self.assertEqual(stripped, ["▁output", "."])
+
+    def test_multiple_trailing_empty_tokens(self):
+        """Multiple trailing empty tokens should be dropped."""
+        stripped = self._strip_incomplete_words(self, ["▁Hello", ".", "", ""])
+        self.assertEqual(stripped, ["▁Hello", "."])
+
+    def test_only_empty_tokens_returns_empty(self):
+        """Only empty tokens should be dropped."""
+        stripped = self._strip_incomplete_words(self, ["", "", ""])
+        self.assertEqual(stripped, [])
+
+    def test_empty_input(self):
+        """Empty input should return an empty list."""
+        stripped = self._strip_incomplete_words(self, [])
+        self.assertEqual(stripped, [])
+
+    def test_single_bow_token_incomplete(self):
+        """A lone BOW token with no following token is itself incomplete."""
+        stripped = self._strip_incomplete_words(self, ["▁Hello"])
+        self.assertEqual(stripped, [])
+
+    def test_no_bow_prefix_at_all(self):
+        """No BOW token anywhere — loop never breaks, returns empty list."""
+        stripped = self._strip_incomplete_words(self, ["ac", "tion"])
+        self.assertEqual(stripped, [])
+
+
 if __name__ == "__main__":
     unittest.main()