diff --git a/examples/vllm_inference.py b/examples/vllm_inference.py
index 950505db8..902d95486 100644
--- a/examples/vllm_inference.py
+++ b/examples/vllm_inference.py
@@ -37,12 +37,14 @@ def main():
         pipeline_name=pipeline_name,
         model_args=model_args,
         data_args=data_args,
         pipeline_args=pipeline_args
     )
-    inferencer.inference(
+    # `release_gpu=True` does an in-process best-effort cleanup; it is
+    # sufficient for this standalone example. For colocated training+inference
+    # (e.g. iterative DPO) or tensor_parallel_size > 1, prefer
+    # `MemorySafeVLLMInferencer` instead.
+    res = inferencer.inference(
         model,
         dataset,
-        release_gpu=False,
-        enable_decode_inference_result=pipeline_args.enable_decode_inference_result,
-        enable_distributed_vllm_inference=pipeline_args.enable_distributed_vllm_inference,
+        release_gpu=True,
     )
diff --git a/scripts/run_sglang_inference.sh b/scripts/run_sglang_inference.sh
index c89c3c86e..2002fcd72 100644
--- a/scripts/run_sglang_inference.sh
+++ b/scripts/run_sglang_inference.sh
@@ -9,4 +9,4 @@ python examples/sglang_inference.py \
     --top_p 0.95 \
     --random_seed 42 \
     --save_inference_results True \
-    --inference_results_path output_data/sglang_inference_results/results.json
\ No newline at end of file
+    --inference_results_path output_data/sglang_inference_results/
\ No newline at end of file
diff --git a/scripts/run_vllm_inference.sh b/scripts/run_vllm_inference.sh
new file mode 100644
index 000000000..044e32f78
--- /dev/null
+++ b/scripts/run_vllm_inference.sh
@@ -0,0 +1,13 @@
+python examples/vllm_inference.py \
+    --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 \
+    --dataset_path data/alpaca/prompt_only \
+    --inference_engine vllm \
+    --inference_gpu_memory_utilization 0.8 \
+    --inference_max_model_len 16384 \
+    --num_output_sequences 2 \
+    --temperature 1.0 \
+    --max_new_tokens 2048 \
+    --top_p 0.95 \
+    --random_seed 42 \
+    --save_inference_results True \
+    --inference_results_path output_data/vllm_inference_results/
diff --git a/setup.py b/setup.py
index e810b8a2b..2e5842a9a 100644
--- a/setup.py
+++ b/setup.py
@@ -17,7 +17,7 @@
 extra_require = {
     "multimodal": ["Pillow"],
-    "vllm": ["vllm>=0.4.3"],
+    "vllm": ["vllm>=0.8.0"],
     "sglang": ["sglang"],
     "ray": ["ray>=2.22.0"],
     "gradio": ["gradio"],
diff --git a/src/lmflow/args.py b/src/lmflow/args.py
index b32d8866a..249c29486 100644
--- a/src/lmflow/args.py
+++ b/src/lmflow/args.py
@@ -1041,9 +1041,27 @@ class InferencerArguments:
     inference_tensor_parallel_size: Optional[int] = field(
         default=1, metadata={"help": "The tensor parallel size for inference."}
     )
+    inference_data_parallel_size: Optional[int] = field(
+        default=1,
+        metadata={
+            "help": (
+                "The data parallel size for inference. Only supported for vLLM (>= 0.8) inference engine. "
+                "Total GPUs used = tensor_parallel_size * data_parallel_size."
+            )
+        },
+    )
     inference_gpu_memory_utilization: Optional[float] = field(
         default=0.95, metadata={"help": "The GPU memory utilization for inference."}
     )
+    inference_max_model_len: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Maximum model context length for inference. If not set, uses the model's default. "
+                "Reduce this if the model's default exceeds available GPU memory."
+            )
+        },
+    )
     enable_deterministic_inference: bool = field(
         default=False,
         metadata={
@@ -1065,7 +1083,14 @@ class InferencerArguments:
     results_path: Optional[str] = field(default=None, metadata={"help": "The path of results."})
     save_inference_results: Optional[bool] = field(default=False, metadata={"help": "Whether to save inference results."})
-    inference_results_path: Optional[str] = field(default=None, metadata={"help": "The path of inference results."})
+    inference_results_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Directory to save inference results. Results are saved as 'inference_results.pkl' inside this directory."
+            )
+        },
+    )

     def __post_init__(self):
         if self.use_accelerator is not None:
@@ -1087,10 +1112,7 @@ def __post_init__(self):
             if self.inference_results_path is None:
                 raise ValueError("Need to specify inference_results_path when save_inference_results is True.")
             else:
-                if not self.inference_results_path.endswith(".json"):
-                    raise ValueError("The inference_results_path must be a json file.")
-                else:
-                    Path(self.inference_results_path).parent.mkdir(parents=True, exist_ok=True)
+                Path(self.inference_results_path).mkdir(parents=True, exist_ok=True)

         if self.use_vllm is True:
             logger.warning(
diff --git a/src/lmflow/models/hf_decoder_model.py b/src/lmflow/models/hf_decoder_model.py
index 14b6b4753..66b57cccc 100644
--- a/src/lmflow/models/hf_decoder_model.py
+++ b/src/lmflow/models/hf_decoder_model.py
@@ -37,10 +37,9 @@
     TEXT_ONLY_DATASET_DESCRIPTION,
 )
 from lmflow.utils.conversation_template import PRESET_TEMPLATES
-from lmflow.utils.data_utils import VLLMInferenceResultWithInput
 from lmflow.utils.deprecated import deprecated_args
 from lmflow.utils.envs import is_accelerate_env
-from lmflow.utils.versioning import is_flash_attn_available, is_ray_available, is_vllm_available
+from lmflow.utils.versioning import is_flash_attn_available, is_vllm_available
 from lmflow.utils.protocol import DataProto

 logger = logging.getLogger(__name__)
@@ -54,10 +53,6 @@
 if is_vllm_available():
     from vllm import SamplingParams

-if is_ray_available():
-    import ray
-    import ray.data
-

 class HFDecoderModel(DecoderModel, HFModelMixin, Tunable):
     r"""
@@ -321,10 +316,12 @@ def inference(
         inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
         gpu_memory_utilization: Optional[float] = None,
         tensor_parallel_size: Optional[int] = None,
+        data_parallel_size: int = 1,
+        max_model_len: Optional[int] = None,
         enable_deterministic_inference: bool = False,
         attention_backend: Optional[str] = None,
         **kwargs,
-    ) -> Union[list[VLLMInferenceResultWithInput] | DataProto]:
+    ) -> Union[list, DataProto]:
         """
         Perform generation process of the model.
@@ -332,9 +329,8 @@
         ------------
         inputs : Union[str, list[str], torch.Tensor, DataProto]
             The sequence used as a prompt for the generation or as model inputs to the model.
-            When the inference engine is "vllm", this should be a string or a list of strings.
+            When the inference engine is "vllm" or "sglang", this should be a DataProto.
             When the inference engine is "huggingface", this should be a tensor.
-            When the inference engine is "sglang", this should be a DataProto.
         sampling_params : Optional[Union[dict, "SamplingParams"]], optional
             The sampling parameters to use, by default None.
         return_logprob : bool, optional
@@ -347,6 +343,10 @@
             The GPU memory utilization to use, by default None.
         tensor_parallel_size : int, optional
             The tensor parallel size to use, by default None.
+        data_parallel_size : int, optional
+            The data parallel size for vllm inference, by default 1.
+        max_model_len : int, optional
+            Maximum model context length for vllm inference, by default None.
         enable_deterministic_inference : bool, optional
             Whether to enable deterministic inference, by default False.
         attention_backend : Optional[str], optional
@@ -365,12 +365,14 @@
             inference_engine=inference_engine,
             gpu_memory_utilization=gpu_memory_utilization,
             tensor_parallel_size=tensor_parallel_size,
+            data_parallel_size=data_parallel_size,
+            max_model_len=max_model_len,
             enable_deterministic_inference=enable_deterministic_inference,
             attention_backend=attention_backend,
         )

         if inference_engine == "vllm":
-            res = self.__vllm_inference(inputs=inputs, sampling_params=sampling_params)
+            res = self.__vllm_inference(inputs=inputs)
         elif inference_engine == "sglang":
             res = self.__sglang_inference(
                 inputs=inputs,
@@ -424,46 +426,29 @@ def __inference(self, inputs, *args, **kwargs):

     def __vllm_inference(
         self,
-        inputs: list[str],
-        sampling_params: Optional["SamplingParams"] = None,
-    ) -> list[VLLMInferenceResultWithInput]:
-        """Perform VLLM inference process of the model.
-
-        Parameters
-        ----------
-        inputs : list[str]
-            Prompt(s), string or a list of strings.
-        sampling_params : Optional[SamplingParams], optional
-            vllm SamplingParams object, by default None.
-
-        Returns
-        -------
-        list[VLLMInferenceResultWithInput]
-            Return a list of VLLMInferenceResultWithInput, where each
-            element contains the input prompt and the corresponding output.
-
-            When `sampling_params.detokenize = True`, the output would be a list of strings,
-            contains sampling_params.n samples for the corresponding prompt.
+        inputs: DataProto,
+    ) -> DataProto:
+        """Perform VLLM inference process of the model."""
+        prompts = inputs.non_tensor_batch["inputs"].tolist()
+        sampling_params_dict = inputs.meta_info["sampling_params"]
+
+        vllm_sampling_params = SamplingParams(
+            n=sampling_params_dict.get("n", 1),
+            temperature=sampling_params_dict.get("temperature", 0.0),
+            max_tokens=sampling_params_dict.get("max_new_tokens", 100),
+            seed=sampling_params_dict.get("seed"),
+            top_p=sampling_params_dict.get("top_p", 1.0),
+            top_k=sampling_params_dict.get("top_k", 0),
+            stop_token_ids=sampling_params_dict.get("stop_token_ids"),
+        )
-
-            When `sampling_params.detokenize = False`, return a list of list of ints
-            (token ids, no decoding after generation).
- """ vllm_outputs = self.backend_model_for_inference.generate( - inputs, - sampling_params=sampling_params, + prompts, + sampling_params=vllm_sampling_params, use_tqdm=True, ) - # TODO: unified lmflow sample format - final_output = [] - for output in vllm_outputs: - if sampling_params.detokenize: - output_list = [sentence.text for sentence in output.outputs] - else: - output_list = [sentence.token_ids for sentence in output.outputs] - - final_output.append({"input": output.prompt, "output": output_list}) - - return final_output + inputs.non_tensor_batch["outputs"] = [output.outputs[0].text for output in vllm_outputs] + return inputs def __sglang_inference( self, @@ -495,9 +480,8 @@ def prepare_inputs_for_inference( dataset: Dataset, apply_chat_template: bool = True, inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface", - enable_distributed_inference: bool = False, sampling_params: Optional[dict] = None, - ) -> Union[list[str], "ray.data.Dataset", DataProto]: + ) -> Union[list[str], DataProto]: if dataset.get_type() == "text_only": if apply_chat_template: dataset = dataset.map( @@ -572,24 +556,17 @@ def preprocess_conversation(sample): inference_inputs = [sentence for sentence in inference_inputs if len(sentence) > 0] - if inference_engine == "vllm" and enable_distributed_inference: - inference_inputs = ray.data.from_items( - inference_inputs - ) # -> dict[str, np.ndarray], {"item": array(['...', '...', '...'])} - - if inference_engine == "sglang": + if inference_engine in ("sglang", "vllm"): if self.tokenizer.bos_token: - # in consistent with sglang bench_serving.py demo inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs] - # currently only test dataproto on sglang inference inference_inputs = np.array(inference_inputs) inference_inputs = DataProto.from_single_dict( data={"inputs": inference_inputs}, meta_info={"sampling_params": {**sampling_params, "n": 1}, "actual_n_rollouts": sampling_params["n"]} ) - - # handling n>1 since we don't want one-to-many mapping. Later this will be applied to all inference engines. + + # handling n>1 since we don't want one-to-many mapping inference_inputs = inference_inputs.repeat(sampling_params["n"]) return inference_inputs diff --git a/src/lmflow/models/hf_model_mixin.py b/src/lmflow/models/hf_model_mixin.py index dce5bd830..71ff6bba8 100644 --- a/src/lmflow/models/hf_model_mixin.py +++ b/src/lmflow/models/hf_model_mixin.py @@ -448,20 +448,27 @@ def __prepare_model_for_vllm_inference( model_args: ModelArguments, gpu_memory_utilization: float, tensor_parallel_size: int, + data_parallel_size: int = 1, + max_model_len: Optional[int] = None, ): if not is_vllm_available(): raise ImportError('VLLM is not available. 
Please install via `pip install -e ".[vllm]"`.') from vllm import LLM - self.backend_model_for_inference = LLM( + kwargs = dict( model=model_args.model_name_or_path, tokenizer=model_args.model_name_or_path, dtype=model_args.torch_dtype if model_args.torch_dtype else "auto", load_format="auto", gpu_memory_utilization=gpu_memory_utilization, tensor_parallel_size=tensor_parallel_size, + data_parallel_size=data_parallel_size, ) + if max_model_len is not None: + kwargs["max_model_len"] = max_model_len + + self.backend_model_for_inference = LLM(**kwargs) def __prepare_model_for_sglang_inference( self, @@ -513,6 +520,8 @@ def activate_model_for_inference( inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface", gpu_memory_utilization: Optional[float] = None, tensor_parallel_size: Optional[int] = None, + data_parallel_size: int = 1, + max_model_len: Optional[int] = None, enable_deterministic_inference: bool = False, attention_backend: Optional[str] = None, ): @@ -525,6 +534,8 @@ def activate_model_for_inference( model_args=self.model_args, gpu_memory_utilization=gpu_memory_utilization, tensor_parallel_size=tensor_parallel_size, + data_parallel_size=data_parallel_size, + max_model_len=max_model_len, ) elif inference_engine == "sglang": self.__prepare_model_for_sglang_inference( @@ -548,20 +559,26 @@ def deactivate_model_for_inference( ): """Deactivate the model and release the resources. - NOTE: Currently, VLLM doesn't have an official way to do this, and the - implementation below cannot release all gpu resources by our observation. - Thus this method is just a placeholder for future implementation. See: - [Github issue](https://github.com/vllm-project/vllm/issues/1908) + NOTE: For vllm (>=0.8), the best-effort release below works for most + single-GPU, inference-only use cases. It remains unreliable when + ``tensor_parallel_size > 1``, CUDA graphs are enabled, or the same + process also holds an HF training model — in those cases use + :class:`MemorySafeVLLMInferencer`, which isolates inference in a + subprocess. vllm still has no official in-process shutdown API + (RFC vllm-project/vllm#24885); ``MemorySafeVLLMInferencer`` is kept + for backward compatibility and will be migrated to vllm sleep mode + in a follow-up. """ if not self._activated: logger.warning("You are trying to deactivate the model for inference, but it is already deactivated.") return if inference_engine == "vllm": - from vllm.distributed.parallel_state import destroy_model_parallel - - destroy_model_parallel() - del self.backend_model_for_inference.llm_engine.model_executor.driver_worker + try: + from vllm.distributed.parallel_state import destroy_model_parallel + destroy_model_parallel() + except Exception: + pass del self.backend_model_for_inference gc.collect() torch.cuda.empty_cache() diff --git a/src/lmflow/pipeline/sglang_inferencer.py b/src/lmflow/pipeline/sglang_inferencer.py index 59bcfcce0..f9d012939 100644 --- a/src/lmflow/pipeline/sglang_inferencer.py +++ b/src/lmflow/pipeline/sglang_inferencer.py @@ -2,6 +2,7 @@ # Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved. import json import logging +import os from typing import Optional, Union from transformers import AutoTokenizer @@ -101,14 +102,13 @@ def save_inference_results( outputs: DataProto, inference_results_path: str, ): - if not inference_results_path.endswith(".pkl"): - logger.warning(f"The inference results path must be a pickle file. 
Change the path to {inference_results_path}.pkl") - inference_results_path = inference_results_path + ".pkl" - outputs.save_to_disk(inference_results_path) - logger.info(f"Inference results are saved to {inference_results_path}.") + save_path = os.path.join(inference_results_path, "inference_results.pkl") + outputs.save_to_disk(save_path) + logger.info(f"Inference results are saved to {save_path}.") def load_inference_results( self, inference_results_path: str, ) -> DataProto: - return DataProto.load_from_disk(inference_results_path) + load_path = os.path.join(inference_results_path, "inference_results.pkl") + return DataProto.load_from_disk(load_path) diff --git a/src/lmflow/pipeline/utils/memory_safe_vllm_inference.py b/src/lmflow/pipeline/utils/memory_safe_vllm_inference.py index db14ae96d..387394ddd 100644 --- a/src/lmflow/pipeline/utils/memory_safe_vllm_inference.py +++ b/src/lmflow/pipeline/utils/memory_safe_vllm_inference.py @@ -45,10 +45,6 @@ def main(): model, dataset, release_gpu=False, - enable_decode_inference_result=pipeline_args.enable_decode_inference_result, - enable_distributed_inference=pipeline_args.enable_distributed_inference, - distributed_inference_num_instances=pipeline_args.distributed_inference_num_instances, - inference_batch_size=pipeline_args.vllm_inference_batch_size, ) print(len(res)) diff --git a/src/lmflow/pipeline/vllm_inferencer.py b/src/lmflow/pipeline/vllm_inferencer.py index 9873598f8..15d80f8fd 100644 --- a/src/lmflow/pipeline/vllm_inferencer.py +++ b/src/lmflow/pipeline/vllm_inferencer.py @@ -1,17 +1,14 @@ #!/usr/bin/env python # Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved. -import copy import importlib.resources as pkg_resources -import json import logging import os os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" import subprocess import sys -from typing import Any, Optional, Union +from typing import Optional -import numpy as np from transformers import AutoTokenizer from lmflow.args import ( @@ -24,288 +21,123 @@ from lmflow.pipeline.base_pipeline import BasePipeline from lmflow.utils.common import make_shell_args_from_dataclass from lmflow.utils.constants import MEMORY_SAFE_VLLM_INFERENCE_ENV_VAR_TO_REMOVE, RETURN_CODE_ERROR_BUFFER -from lmflow.utils.data_utils import VLLMInferenceResultWithInput -from lmflow.utils.versioning import is_ray_available, is_vllm_available +from lmflow.utils.protocol import DataProto +from lmflow.utils.versioning import is_vllm_available logger = logging.getLogger(__name__) if is_vllm_available(): - from vllm import SamplingParams + pass else: raise ImportError("VLLM is not available, please install vllm.") -if is_ray_available(): - import ray - import ray.data - from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -else: - logger.warning("Ray is not available, distributed vllm inference will not be supported.") - -class InferencerWithOffloading(BasePipeline): +class VLLMInferencer(BasePipeline): def __init__( self, model_args: ModelArguments, data_args: DatasetArguments, inferencer_args: InferencerArguments, ): + assert inferencer_args.inference_engine == "vllm" self.model_args = model_args self.data_args = data_args self.inferencer_args = inferencer_args self.eos_token_id = AutoTokenizer.from_pretrained(model_args.model_name_or_path).eos_token_id + self.sampling_params = self._parse_args_to_sampling_params(inferencer_args) - def inference(self): - raise NotImplementedError(".inference is not implemented") - - def save_inference_results(self): - raise 
NotImplementedError(".save_inference_results is not implemented") - - def load_inference_results(self): - raise NotImplementedError(".load_inference_results is not implemented") - - -class VLLMInferencer(InferencerWithOffloading): - def __init__( - self, - model_args: ModelArguments, - data_args: DatasetArguments, - inferencer_args: InferencerArguments, - ): - assert inferencer_args.use_vllm, "The inferencer_args.use_vllm must be True." - super().__init__(model_args, data_args, inferencer_args) - self.sampling_params = self.parse_to_sampling_params(inferencer_args) - - def parse_to_sampling_params( + def _parse_args_to_sampling_params( self, inference_args: InferencerArguments, - ) -> SamplingParams: - return SamplingParams( - use_beam_search=inference_args.use_beam_search, - n=inference_args.num_output_sequences, - temperature=inference_args.temperature + 1e-6, - max_tokens=inference_args.max_new_tokens, - seed=inference_args.random_seed, - top_p=inference_args.top_p, - top_k=inference_args.top_k, - stop_token_ids=[self.eos_token_id] + inference_args.additional_stop_token_ids, - ) + ) -> dict: + if inference_args.use_beam_search: + logger.warning("`use_beam_search` is ignored, as vLLM V1 engine no longer supports beam search.") + + sampling_params = { + "n": inference_args.num_output_sequences, + "temperature": inference_args.temperature + 1e-6, + "max_new_tokens": inference_args.max_new_tokens, + "seed": inference_args.random_seed, + "top_p": inference_args.top_p, + "top_k": inference_args.top_k, + "stop_token_ids": [self.eos_token_id] + inference_args.additional_stop_token_ids, + } + + return sampling_params def inference( self, model: HFDecoderModel, dataset: Dataset, - enable_decode_inference_result: bool = True, release_gpu: bool = False, inference_args: Optional[InferencerArguments] = None, - enable_distributed_inference: bool = False, - **kwargs, - ) -> list[VLLMInferenceResultWithInput]: - """Perform inference using the provided model and dataset. Will save inference results if - `save_results` is set to True in `inferencer_args`. - - Parameters - ---------- - model : HFDecoderModel - LMFlow HFDecoderModel object - dataset : Dataset - LMFlow Dataset object - apply_chat_template : bool, optional - Whether to apply chat template to the input, by default True. - enable_decode_inference_result : bool, optional - Whether to decode after generation, by default False. - release_gpu : bool, optional - Whether to release gpu resources, by default False. - inference_args : InferencerArguments, optional - by default None - - Returns - ------- - list[VLLMInferenceResultWithInput] - Return a list of VLLMInferenceResultWithInput, where each - element contains the input prompt and the corresponding output. - - When `enable_decode_inference_result = True`, the output would be a list of strings, - contains sampling_params.n samples for the corresponding prompt. - - When `enable_decode_inference_result = False`, return a list of list of ints - (token ids, no decoding after generation). 
- """ + ) -> DataProto: if inference_args: logger.warning("Overriding the default inference arguments with the provided arguments in .inference()") - sampling_params = self.parse_to_sampling_params(inference_args) + sampling_params = self._parse_args_to_sampling_params(inference_args) else: sampling_params = self.sampling_params - sampling_params.detokenize = enable_decode_inference_result - model_input = model.prepare_inputs_for_inference( dataset=dataset, apply_chat_template=self.inferencer_args.apply_chat_template, - use_vllm=self.inferencer_args.use_vllm, - enable_distributed_inference=enable_distributed_inference, + inference_engine="vllm", + sampling_params=sampling_params, ) - if enable_distributed_inference: - outputs = self._distributed_inference( - model=model, - model_input=model_input, - sampling_params=sampling_params, - num_instances=kwargs.get("distributed_inference_num_instances"), - batch_size=kwargs.get("inference_batch_size", 4), - release_gpu=release_gpu, - ) - else: - outputs = self._inference( - model=model, - model_input=model_input, - sampling_params=sampling_params, - release_gpu=release_gpu, - ) - - if self.inferencer_args.save_results: - self.save_inference_results(outputs, self.inferencer_args.results_path) - - return outputs - - def _inference( - self, - model: HFDecoderModel, - model_input: list[str], - sampling_params: SamplingParams, - release_gpu: bool = False, - ) -> list[VLLMInferenceResultWithInput]: outputs = model.inference( inputs=model_input, - sampling_params=sampling_params, release_gpu=release_gpu, - use_vllm=True, + inference_engine="vllm", gpu_memory_utilization=self.inferencer_args.inference_gpu_memory_utilization, tensor_parallel_size=self.inferencer_args.inference_tensor_parallel_size, + data_parallel_size=self.inferencer_args.inference_data_parallel_size, + max_model_len=self.inferencer_args.inference_max_model_len, ) - return outputs - - def _distributed_inference( - self, - model: HFDecoderModel, - model_input: ray.data.Dataset, - sampling_params: SamplingParams, - num_instances: int, - batch_size: int = 4, - release_gpu: bool = False, - ) -> list[VLLMInferenceResultWithInput]: - # prepare distributed inference resources - # from https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_distributed.py - ## strategy - def scheduling_strategy_fn(): - # One bundle per tensor parallel worker - pg = ray.util.placement_group( - [{"GPU": 1, "CPU": 1}] * self.inferencer_args.inference_tensor_parallel_size, - strategy="STRICT_PACK", - ) - return dict( - scheduling_strategy=PlacementGroupSchedulingStrategy(pg, placement_group_capture_child_tasks=True) - ) - - resources_kwarg: dict[str, Any] = {} - if self.inferencer_args.inference_tensor_parallel_size == 1: - # For tensor_parallel_size == 1, we simply set num_gpus=1. - resources_kwarg["num_gpus"] = 1 - else: - # Otherwise, we have to set num_gpus=0 and provide - # a function that will create a placement group for - # each instance. 
- resources_kwarg["num_gpus"] = 0 - resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn - - ## predictor - class DistributedPredictor: - def __init__( - self, - model: HFDecoderModel, - sampling_params: SamplingParams, - gpu_memory_utilization: float, - tensor_parallel_size: int, - release_gpu: bool = False, - ): - self.model = copy.deepcopy(model) - self.model.activate_model_for_inference( - use_vllm=True, - gpu_memory_utilization=gpu_memory_utilization, - tensor_parallel_size=tensor_parallel_size, - ) - self.sampling_params = sampling_params - self.release_gpu = release_gpu - - def __call__(self, batch: dict[str, np.ndarray]): - """batch: dict[str, np.ndarray], {"item": array(['...', '...', '...', ...])}""" - batched_inference_res = self.model.inference( - inputs=batch["item"], - sampling_params=self.sampling_params, - release_gpu=self.release_gpu, - use_vllm=True, - ) # this is the postprocessed output, see model.__vllm_inference - batched_final_res = { - "input": [sample["input"] for sample in batched_inference_res], - "output": [sample["output"] for sample in batched_inference_res], - } # do this since we're writing to a pandas dataframe - return batched_final_res - - # inference - model_input_mapping = model_input.map_batches( - DistributedPredictor, - concurrency=num_instances, # Set the concurrency to the number of LLM instances. - batch_size=batch_size, - fn_constructor_kwargs={ - "model": model, - "sampling_params": sampling_params, - "gpu_memory_utilization": self.inferencer_args.inference_gpu_memory_utilization, - "tensor_parallel_size": self.inferencer_args.inference_tensor_parallel_size, - "release_gpu": release_gpu, - }, - **resources_kwarg, - ) - - df_model_output = model_input_mapping.to_pandas() # the actual forwards are executed here - logger.info(f"Distributed vllm inference result preview:\n{df_model_output.head(10)}") - - model_output = [{"input": row["input"], "output": row["output"]} for _, row in df_model_output[:].iterrows()] + if self.inferencer_args.save_inference_results: + self.save_inference_results(outputs, self.inferencer_args.inference_results_path) - return model_output + return outputs def save_inference_results( self, - outputs: Union[list[list[str]], list[list[list[int]]]], - save_file_path: str, + outputs: DataProto, + inference_results_path: str, ): - with open(save_file_path, "w", encoding="utf-8") as f: - json.dump(outputs, f, ensure_ascii=False, indent=4) - - logger.info(f"Inference results are saved to {save_file_path}.") + save_path = os.path.join(inference_results_path, "inference_results.pkl") + outputs.save_to_disk(save_path) + logger.info(f"Inference results are saved to {save_path}.") def load_inference_results( self, - results_path: str, - ) -> Union[list[list[str]], list[list[list[int]]]]: - with open(results_path) as f: - results = json.load(f) - - return results + inference_results_path: str, + ) -> DataProto: + load_path = os.path.join(inference_results_path, "inference_results.pkl") + return DataProto.load_from_disk(load_path) class MemorySafeVLLMInferencer(VLLMInferencer): + """Run VLLM inference in a subprocess for memory safety. + + This is a workaround since vllm cannot release GPU memory properly + in-process. See: https://github.com/vllm-project/vllm/issues/1908 + """ + def __init__( self, model_args: ModelArguments, data_args: DatasetArguments, inferencer_args: InferencerArguments, ): - assert inferencer_args.save_results, "For MemorySafeVLLMInferencer, `save_results` must be True." 
+        assert inferencer_args.save_inference_results or inferencer_args.save_results, (
+            "For MemorySafeVLLMInferencer, `save_inference_results` must be True."
+        )
         super().__init__(model_args, data_args, inferencer_args)
         self.inferencer_file_path = pkg_resources.files("lmflow.pipeline.utils") / "memory_safe_vllm_inference.py"

-    def inference(self) -> list[VLLMInferenceResultWithInput]:
+    def inference(self) -> DataProto:
         inferencer_args = make_shell_args_from_dataclass(
             dataclass_objects=[
                 self.model_args,
@@ -330,9 +162,6 @@ def inference(self) -> list[VLLMInferenceResultWithInput]:
         logger.info(f"MemorySafeVLLMInference subprocess run finished, info at finish: {cli_res}")

         if cli_res.returncode in RETURN_CODE_ERROR_BUFFER:
-            # > Fatal Python error: _enter_buffered_busy: could not acquire lock for
-            # > <_io.BufferedWriter name=''> at interpreter shutdown, possibly
-            # > due to daemon threads
             logger.warning(
                 "^^^^^^^^^^ Please ignore the above error, as it comes from the subprocess. "
                 "This may due to a kill signal with unfinished stdout/stderr writing in the subprocess. "
@@ -341,7 +170,8 @@
         if cli_res.returncode != 0:
             raise RuntimeError(f"Error during MemorySafeVLLMInference: {cli_res}")

-        outputs = self.load_inference_results(self.inferencer_args.results_path)
+        inference_results_path = self.inferencer_args.inference_results_path or self.inferencer_args.results_path
+        outputs = self.load_inference_results(inference_results_path)
         logger.info("MemorySafeVLLMInference result captured.")

         return outputs
diff --git a/tests/pipeline/test_vllm_inferencer.py b/tests/pipeline/test_vllm_inferencer.py
new file mode 100644
index 000000000..7aaa512fd
--- /dev/null
+++ b/tests/pipeline/test_vllm_inferencer.py
@@ -0,0 +1,249 @@
+import os
+import tempfile
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+
+from lmflow.args import DatasetArguments, InferencerArguments, ModelArguments
+from lmflow.utils.protocol import DataProto
+
+
+@pytest.fixture
+def model_args():
+    return ModelArguments(model_name_or_path="Qwen/Qwen2-0.5B")
+
+
+@pytest.fixture
+def inferencer_args():
+    return InferencerArguments(
+        inference_engine="vllm",
+        inference_gpu_memory_utilization=0.8,
+        num_output_sequences=2,
+        temperature=1.0,
+        max_new_tokens=128,
+        top_p=0.95,
+        top_k=50,
+        random_seed=42,
+        use_beam_search=False,
+    )
+
+
+@pytest.fixture
+def data_args():
+    return DatasetArguments(dataset_path=None)
+
+
+class TestParseArgsToSamplingParams:
+    """Test that _parse_args_to_sampling_params returns the correct dict."""
+
+    def test_returns_dict(self, model_args, inferencer_args):
+        with patch("lmflow.pipeline.vllm_inferencer.AutoTokenizer") as mock_tok:
+            mock_tok.from_pretrained.return_value = MagicMock(eos_token_id=151643)
+            from lmflow.pipeline.vllm_inferencer import VLLMInferencer
+
+            inferencer = VLLMInferencer(model_args, DatasetArguments(dataset_path=None), inferencer_args)
+            params = inferencer.sampling_params
+
+        assert isinstance(params, dict)
+        assert set(params.keys()) == {"n", "temperature", "max_new_tokens", "seed", "top_p", "top_k", "stop_token_ids"}
+
+    def test_values_match_args(self, model_args, inferencer_args):
+        with patch("lmflow.pipeline.vllm_inferencer.AutoTokenizer") as mock_tok:
+            mock_tok.from_pretrained.return_value = MagicMock(eos_token_id=151643)
+            from lmflow.pipeline.vllm_inferencer import VLLMInferencer
+
+            inferencer = VLLMInferencer(model_args, DatasetArguments(dataset_path=None), inferencer_args)
+            params = inferencer.sampling_params
+
+        assert params["n"] == 2
+        assert params["max_new_tokens"] == 128
+        assert params["seed"] == 42
+        assert params["top_p"] == 0.95
+        assert params["top_k"] == 50
+        assert abs(params["temperature"] - 1.0) < 1e-4
+        assert 151643 in params["stop_token_ids"]
+
+    def test_override_with_inference_args(self, model_args, inferencer_args):
+        with patch("lmflow.pipeline.vllm_inferencer.AutoTokenizer") as mock_tok:
+            mock_tok.from_pretrained.return_value = MagicMock(eos_token_id=151643)
+            from lmflow.pipeline.vllm_inferencer import VLLMInferencer
+
+            inferencer = VLLMInferencer(model_args, DatasetArguments(dataset_path=None), inferencer_args)
+
+            override_args = InferencerArguments(
+                inference_engine="vllm",
+                temperature=0.5,
+                max_new_tokens=256,
+                num_output_sequences=4,
+            )
+            new_params = inferencer._parse_args_to_sampling_params(override_args)
+            assert new_params["max_new_tokens"] == 256
+            assert new_params["n"] == 4
+            assert abs(new_params["temperature"] - 0.5) < 1e-4
+
+
+class TestDataProtoSaveLoad:
+    """Test DataProto pickle round-trip used by save/load_inference_results."""
+
+    def test_roundtrip(self):
+        proto = DataProto.from_single_dict(
+            data={"inputs": np.array(["Hello", "World"]), "outputs": np.array(["Hi", "Earth"])},
+            meta_info={"sampling_params": {"n": 1, "temperature": 1.0}},
+        )
+        with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f:
+            path = f.name
+
+        proto.save_to_disk(path)
+        loaded = DataProto.load_from_disk(path)
+
+        assert len(loaded) == 2
+        assert list(loaded.non_tensor_batch["inputs"]) == ["Hello", "World"]
+        assert list(loaded.non_tensor_batch["outputs"]) == ["Hi", "Earth"]
+        assert loaded.meta_info["sampling_params"]["n"] == 1
+
+    def test_save_load_uses_dir(self):
+        """VLLMInferencer saves inference_results.pkl inside the given directory."""
+        with patch("lmflow.pipeline.vllm_inferencer.AutoTokenizer") as mock_tok:
+            mock_tok.from_pretrained.return_value = MagicMock(eos_token_id=0)
+            from lmflow.pipeline.vllm_inferencer import VLLMInferencer
+
+            args = InferencerArguments(inference_engine="vllm")
+            inferencer = VLLMInferencer(
+                ModelArguments(model_name_or_path="dummy"), DatasetArguments(dataset_path=None), args
+            )
+
+        proto = DataProto.from_single_dict(data={"inputs": np.array(["a"])})
+        with tempfile.TemporaryDirectory() as tmpdir:
+            results_dir = f"{tmpdir}/results"
+            os.makedirs(results_dir)
+            inferencer.save_inference_results(proto, results_dir)
+            assert os.path.exists(os.path.join(results_dir, "inference_results.pkl"))
+            loaded = inferencer.load_inference_results(results_dir)
+            assert len(loaded) == 1
+
+
+class TestPrepareInputsDataProto:
+    """Test that prepare_inputs_for_inference creates a proper DataProto for vllm."""
+
+    def test_creates_dataproto_with_repeat(self):
+        """Simulate what prepare_inputs_for_inference does for vllm."""
+        prompts = ["prompt_a", "prompt_b"]
+        sampling_params = {"n": 3, "temperature": 1.0}
+
+        inference_inputs = np.array(prompts)
+        proto = DataProto.from_single_dict(
+            data={"inputs": inference_inputs},
+            meta_info={"sampling_params": {**sampling_params, "n": 1}, "actual_n_rollouts": sampling_params["n"]},
+        )
+        proto = proto.repeat(sampling_params["n"])
+
+        # 2 prompts * n=3 = 6 rows
+        assert len(proto) == 6
+        assert proto.meta_info["sampling_params"]["n"] == 1
+        assert proto.meta_info["actual_n_rollouts"] == 3
+
+        inputs_list = proto.non_tensor_batch["inputs"].tolist()
+        # repeat interleaves: [a, a, a, b, b, b]
+        assert inputs_list == ["prompt_a"] * 3 + ["prompt_b"] * 3
+
+
+vllm = pytest.importorskip("vllm")
+
+from lmflow.datasets.dataset import Dataset
+from lmflow.models.hf_decoder_model import HFDecoderModel
+from lmflow.pipeline.vllm_inferencer import VLLMInferencer
+from tests.datasets.conftest import dataset_inference_conversation_batch  # noqa: F401
+
+
+@pytest.fixture
+def vllm_test_model_args() -> ModelArguments:
+    return ModelArguments(model_name_or_path="Qwen/Qwen3-0.6B")
+
+
+@pytest.fixture
+def vllm_test_inferencer_args() -> InferencerArguments:
+    return InferencerArguments(
+        inference_engine="vllm",
+        inference_gpu_memory_utilization=0.8,
+        num_output_sequences=2,
+        temperature=1.0,
+        max_new_tokens=64,
+        top_p=0.95,
+        random_seed=42,
+    )
+
+
+@pytest.mark.gpu
+def test_vllm_inferencer(
+    dataset_inference_conversation_batch: Dataset,  # noqa: F811
+    vllm_test_model_args: ModelArguments,
+    vllm_test_inferencer_args: InferencerArguments,
+):
+    model = HFDecoderModel(model_args=vllm_test_model_args)
+    inferencer = VLLMInferencer(
+        data_args=dataset_inference_conversation_batch.data_args,
+        model_args=vllm_test_model_args,
+        inferencer_args=vllm_test_inferencer_args,
+    )
+    res = inferencer.inference(
+        model=model,
+        dataset=dataset_inference_conversation_batch,
+        release_gpu=True,
+    )
+
+    # DataProto structure checks
+    assert isinstance(res, DataProto)
+
+    # 2 conversations * n=2 = 4 rows
+    assert len(res) == 4
+
+    # Has inputs and outputs in non_tensor_batch
+    assert "inputs" in res.non_tensor_batch
+    assert "outputs" in res.non_tensor_batch
+    assert len(res.non_tensor_batch["inputs"]) == 4
+    assert len(res.non_tensor_batch["outputs"]) == 4
+
+    # Each output should be a non-empty string
+    for output in res.non_tensor_batch["outputs"]:
+        assert isinstance(output, str)
+        assert len(output) > 0
+
+    # Sampling params in meta_info
+    assert "sampling_params" in res.meta_info
+    assert res.meta_info["sampling_params"]["n"] == 1
+
+    # Inputs repeat pattern: [conv1, conv1, conv2, conv2]
+    inputs = res.non_tensor_batch["inputs"].tolist()
+    assert inputs[0] == inputs[1]
+    assert inputs[2] == inputs[3]
+    assert inputs[0] != inputs[2]
+
+
+@pytest.mark.gpu
+def test_vllm_inferencer_save_load(
+    dataset_inference_conversation_batch: Dataset,  # noqa: F811
+    vllm_test_model_args: ModelArguments,
+    vllm_test_inferencer_args: InferencerArguments,
+):
+    model = HFDecoderModel(model_args=vllm_test_model_args)
+    inferencer = VLLMInferencer(
+        data_args=dataset_inference_conversation_batch.data_args,
+        model_args=vllm_test_model_args,
+        inferencer_args=vllm_test_inferencer_args,
+    )
+    res = inferencer.inference(
+        model=model,
+        dataset=dataset_inference_conversation_batch,
+        release_gpu=True,
+    )
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        results_dir = os.path.join(tmpdir, "results")
+        os.makedirs(results_dir)
+        inferencer.save_inference_results(res, results_dir)
+        loaded = inferencer.load_inference_results(results_dir)
+
+    assert len(loaded) == len(res)
+    assert list(loaded.non_tensor_batch["inputs"]) == list(res.non_tensor_batch["inputs"])
+    assert list(loaded.non_tensor_batch["outputs"]) == list(res.non_tensor_batch["outputs"])