Regularized scVI with ambient RNA correction and overdispersion regularisation, based on cell2location/cell2fate modelling principles (Kleshchevnikov et al. 2022, Aivazidis et al. 2025, Simpson et al. 2017).
The modifications (ambient RNA correction, dispersion prior, batch-free decoder, learned library size) act as structural inductive biases that make a high-capacity model (`n_hidden=512`+, `n_latent=128`+) well-behaved by default, removing the need for careful per-dataset hyperparameter tuning. This is particularly important for complex datasets with hundreds of cell types (e.g. whole-embryo atlases, cross-atlas integration) where large latent spaces and wide hidden layers are needed to avoid competition between cell types for representational capacity.
Standard scVI (Lopez et al. 2018) models observed UMI counts $x_{ng}$ as

$$x_{ng} \sim \text{NegativeBinomial}\!\left(\ell_n \, \rho_{ng},\; \theta_{g,s_n}\right), \qquad \rho_n = f_w(z_n, s_n), \qquad z_n \sim \mathcal{N}(0, I_d)$$

where:
- $z_n \in \mathbb{R}^d$ — low-dimensional latent cell state
- $\ell_n \in (0, \infty)$ — library size (by default fixed to total UMI count per cell), with log-normal prior parameterised per batch $s_n$
- $\rho_n \in \Delta^{G-1}$ — decoder output on the probability simplex (via softmax), a fraction of the total RNA $\ell_n$ per cell, representing denoised normalised gene expression
- $f_w(z_n, s_n): \mathbb{R}^d \times \{0,1\}^K \to \Delta^{G-1}$ — decoder neural network, conditioned on batch $s_n$
- $\theta_{g,s_n} \in (0, \infty)$ — per-gene, per-batch inverse dispersion (code: `px_r`, stored as unconstrained $\phi_{g,s_n}$ where $\theta_{g,s_n} = \exp(\phi_{g,s_n})$)
- $s_n \in \{0,1\}^K$ — one-hot batch indicator for cell $n$
- $c_n \in \{0,1\}^K$ — one-hot categorical covariate indicator for cell $n$
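To make the generative process concrete, here is a minimal numpy simulation of it (a sketch only: the random linear map stands in for the decoder, and all sizes are illustrative, not the library implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_genes, d = 100, 50, 10

# Latent cell state z_n ~ N(0, I) and a toy "decoder" (random linear map + softmax)
z = rng.normal(size=(n_cells, d))
W = rng.normal(size=(d, n_genes))
logits = z @ W
rho = np.exp(logits - logits.max(axis=1, keepdims=True))
rho /= rho.sum(axis=1, keepdims=True)                # rho_n on the probability simplex

ell = rng.integers(1_000, 5_000, size=(n_cells, 1))  # library size per cell
theta = 9.0                                          # inverse dispersion

# NB(mu, theta) as a Gamma-Poisson mixture: rate ~ Gamma(theta, theta/mu), x ~ Poisson(rate)
mu = ell * rho
rate = rng.gamma(shape=theta, scale=mu / theta)
x = rng.poisson(rate)

assert np.allclose(rho.sum(axis=1), 1.0)  # softmax output sums to 1 per cell
assert x.shape == (n_cells, n_genes)
```

The Gamma-Poisson sampling used here is exactly the negative binomial parameterisation the document refers to later.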
The inference model uses amortised variational inference: encoder neural networks parameterise the approximate posteriors for all cell-specific variables:

$$q_\eta(z_n \mid x_n) = \mathcal{N}\!\left(\mu_\eta(x_n),\, \operatorname{diag}\sigma^2_\eta(x_n)\right), \qquad q_\eta(\log \ell_n \mid x_n) = \mathcal{N}\!\left(\mu^\ell_\eta(x_n),\, \sigma^{\ell\,2}_\eta(x_n)\right)$$
`regularizedvi` adapts cell2location/cell2fate modelling principles to scVI. All learnable parameters are initialised at their prior means to improve training stability.
Latent variable and library size — standard scVI structure with a constrained library prior. Library prior parameters are estimated per `library_size_key` group (`_module.py:297–312`):

$$z_n \sim \mathcal{N}(0, I_d), \qquad \log \ell_n \sim \mathcal{N}\!\left(\ell_p^{\mu},\; 0.5 \cdot \ell_p^{\sigma^2}\right)$$
Decoder output — a batch-free decoder maps the latent code $z_n$ and categorical covariates $c_{k,n}$ (selected by `nn_conditioning_covariate_keys`, many keys) to non-negative gene signal via softplus (`_module.py:710–739`, `_components.py:461–462`):

$$\rho_{ng} = \text{softplus}\!\left(f_w(z_n, c_{k,n})\right)_g$$
Additive background — per-gene ambient RNA with a Gamma prior pushing $s_{e,g}$ toward 0.01, indexed by `ambient_covariate_keys` (many keys, concatenated one-hot; `_module.py:379–389` init, `_module.py:700–708` one-hot selection, `_module.py:877–885` prior penalty). The background parameter is always initialised at the prior mean, but the Gamma prior penalty in the loss is off by default (`regularise_background=False`); enable it explicitly if needed:

$$s_{e,g} = \exp(\beta_{e,g}), \qquad s_{e,g} \sim \text{Gamma}(1, 100)$$
Feature scaling — per-gene, per-covariate multiplicative scaling capturing systematic biases (e.g. PCR amplification, RT efficiency differences between protocols). Parameterised per category of `feature_scaling_covariate_keys` (many keys); each covariate category gets its own scaling factor. When no scaling covariates are provided, a single global factor is used (`_module.py:391–403` init, `_module.py:758–771` application, `_module.py:887–896` prior penalty):

$$y_{t,g} = \text{softplus}(\gamma_{t,g}) / 0.7, \qquad y_{t,g} \sim \text{Gamma}(200, 200)$$
Hierarchical dispersion prior with variational posterior — a two-level prior on the inverse dispersion $\theta_{g,d}$, with a LogNormal variational posterior parameterised by `px_r_mu` and `px_r_log_sigma`. During training, $\theta$ is sampled from this posterior (`_module.py:750–755` sampling). Dispersion groups are selected by `dispersion_key` (one key). A learned rate $\lambda_d$ carries a $\text{Gamma}(9, 3)$ hyper-prior (`_module.py:834–873` full block, `_module.py:844` Level 1 softplus, `_module.py:867–869` Level 2 transform):

$$\frac{1}{\sqrt{\theta_{g,d}}} \sim \text{Exponential}(\lambda_d), \qquad \lambda_d \sim \text{Gamma}(9, 3)$$
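A minimal PyTorch sketch of the variational treatment of $\theta$ described above (the parameter names `px_r_mu`/`px_r_log_sigma` follow the text; the shapes and the fixed $\sigma$ initialisation are illustrative):

```python
import torch

n_genes, n_disp_groups = 50, 3
lam = 3.0  # prior-mean rate; equilibrium theta = lam**2 = 9

# Variational parameters, initialised at the prior mean as described above
px_r_mu = torch.full((n_genes, n_disp_groups), float(torch.log(torch.tensor(lam**2))))
px_r_log_sigma = torch.full((n_genes, n_disp_groups), -2.0)

# During training: reparameterised sample theta = exp(mu + sigma * eps)
eps = torch.randn_like(px_r_mu)
theta = torch.exp(px_r_mu + torch.exp(px_r_log_sigma) * eps)

# At inference: theta = exp(mu), here equal to the equilibrium value 9
theta_eval = torch.exp(px_r_mu)

assert (theta > 0).all()
assert abs(theta_eval[0, 0].item() - 9.0) < 1e-3
```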
Expected mean counts — decoder output plus optional background, scaled by library size and feature scaling (`_components.py:467` base rate, `_module.py:758–771` feature scaling):

$$\mu_{ng} = \ell_n \, y_{t_n,g} \left(\rho_{ng} + s_{e_n,g}\right)$$
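Numerically, the mean assembles as follows (a toy numpy sketch; all values are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
n_cells, n_genes = 4, 6

rho = rng.random((n_cells, n_genes))                     # softplus decoder output (non-negative)
s = np.full(n_genes, 0.01)                               # ambient background near its prior mean
y = np.ones(n_genes)                                     # feature scaling, prior centred at 1
ell = np.array([[1000.0], [2000.0], [1500.0], [800.0]])  # learned library size per cell

mu = ell * y * (rho + s)   # expected mean counts per cell and gene

assert mu.shape == (n_cells, n_genes)
assert (mu > 0).all()      # strictly positive thanks to the additive ambient term
```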
Observation model — GammaPoisson (= negative binomial) with mean $\mu_{ng}$ and inverse dispersion $\theta_{g,d_n}$:

$$x_{ng} \sim \text{GammaPoisson}\!\left(\mu_{ng},\, \theta_{g,d_n}\right)$$
Notation:

- $s_{e,g} = \exp(\beta_{e,g})$ — per-gene ambient background indexed by `ambient_covariate_keys` (many keys); $\beta$ is the unconstrained parameter (code: `additive_background`). When `batch_key` is used alone, $e$ = batch group.
- $c_{k,n}$ — categorical covariates (site, donor, etc.), selected by `nn_conditioning_covariate_keys` (many keys). Injected into the decoder; optionally into the encoder via `encoder_covariate_keys`.
- $y_{t,g} = \text{softplus}(\gamma_{t,g})/0.7$ — per-gene feature scaling indexed by `feature_scaling_covariate_keys` (many keys); tight $\text{Gamma}(200, 200)$ prior centred at 1.0
- $\rho_{ng} \in \mathbb{R}_{\geq 0}^G$ — decoder output via softplus (not on the simplex), since $\rho_{ng} + s_{e,g}$ need not sum to 1
- $\theta_{g,d}$ — inverse dispersion indexed by `dispersion_key` (one key) with variational LogNormal posterior: `px_r_mu` ($\mu_{g,d}$) and `px_r_log_sigma` ($\log \sigma_{g,d}$) are learnable parameters; $\theta = \exp(\mu)$ at inference. Initialised at $\mu = \log(\lambda^2) \approx 2.2$ so $\theta \approx 9$ (equilibrium: $1/\sqrt{\theta} = 1/\lambda$) (`_module.py:314–347`)
- $\lambda_d$ — learned Exponential rate, one per dispersion group; $\text{Gamma}(9, 3)$ hyper-prior has mean 3 (`_module.py:349–360`)
- $\ell_p^{\mu}$, $\ell_p^{\sigma^2}$ — library prior mean and variance per `library_size_key` group $p$ (one key). The $0.5$ scaling factor prevents library size from absorbing biological signal.
Backward compat: when `batch_key` is used alone, $e = d = p$ (all index the same batch groups) and $t = k$ (categorical and scaling covariates share groups).
The NB variance is

$$\text{Var}(x_{ng}) = \mu_{ng} + \frac{\mu_{ng}^2}{\theta_{g,d}},$$

so small $\theta$ means high overdispersion, and $\theta \to \infty$ recovers the Poisson.
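The GammaPoisson = NB identity and this variance formula can be checked by Monte Carlo (numpy sketch with illustrative values):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, theta, n = 50.0, 9.0, 200_000

# GammaPoisson: rate ~ Gamma(theta, theta/mu), so E[rate] = mu; then x ~ Poisson(rate)
rate = rng.gamma(shape=theta, scale=mu / theta, size=n)
x = rng.poisson(rate)

expected_var = mu + mu**2 / theta  # NB variance: 50 + 2500/9

assert abs(x.mean() - mu) / mu < 0.02
assert abs(x.var() - expected_var) / expected_var < 0.05
```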
`RegularizedMultimodalVI` extends `regularizedvi` to multiple modalities (e.g. RNA + ATAC). Each modality has its own encoder and its own slice of the latent space, with per-modality latent dimensions (e.g. `n_latent={"rna": 96, "atac": 32}`). Because every decoder receives the full concatenated latent code, cross-modality dependencies are captured at the decoding stage. The following equations describe how observed counts are modelled per modality.
Library size — always learned (observed totals include ambient contamination). A low-capacity encoder infers library size per cell, regularised by a tight LogNormal prior estimated per `library_size_key` group (`_multimodule.py:415–431` prior buffers, `_multimodule.py:981–990` loss):

$$\log \ell_n \sim \mathcal{N}\!\left(\ell_p^{\mu},\; 0.5 \cdot \ell_p^{\sigma^2}\right)$$
Decoder output — maps the joint latent code $z_n$ and categorical covariates (selected by `nn_conditioning_covariate_keys`, many keys) to non-negative feature signal via softplus (`_multimodule.py:791–808`):

$$\rho_{nf} = \text{softplus}\!\left(f_w(z_n, c_{k,n})\right)_f$$
Additive background — per-feature ambient contamination with Gamma prior, indexed by `ambient_covariate_keys` (many keys, concatenated one-hot; `_multimodule.py:480–494` init, `_multimodule.py:779–789` one-hot selection, `_multimodule.py:1058–1070` prior penalty):

$$s_{e,f} = \exp(\beta_{e,f}), \qquad s_{e,f} \sim \text{Gamma}(1, 100)$$
Feature scaling — per-feature, per-covariate multiplicative scaling capturing systematic biases (GC content, mappability, peak caller sensitivity). Parameterised per category of `feature_scaling_covariate_keys` (many keys); each covariate category gets its own factor (`_multimodule.py:496–505` init, `_multimodule.py:810–820` activation and selection, `_multimodule.py:1039–1056` prior penalty):

$$y_{t,f} = \text{softplus}(\gamma_{t,f}) / 0.7, \qquad y_{t,f} \sim \text{Gamma}(200, 200)$$
Expected mean counts — decoder output plus optional background, scaled by library size and feature scaling (`_components.py:467` base rate, `_multimodule.py:820` feature scaling):

$$\mu_{nf} = \ell_n \, y_{t_n,f} \left(\rho_{nf} + s_{e_n,f}\right)$$
Hierarchical dispersion prior — same two-level structure as the single-modality model, with a variational LogNormal posterior, per modality and `dispersion_key` group (`_multimodule.py:995–1037`):

$$\frac{1}{\sqrt{\theta_{f,d}}} \sim \text{Exponential}(\lambda_d), \qquad \lambda_d \sim \text{Gamma}(9, 3)$$
Observation model — GammaPoisson (= negative binomial) with mean $\mu_{nf}$ and inverse dispersion $\theta_{f,d_n}$, per modality:

$$x_{nf} \sim \text{GammaPoisson}\!\left(\mu_{nf},\, \theta_{f,d_n}\right)$$
| Term | Symbol | Prior | What it captures | RNA default | ATAC default |
|---|---|---|---|---|---|
| Additive background | $s_{e,g}$ | $\text{Gamma}(1, 100)$ | Per-feature ambient contamination; `ambient_covariate_keys` (many keys) | ON | off |
| Feature scaling | $y_{t,g}$ | $\text{Gamma}(200, 200)$ | Per-feature multiplicative bias; `feature_scaling_covariate_keys` (many keys) | off | ON |
| Learned library size | $\ell_n$ | LogNormal per group | Low-capacity encoder; `library_size_key` (one key) | always ON | always ON |
| Dispersion regularisation | $\theta_{g,d}$ | $1/\sqrt{\theta} \sim \text{Exp}(\lambda_d)$ | Containment prior; `dispersion_key` (one key) | ON | ON |
| Batch-free decoder | — | — | Decoder conditioned only on `nn_conditioning_covariate_keys` (many keys) | ON | ON |
Per-modality encoder — each modality's encoder takes its own observed counts as input and independently constructs a Gaussian posterior over its private latent slice. The RNA encoder sees only RNA counts; the ATAC encoder sees only ATAC counts. This forces the model to build a dedicated representation for each modality before combining them:
Posterior concatenation — samples from the per-modality posteriors are concatenated to form the joint representation fed to all decoders. Because every decoder sees the full concatenated code, each modality's reconstruction can draw on information from the other modalities' latent slices.
Alternative latent strategies (selectable via `latent_mode`):

- `"concatenation"` (default) — per-modality encoders, posteriors concatenated; total latent dim $= \sum_m d_m$
- `"weighted_mean"` — per-modality encoders, posteriors mixed into a single shared latent by learned scalar weights (MultiVI-style); requires equal $d_m$ across modalities
- `"single_encoder"` — one joint encoder on all concatenated inputs, producing a single shared latent; simplest but loses per-modality interpretability
With a concatenated latent space it is useful to know which latent dimensions each decoder actually uses. `get_modality_attribution()` computes the mean absolute Jacobian of each decoder's predicted mean with respect to each latent dimension. This reveals the empirical partition of the latent space: even though concatenation assigns each slice to a modality by construction, decoders can learn to cross-use other modalities' slices. An attribution-weighted latent representation is exposed as `weighted_z`.
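A from-scratch sketch of the attribution computation for a toy decoder (this is not the `get_modality_attribution()` implementation, just the mean-absolute-Jacobian idea with illustrative sizes):

```python
import torch

d_rna, d_atac, n_genes = 6, 4, 20
torch.manual_seed(0)
decoder = torch.nn.Sequential(torch.nn.Linear(d_rna + d_atac, n_genes), torch.nn.Softplus())

z = torch.randn(16, d_rna + d_atac)

# Jacobian of the decoder output summed over cells w.r.t. z: shape (n_genes, n_cells, d_total)
jac = torch.autograd.functional.jacobian(lambda inp: decoder(inp).sum(0), z)

# Mean |d mu_g / d z_j| aggregated over cells and output features: one score per latent dim
attribution = jac.abs().mean(dim=(0, 1))

rna_score = attribution[:d_rna].sum().item()    # mass on the RNA slice
atac_score = attribution[d_rna:].sum().item()   # mass on the ATAC slice

assert attribution.shape == (d_rna + d_atac,)
assert rna_score > 0 and atac_score > 0
```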
- Ambient RNA correction with Gamma prior: per-gene, per-ambient-category additive background $s_{e,g} = \exp(\beta_{e,g})$ captures ambient RNA contamination, mirroring cell2location's $s_g \cdot g_{e,g}$ structure. A $\text{Gamma}(1, 100)$ prior pushes $s_{e,g}$ toward 0.01, keeping background small relative to biological signal. Initialised at `log(0.01)` (prior mean), with per-category selection via concatenated one-hot encoding across all `ambient_covariate_keys`.
- Hierarchical dispersion regularisation: the prior $1/\sqrt{\theta_{g,d}} \sim \text{Exponential}(\lambda_d)$ is a containment prior (Simpson et al. 2017) that penalises small $\theta$ (excessive overdispersion), regularising the NB toward the Poisson baseline during gradient-based training. The data likelihood provides the opposing force, pulling $\theta$ toward values that explain observed count variance. The rate $\lambda_d$ is learned per dispersion group (selected by `dispersion_key`) with a $\text{Gamma}(9, 3)$ hyper-prior (mean 3). Dispersion $\theta = \exp(\phi)$ is initialised at $\lambda^2 = 9$ (equilibrium). As used in cell2location/cell2fate.
- Batch-free decoder with separated correction paths: the decoder $f_w(z_n, c_{k,n})$ receives only categorical covariates $c_{k,n}$ (site, donor, protocol via `nn_conditioning_covariate_keys`), not the ambient or dispersion covariates. This separates batch correction into structurally different paths: (a) a constrained additive path ($s_{e,g}$ with Gamma prior, selected by `ambient_covariate_keys`) for per-sample ambient RNA, (b) a flexible multiplicative path through categorical covariates in the decoder for systematic differences between donors, protocols, or sites (e.g. PCR bias, RT efficiency, 10x chemistry versions), and (c) per-group dispersion $\theta_{g,d}$ (selected by `dispersion_key`) for variance differences. In standard scVI, the decoder handles all batch effects through a single flexible path, which can absorb biological variation. The separation is most beneficial when batches have high within-batch cell type diversity (e.g. whole-embryo samples), because the additive background can then be cleanly identified as the baseline signal shared across all cells in a batch.
- Softplus activation: because $\rho_{ng} + s_{e_n,g}$ must be non-negative but need not sum to 1 across genes, softmax is replaced with softplus. The library size $\ell_n$ acts as a true normalisation factor.
- Learned library size with constrained prior: the observed total counts include ambient RNA, so library size must be learned (not observed). Prior variance is scaled by 0.5 to prevent the library size from absorbing biological signal. The library encoder has low capacity (`n_hidden=16`).
- LayerNorm and dropout-on-input: LayerNorm replaces BatchNorm (independent of batch composition). Dropout is applied before the linear layer (feature-level masking).
- Auto-scaled early stopping: the `early_stopping_min_delta_per_feature` parameter (default: 0.0002) auto-scales the early stopping threshold as $\text{min\_delta} = n_\text{features} \times \text{early\_stopping\_min\_delta\_per\_feature}$. This adapts the stopping criterion to dataset size: datasets with more features produce larger expected ELBO values and need a proportionally larger delta to distinguish meaningful improvement from noise.
- Best suited for single-nucleus RNA-seq (independent modality and multiome), which typically has substantial ambient RNA contamination. The ambient correction is less necessary for single-cell RNA-seq, where ambient levels are lower.
- Study design matters: the structured assumptions (additive ambient + multiplicative categorical covariates) depend on the experimental design. With some study designs, every batch has both additive effects (ambient RNA) and multiplicative effects (PCR bias, RT differences, 10x 3' v1 vs v2 vs v3, 3' vs 5'). These assumptions may not hold for Smart-seq type data, where every cell can have its own PCR bias and RT differences.
- Using as standard scVI with ambient correction: if you provide the batch covariate to both `batch_key` and `nn_conditioning_covariate_keys`, the model effectively operates as standard scVI with ambient RNA correction (batch effects handled through both additive and multiplicative paths).
- Not a strict ambient correction model: unlike CellBender (Fleming et al. 2023), this model is not constrained by the ambient count distribution from empty droplets. However, because it does not require empty-droplet data, it can be more easily applied to integration of published datasets where empty droplet profiles are unavailable.
- Additivity in non-negative space: the additive background operates in non-negative space ($s_{e,g} = \exp(\beta_{e,g})$), reflecting the ambient RNA correction mechanism. Without empty-droplet data, the additive component can learn the minimal expression of each gene across cells — for many genes this reflects ambient levels, but for ubiquitously expressed genes it captures genuine baseline expression. The additive mechanism therefore works best when individual batches are composed of diverse cell types.
- Regularised overdispersion alone likely helps: the containment prior on overdispersion regularises the NB toward the Poisson baseline, preventing the model from absorbing residuals through excessive variance (small $\theta$). This forces the decoder to capture genuine biological signal through its mean structure rather than relying on high overdispersion to explain noise. This likely contributes to improved sensitivity, but needs more systematic testing.
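The containment prior's pull toward the Poisson limit can be seen by evaluating the Exponential log-density of $u = 1/\sqrt{\theta}$ at a few candidate dispersions (a numpy sketch; for simplicity it evaluates the density of $u$ itself, omitting the change-of-variables Jacobian that a full log-prior on $\theta$ would include):

```python
import numpy as np

lam = 3.0                             # learned rate; Gamma(9, 3) hyper-prior has mean 3
theta = np.array([0.5, 9.0, 100.0])   # candidate inverse dispersions

# Containment prior: u = 1/sqrt(theta) ~ Exponential(lam)
u = 1.0 / np.sqrt(theta)
log_prior = np.log(lam) - lam * u     # Exponential log-density at u

# Small theta (heavy overdispersion) is penalised; large theta (near-Poisson) is favoured
assert log_prior[0] < log_prior[1] < log_prior[2]
```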
In standard scVI (and early `regularizedvi`), a single `batch_key` controls all batch-dependent model components: additive ambient background, dispersion, and library size prior. This works well for 10x Chromium experiments where one sample = one GEM well = one set of technical biases.
However, in complex experimental designs — particularly combinatorial indexing protocols like sci-RNA-seq3 — different technical effects arise at different experimental granularities:
| Technical effect | 10x Chromium source | sci-RNA-seq3 source | Model component |
|---|---|---|---|
| Ambient RNA contamination | GEM well | Embryo (lysis batch) | `ambient_covariate_keys` |
| Library size distribution | GEM well | PCR well (amplification batch) | `library_size_key` |
| Overdispersion profile | GEM well | Embryo or experiment | `dispersion_key` |
| Multiplicative biases (RT, PCR) | Protocol version | Experiment batch | `nn_conditioning_covariate_keys` |
| Per-feature scaling | — | Experiment batch | `feature_scaling_covariate_keys` |
Using a single `batch_key` forces all components to share the same granularity, which either under-corrects (too few groups) or over-parameterises (too many groups for components that don't need that resolution).
| Parameter | Model component | Encoder | Decoder | Shape per modality |
|---|---|---|---|---|
| `ambient_covariate_keys` | Additive background | No | No (additive on rate) | `(n_feat, sum(n_cats))` |
| `nn_conditioning_covariate_keys` | Standard scVI-style injection | No | Yes (one-hot) | Standard |
| `feature_scaling_covariate_keys` | Feature scaling | No | Multiplicative on rate | `(sum(n_cats), n_feat)` |
| `dispersion_key` | Inverse dispersion | No | Indexing | `(n_feat, n_disp_cats)` |
| `library_size_key` | Library prior | No | No | `(1, n_lib_cats)` |
| `encoder_covariate_keys` | Categorical covariates for encoder only | Yes (one-hot) | No | Standard |
- `batch_key` alone (backward compatible): automatically fans out to `ambient_covariate_keys=[batch_key]`, `dispersion_key=batch_key`, `library_size_key=batch_key`. Equivalent to the original single-batch design — in the notation above, $e = d = p$ (same batch groups) and $t = k$ (same categorical groups).
- `batch_key` + purpose-specific keys: raises `ValueError`. These are mutually exclusive — use one approach or the other.
- `feature_scaling_covariate_keys`: if not specified but `nn_conditioning_covariate_keys` is provided, defaults to `nn_conditioning_covariate_keys` (multi-modal only).
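These rules can be sketched as a small validation helper (`resolve_covariate_keys` is a hypothetical name for illustration, not the package's actual function):

```python
def resolve_covariate_keys(batch_key=None, ambient_covariate_keys=None,
                           dispersion_key=None, library_size_key=None):
    """Sketch of the fan-out/validation rules described above (hypothetical helper)."""
    purpose_specific = (ambient_covariate_keys, dispersion_key, library_size_key)
    if batch_key is not None and any(k is not None for k in purpose_specific):
        raise ValueError("batch_key and purpose-specific keys are mutually exclusive")
    if batch_key is not None:
        # Backward compatible: batch_key fans out to every component (e = d = p)
        return {"ambient_covariate_keys": [batch_key],
                "dispersion_key": batch_key,
                "library_size_key": batch_key}
    return {"ambient_covariate_keys": ambient_covariate_keys,
            "dispersion_key": dispersion_key,
            "library_size_key": library_size_key}

# Fan-out from a single batch_key
resolved = resolve_covariate_keys(batch_key="sample")
assert resolved["dispersion_key"] == "sample"
assert resolved["ambient_covariate_keys"] == ["sample"]

# Mixing both styles raises, per the rules above
try:
    resolve_covariate_keys(batch_key="sample", dispersion_key="embryo_id")
except ValueError:
    pass
else:
    raise AssertionError("expected ValueError")
```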
The encoder and decoder receive different subsets of covariates:
- Encoder receives: gene expression $x_n$ + continuous covariates (if any) + `encoder_covariate_keys` categoricals (if any). By default `encoder_covariate_keys=False`, so the encoder sees only expression and continuous covariates — matching the scVI/MultiVI/PeakVI default. This keeps the latent space free of batch information. Setting `encoder_covariate_keys` to a list of keys (e.g. `["batch", "site"]`) injects those categoricals into the encoder; a warning is emitted for non-default values.
- Decoder receives: `[cat_covs...]` from `nn_conditioning_covariate_keys` only (batch-free by default). When `use_batch_in_decoder=True`, batch is additionally included.
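A minimal sketch of this split (hypothetical shapes; `donor`/`site` stand in for `nn_conditioning_covariate_keys` categoricals):

```python
import torch
import torch.nn.functional as F

n_cells, d_latent = 8, 16
n_donors, n_sites = 3, 2

z = torch.randn(n_cells, d_latent)
donor = torch.randint(0, n_donors, (n_cells,))
site = torch.randint(0, n_sites, (n_cells,))

# Decoder input: latent code + concatenated one-hot categoricals.
# The encoder, by default, receives none of these categoricals,
# keeping the latent space free of batch information.
cat_covs = torch.cat([F.one_hot(donor, n_donors), F.one_hot(site, n_sites)], dim=-1)
decoder_in = torch.cat([z, cat_covs.float()], dim=-1)

assert decoder_in.shape == (n_cells, d_latent + n_donors + n_sites)
```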
```python
# batch_key fans out to all components — equivalent to original API
regularizedvi.AmbientRegularizedSCVI.setup_anndata(
    adata,
    layer="counts",
    batch_key="sample",
    nn_conditioning_covariate_keys=["donor", "site"],
)
```

```python
regularizedvi.AmbientRegularizedSCVI.setup_anndata(
    adata,
    layer="counts",
    # Ambient RNA comes from lysis: each embryo has its own ambient profile
    ambient_covariate_keys=["embryo_id", "pcr_well"],
    # Multiplicative effects from experiment-level protocol differences
    nn_conditioning_covariate_keys=["experiment"],
    # Dispersion varies by embryo (tissue composition → variance structure)
    dispersion_key="embryo_id",
    # Library size determined by PCR well (amplification batch)
    library_size_key="pcr_well",
)
```

```bash
export PYTHONNOUSERSITE="1"
conda create -y -n regularizedvi python=3.11
conda activate regularizedvi
# Install PyTorch with CUDA 12.4 support
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
# Install JAX (optional, for some scvi-tools features)
pip install jax
# Install scvi-tools and regularizedvi
pip install scvi-tools
pip install git+https://github.com/vitkl/regularizedvi.git@main
# Install additional analysis packages
pip install scanpy igraph matplotlib ipykernel jupyter
# Register Jupyter kernel
python -m ipykernel install --user --name=regularizedvi --display-name='Environment (regularizedvi)'
```

```bash
git clone https://github.com/vitkl/regularizedvi.git
cd regularizedvi
pip install -e ".[dev,test]"
```

When a single batch variable can describe several technical effects — ambient RNA contamination, library size distribution, and overdispersion profile (typical for 10x Chromium where one sample = one GEM well):
```python
import regularizedvi

regularizedvi.AmbientRegularizedSCVI.setup_anndata(
    adata,
    layer="counts",
    batch_key="batch",
    nn_conditioning_covariate_keys=["site", "donor"],
)
model = regularizedvi.AmbientRegularizedSCVI(
    adata,
    n_hidden=512,
    n_layers=1,
    n_latent=128,
)
model.train(
    train_size=1.0,
    max_epochs=2000,
    batch_size=1024,
)
latent = model.get_latent_representation()
```

When different technical effects arise at different experimental granularities (e.g. sci-RNA-seq3, combinatorial indexing), you can assign each model component its own covariate:
```python
regularizedvi.AmbientRegularizedSCVI.setup_anndata(
    adata,
    layer="counts",
    ambient_covariate_keys=["embryo_id", "pcr_well"],  # additive background
    nn_conditioning_covariate_keys=["experiment"],     # encoder/decoder injection
    dispersion_key="embryo_id",                        # per-group overdispersion
    library_size_key="pcr_well",                       # library size prior groups
)
```

See Covariate design below for full details.
The model now uses GammaPoisson likelihood (cell2location-style, mathematically equivalent to NB) by default with a containment prior on overdispersion to regularise the model. The default dispersion is "gene-batch", providing per-gene, per-batch inverse dispersion parameters.
See the changelog.
For questions and help requests, you can reach out in the scverse discourse. If you found a bug, please use the issue tracker.
t.b.a
- Lopez, R., Regier, J., Cole, M.B. et al. Deep generative modeling for single-cell transcriptomics. Nat Methods 15, 1053–1058 (2018). doi:10.1038/s41592-018-0229-2
- Kleshchevnikov, V., Shmatko, A., Dann, E. et al. Cell2location maps fine-grained cell types in spatial transcriptomics. Nat Biotechnol 40, 661–671 (2022). doi:10.1038/s41587-021-01139-4
- Aivazidis, A., Memi, F., Kleshchevnikov, V. et al. Cell2fate infers RNA velocity modules to improve cell fate prediction. Nat Methods 22, 698–707 (2025). doi:10.1038/s41592-025-02608-3
- Simpson, D., Rue, H., Riebler, A. et al. Penalising Model Component Complexity: A Principled, Practical Approach to Constructing Priors. Statist. Sci. 32(1), 1-28 (2017). doi:10.1214/16-STS576