diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 9bf1ad57e..71b092249 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -321,9 +321,19 @@ def __call__(self, *args, **cfg): def generate(self, input, input_len=None, progress_callback=None, **cfg): self._reset_runtime_configs() if self.vad_model is None: - return self.inference( + results = self.inference( input, input_len=input_len, progress_callback=progress_callback, **cfg ) + if self.punc_model is not None: + deep_update(self.punc_kwargs, cfg) + for result in results: + punc_res = self.inference( + result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg + ) + if cfg.get("return_raw_text", self.kwargs.get("return_raw_text", False)): + result["raw_text"] = copy.copy(result["text"]) + result["text"] = punc_res[0]["text"] + return results else: return self.inference_with_vad(