Paper | Hugging Face | ReXRank Leaderboard
CheXOne is a reasoning-aware vision-language model for chest X-ray interpretation.
✨ Key Features:
- Reasoning Capability: Generates explicit reasoning traces alongside final answers.
- Multi-Task Support: Handles Visual Question Answering (VQA), report generation, and visual grounding tasks.
- Resident-Level Reporting: Achieves report quality that matches or surpasses resident-written reports in over 50% of studied cases.
- Two Inference Modes:
- Reasoning Mode: Higher performance with explicit reasoning traces.
- Instruct Mode: Faster inference without reasoning traces.
This code release includes:
- Step-by-step instructions to reproduce our methodology.
- Data preparation scripts for CheXInstruct-v2 and CheXReason.
- Complete training code, including instruction tuning and GRPO.
- Complete inference code, with evaluation for our model and comparative baselines.
- User study scripts and related documentation.
- Code for generating publication figures.
git clone https://github.com/YBZh/CheXOne.git
cd CheXOne
pip install -e .

Training and Fast Inference Environment:
| | Range | Recommended | Notes |
|---|---|---|---|
| python | >=3.10,<3.12 | 3.10/3.11 | |
| cuda | 12.x | cuda12 | No need to install if using CPU, NPU, MPS |
| torch | >=2.0 | 2.7.1 | |
| transformers | >=4.33 | 4.56.2 | |
| modelscope | >=1.23 | 1.30.0 | |
| peft | >=0.11,<0.18 | 0.17.1 | |
| flash_attn | | 2.5.8 | |
| trl | >=0.15,<0.21 | 0.20.0 | RLHF |
| deepspeed | >=0.14 | 0.17.6 | Training |
| vllm | >=0.5.1 | 0.10.1.1 | Inference/Deployment |
| sglang | >=0.4.6 | 0.4.10.post2 | Inference/Deployment |
| lmdeploy | >=0.5 | 0.10.1 | Inference/Deployment |
| evalscope | >=1.0 | 1.0.2 | Evaluation |
| gradio | | 5.32.1 | Web-UI/App |
For more optional dependencies, see the installation documentation in the repository.
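As a quick sanity check against the version floors in the table above, a small helper like the following can flag outdated packages. This is an illustrative sketch, not part of the repo: the `FLOORS` subset and the naive dot-separated version parsing are assumptions (for real use, prefer `packaging.version.Version`).

```python
# Minimal version-floor check against the table above (illustrative sketch).
import importlib.metadata as md

# A subset of the version floors from the dependency table.
FLOORS = {"torch": "2.0", "transformers": "4.33", "peft": "0.11", "trl": "0.15"}

def as_tuple(v: str):
    # Keep only leading numeric components, e.g. "4.56.2" -> (4, 56, 2);
    # suffixes like "post2" are ignored. This is deliberately naive.
    parts = []
    for p in v.split("."):
        if p.isdigit():
            parts.append(int(p))
        else:
            break
    return tuple(parts)

def installed_versions() -> dict:
    """Collect installed versions of the packages in FLOORS, skipping missing ones."""
    out = {}
    for name in FLOORS:
        try:
            out[name] = md.version(name)
        except md.PackageNotFoundError:
            pass
    return out

def check(installed: dict) -> list:
    """Return the packages whose installed version is below the floor."""
    return [name for name, floor in FLOORS.items()
            if name in installed and as_tuple(installed[name]) < as_tuple(floor)]
```

For example, `check(installed_versions())` returns an empty list when the environment meets the floors.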
CheXOne is post-trained on the Qwen2.5VL-3B-Instruct model, which is integrated in the latest HuggingFace Transformers. We advise you to build transformers from source as follows:
pip install git+https://github.com/huggingface/transformers accelerate

Below is an example usage to get started:
import torch  # needed for torch.bfloat16 in the flash-attention variant below
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
# Load the model on the available device(s)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"StanfordAIMI/CheXOne", torch_dtype="auto", device_map="auto"
)
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image scenarios.
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
# "StanfordAIMI/CheXOne",
# torch_dtype=torch.bfloat16,
# attn_implementation="flash_attention_2",
# device_map="auto",
# )
# Default processor
processor = AutoProcessor.from_pretrained("StanfordAIMI/CheXOne")
# The default range for the number of visual tokens per image in the model is 4-16384.
# We recommend setting max_pixels=512*512 to align with the training setting.
# min_pixels = 256*28*28
# max_pixels = 512*512
# processor = AutoProcessor.from_pretrained("StanfordAIMI/CheXOne", min_pixels=min_pixels, max_pixels=max_pixels)
# Inference Mode: Reasoning
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": "https://github.com/YBZh/CheXOne/blob/main/asset/cxr.jpg",
},
{
"type": "text",
"text": "Write an example findings section for the CXR. Please reason step by step, and put your final answer within \\boxed{{}}.",
},
],
}
]
# Inference Mode: Instruct
# messages = [
# {
# "role": "user",
# "content": [
# {
# "type": "image",
# "image": "https://github.com/YBZh/CheXOne/blob/main/asset/cxr.jpg",
# },
# {
# "type": "text",
# "text": "Write an example findings section for the CXR.",
# },
# ],
# }
# ]
# Preparation for inference
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to("cuda")
# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=1024)
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

- Images: Please download the CXR images following the instructions in the corresponding datasets: 📂 data.md
- Texts:
See: examples/train/chexone/train_script/1_sft.sh
This step performs supervised fine-tuning using the curated CheXInstruct-v2 and CheXReason datasets.
See: examples/train/chexone/grpo_prepare
To ensure strong learning signals for GRPO, we filter out low-variance samples. For each candidate, several stochastic model runs are used to estimate reward variance, and only the most informative samples in each category (those with the highest reward variance) are selected. This strategy improves both the effectiveness and the efficiency of GRPO.
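The filtering step above can be sketched as follows. This is an illustrative sketch, not the repo's actual code: the function name, the `"category"`/`"rewards"` field names, and the per-category `top_k` cutoff are assumptions.

```python
# Sketch of variance-based sample filtering for GRPO (field names illustrative).
from statistics import pvariance

def filter_high_variance(samples, top_k):
    """Keep, per category, the top_k samples with the highest reward variance.

    `samples` is a list of dicts with keys "category" and "rewards", where
    "rewards" holds the rewards from several stochastic model runs.
    """
    by_category = {}
    for s in samples:
        by_category.setdefault(s["category"], []).append(s)

    selected = []
    for group in by_category.values():
        # Samples whose reward never varies carry no learning signal for GRPO,
        # so they sort to the bottom and are dropped by the top_k cutoff.
        group.sort(key=lambda s: pvariance(s["rewards"]), reverse=True)
        selected.extend(group[:top_k])
    return selected
```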
See: examples/train/chexone/train_script/2_grpo.sh
This step further optimizes the model with the GRPO algorithm to improve reasoning capabilities and robustness.
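At the core of GRPO is the group-relative advantage: for each prompt, several completions are sampled, and each completion's advantage is its reward normalized by the group's mean and standard deviation. A minimal sketch of that normalization (not the repo's implementation; the `eps` stabilizer is an assumption):

```python
# Group-relative advantage, the normalization at the heart of GRPO (sketch).
from statistics import mean, pstdev

def group_relative_advantages(rewards, eps=1e-6):
    """A_i = (r_i - mean(r)) / (std(r) + eps) over one sampled group."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]
```

Completions that beat their group's average get positive advantages and are reinforced; this is also why the zero-variance samples filtered out above are useless, since their advantages are all (near) zero.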
1. Normal Inference as stated in Quick Start
See: examples/train/chexone/prepare_testing/3_inference_code/CheXOne.sh
See: examples/train/chexone/prepare_testing/3_inference_code/OtherModels
See: examples/train/chexone/Reader-Study
This framework is licensed under the Apache License (Version 2.0). For models and datasets, please refer to their original resource pages and follow the corresponding licenses.
