
MultiDiffSense: Diffusion-Based Multi-Modal Visuo-Tactile Image Generation


Official implementation of MultiDiffSense, a unified ControlNet-based diffusion model that generates realistic and physically grounded tactile sensor images across three different sensor types (ViTac, TacTip, ViTacTip) from a single model, conditioned on CAD-derived depth maps and structured text prompts encoding contact pose and sensor modality.

Figure: overview of the MultiDiffSense pipeline.




Overview

MultiDiffSense leverages ControlNet (built on Stable Diffusion 1.5) to translate depth map renderings of 3D objects into realistic tactile sensor images across three sensor modalities:

  • TacTip -- Optical tactile sensor with pin-based deformation markers
  • ViTac -- Vision-based tactile sensor (no markers)
  • ViTacTip -- Hybrid vision-tactile sensor

The model is conditioned on:

  1. Depth maps (rendered from STL files) as spatial control signals
  2. Text prompts describing the 4-DOF contact pose and the target sensor type

Quick Start (Pre-trained Model)

Generate tactile images directly using the pre-trained checkpoint (conditioned on short prompts + depth maps) from Hugging Face -- no training required. The checkpoint is downloaded automatically on first run.

pip install huggingface_hub

# Option 1: From a single depth map + text prompt:
python multidiffsense/controlnet/generate.py \
    --source_image path/to/depth_map.png \
    --prompt '{"sensor_context": "captured by a high-resolution vision only sensor ViTac.", "object_pose": {"x": 0.12, "y": -0.34, "z": 1.5, "yaw": 15.0}}'

# Option 2: From a prompt file (batch) -- each line contains a depth map path and prompt:
python multidiffsense/controlnet/generate.py \
    --dataset_dir datasets \
    --prompt_json datasets/test/prompt_ViTacTip.json

Note: for Option 2, each line in the prompt file is a JSON object specifying the depth map path (relative to --dataset_dir), the text prompt, and, for training/testing only, the target image path (omitted at inference time):

{"source": "source/1_0.png", "target": "target/1_ViTacTip_0.png", "prompt": {"sensor_context": "captured by a high-resolution vision only sensor ViTac.", "object_pose": {"x": 0.12, "y": -0.34, "z": 1.5, "yaw": 15.0}}}

For training from scratch, dataset preparation, evaluation, and ablation studies, see the sections below.


Repository Structure

MultiDiffSense/
|-- cldm/                             # ControlNet modules (at repo root)
|   |-- cldm.py                      # ControlledUNet, ControlNet, ControlLDM
|   |-- ddim_hacked.py               # DDIM sampler for ControlNet
|   |-- hack.py                      # CLIP and attention hacks
|   |-- logger.py                    # Image logging callback
|   |-- loss_plotter.py              # Training loss visualisation callback
|   +-- model.py                     # Model creation and checkpoint loading
|
|-- configs/                          # Configuration files
|   |-- controlnet_train.yaml         # Training config (short prompts)
|   +-- controlnet_train_long_prompt.yaml  # Training config (long prompts, ablation 2)
|
|-- ldm/                              # Latent Diffusion Model core (from CompVis/stable-diffusion)
|   |-- data/                        # Data utilities
|   |-- models/
|   |   |-- diffusion/               # DDPM, DDIM, PLMS samplers
|   |   +-- autoencoder.py           # VQ-VAE encoder/decoder
|   +-- modules/                     # UNet, attention, encoders, EMA, distributions
|
|-- multidiffsense/                   # Core source code
|   |-- controlnet/                   # ControlNet training/testing scripts
|   |   |-- train.py                 # Training (supports --no_prompt / --no_source ablation)
|   |   |-- test.py                  # Testing with quantitative metrics
|   |   |-- generate.py              # Inference-only generation
|   |   +-- data_loader.py           # Dataset class for ControlNet
|   |
|   |-- baseline_cgan/               # Pix2Pix (cGAN) baseline
|   |   |-- train.py                 # cGAN training
|   |   |-- test.py                  # cGAN testing with same metrics
|   |   |-- dataset_converter.py     # Convert ControlNet format -> Pix2Pix format
|   |   +-- README.md                # Baseline-specific setup instructions
|   |
|   |-- data_preparation/            # Dataset building pipeline
|   |   |-- all_processing.py        # Orchestrator: run full pipeline per object
|   |   |-- source_processing.py     # Render depth maps from STL (target-driven alignment)
|   |   |-- target_processing.py     # Rename + resize tactile sensor images
|   |   |-- prompt_creation.py       # Generate prompt.json (short or long style)
|   |   |-- ds_creation.py           # Assemble mega dataset (merge per-object datasets)
|   |   |-- dataset_split.py         # Train/val/test splitting (70/15/15)
|   |   +-- modality_split.py        # Split prompts by sensor modality for evaluation
|   |
|   +-- evaluation/                  # Evaluation utilities
|       +-- metrics.py               # SSIM, PSNR, MSE, LPIPS, FID computation
|
|-- data/                             # Raw data directory (user-populated)
|   +-- example/                     # Minimal example dataset
|       |-- stl/                     # STL mesh files: <obj_id>.stl
|       |-- csv/                     # Per-object pose CSV: <obj_id>.csv
|       +-- tactile/                 # Tactile images per object/sensor
|           +-- <obj_id>/
|               |-- TacTip/target/
|               |-- ViTac/target/
|               +-- ViTacTip/target/
|
|-- datasets/                         # Assembled dataset (generated by pipeline)
|   |-- source/                      # All depth maps (shared across splits)
|   |-- target/                      # All tactile images (shared across splits)
|   |-- prompt.json                  # Merged short prompts
|   |-- prompt_long.json             # Merged long prompts (ablation 2)
|   |-- train/
|   |   |-- prompt.json
|   |   +-- prompt_long.json
|   |-- val/
|   |   |-- prompt.json
|   |   +-- prompt_long.json
|   +-- test/
|       |-- prompt.json
|       |-- prompt_long.json
|       |-- prompt_TacTip.json       # Per-modality splits (from modality_split)
|       |-- prompt_ViTac.json
|       +-- prompt_ViTacTip.json
|
|-- models/                          # Model checkpoints
|   +-- cldm_v15.yaml                # ControlNet + SD1.5 architecture config
|
|-- scripts/                         # Shell scripts for common workflows
|-- figures/                         # Figures included in README
|-- tool_add_control.py              # Utility: create ControlNet init weights from SD1.5
|-- requirements.txt                 # Python dependencies
+-- README.md                        # This file

Installation

Option A: Conda (recommended)

git clone https://github.com/sirine-b/MultiDiffSense.git
cd MultiDiffSense
conda env create -f environment.yml
conda activate multidiffsense

Option B: pip

git clone https://github.com/sirine-b/MultiDiffSense.git
cd MultiDiffSense
pip install -r requirements.txt

Pre-trained Weights

Download the Stable Diffusion v1.5 checkpoint and create the ControlNet initialisation weights:

bash scripts/prepare_model.sh

This will:

  1. Download v1-5-pruned.ckpt from Hugging Face
  2. Run tool_add_control.py to produce models/control_sd15_ini.ckpt
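Conceptually, the initialisation seeds each ControlNet-branch weight from the matching frozen SD UNet weight. The sketch below illustrates the state-dict copying idea only (simplified, with assumed key prefixes, not the actual tool_add_control.py logic):

```python
def seed_control_weights(sd_state,
                         control_prefix="control_model.",
                         unet_prefix="model.diffusion_model."):
    """Build an initial state dict where each ControlNet-branch key is
    seeded from the corresponding SD UNet weight when one exists.

    Key prefixes are illustrative assumptions, not the script's exact names.
    """
    out = dict(sd_state)  # keep every original SD weight
    for key, weight in sd_state.items():
        if key.startswith(unet_prefix):
            # mirror the UNet weight under the control branch's name
            out[control_prefix + key[len(unet_prefix):]] = weight
    return out
```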

Dataset Preparation

The full pipeline builds the training dataset from raw data in 4 steps.

Expected raw data structure:

data/example/
|-- stl/          # STL mesh files: <obj_id>.stl
|-- csv/          # Pose CSV files: <obj_id>.csv
+-- tactile/      # Tactile images per object/sensor
    +-- <obj_id>/
        |-- TacTip/target/
        |-- ViTac/target/
        +-- ViTacTip/target/

Step 1: Per-Object Processing

Process one or more objects end-to-end across all three sensor modalities:

python -m multidiffsense.data_preparation.all_processing \
    --stl_dir data/example/stl \
    --csv_dir data/example/csv \
    --tactile_dir data/example/tactile \
    --obj_ids 1

Processing order per object:

| Step              | ViTac (1st)       | TacTip (2nd)      | ViTacTip (3rd)    |
|-------------------|-------------------|-------------------|-------------------|
| Target processing | Rename + resize   | Rename + resize   | Rename + resize   |
| Source processing | Generate from STL | Copy from ViTac   | Copy from ViTac   |
| Prompt creation   | Generate from CSV | Generate from CSV | Generate from CSV |

Why ViTac first? Source (depth map) generation aligns each frame by extracting the object from the tactile image to determine its bounding box and centre position. ViTac images are vision-only with no pin markers on the sensor surface, making the object boundary much clearer and easier to segment than TacTip (pin markers) or ViTacTip (hybrid markers). Since the source depth maps represent the same object at the same pose regardless of sensor, they are generated once from ViTac and copied to the other two modalities.

The pipeline iterates only over frames that actually exist in the target directory (not the CSV row count), so missing or removed frames are handled gracefully.
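That frame-discovery step can be sketched as follows (`existing_frame_ids` is a hypothetical helper; the real logic lives in all_processing.py):

```python
from pathlib import Path

def existing_frame_ids(target_dir, obj_id, sensor):
    """Return sorted frame indices for files named <obj_id>_<sensor>_<frame>.png
    that actually exist in the target directory, ignoring CSV row counts."""
    pattern = f"{obj_id}_{sensor}_*.png"
    frames = []
    for p in Path(target_dir).glob(pattern):
        stem = p.stem  # e.g. "1_ViTac_7"
        frames.append(int(stem.rsplit("_", 1)[1]))
    return sorted(frames)
```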

Under the hood, this runs three sub-steps per sensor:

  1. Target processing (target_processing.py) -- renames raw tactile images to <obj_id>_<sensor>_<frame>.png and resizes to 512x512.
  2. Source processing (source_processing.py) -- uses target-driven alignment: segments the object in each target frame, then resizes, rotates, and positions the CAD depth map to match the target exactly. Uses Otsu's automatic thresholding (no per-object tuning). Only runs for ViTac; source images are copied to TacTip and ViTacTip.
  3. Prompt creation (prompt_creation.py) -- reads the per-object CSV and writes a JSONL prompt file. Supports --prompt_style short (default) or --prompt_style long for ablation studies (see Ablation Studies).
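For reference, Otsu's method picks the threshold that maximises the between-class variance of the greyscale histogram. A pure-NumPy sketch of the idea (the pipeline itself likely calls an OpenCV or scikit-image implementation):

```python
import numpy as np

def otsu_threshold(gray):
    """Return the threshold (0-255) maximising between-class variance
    for a uint8 greyscale image."""
    hist = np.bincount(gray.ravel(), minlength=256).astype(float)
    probs = hist / hist.sum()
    omega = np.cumsum(probs)                 # class-0 probability up to each t
    mu = np.cumsum(probs * np.arange(256))   # cumulative intensity mean
    mu_t = mu[-1]                            # global mean
    # between-class variance sigma_b^2(t); guard against empty classes
    denom = omega * (1.0 - omega)
    denom[denom == 0] = np.nan
    sigma_b2 = (mu_t * omega - mu) ** 2 / denom
    return int(np.nanargmax(sigma_b2))
```

Pixels above the returned threshold form the foreground mask from which the bounding box and centre are taken.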

Step 2: Assemble Mega Dataset

Merge per-object datasets across all three sensor modalities into a single dataset:

python -m multidiffsense.data_preparation.ds_creation \
    --tactile_dir data/example/tactile \
    --output_dir datasets \
    --object_ids 1 \
    --sensors TacTip ViTac ViTacTip

This copies all source/target images into flat source/ and target/ directories and merges all per-object prompt.json files into one.

Step 3: Train/Val/Test Split

python -m multidiffsense.data_preparation.dataset_split \
    --base_dir datasets \
    --seed 16

Splits the merged prompt.json into train/, val/, test/ subdirectories (70/15/15). Groups by source image so all sensor modalities for the same contact stay in the same split. Images remain in the parent datasets/source/ and datasets/target/; only prompt files are placed in the split subdirectories.
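The grouped split behaviour can be sketched like this (`grouped_split` is a hypothetical stand-in for dataset_split.py):

```python
import random
from collections import defaultdict

def grouped_split(entries, seed=16, train=0.70, val=0.15):
    """Split prompt entries 70/15/15 while keeping every entry that
    shares a source image in the same split."""
    groups = defaultdict(list)
    for e in entries:
        groups[e["source"]].append(e)
    keys = sorted(groups)
    random.Random(seed).shuffle(keys)  # deterministic shuffle of source groups
    n = len(keys)
    n_train, n_val = int(n * train), int(n * val)
    split_keys = {
        "train": keys[:n_train],
        "val": keys[n_train:n_train + n_val],
        "test": keys[n_train + n_val:],
    }
    return {name: [e for k in ks for e in groups[k]]
            for name, ks in split_keys.items()}
```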

Step 4: Per-Modality Split (for evaluation)

python -m multidiffsense.data_preparation.modality_split \
    --prompt_path datasets/test/prompt.json \
    --output_dir datasets/test

Creates prompt_TacTip.json, prompt_ViTac.json, and prompt_ViTacTip.json for per-sensor evaluation.
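Modality routing can be sketched by exact-matching the sensor token in each target filename (<obj_id>_<sensor>_<frame>.png); exact matching matters because "ViTac" is a prefix of "ViTacTip". A hypothetical sketch:

```python
import json

SENSORS = ("TacTip", "ViTac", "ViTacTip")

def sensor_of(entry):
    """Extract the sensor name from a target path like 'target/1_ViTacTip_0.png'."""
    stem = entry["target"].rsplit("/", 1)[-1].rsplit(".", 1)[0]
    sensor = stem.split("_")[1]  # <obj_id>_<sensor>_<frame>
    if sensor not in SENSORS:
        raise ValueError(f"unknown sensor: {sensor}")
    return sensor

def split_by_modality(lines):
    """Group JSONL prompt lines into one list per sensor modality."""
    buckets = {s: [] for s in SENSORS}
    for line in lines:
        entry = json.loads(line)
        buckets[sensor_of(entry)].append(line)
    return buckets
```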

Final dataset layout:

datasets/
|-- source/            # All depth maps (shared across splits)
|-- target/            # All tactile images (shared across splits)
|-- prompt.json        # All samples
|-- train/prompt.json  # Train split (prompt entries only, no images)
|-- val/prompt.json    # Val split
+-- test/
    |-- prompt.json
    |-- prompt_TacTip.json
    |-- prompt_ViTac.json
    +-- prompt_ViTacTip.json

Training

ControlNet (MultiDiffSense)

python multidiffsense/controlnet/train.py \
    --config configs/controlnet_train.yaml \
    --batch_size 8 \
    --lr 1e-5 \
    --max_epochs 150 \
    --sd_locked

Key training parameters:

| Parameter           | Default | Description                      |
|---------------------|---------|----------------------------------|
| batch_size          | 8       | Training batch size              |
| lr                  | 1e-5    | Learning rate                    |
| max_epochs          | 150     | Maximum training epochs          |
| sd_locked           | True    | Freeze Stable Diffusion backbone |
| precision           | 32      | Training precision               |
| early_stop_patience | 10      | Early stopping patience          |

Training logs and checkpoints are saved to results/lightning_logs/.


Testing & Evaluation

# Test on seen objects (test split)
python multidiffsense/controlnet/test.py \
    --config configs/controlnet_train.yaml \
    --checkpoint path/to/best_checkpoint.ckpt \
    --modality ViTacTip \
    --seen_objects \
    --output_dir results/test_seen

# Test on unseen objects
python multidiffsense/controlnet/test.py \
    --config configs/controlnet_train.yaml \
    --checkpoint path/to/best_checkpoint.ckpt \
    --modality TacTip \
    --output_dir results/test_unseen

Reported metrics (computed per-image and aggregated): SSIM (Structural Similarity Index), PSNR (Peak Signal-to-Noise Ratio in dB), MSE (Mean Squared Error), LPIPS (Learned Perceptual Image Patch Similarity, AlexNet), and FID (Fréchet Inception Distance).
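For orientation, MSE and PSNR are simple enough to compute directly (SSIM, LPIPS, and FID typically come from scikit-image, the lpips package, and an FID implementation respectively); a NumPy sketch for uint8 images:

```python
import numpy as np

def mse(a, b):
    """Mean squared error between two uint8 images."""
    return float(np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2))

def psnr(a, b, max_val=255.0):
    """Peak signal-to-noise ratio in dB; infinite for identical images."""
    err = mse(a, b)
    if err == 0:
        return float("inf")
    return 10.0 * np.log10(max_val ** 2 / err)
```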

Results are saved as a CSV file and visual grids (control | target | generated).


Ablation Studies

All ablations involve retraining the model (not just test-time flag changes). Each ablation produces a separate checkpoint that is then evaluated.

Ablation 1: Conditioning Modality

Tests the contribution of each conditioning signal by training without it.

1a. Source only (no text prompt) -- train with empty prompts, depth map conditioning only:

# Train
python multidiffsense/controlnet/train.py \
    --config configs/controlnet_train.yaml \
    --no_prompt \
    --output_suffix _no_prompt

# Test (use the no-prompt checkpoint, with --no_prompt to match)
python multidiffsense/controlnet/test.py \
    --config configs/controlnet_train.yaml \
    --checkpoint results_no_prompt/lightning_logs/.../best.ckpt \
    --modality ViTacTip \
    --no_prompt \
    --output_dir results/ablation_no_prompt

1b. Prompt only (no depth map) -- train with blank source images, text prompt conditioning only:

# Train
python multidiffsense/controlnet/train.py \
    --config configs/controlnet_train.yaml \
    --no_source \
    --output_suffix _no_source

# Test
python multidiffsense/controlnet/test.py \
    --config configs/controlnet_train.yaml \
    --checkpoint results_no_source/lightning_logs/.../best.ckpt \
    --modality ViTacTip \
    --no_source \
    --output_dir results/ablation_no_source

The --output_suffix flag appends to the output directory (e.g. results_no_prompt/) so checkpoints from different ablations don't overwrite each other.

Ablation 2: Prompt Richness (Short vs Long)

Tests whether richer text prompts improve generation quality.

Short prompt (default): sensor context + object pose.

{"sensor_context": "captured by a high-resolution vision only sensor ViTac.",
 "object_pose": {"x": 0.12, "y": -0.34, "z": 1.5, "yaw": 15.0}}

Long prompt: object description + contact description + sensor context + style tags + negatives + object pose.

{"object_description": "A edge-shaped object with distinct geometric features",
 "contact_description": "Medium contact on the object surface with moderate indentation",
 "sensor_context": "Captured by a high-resolution vision only sensor ViTac",
 "style_tags": "High quality, detailed texture, realistic tactile response, sharp sensor reading",
 "negatives": "Blurry, low quality, artifacts, noise, distortion",
 "object_pose": {"x": 0.12, "y": -0.34, "z": 1.5, "yaw": 15.0}}

Workflow -- generate both prompt types, then train separately:

# Step 1: Generate short prompts (default, already done in normal pipeline)
python -m multidiffsense.data_preparation.all_processing \
    --stl_dir data/example/stl --csv_dir data/example/csv \
    --tactile_dir data/example/tactile --obj_ids 1 \
    --prompt_style short

# Step 2: Generate long prompts (saved as prompt_long.json alongside prompt.json)
python -m multidiffsense.data_preparation.all_processing \
    --stl_dir data/example/stl --csv_dir data/example/csv \
    --tactile_dir data/example/tactile --obj_ids 1  \
    --prompt_style long

# Step 3: Assemble + split both
python -m multidiffsense.data_preparation.ds_creation \
    --tactile_dir data/example/tactile --output_dir datasets \
    --object_ids 1 --prompt_style short
python -m multidiffsense.data_preparation.ds_creation \
    --tactile_dir data/example/tactile --output_dir datasets \
    --object_ids 1 --prompt_style long

python -m multidiffsense.data_preparation.dataset_split --base_dir datasets --prompt_style short
python -m multidiffsense.data_preparation.dataset_split --base_dir datasets --prompt_style long

# Step 4: Train with short prompts (default config)
python multidiffsense/controlnet/train.py \
    --config configs/controlnet_train.yaml

# Step 5: Train with long prompts (separate config pointing to prompt_long.json)
python multidiffsense/controlnet/train.py \
    --config configs/controlnet_train_long_prompt.yaml

Short and long prompts coexist in the same dataset directory -- prompt.json and prompt_long.json sit side by side, sharing the same source/target images.


Baseline Comparison (cGAN / Pix2Pix)

We compare against Pix2Pix as a conditional GAN baseline using the pytorch-CycleGAN-and-pix2pix framework.

Setup

git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix external/pytorch-CycleGAN-and-pix2pix

Convert Dataset Format

python multidiffsense/baseline_cgan/dataset_converter.py \
    --controlnet_dataset datasets \
    --output_path external/pytorch-CycleGAN-and-pix2pix/datasets/depth_to_sensor \
    --modality TacTip

Train Pix2Pix

cd external/pytorch-CycleGAN-and-pix2pix
python train.py \
    --dataroot datasets/depth_to_sensor \
    --name depth_to_sensor_experiment \
    --model pix2pix \
    --direction AtoB \
    --n_epochs 200 \
    --n_epochs_decay 100

Test Pix2Pix

bash scripts/test_pix2pix.sh

The test script computes the same metrics (SSIM, PSNR, MSE, LPIPS, FID) for fair comparison.


Inference / Generation

Generate tactile images from depth maps without ground truth targets:

python multidiffsense/controlnet/generate.py \
    --config configs/controlnet_train.yaml \
    --checkpoint path/to/best_checkpoint.ckpt \
    --dataset_dir datasets \
    --prompt_json datasets/test/prompt_ViTacTip.json \
    --output_dir results/generated

Example Data

The data/example/ directory contains a minimal working example with:

  • 1 object across 3 sensor modalities (TacTip, ViTac, ViTacTip)
  • Per-object pose CSV in csv/<obj_id>.csv (tab-separated, 4-DOF pose)
  • STL source file in stl/<obj_id>.stl
  • Tactile images in tactile/<obj_id>/<sensor_type>/target/

To verify your installation, run the per-object pipeline on the example data:

python -m multidiffsense.data_preparation.all_processing \
    --stl_dir data/example/stl \
    --csv_dir data/example/csv \
    --tactile_dir data/example/tactile \
    --obj_ids 1

Citation

If you find this work useful, please cite:

@inproceedings{multidiffsense2026,
    title     = {MultiDiffSense: Diffusion-Based Multi-Modal Visuo-Tactile Image Generation Conditioned on Object Shape and Contact Pose},
    author    = {Sirine Bhouri and Lan Wei and Jian-Qing Zheng and Dandan Zhang},
    booktitle = {IEEE International Conference on Robotics and Automation (ICRA)},
    year      = {2026},
    url       = {https://arxiv.org/abs/2602.19348}
}

Acknowledgements


License

This project is licensed under the MIT License -- see LICENSE for details.
