"""Unit tests for ``SglDecodeEngineMixin.generate_with_decode`` using in-memory fakes."""

from types import SimpleNamespace

import pytest

from torchspec.inference.engine.sgl_engine_decode import SglDecodeEngineMixin


class FakeSglEngine:
    """Stand-in for the SGLang engine that records the kwargs passed to ``generate``.

    Every result it returns reports 5 prompt tokens, the configured number of
    completion tokens, the configured ``output_ids`` and a single mooncake key
    named after the result's position in the batch.
    """

    def __init__(self, completion_tokens: int = 3, output_ids=None, prompt_lengths=None):
        def encode(prompt):
            # Positional-only in practice: any keyword argument raises TypeError.
            # Prompts missing from ``prompt_lengths`` tokenize to a single token.
            table = prompt_lengths or {}
            return [0] * table.get(prompt, 1)

        self.completion_tokens = completion_tokens
        self.output_ids = [11, 12, 13] if output_ids is None else output_ids
        self.generate_kwargs = None
        self.tokenizer_manager = SimpleNamespace(tokenizer=SimpleNamespace(encode=encode))

    def generate(self, **kwargs):
        """Record ``kwargs`` and return one canned result per batch entry."""
        self.generate_kwargs = kwargs
        batch_key = "prompt" if "prompt" in kwargs else "input_ids"
        results = []
        for position in range(len(kwargs[batch_key])):
            meta_info = {
                "spec_training_mooncake_store_keys": [f"sample-key-{position}"],
                "prompt_tokens": 5,
                "completion_tokens": self.completion_tokens,
            }
            results.append({"meta_info": meta_info, "output_ids": self.output_ids})
        return results


class FakeDecodeEngine(SglDecodeEngineMixin):
    """Minimal host for the mixin: supplies ``args``, ``rank``, ``_engine`` and hooks."""

    def __init__(
        self,
        args,
        completion_tokens: int = 3,
        output_ids=None,
        prompt_lengths=None,
    ):
        self.args = args
        self.rank = 0
        self._engine = FakeSglEngine(
            completion_tokens=completion_tokens,
            output_ids=output_ids,
            prompt_lengths=prompt_lengths,
        )

    def _extract_image_data(self, multimodal_inputs):
        # These tests never supply image data.
        return None

    def _get_tensor_shapes(self, seq_len):
        hidden_shape = (seq_len, 16)
        return {
            "input_ids": (seq_len,),
            "hidden_states": hidden_shape,
            "last_hidden_states": hidden_shape,
        }

    def _get_tensor_dtypes(self):
        return {
            "input_ids": "torch.int64",
            "hidden_states": "torch.bfloat16",
            "last_hidden_states": "torch.bfloat16",
        }


def make_args(**overrides):
    """Return a namespace of default decode settings with ``overrides`` applied on top."""
    defaults = {
        "decode_max_new_tokens": 16,
        "decode_min_new_tokens": 2,
        "decode_stop_token_ids": None,
        "decode_temperature": 1.0,
        "decode_top_p": 1.0,
        "decode_top_k": -1,
        "attention_backend": "flex_attention",
        "max_seq_length": None,
    }
    return SimpleNamespace(**{**defaults, **overrides})


def _decode_single_row(decoder, prompt="prompt"):
    """Run ``generate_with_decode`` for one row with id ``row-1``."""
    return decoder.generate_with_decode(data_id=["row-1"], formatted_prompts=[prompt])


def test_generate_with_decode_passes_min_new_tokens():
    decoder = FakeDecodeEngine(make_args(decode_min_new_tokens=4))

    results = _decode_single_row(decoder)

    sampling = decoder._engine.generate_kwargs["sampling_params"]
    assert results[0]["packed_loss_mask"] == "5,2"
    assert sampling["max_new_tokens"] == 16
    assert sampling["min_new_tokens"] == 4


def test_generate_with_decode_passes_stop_token_ids():
    decoder = FakeDecodeEngine(make_args(decode_stop_token_ids=[163586]))

    _decode_single_row(decoder)

    sampling = decoder._engine.generate_kwargs["sampling_params"]
    assert sampling["stop_token_ids"] == [163586]


def test_generate_with_decode_rejects_invalid_min_new_tokens():
    bad_args = make_args(decode_max_new_tokens=2, decode_min_new_tokens=3)
    decoder = FakeDecodeEngine(bad_args)

    with pytest.raises(ValueError, match="cannot exceed"):
        _decode_single_row(decoder)


def test_generate_with_decode_drops_zero_loss_completions():
    decoder = FakeDecodeEngine(make_args(), completion_tokens=1)

    results = _decode_single_row(decoder)

    assert results == [None]


def test_generate_with_decode_drops_leading_stop_token_completions():
    decoder = FakeDecodeEngine(
        make_args(decode_stop_token_ids=[163586]),
        completion_tokens=3,
        output_ids=[163586, 11, 12],
    )

    results = _decode_single_row(decoder)

    assert results == [None]


def test_generate_with_decode_skips_prompts_without_min_new_token_room():
    decoder = FakeDecodeEngine(
        make_args(max_seq_length=4, decode_min_new_tokens=2),
        prompt_lengths={"too-long": 3},
    )

    results = _decode_single_row(decoder, prompt="too-long")

    assert results == [None]
    # The only prompt was filtered out, so the engine must never have been called.
    assert decoder._engine.generate_kwargs is None


def test_generate_with_decode_preserves_batch_positions_when_skipping_prompts():
    decoder = FakeDecodeEngine(
        make_args(max_seq_length=4, decode_min_new_tokens=2),
        prompt_lengths={"too-long": 3, "ok": 2},
    )

    results = decoder.generate_with_decode(
        data_id=["row-1", "row-2"],
        formatted_prompts=["too-long", "ok"],
    )

    sent = decoder._engine.generate_kwargs
    assert results[0] is None
    assert results[1]["mooncake_key"] == "sample-key-0"
    assert sent["prompt"] == ["ok"]
    assert sent["spec_training_data_id"] == ["row-2"]
SGLang/vLLM are optional — HF-only training (e.g. single-GPU # DFlash) should not require these heavy dependencies to be installed. try: diff --git a/torchspec/inference/engine/sgl_engine_decode.py b/torchspec/inference/engine/sgl_engine_decode.py index 29f770f0..8785d427 100644 --- a/torchspec/inference/engine/sgl_engine_decode.py +++ b/torchspec/inference/engine/sgl_engine_decode.py @@ -90,6 +90,43 @@ def _build_decode_engine_kwargs(self, engine_kwargs: dict) -> None: # -- Generation ----------------------------------------------------------- + def _decode_prompt_token_lengths( + self, + *, + use_prompts: bool, + formatted_prompts: list[str] | None, + input_ids_list: list[torch.Tensor] | None, + ) -> list[int | None]: + if not use_prompts: + lengths: list[int | None] = [] + assert input_ids_list is not None + for ids in input_ids_list: + if ids.dim() == 2 and ids.shape[0] == 1: + ids = ids.squeeze(0) + lengths.append(int(ids.numel())) + return lengths + + tokenizer_manager = getattr(self._engine, "tokenizer_manager", None) + tokenizer = getattr(tokenizer_manager, "tokenizer", None) + if tokenizer is None or not hasattr(tokenizer, "encode"): + return [None for _ in formatted_prompts or []] + + lengths = [] + for prompt in formatted_prompts or []: + try: + try: + token_ids = tokenizer.encode(prompt, add_special_tokens=False) + except TypeError: + token_ids = tokenizer.encode(prompt) + lengths.append(len(token_ids)) + except Exception as exc: + logger.warning( + f"SglEngine rank {self.rank}: failed to estimate prompt token " + f"length before decode ({exc!r}); letting SGLang tokenize it." 
+ ) + lengths.append(None) + return lengths + def generate_with_decode( self, data_id: str | list[str], @@ -99,7 +136,7 @@ def generate_with_decode( return_last_hidden_states: bool = False, return_logits: bool = True, multimodal_inputs: list[dict] | None = None, - ) -> list[dict[str, Any]]: + ) -> list[dict[str, Any] | None]: """Generate training data with decoding (spec training with actual token generation). Unlike generate() which does prefill-only, this method generates new tokens @@ -149,7 +186,22 @@ def generate_with_decode( # Build sampling params for decode mode max_new_tokens = getattr(self.args, "decode_max_new_tokens", 512) - sampling_params = {"max_new_tokens": max_new_tokens} + min_new_tokens = getattr(self.args, "decode_min_new_tokens", 2) + if min_new_tokens < 0: + raise ValueError(f"decode.min_new_tokens must be >= 0, got {min_new_tokens}") + if min_new_tokens > max_new_tokens: + raise ValueError( + f"decode.min_new_tokens ({min_new_tokens}) cannot exceed " + f"decode.max_new_tokens ({max_new_tokens})" + ) + sampling_params = { + "max_new_tokens": max_new_tokens, + "min_new_tokens": min_new_tokens, + } + stop_token_ids = getattr(self.args, "decode_stop_token_ids", None) + if stop_token_ids: + sampling_params["stop_token_ids"] = list(stop_token_ids) + stop_token_id_set = set(stop_token_ids or []) temperature = getattr(self.args, "decode_temperature", 1.0) if temperature != 1.0: sampling_params["temperature"] = temperature @@ -159,56 +211,100 @@ def generate_with_decode( top_k = getattr(self.args, "decode_top_k", -1) if top_k > 0: sampling_params["top_k"] = top_k + logger.debug( + f"SglEngine rank {self.rank}: decode sampling_params={sampling_params}" + ) + + outputs: list[dict[str, Any] | None] = [None for _ in range(batch_size)] + active_indices = list(range(batch_size)) + max_seq_length = getattr(self.args, "max_seq_length", None) + if max_seq_length: + prompt_lengths = self._decode_prompt_token_lengths( + use_prompts=use_prompts, + 
formatted_prompts=formatted_prompts, + input_ids_list=None if use_prompts else input_ids_list, + ) + active_indices = [] + for i, prompt_len in enumerate(prompt_lengths): + if prompt_len is not None and prompt_len + min_new_tokens > max_seq_length: + logger.warning( + f"SglEngine rank {self.rank}: skipping data_id={data_ids[i]} " + f"because prompt_tokens={prompt_len} leaves less than " + f"min_new_tokens={min_new_tokens} within max_seq_length={max_seq_length}" + ) + continue + active_indices.append(i) + + if not active_indices: + return outputs + + active_data_ids = [data_ids[i] for i in active_indices] + active_multimodal_inputs = ( + [multimodal_inputs[i] for i in active_indices] + if multimodal_inputs is not None + else None + ) if use_prompts: + active_formatted_prompts = [formatted_prompts[i] for i in active_indices] logger.debug( - f"SglEngine rank {self.rank}: decode prompt mode processing data_ids={data_ids}, " - f"num_prompts={len(formatted_prompts)}" + f"SglEngine rank {self.rank}: decode prompt mode processing data_ids={active_data_ids}, " + f"num_prompts={len(active_formatted_prompts)}" ) engine_kwargs = { - "prompt": formatted_prompts, - "spec_training_data_id": data_ids, + "prompt": active_formatted_prompts, + "spec_training_data_id": active_data_ids, "sampling_params": sampling_params, "return_hidden_states": True, } else: input_ids_list_of_lists = [] - for ids in input_ids_list: + for i in active_indices: + ids = input_ids_list[i] if ids.dim() == 2 and ids.shape[0] == 1: ids = ids.squeeze(0) elif ids.dim() > 2: raise ValueError(f"Unexpected input_ids shape: {ids.shape}") input_ids_list_of_lists.append(ids.tolist()) + active_packed_loss_mask_list = ( + [packed_loss_mask_list[i] for i in active_indices] + if packed_loss_mask_list is not None + else None + ) logger.debug( - f"SglEngine rank {self.rank}: decode mode processing data_ids={data_ids}, " + f"SglEngine rank {self.rank}: decode mode processing data_ids={active_data_ids}, " f"shapes: 
{[len(ids) for ids in input_ids_list_of_lists]}" ) engine_kwargs = { "input_ids": input_ids_list_of_lists, - "spec_training_data_id": data_ids, - "packed_loss_mask": packed_loss_mask_list, + "spec_training_data_id": active_data_ids, + "packed_loss_mask": active_packed_loss_mask_list, "sampling_params": sampling_params, "return_hidden_states": True, } - image_data = self._extract_image_data(multimodal_inputs) + image_data = self._extract_image_data(active_multimodal_inputs) if image_data is not None: engine_kwargs["image_data"] = image_data results = self._engine.generate(**engine_kwargs) + if len(results) != len(active_indices): + raise RuntimeError( + f"SglEngine rank {self.rank}: decode expected {len(active_indices)} " + f"results from SGLang, got {len(results)}" + ) # IMPORTANT: Must produce exactly one output per input result to match # the zip(entries, outputs, strict=True) in the inference manager. - outputs = [] - for i, result in enumerate(results): + for result_index, result in enumerate(results): + i = active_indices[result_index] store_keys = result["meta_info"].get("spec_training_mooncake_store_keys", []) if not store_keys: logger.warning( f"SglEngine rank {self.rank}: No mooncake keys returned for " - f"data_id={data_ids[i]}, skipping this sample." + f"data_id={active_data_ids[result_index]}, skipping this sample." 
) - outputs.append(None) continue meta_info = result["meta_info"] @@ -226,14 +322,25 @@ def generate_with_decode( key = store_keys[0] prompt_tokens = meta_info.get("prompt_tokens", 0) completion_tokens = meta_info.get("completion_tokens", 0) - if completion_tokens > 0: - seq_len = prompt_tokens + completion_tokens - 1 - else: + output_ids = result.get("output_ids", []) + if completion_tokens <= 1: + text = result.get("text", "") + logger.warning( + f"SglEngine rank {self.rank}: completion_tokens={completion_tokens} for " + f"data_id={active_data_ids[result_index]}, finish_reason={meta_info.get('finish_reason')}, " + f"output_ids={output_ids}, text={text!r}, dropping sample to avoid " + f"a zero-loss mask" + ) + continue + if output_ids and output_ids[0] in stop_token_id_set: logger.warning( - f"SglEngine rank {self.rank}: completion_tokens=0 for " - f"data_id={data_ids[i]}, sample will produce zero loss" + f"SglEngine rank {self.rank}: completion starts with stop token " + f"{output_ids[0]} for data_id={active_data_ids[result_index]}, " + f"finish_reason={meta_info.get('finish_reason')}, dropping sample." ) - seq_len = prompt_tokens + continue + + seq_len = prompt_tokens + completion_tokens - 1 logger.debug( f"SglEngine rank {self.rank}: decode mode - " f"prompt={prompt_tokens}, completion={completion_tokens}, " @@ -272,11 +379,12 @@ def generate_with_decode( ) output_dict.update({k: meta_info[k] for k in _METRIC_KEYS if k in meta_info}) - outputs.append(output_dict) + outputs[i] = output_dict + accepted_count = sum(output is not None for output in outputs) logger.debug( - f"SglEngine rank {self.rank}: decode generated {len(outputs)} mooncake keys " - f"for data_ids={data_ids}" + f"SglEngine rank {self.rank}: decode generated {accepted_count} mooncake keys " + f"for {len(data_ids)} data_ids={data_ids}" ) return outputs