
fudan-generative-vision/PromptReinjection


[ICML 2026] Prompt Reinjection: Alleviating Prompt Forgetting in Multimodal Diffusion Transformers

Yuxuan Yao (1,2,*), Yuxuan Chen (1,*), Hui Li (1), Kaihui Cheng (1), Qipeng Guo (3), Yuwei Sun (4),
Zilong Dong (5), Jingdong Wang (6), Siyu Zhu (1,2)

(1) Fudan University, (2) Shanghai Innovation Institute
(3) Shanghai AI Laboratory, (4) Shanghai Academy of AI for Science, (5) Alibaba Group, (6) Baidu


📖 Introduction

Prompt Reinjection is a training-free inference method for multimodal diffusion transformers that mitigates prompt forgetting by reinjecting early-layer prompt features into deeper text layers. It improves GenEval overall scores by 6.48% on SD3.5-large and 7.75% on HunyuanImage-2.1, while adding only about 0.00002x block-level FLOPs for base reinjection and 0.088x for the full aligned variant.

📘 Overview

Prompt Reinjection starts from a simple observation: in multimodal diffusion transformers (MMDiTs) such as SD3-medium, SD3.5-large, FLUX.1, HunyuanImage-2.1, and Qwen-Image, prompt information fades as depth increases. This weakens instruction following, especially for prompts involving position, attributes, counting, or long descriptions.

Prompt Forgetting

Unlike traditional DiTs, where text serves as a relatively stable conditioning signal, MMDiTs jointly update text and image tokens throughout denoising, even though the text tokens receive no direct supervision. The paper shows that deeper text features gradually lose fine-grained prompt semantics, a phenomenon we call prompt forgetting.

[Figure: Prompt Forgetting]

The figure above captures the core trend: for SD3, SD3.5, and FLUX, prompt information becomes less recoverable in deeper layers. This helps explain why base MMDiT models often miss spatial relations, attributes, and numeracy constraints in generation.

Prompt Reinjection

Prompt Reinjection fixes this at inference time. It takes semantically stronger text features from an early layer, aligns them to the deeper feature space, and reinjects them into later blocks so prompt constraints stay active through the full denoising stack.

[Figure: Prompt Reinjection]

The method requires no retraining, adds little overhead, and plugs directly into the original MMDiT forward process.
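The toy NumPy sketch below illustrates the data flow described above: text and image tokens are updated jointly in every block, and the text features leaving an origin block are added, scaled by a small weight, to the text stream entering each target block. It is an illustration under stated assumptions, not the repository's implementation: `joint_block` and `forward` are hypothetical names, each block is a single random-weight joint-attention layer, and the alignment step is the identity (no anchoring, no rotation), which corresponds to the most basic variant.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Parameter-free normalization over the channel dimension.
    mu = x.mean(axis=-1, keepdims=True)
    sd = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sd + eps)

def joint_block(txt, img, seed):
    """Toy MMDiT-style block: text and image tokens attend jointly, so both streams change."""
    d = txt.shape[-1]
    w_q, w_k, w_v = np.random.default_rng(seed).standard_normal((3, d, d)) / np.sqrt(d)
    x = np.concatenate([txt, img], axis=0)          # joint sequence [text; image]
    h = layer_norm(x)
    scores = (h @ w_q) @ (h @ w_k).T / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)    # numerically stable softmax
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)
    x = x + attn @ (h @ w_v)                        # residual update of both streams
    return x[: len(txt)], x[len(txt):]

def forward(txt, img, num_blocks, origin=None, targets=(), weight=0.0):
    """Run the toy stack; optionally reinject origin-block text features into target blocks."""
    saved = None
    for layer in range(num_blocks):
        if saved is not None and layer in targets:
            txt = txt + weight * saved              # reinjection with identity alignment (assumption)
        txt, img = joint_block(txt, img, seed=layer)
        if layer == origin:
            saved = txt.copy()                      # text features leaving the origin block
    return txt, img
```

With `weight=0.0` the result equals the plain forward pass, so in this sketch reinjection is a pure add-on to the base model's computation.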

[Figure: Qualitative Comparison]

As shown above, Prompt Reinjection makes SD3.5, FLUX, and Qwen-Image follow prompt constraints more consistently across position, attribute, counting, and complex prompts. The paper reports consistent gains on GenEval, DPG-Bench, and T2I-CompBench++ while preserving overall generation quality.

🚀 Quick Start

1. Create an environment

We recommend Python 3.10+ and one environment per model family.

python3.10 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip

Install dependencies with one of the following: an editable package install, or the pinned root requirements.

pip install -e .                   # option 1: editable package install
pip install -r requirements.txt    # option 2: pinned root requirements

If you want a narrower per-model environment, install one model-specific requirement file instead:

pip install -r requirements/sd3.txt
pip install -r requirements/sd3.5.txt
pip install -r requirements/flux.txt
pip install -r requirements/qwen.txt
pip install -r requirements/hunyuanimage.txt

If you need a specific CUDA build, install torch and torchvision first from the official PyTorch channel, then rerun one of the commands above.

For HunyuanImage, install Tencent's official runtime before requirements/hunyuanimage.txt:

git clone https://github.com/Tencent-Hunyuan/HunyuanImage-2.1.git
pip install -r HunyuanImage-2.1/requirements.txt
pip install flash-attn==2.7.3 --no-build-isolation
pip install -r requirements/hunyuanimage.txt

2. Prepare checkpoints

Pass model paths explicitly with --model-path:

  • sd3, sd3.5, flux, qwen: --model-path /path/to/model
  • hunyuanimage: --model-path /path/to/HunyuanImage-2.1 and optional --model-name

For hunyuanimage, --model-path can point either to the HunyuanImage runtime root or to its ckpts directory.

3. Run one prompt

The default helper script reads the released per-model inference and Prompt Reinjection settings from prompt_reinjection/reinjection_configs.json, so standard inference does not need manual residual arguments.

By default, all models run without memory-saving inference options such as CPU offload, Hunyuan runtime offload, VAE slicing, VAE tiling, or attention slicing, so the default path stays as close as possible to the original plain inference flow. These options are enabled only when you pass them explicitly.

bash prompt_reinjection/test_reinjection.sh \
  --model sd3 \
  --model-path /path/to/SD3 \
  --prompt "A photo of a couch below a potted plant."

Supported --model values: sd3, sd3.5, flux, qwen, hunyuanimage.

To change the default released settings for a model, edit prompt_reinjection/reinjection_configs.json. The helper script and the benchmark entrypoints all read from the same file.

To run the plain base model instead of Prompt Reinjection:

bash prompt_reinjection/test_reinjection.sh \
  --model sd3 \
  --model-path /path/to/SD3 \
  --prompt "A photo of a couch below a potted plant." \
  --reinjection off

To enable memory-saving options explicitly:

  • flux and qwen: pass --cpu-offload model or --cpu-offload sequential
  • flux and qwen: optionally add --vae-slicing, --vae-tiling, or --attention-slicing auto
  • hunyuanimage: pass --enable-offload

Example:

bash prompt_reinjection/test_reinjection.sh \
  --model flux \
  --model-path /path/to/FLUX.1-dev \
  --prompt "A photo of a couch below a potted plant." \
  --cpu-offload model \
  --vae-slicing

🧭 Manual SD3 Example

If you want to bypass the helper script and set the reinjection parameters manually, an SD3 example is:

python -m prompt_reinjection.run_sample \
  --model sd3 \
  --model-path /path/to/SD3 \
  --prompt "A photo of a couch below a potted plant." \
  --output outputs/manual_sd3.png \
  --steps 28 \
  --cfg 7.0 \
  --residual_origin_layer 1 \
  --residual_target_layers $(seq 2 23) \
  --residual_weights 0.025 \
  --residual_use_anchoring 1 \
  --residual_procrustes_path prompt_reinjection/rotations/sd3_coco5k_o1.pt

The same default memory policy also applies to the Python entrypoints. If you do not pass a memory-saving flag explicitly, the run stays on the plain inference path. For example, FLUX and Qwen only enable Diffusers offload when you pass --cpu-offload model or --cpu-offload sequential, and HunyuanImage only enables its runtime offload when you pass --enable-offload.

🧮 Procrustes Precomputation

We compute the released Procrustes statistics on COCO-5K.

For open-source usage, we recommend:

  • sd3 and flux: use the released Procrustes-aligned Prompt Reinjection settings.
  • sd3.5, qwen, and hunyuanimage: use the most basic Prompt Reinjection variant without anchoring and without rotation. It already works well and adds almost zero inference cost.

SD3

python SD3/compute.py \
  --model /path/to/SD3 \
  --dataset coco \
  --datadir data \
  --num-samples 5000 \
  --origin-layer 1 \
  --target-layer-start 2 \
  --col-center \
  --output outputs/procrustes_rotations/sd3_coco5k_o1.pt

FLUX

python FLUX/compute.py \
  --model /path/to/FLUX.1-dev \
  --dataset coco \
  --datadir data \
  --num-samples 5000 \
  --origin-layer 2 \
  --target-layer-start 3 \
  --col-center \
  --output outputs/procrustes_rotations/flux_coco5k_o2.pt
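Conceptually, the rotation for each target layer can be obtained by solving an orthogonal Procrustes problem between origin-layer and target-layer text features. The sketch below shows the textbook SVD solution. It assumes that --col-center subtracts the per-channel mean before fitting and that features are stacked into one (N, d) matrix per layer; `fit_procrustes`, `fit_rotations`, and that data layout are hypothetical, and the exact pooling, centering, and .pt file format used by SD3/compute.py and FLUX/compute.py may differ.

```python
import numpy as np

def fit_procrustes(src, dst, col_center=True):
    """Orthogonal R minimizing ||src @ R - dst||_F (orthogonal Procrustes, SVD solution)."""
    if col_center:
        # Assumed meaning of --col-center: remove the per-channel (column) mean first.
        src = src - src.mean(axis=0, keepdims=True)
        dst = dst - dst.mean(axis=0, keepdims=True)
    u, _, vt = np.linalg.svd(src.T @ dst)  # SVD of the d x d cross-covariance
    return u @ vt

def fit_rotations(layer_feats, origin_layer, target_layer_start):
    """One rotation per target layer, mapping origin-layer features onto that layer.

    layer_feats[l] is an (N, d) array of text-token features from block l,
    stacked over the prompt set (hypothetical layout).
    """
    src = layer_feats[origin_layer]
    return {
        t: fit_procrustes(src, layer_feats[t])
        for t in range(target_layer_start, len(layer_feats))
    }
```

In this sketch, the SD3 command corresponds to fit_rotations(feats, origin_layer=1, target_layer_start=2) and the FLUX command to origin_layer=2, target_layer_start=3.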

🤗 Released Rotations

The released Procrustes rotations are hosted on the Hugging Face Hub at LewisYao/PromptReinjection.

Download them to prompt_reinjection/rotations/ with:

hf download LewisYao/PromptReinjection \
  sd3_coco5k_o1.pt \
  flux_coco5k_o2.pt \
  --local-dir prompt_reinjection/rotations

The default helper script will pick them up automatically once they are placed under prompt_reinjection/rotations/.
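Before running the helper script, a small check like the one below can confirm that the released files are where the helper expects them. It is a hypothetical convenience, assuming the helper looks for exactly these two filenames under prompt_reinjection/rotations/ relative to the repository root.

```python
from pathlib import Path

# Filenames taken from the `hf download` command above.
RELEASED_ROTATIONS = ("sd3_coco5k_o1.pt", "flux_coco5k_o2.pt")

def missing_rotations(repo_root="."):
    """Return the released rotation files not yet present under prompt_reinjection/rotations/."""
    rot_dir = Path(repo_root) / "prompt_reinjection" / "rotations"
    return [name for name in RELEASED_ROTATIONS if not (rot_dir / name).is_file()]

if __name__ == "__main__":
    missing = missing_rotations()
    print("all rotations present" if not missing else f"missing: {missing}")
```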

📊 Evaluation

The benchmark scripts below also read the released per-model defaults from prompt_reinjection/reinjection_configs.json. By default, they run with --reinjection on. Use --reinjection off for the plain base model, or edit the config file if you want to change the released settings globally.

Like the helper script, these benchmark entrypoints do not enable CPU offload, Hunyuan runtime offload, VAE slicing, VAE tiling, or attention slicing unless you pass those flags explicitly.

GenEval

python -m prompt_reinjection.run_geneval \
  --model sd3 \
  --model-path /path/to/SD3 \
  --metadata_file /path/to/geneval/metadata.jsonl \
  --outdir outputs/geneval_sd3

Base-model variant:

python -m prompt_reinjection.run_geneval \
  --model sd3 \
  --model-path /path/to/SD3 \
  --metadata_file /path/to/geneval/metadata.jsonl \
  --outdir outputs/geneval_sd3_base \
  --reinjection off

DPG-Bench

python -m prompt_reinjection.run_dpg \
  --model sd3.5 \
  --model-path /path/to/SD3.5-large \
  --prompt_dir /path/to/dpg/prompts \
  --save_dir outputs/dpg_sd35

T2I-CompBench++

python -m prompt_reinjection.run_t2i \
  --model qwen \
  --model-path /path/to/Qwen-Image \
  --dataset_dir /path/to/t2i-compbench/prompts \
  --outdir_base outputs/t2i_qwen

These scripts generate benchmark-format images. Final scoring should still be done with the official benchmark evaluators.

🧩 Apply to New Models

Prompt Reinjection is designed for MMDiT-style open-source models where text features evolve inside the denoising transformer together with visual features. If your new model follows this pattern, you can usually add a basic reinjection version with only a small amount of integration work.

To plug a new model into this framework, the minimum steps are:

  • Add a new model folder with an adapter.py that implements the adapter interface used in prompt_reinjection/adapter_api.py and follows the existing examples in SD3/adapter.py, FLUX/adapter.py, and Qwen/adapter.py.
  • Make the model pipeline or transformer expose set_residual_config(...) so it can receive residual_origin_layer, residual_target_layers, residual_weights, residual_use_anchoring, and residual_rotation_matrices, as shown in SD3/pipeline.py, FLUX/pipeline.py, and Qwen/pipeline.py.
  • Register the new adapter in prompt_reinjection/registry.py so it becomes available to run_sample, run_geneval, run_dpg, run_t2i, and compute_procrustes.
  • If you want rotation-based alignment later, also provide a model-specific compute.py and expose it as the adapter's compute_script, so python -m prompt_reinjection.compute_procrustes --model YOUR_MODEL ... can dispatch correctly.

For a new model, we recommend starting from the most basic reinjection setting first:

  • origin = 1
  • target = 2-last
  • weight = 0.025
  • no anchoring
  • no rotation

In practice, that means using the shallowest stable MMDiT block as the source, reinjecting into all later blocks, and keeping the setup as lightweight as possible. If your model has L blocks indexed from 0 to L-1, the default starting rule is:

residual_origin_layer = 1
residual_target_layers = [2, 3, ..., L-1]
residual_weights = 0.025
residual_use_anchoring = 0
residual_procrustes_path = ""

A typical manual run looks like this after the model has been integrated into the registry:

python -m prompt_reinjection.run_sample \
  --model your_model \
  --model-path /path/to/your/model \
  --prompt "A photo of a couch below a potted plant." \
  --output outputs/your_model_base_reinjection.png \
  --residual_origin_layer 1 \
  --residual_target_layers $(seq 2 LAST_LAYER) \
  --residual_weights 0.025 \
  --residual_use_anchoring 0 \
  --residual_procrustes_path ""

Replace LAST_LAYER with the last text-processing block index of your model. For example, if the model has 24 blocks indexed from 0 to 23, use $(seq 2 23).
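If you drive runs from Python instead of a shell, the manual command above can be rebuilt as an argument list. Note that seq 2 LAST_LAYER is inclusive, so for a model with L blocks the target layers are Python's range(2, L). `run_sample_cmd` is a hypothetical helper; the flags it emits are exactly those used in the command above.

```python
import shlex

def run_sample_cmd(model, model_path, prompt, output, num_blocks, weight=0.025):
    """Hypothetical helper: argv for the basic reinjection run shown above.

    Origin layer 1, target layers 2..L-1, no anchoring, no Procrustes rotation.
    """
    if num_blocks < 3:
        raise ValueError("need at least 3 blocks (origin 1 plus one target)")
    # Same as $(seq 2 LAST_LAYER) with LAST_LAYER = num_blocks - 1.
    targets = [str(i) for i in range(2, num_blocks)]
    return [
        "python", "-m", "prompt_reinjection.run_sample",
        "--model", model,
        "--model-path", model_path,
        "--prompt", prompt,
        "--output", output,
        "--residual_origin_layer", "1",
        "--residual_target_layers", *targets,
        "--residual_weights", str(weight),
        "--residual_use_anchoring", "0",
        "--residual_procrustes_path", "",
    ]

if __name__ == "__main__":
    cmd = run_sample_cmd(
        "your_model", "/path/to/your/model",
        "A photo of a couch below a potted plant.",
        "outputs/your_model_base_reinjection.png",
        num_blocks=24,
    )
    print(shlex.join(cmd))  # copy-pasteable shell command
```

Passing the list to subprocess.run(cmd, check=True) launches the run without shell quoting issues; the empty string for --residual_procrustes_path is preserved as its own argument.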

Once this basic version runs and already improves instruction following, the next recommended upgrades are:

  • Turn on anchoring first by setting --residual_use_anchoring 1. This is usually the safest first upgrade when the model shows cross-layer scale or shift mismatch.
  • If you want further gains, add rotation-based alignment by computing a Procrustes file on a prompt set such as COCO-5K and passing it through --residual_procrustes_path. This is useful when shallow and deep text features differ not only in scale, but also in feature geometry.

In short, the recommended order for a new model is: first make basic reinjection work, then try anchoring, and only then add rotation if you want the strongest alignment.
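The sketch below shows one plausible reading of these upgrade steps for a single target layer: optional rotation of the origin features, optional anchoring, then the weighted add. It assumes that anchoring means matching the per-channel mean and standard deviation of the (possibly rotated) origin features to the target layer's text features, and that rotation is applied before anchoring; both are assumptions, and `anchor` and `reinject` are hypothetical names rather than the repository's API.

```python
import numpy as np

def anchor(src, target, eps=1e-6):
    """Match src's per-channel mean/std (over tokens) to target's (assumed anchoring rule)."""
    mu_s, sd_s = src.mean(axis=0, keepdims=True), src.std(axis=0, keepdims=True)
    mu_t, sd_t = target.mean(axis=0, keepdims=True), target.std(axis=0, keepdims=True)
    return (src - mu_s) / (sd_s + eps) * sd_t + mu_t

def reinject(target, origin, weight=0.025, rotation=None, use_anchoring=False):
    """Reinject origin-layer text features into one target layer's text features.

    target, origin: (tokens, d) arrays; rotation: optional (d, d) orthogonal matrix.
    """
    # Optional rotation toward the target layer's feature geometry.
    src = origin if rotation is None else origin @ rotation
    if use_anchoring:
        # Optional scale/shift anchoring to the target layer's statistics.
        src = anchor(src, target)
    # Basic reinjection: weighted residual add.
    return target + weight * src
```

With the defaults (no rotation, no anchoring) this reduces to the basic recipe target + 0.025 * origin.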

๐Ÿ“ Citation

If you find this project useful, please cite the ICML 2026 paper:

@inproceedings{yao2026prompt,
  title={Prompt Reinjection: Alleviating Prompt Forgetting in Multimodal Diffusion Transformers},
  author={Yao, Yuxuan and Chen, Yuxuan and Li, Hui and Cheng, Kaihui and Guo, Qipeng and Sun, Yuwei and Dong, Zilong and Wang, Jingdong and Zhu, Siyu},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2026}
}

🤗 Acknowledgements

We thank the open-source communities behind SD3, SD3.5, FLUX, Qwen-Image, and HunyuanImage. This release builds on their public model and runtime ecosystems to make Prompt Reinjection reproducible in open source.
