Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 223 additions & 1 deletion sbi/inference/posteriors/vector_field_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import math
import warnings
from typing import Dict, Literal, Optional, Union
from typing import Any, Dict, Literal, Optional, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -33,6 +33,192 @@
from sbi.utils.torchutils import ensure_theta_batched


class PosteriorBuilder:
    """Fluent builder returned by ``VectorFieldPosterior.with_iid()`` and
    ``VectorFieldPosterior.with_guidance()``.

    It holds a modified ``VectorFieldBasedPotential`` that has been configured
    for a specific iid method or guidance method. Calling ``sample()`` or
    ``log_prob()`` on the builder ensures that ``potential_fn.init()`` is called
    exactly once (on the first invocation) before handing off to the underlying
    posterior.

    Builders can be stacked::

        samples = (
            posterior
            .with_iid(method="auto_gauss", num_obs=10)
            .with_guidance(method="affine_classifier_free", likelihood_scale=2.0)
            .sample((1000,), x=x_obs)
        )

    In the stacked case each ``with_*`` call wraps the configuration of the
    previous one on the **same** ``VectorFieldBasedPotential``; the final
    ``init()`` call on that potential handles all preprocessing at once.
    """

    def __init__(
        self,
        posterior: "VectorFieldPosterior",
        iid_method: Optional[str] = None,
        iid_params: Optional[Dict[str, Any]] = None,
        guidance_method: Optional[str] = None,
        guidance_params: Optional[Dict[str, Any]] = None,
    ):
        self._posterior = posterior
        self._iid_method = iid_method
        self._iid_params = iid_params or {}
        self._guidance_method = guidance_method
        self._guidance_params = guidance_params or {}
        # Set to True after the first _apply_config_and_init() call so that
        # sample() and log_prob() trigger the expensive init() only once.
        self._initialized = False

    # Chaining helpers
    def with_iid(
        self,
        method: Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"] = "auto_gauss",
        **method_kwargs: Any,
    ) -> "PosteriorBuilder":
        """Return a new builder with an updated iid method.

        Args:
            method: The IID score accumulation method to use.
            **method_kwargs: Additional parameters forwarded to the IID method.

        Returns:
            A new ``PosteriorBuilder`` with the given iid configuration.
        """
        return PosteriorBuilder(
            self._posterior,
            iid_method=method,
            iid_params=method_kwargs,
            guidance_method=self._guidance_method,
            guidance_params=self._guidance_params,
        )

    def with_guidance(
        self,
        method: str,
        **method_kwargs: Any,
    ) -> "PosteriorBuilder":
        """Return a new builder with an updated guidance method.

        Args:
            method: The guidance method to use (e.g. ``"affine_classifier_free"``).
            **method_kwargs: Additional parameters forwarded to the guidance method.

        Returns:
            A new ``PosteriorBuilder`` with the given guidance configuration.
        """
        return PosteriorBuilder(
            self._posterior,
            iid_method=self._iid_method,
            iid_params=self._iid_params,
            guidance_method=method,
            guidance_params=method_kwargs,
        )

    # Internal helpers
    def _apply_config_and_init(self, x: Tensor) -> VectorFieldBasedPotential:
        """Configure the potential with iid/guidance settings and run init().

        This is called once, before the first sample/log_prob call.

        Args:
            x: The observed data, already reshaped to batch-event form.

        Returns:
            The configured ``VectorFieldBasedPotential``.
        """
        potential_fn = self._posterior.potential_fn
        # A leading batch dimension > 1 is interpreted as iid observations.
        is_iid = x.shape[0] > 1

        potential_fn.set_x(
            x,
            x_is_iid=is_iid,
            iid_method=self._iid_method or potential_fn.iid_method,
            iid_params=self._iid_params or None,
            guidance_method=self._guidance_method,
            guidance_params=self._guidance_params or None,
        )
        # Explicit one-time setup — precision estimation, GMM fitting, etc.
        potential_fn.init()
        self._initialized = True
        return potential_fn

    # Public API
    def sample(
        self,
        sample_shape: Shape = torch.Size(),
        x: Optional[Tensor] = None,
        **sampler_kwargs: Any,
    ) -> Tensor:
        """Draw samples from the posterior using the configured iid/guidance method.

        ``init()`` is called on the potential function exactly once, before the
        sampler starts, so expensive preprocessing (precision estimation, GMM
        fitting) is never hidden inside the first diffusion step.

        Args:
            sample_shape: Shape of the samples to draw.
            x: Observed data. If ``None``, the posterior's default x is used.
            **sampler_kwargs: Additional keyword arguments forwarded to
                ``VectorFieldPosterior.sample()``. Do *not* pass ``iid_method``
                or ``guidance_method`` here — configure them via
                ``with_iid()`` / ``with_guidance()`` instead.

        Returns:
            Samples of shape ``(*sample_shape, *theta_shape)``.
        """
        posterior = self._posterior
        x_resolved = posterior._x_else_default_x(x)
        x_reshaped = reshape_to_batch_event(
            x_resolved, posterior.vector_field_estimator.condition_shape
        )
        # Honor the "exactly once" contract (mirrors log_prob below): only the
        # first call pays for set_x() + init(); subsequent calls reuse the
        # warmed-up caches on the potential.
        if not self._initialized:
            self._apply_config_and_init(x_reshaped)

        # Forward x and the iid/guidance configuration to the underlying
        # sampler. The potential has already been configured and initialized
        # above, so any repeated preprocessing triggered downstream hits the
        # warm caches and is cheap.
        return posterior.sample(
            sample_shape,
            x=x_resolved,
            iid_method=self._iid_method,
            iid_params=self._iid_params or None,
            guidance_method=self._guidance_method,
            guidance_params=self._guidance_params or None,
            **sampler_kwargs,
        )

    def log_prob(
        self,
        theta: Tensor,
        x: Optional[Tensor] = None,
        track_gradients: bool = False,
        ode_kwargs: Optional[Dict] = None,
    ) -> Tensor:
        """Return the log-probability of the posterior at ``theta``.

        Args:
            theta: Parameters at which to evaluate the log-probability.
            x: Observed data. If ``None``, the posterior's default x is used.
            track_gradients: Whether to track gradients through the computation.
            ode_kwargs: Additional keyword arguments for the ODE solver.

        Returns:
            Log-posterior probability of shape ``(len(theta),)``.
        """
        posterior = self._posterior
        x_resolved = posterior._x_else_default_x(x)
        x_reshaped = reshape_to_batch_event(
            x_resolved, posterior.vector_field_estimator.condition_shape
        )
        if not self._initialized:
            self._apply_config_and_init(x_reshaped)
        return posterior.log_prob(
            theta, x=x_resolved, track_gradients=track_gradients, ode_kwargs=ode_kwargs
        )


class VectorFieldPosterior(NeuralPosterior):
r"""Posterior based on flow- or score-matching estimators.

Expand Down Expand Up @@ -109,6 +295,42 @@ def __init__(
self._purpose = """It samples from the vector field model given the \
vector_field_estimator."""

def with_iid(
    self,
    method: Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"] = "auto_gauss",
    **method_kwargs: Any,
) -> "PosteriorBuilder":
    """Create a builder that accumulates scores across i.i.d. observations.

    Args:
        method: Name of the IID score accumulation scheme.
        **method_kwargs: Extra options passed through to the selected
            ``IIDScoreFunction`` subclass.

    Returns:
        A :class:`PosteriorBuilder` configured for i.i.d. observations,
        ready to call ``.sample()`` on.
    """
    builder = PosteriorBuilder(
        self,
        iid_method=method,
        iid_params=dict(method_kwargs),
    )
    return builder

def with_guidance(
    self,
    method: str,
    **method_kwargs: Any,
) -> "PosteriorBuilder":
    """Create a builder that samples with the given guidance method.

    Args:
        method: Guidance method name.
        **method_kwargs: Extra options passed through to the guidance
            method.

    Returns:
        A :class:`PosteriorBuilder` configured with the guidance method,
        ready to call ``.sample()`` on.
    """
    builder = PosteriorBuilder(
        self,
        guidance_method=method,
        guidance_params=dict(method_kwargs),
    )
    return builder

def to(self, device: Union[str, torch.device]) -> None:
"""Move posterior to device.

Expand Down
24 changes: 23 additions & 1 deletion sbi/inference/potentials/base_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from abc import ABCMeta, abstractmethod
from typing import Optional, Protocol, Union
from typing import Any, Optional, Protocol, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -40,7 +40,29 @@ def gradient(
self, theta: Tensor, time: Optional[Tensor] = None, track_gradients: bool = True
) -> Tensor:
raise NotImplementedError

def init(self, **kwargs: Any) -> "BasePotential":
    """Perform one-time setup before sampling starts.

    Subclasses whose first evaluation requires costly preprocessing
    (hyperparameter search, covariance estimation, GMM fitting, ...)
    override this hook so that the work happens explicitly up-front
    rather than lazily inside the first ``__call__``.

    The base implementation does nothing, so every existing potential
    keeps working without modification.

    Because ``self`` is returned, calls chain fluently::

        potential.init(x_obs=x).gradient(theta, t)

    Args:
        **kwargs: Subclass-specific keyword arguments (e.g. ``x_obs``).

    Returns:
        ``self`` for method chaining.
    """
    return self

@property
def x_is_iid(self) -> bool:
"""If x has batch dimension greater than 1, whether to intepret the batch as iid
Expand Down
63 changes: 63 additions & 0 deletions sbi/inference/potentials/vector_field_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,69 @@ def rebuild_flows_for_batch(self, **kwargs) -> List[NormalizingFlow]:
flows.append(flow)
return flows


def init(self, **kwargs: Any) -> "VectorFieldBasedPotential":
    """Run one-time hyperparameter estimation before sampling.

    For iid methods that need to estimate precision matrices (``auto_gauss``,
    ``gauss``) or guidance methods that fit a GMM (``prior_guide``), calling
    ``init()`` explicitly ensures that expensive preprocessing is done once,
    up-front, before the sampler starts.

    This is typically called automatically by the ``PosteriorBuilder`` API
    (``posterior.with_iid(...).sample(...)``). It can also be called manually
    when fine-grained control over initialization timing is desired::

        potential_fn.set_x(x_o, x_is_iid=True)
        potential_fn.init()  # precision estimation happens here
        posterior.sample((1000,))  # no surprise overhead on first step

    Args:
        **kwargs: Forwarded to the underlying IID / guidance ``init()`` hooks.
            Currently unused; reserved for future subclass customisation.

    Returns:
        ``self`` for method chaining.

    Raises:
        ValueError: If the iid path is taken but no prior is available.
    """
    x_o = self._x_o
    if x_o is None:
        # Nothing to prepare until an observation has been set via set_x().
        return self

    device = x_o.device

    # --- guidance init ---
    if self.guidance_method is not None:
        score_wrapper, config_cls = get_guidance_method(self.guidance_method)
        config_params = config_cls(**(self.guidance_params or {}))
        vf_estimator = score_wrapper(
            self.vector_field_estimator,
            self.prior,
            config=config_params,
            device=device,
        )
    else:
        vf_estimator = self.vector_field_estimator

    # --- iid init (precision estimation) ---
    if self.x_is_iid and x_o.shape[0] > 1:
        # Explicit raise instead of `assert` so the check survives `python -O`,
        # which strips assert statements.
        if self.prior is None:
            raise ValueError("Prior is required for iid init.")
        iid_method = get_iid_method(self.iid_method)
        # Constructing the iid method triggers precision estimation in
        # GaussCorrectedScoreFn and AutoGaussCorrectedScoreFn (via
        # estimate_prior_precision / estimate_posterior_precision which are
        # lru_cache'd class methods). After this call the cached results are
        # warm, so the first gradient() call during sampling pays no extra cost.
        iid_method(
            vf_estimator,
            self.prior,
            device=device,
            **(self.iid_params or {}),
        )

    return self




def vector_field_estimator_based_potential(
vector_field_estimator: ConditionalVectorFieldEstimator,
Expand Down
Loading