# [Data] Apply DataProto to vLLM Inference & Align API with SGLang #967
wheresmyhair merged 2 commits into `main`
## Conversation

**Code review** found 1 issue: a docstring/example mismatch between `LMFlow/examples/vllm_inference.py` (lines 40 to 45) and `LMFlow/src/lmflow/models/hf_model_mixin.py` (lines 559 to 581) at `dee43cf`. Review prepared by @Jingyuan-zhu.

Re-checked at `7408430`: the docstring/example mismatch flagged above is fully fixed.
## Overview

This PR applies `DataProto` to the vLLM inference pipeline, aligning its API with the SGLang inferencer introduced in #960 (Unified data exchange protocol across modules). This unifies data exchange across inference engines and modernizes the vLLM integration.

## Detailed Description
### DataProto integration

- `VLLMInferencer` now returns `DataProto` instead of `list[VLLMInferenceResultWithInput]`, with prompts in `non_tensor_batch["inputs"]` and generated text in `non_tensor_batch["outputs"]`.
- `prepare_inputs_for_inference` creates `DataProto` for both SGLang and vLLM through a unified code path.
- `__vllm_inference` in `HFDecoderModel` extracts prompts and sampling params from `DataProto`, converts them to `vllm.SamplingParams`, and stores outputs back into the proto.
- Results can be persisted with `DataProto.save_to_disk` / `load_from_disk`; `inference_results_path` now accepts a directory, and results are automatically saved as `inference_results.pkl` inside it (a consumption sketch follows this list).
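For orientation, here is a minimal consumption sketch. It assumes a verl-style `DataProto` as introduced in #960; the import path and the `from_dict` constructor below are assumptions for illustration, not confirmed LMFlow API.

```python
import numpy as np

from lmflow.utils.data_utils import DataProto  # hypothetical import path

# Build a proto shaped like the inferencer's output: prompts and
# generations live in non_tensor_batch under "inputs" and "outputs".
results = DataProto.from_dict(non_tensors={
    "inputs": np.array(["What is 2+2?"], dtype=object),
    "outputs": np.array(["4"], dtype=object),
})

for prompt, completion in zip(results.non_tensor_batch["inputs"],
                              results.non_tensor_batch["outputs"]):
    print(f"{prompt!r} -> {completion!r}")

# Round-trip to disk. Note that the inferencer's inference_results_path
# accepts a directory and writes inference_results.pkl inside it; the
# explicit file path here is for the raw save/load calls.
results.save_to_disk("outputs/inference_results.pkl")
restored = DataProto.load_from_disk("outputs/inference_results.pkl")
```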
### API alignment with SGLang and modernization

- `VLLMInferencer` now mirrors `SGLangInferencer`.
- Removed the `InferencerWithOffloading` base class and all Ray-based distributed inference code: vLLM >= 0.8 supports `data_parallel_size` natively in `vllm.LLM()`, using a multiprocessing backend with no Ray dependency (see the engine-setup sketch after this list).
- Added the `--inference_data_parallel_size` argument; the total GPU count is `tensor_parallel_size × data_parallel_size`.
- Removed `use_beam_search` from sampling params (dropped in vLLM V1) and added a deprecation warning.
- Simplified `deactivate_model_for_inference`: the old cleanup code referenced `llm_engine.model_executor.driver_worker`, which no longer exists in V1.
- Added `--inference_max_model_len` to cap context length (prompt and output) for models with large defaults.
- Bumped the vLLM requirement from `>=0.4.3` to `>=0.8.0` in `setup.py`.
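The removal of the Ray launcher rests on vLLM's native data-parallel support. Below is a sketch of the modernized engine setup under that assumption; the model path and parallel sizes are placeholders, not values from this PR.

```python
from vllm import LLM, SamplingParams

# Per this PR, vLLM >= 0.8 accepts data_parallel_size directly in
# vllm.LLM(), using a multiprocessing backend with no Ray dependency.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    tensor_parallel_size=2,
    data_parallel_size=2,  # --inference_data_parallel_size; 2 x 2 = 4 GPUs total
    max_model_len=4096,    # --inference_max_model_len caps prompt + output length
)

# use_beam_search was dropped in vLLM V1 and is no longer forwarded.
sampling = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=256)
outputs = llm.generate(["Explain DataProto in one sentence."], sampling)
print(outputs[0].outputs[0].text)
```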
## Files changed

- `src/lmflow/pipeline/vllm_inferencer.py`
- `src/lmflow/models/hf_decoder_model.py`
- `src/lmflow/models/hf_model_mixin.py`
- `src/lmflow/args.py`
- `src/lmflow/pipeline/sglang_inferencer.py`
- `src/lmflow/pipeline/utils/memory_safe_vllm_inference.py`
- `examples/vllm_inference.py`
- `scripts/run_vllm_inference.sh`
- `scripts/run_sglang_inference.sh`
- `setup.py`
- `tests/pipeline/test_vllm_inferencer.py`
## Downstream impact

- `MemorySafeVLLMInferencer` is updated to return `DataProto`.
- `iterative_dpo_aligner.py` consumes `MemorySafeVLLMInferencer` and will need a separate update to handle `DataProto` instead of `list[VLLMInferenceResultWithInput]` (see the migration sketch below).
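A hedged migration sketch for such consumers; the old per-item access in the comment is an assumption for illustration, and only the `DataProto` access pattern comes from this PR.

```python
from lmflow.utils.data_utils import DataProto  # hypothetical import path

def consume(results: DataProto) -> None:
    """Consume the new MemorySafeVLLMInferencer output.

    Before: results was a list[VLLMInferenceResultWithInput],
    iterated item by item. After: prompts and generations are
    columnar fields on the proto.
    """
    for prompt, completion in zip(results.non_tensor_batch["inputs"],
                                  results.non_tensor_batch["outputs"]):
        print(prompt, "->", completion)  # placeholder for real handling
```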
## Tests

- Ran `scripts/run_vllm_inference.sh` end-to-end with the target model.