diff --git a/simulstreaming/translate/hovercraft.py b/simulstreaming/translate/hovercraft.py
index 6e81b68..84ccb28 100644
--- a/simulstreaming/translate/hovercraft.py
+++ b/simulstreaming/translate/hovercraft.py
@@ -57,6 +57,8 @@
     "uk": "Моє судно на повітряній подушці наповнене вуграми.",
     # chatGPT provided:
    # "uk": "Мій ховеркрафт повний вугрів."
+
+    "he": "הרחפת שלי מלאה בצלופחים.",
 }
 def hovercraft_sentence(lang_code):
     return hovercraft_translations[lang_code]
diff --git a/simulstreaming_translate.py b/simulstreaming_translate.py
index 8f2f155..8d57637 100644
--- a/simulstreaming_translate.py
+++ b/simulstreaming_translate.py
@@ -30,7 +30,7 @@ def generate_words(sp, step_results):
 
 
 class LLMTranslator:
-    def __init__(self, system_prompt='Please translate.', max_context_length=4096, len_ratio=None, model_dir="ct2_EuroLLM-9B-Instruct/", tokenizer_dir="EuroLLM-9B-Instruct/"):
+    def __init__(self, system_prompt='Please translate.', max_context_length=4096, len_ratio=None, model_dir="ct2_EuroLLM-9B-Instruct/", tokenizer_dir="EuroLLM-9B-Instruct/", sys_prompt_in_user=False):
         self.system_prompt = system_prompt
 
         print("Loading the model...", file=sys.stderr)
@@ -46,16 +46,14 @@ def __init__(self, system_prompt='Please translate.', max_context_length=4096, l
 
         # my regex sentence segmenter
         self.segmenter = SentenceSegmenter()
-
-    def start_dialog(self):
-        return [{'role':'system', 'content': self.system_prompt }]
+        self.sys_prompt_in_user = sys_prompt_in_user
 
     def build_prompt(self, dialog):
         # Build prompt with system + user + generation prompt (assistant turn marker),
-        base_toks = self.tokenizer.apply_chat_template(dialog[:2], tokenize=True, add_generation_prompt=True)["input_ids"]
-        if len(dialog) == 3: # if there is assistant message
+        base_toks = self.tokenizer.apply_chat_template(dialog[:1 if self.sys_prompt_in_user else 2], tokenize=True, add_generation_prompt=True)["input_ids"]
+        if len(dialog) == (2 if self.sys_prompt_in_user else 3): # if there is assistant message
             # then append the forced assistant content tokens.
-            forced_toks = self.tokenizer.encode(dialog[2]['content'], add_special_tokens=False)
+            forced_toks = self.tokenizer.encode(dialog[-1]['content'], add_special_tokens=False)
             toks = base_toks + forced_toks
         else:
             toks = base_toks
@@ -66,8 +64,11 @@ def build_prompt(self, dialog):
 
 
     def translate(self, src, tgt_forced=""):
-        dialog = self.start_dialog()
-        dialog += [{'role':'user','content': src}]
+        if not self.sys_prompt_in_user:
+            dialog = [{'role':'system', 'content': self.system_prompt },
+                      {'role':'user','content': src}]
+        else:
+            dialog = [{'role':'user', 'content': self.system_prompt + "\n\n" + src}]
 
         if tgt_forced != "":
             dialog += [{'role':'assistant','content': tgt_forced}]
@@ -344,7 +345,11 @@ def finish(self):
     "no": "Norwegian",
     "ru": "Russian",
     "tr": "Turkish",
-    "uk": "Ukrainian"
+    "uk": "Ukrainian",
+
+
+    # other languages:
+    "he": "Hebrew",
 }
 
 SrcLang = "English" # TODO: default parameters.
@@ -400,6 +405,9 @@ def translate_args(parser):
     parser.add_argument('--sys_prompt', type=str, default=None,
                         help='System prompt. If None, default one is used, depending on the language.')
 
+    parser.add_argument('--sys-in-user', default=False, action="store_true",
+                        help='Place the "system prompt" in the user message, e.g. for Gemma models.')
+
     parser.add_argument('--init_prompt_src', type=str, default=None,
                         help='Init translation with source text. It should be a complete sentence in the source language. '
                         'It can be context specific for the given input. '
                        'Default is ')
@@ -479,7 +487,7 @@ def simul_translator_factory(args):
     len_threshold = args.len_threshold
     llmtrans = LLMTranslator(system_prompt=sys_prompt, max_context_length=args.max_context_length, len_ratio=len_threshold,
-                             model_dir=args.model_dir, tokenizer_dir=args.tokenizer_dir)
+                             model_dir=args.model_dir, tokenizer_dir=args.tokenizer_dir, sys_prompt_in_user=args.sys_in_user)
     lan = args.tgt_lan if not args.tgt_lan.startswith("zh") else "zh"
     simul = SimulLLM(llmtrans,language=lan, min_len=args.min_len, chunk=args.min_chunk_size, init_src=init_src, init_tgt=init_tgt, trimming=args.buffer_trimming