Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions simulstreaming/translate/hovercraft.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
"uk": "Моє судно на повітряній подушці наповнене вуграми.",
# chatGPT provided:
# "uk": "Мій ховеркрафт повний вугрів."

"he": "הרחפת שלי מלאה בצלופחים.",
}
def hovercraft_sentence(lang_code):
    """Return the canonical test sentence ("My hovercraft is full of eels")
    translated into the language identified by *lang_code*.

    Raises:
        KeyError: if no translation exists for *lang_code*.
    """
    sentence = hovercraft_translations[lang_code]
    return sentence
30 changes: 19 additions & 11 deletions simulstreaming_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def generate_words(sp, step_results):

class LLMTranslator:

def __init__(self, system_prompt='Please translate.', max_context_length=4096, len_ratio=None, model_dir="ct2_EuroLLM-9B-Instruct/", tokenizer_dir="EuroLLM-9B-Instruct/"):
def __init__(self, system_prompt='Please translate.', max_context_length=4096, len_ratio=None, model_dir="ct2_EuroLLM-9B-Instruct/", tokenizer_dir="EuroLLM-9B-Instruct/", sys_prompt_in_user=False):
self.system_prompt = system_prompt

print("Loading the model...", file=sys.stderr)
Expand All @@ -46,16 +46,14 @@ def __init__(self, system_prompt='Please translate.', max_context_length=4096, l

# my regex sentence segmenter
self.segmenter = SentenceSegmenter()

def start_dialog(self):
return [{'role':'system', 'content': self.system_prompt }]
self.sys_prompt_in_user = sys_prompt_in_user

def build_prompt(self, dialog):
# Build prompt with system + user + generation prompt (assistant turn marker),
base_toks = self.tokenizer.apply_chat_template(dialog[:2], tokenize=True, add_generation_prompt=True)["input_ids"]
if len(dialog) == 3: # if there is assistant message
base_toks = self.tokenizer.apply_chat_template(dialog[:1 if self.sys_prompt_in_user else 2], tokenize=True, add_generation_prompt=True)["input_ids"]
if len(dialog) == (2 if self.sys_prompt_in_user else 3): # if there is assistant message
# then append the forced assistant content tokens.
forced_toks = self.tokenizer.encode(dialog[2]['content'], add_special_tokens=False)
forced_toks = self.tokenizer.encode(dialog[-1]['content'], add_special_tokens=False)
toks = base_toks + forced_toks
else:
toks = base_toks
Expand All @@ -66,8 +64,11 @@ def build_prompt(self, dialog):

def translate(self, src, tgt_forced=""):

dialog = self.start_dialog()
dialog += [{'role':'user','content': src}]
if not self.sys_prompt_in_user:
dialog = [{'role':'system', 'content': self.system_prompt },
{'role':'user','content': src}]
else:
dialog = [{'role':'user', 'content': self.system_prompt + "\n\n" + src}]
if tgt_forced != "":
dialog += [{'role':'assistant','content': tgt_forced}]

Expand Down Expand Up @@ -344,7 +345,11 @@ def finish(self):
"no": "Norwegian",
"ru": "Russian",
"tr": "Turkish",
"uk": "Ukrainian"
"uk": "Ukrainian",


# other languages:
"he": "Hebrew",
}

SrcLang = "English" # TODO: default parameters.
Expand Down Expand Up @@ -400,6 +405,9 @@ def translate_args(parser):

parser.add_argument('--sys_prompt', type=str, default=None,
help='System prompt. If None, default one is used, depending on the language.')
parser.add_argument('--sys-in-user', default=False, action="store_true",
help='Place "system prompt" in use message, e.g. for gemma models.')


parser.add_argument('--init_prompt_src', type=str, default=None, help='Init translation with source text. It should be a complete sentence in the source language. '
'It can be context specific for the given input. Default is None.')
Expand Down Expand Up @@ -479,7 +487,7 @@ def simul_translator_factory(args):
len_threshold = args.len_threshold

llmtrans = LLMTranslator(system_prompt=sys_prompt, max_context_length=args.max_context_length, len_ratio=len_threshold,
model_dir=args.model_dir, tokenizer_dir=args.tokenizer_dir)
model_dir=args.model_dir, tokenizer_dir=args.tokenizer_dir, sys_prompt_in_user=args.sys_in_user)
lan = args.tgt_lan if not args.tgt_lan.startswith("zh") else "zh"
simul = SimulLLM(llmtrans,language=lan, min_len=args.min_len, chunk=args.min_chunk_size,
init_src=init_src, init_tgt=init_tgt, trimming=args.buffer_trimming
Expand Down