Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4891d14
wip: update pyproject toml and copy over npe_pfn implementation
jsvetter Feb 19, 2026
4297cb1
first working pure density estimator, with some rough edges still
jsvetter Feb 19, 2026
174e49d
wip: adding builder, some decisions necessary soon
jsvetter Feb 19, 2026
b7be749
first kinda working version, many rough edges
jsvetter Feb 19, 2026
2b844f3
working, with inheritance from neural inference
jsvetter Feb 19, 2026
916d07b
revert some unnecessary changes
jsvetter Feb 19, 2026
508a8ea
some device handling
jsvetter Feb 19, 2026
75f8e4e
wip, don't allow standardizing x for now
jsvetter Feb 19, 2026
85b1d1a
no z scoring in default
jsvetter Feb 19, 2026
6b1e7cf
very strict handling of standardization
jsvetter Feb 19, 2026
3d0207f
first working filtering logic
jsvetter Feb 20, 2026
b4a6bc6
some renaming
jsvetter Feb 20, 2026
49dbabb
cleaner via build_posterior
jsvetter Feb 20, 2026
3a11993
completely get rid of train
jsvetter Feb 20, 2026
4c838e4
add TODO
jsvetter Feb 20, 2026
fd074b4
simplify stuff, add max context
jsvetter Feb 23, 2026
9c8a7e7
add flexible filtering
jsvetter Feb 23, 2026
b3dee2c
update docstrings
jsvetter Feb 23, 2026
acbfebd
implement sample_and_log_prob
jsvetter Feb 23, 2026
335461f
more docstrings
jsvetter Feb 23, 2026
5b6407a
fix filter_size validation lower bound
jsvetter Feb 23, 2026
c1147ae
fix all reported precommit issues
jsvetter Feb 23, 2026
1f8f054
run mini bm and add imports
jsvetter Feb 23, 2026
f3560a3
deal with TabPFN license
jsvetter Feb 24, 2026
3adfc2f
use embedded dataset for filtering
jsvetter Feb 24, 2026
d55348b
small fix in posterior parameters
jsvetter Feb 24, 2026
690df4a
Merge remote-tracking branch 'upstream/main' into npe_pfn_dev
jsvetter Feb 24, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies = [
"torch>=1.13.0",
"tqdm",
"zuko>=1.2.0",
"tabpfn",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe @janfb if you have any opinion on adding it as fixed dependency. The only extra dependencies this adds is pydantic, eval-type-backport, tabpfn-common-utils[telemetry-interactive] and filelock, which all have light dependencies. But one can also think about making it an optional dependency only installed when needed. I am fine with adding it.

]

[project.optional-dependencies]
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)
from sbi.inference.trainers.marginal import MarginalTrainer
from sbi.inference.trainers.nle import MNLE, NLE_A
from sbi.inference.trainers.npe import MNPE, NPE_A, NPE_B, NPE_C # noqa: F401
from sbi.inference.trainers.npe import MNPE, NPE_A, NPE_B, NPE_C, NPE_PFN # noqa: F401
from sbi.inference.trainers.nre import BNRE, NRE_A, NRE_B, NRE_C # noqa: F401
from sbi.inference.trainers.vfpe import FMPE, NPSE

Expand All @@ -20,7 +20,7 @@
SNPE_A = NPE_A
SNPE_B = NPE_B
SNPE = APT = SNPE_C = NPE = NPE_C
_npe_family = ["NPE_A", "NPE_B", "NPE_C"]
_npe_family = ["NPE_A", "NPE_B", "NPE_C", "NPE_PFN"]


SRE = SNRE = SNRE_B = NRE = NRE_B
Expand Down
4 changes: 4 additions & 0 deletions sbi/inference/posteriors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.inference.posteriors.ensemble_posterior import EnsemblePosterior
from sbi.inference.posteriors.filtered_direct_posterior import FilteredDirectPosterior
from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
from sbi.inference.posteriors.npe_a_posterior import NPE_A_Posterior
from sbi.inference.posteriors.posterior_parameters import (
DirectPosteriorParameters,
FilteredDirectPosteriorParameters,
ImportanceSamplingPosteriorParameters,
MCMCPosteriorParameters,
RejectionPosteriorParameters,
Expand All @@ -28,9 +30,11 @@
"VectorFieldPosterior",
"VIPosterior",
"DirectPosteriorParameters",
"FilteredDirectPosteriorParameters",
"ImportanceSamplingPosteriorParameters",
"MCMCPosteriorParameters",
"RejectionPosteriorParameters",
"VectorFieldPosteriorParameters",
"VIPosteriorParameters",
"FilteredDirectPosterior",
]
256 changes: 256 additions & 0 deletions sbi/inference/posteriors/filtered_direct_posterior.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
import warnings
from typing import Callable, Literal, Optional, Union

import torch
from torch import Tensor
from torch.distributions import Distribution

from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.neural_nets.estimators.tabpfn_flow import TabPFNFlow
from sbi.sbi_types import Shape

FilterMode = Literal["knn", "first"]
FilterFn = Callable[[Tensor, Tensor, Tensor, int], Tensor]
FilterType = Union[FilterMode, FilterFn]


class FilteredDirectPosterior(DirectPosterior):
    r"""Direct posterior with context filtering for TabPFN estimators.

    For every queried condition `x`, this posterior selects a subset of context
    simulations and updates the underlying `TabPFNFlow` context before delegating to
    `DirectPosterior` sampling / log-probability logic.
    """

    def __init__(
        self,
        estimator: TabPFNFlow,
        prior: Distribution,
        full_context_input: Tensor,
        full_context_condition: Tensor,
        max_sampling_batch_size: int = 10_000,
        device: Optional[Union[str, torch.device]] = None,
        x_shape: Optional[torch.Size] = None,
        enable_transform: bool = True,
        filter_type: FilterType = "knn",
        filter_size: int = 2048,
    ):
        r"""Initialize a direct posterior with observation-dependent context filtering.

        Args:
            estimator: TabPFN-based posterior estimator used for evaluation.
            prior: Prior distribution over parameters.
            full_context_input: Full set of context inputs (typically `theta`).
            full_context_condition: Full set of context conditions (typically `x`).
            max_sampling_batch_size: Maximum number of samples drawn per internal batch.
            device: Device on which posterior computations are performed.
            x_shape: Optional event shape for observations.
            enable_transform: Whether to use unconstrained-space transforms for MAP.
            filter_type: Context filtering strategy. Either `"knn"`, `"first"`,
                or a callable returning selected indices.
            filter_size: Maximum number of context points retained per observation.

        Raises:
            ValueError: If `filter_size` is not greater than 1.
        """
        if filter_size <= 1:
            raise ValueError(f"filter_size must be greater than 1, got {filter_size}.")

        super().__init__(
            posterior_estimator=estimator,
            prior=prior,
            max_sampling_batch_size=max_sampling_batch_size,
            device=device,
            x_shape=x_shape,
            enable_transform=enable_transform,
        )

        self.filter_size = int(filter_size)
        self.filtering = filter_type
        self._full_context_input = full_context_input
        self._full_context_condition = full_context_condition
        # Embed the full context once up front; the embedding is reused for
        # every subsequent knn filtering query.
        self._full_context_condition_embedded = estimator.embed(full_context_condition)

    def _validate_filter_indices(self, indices: Tensor, num_context: int) -> Tensor:
        """Validate and normalize context indices returned by a filter.

        Args:
            indices: Candidate context indices, as an integer tensor of any
                shape/device.
            num_context: Total number of available context points. All indices
                must lie in ``[0, num_context)``.

        Returns:
            Unique indices of dtype `torch.long` on the context device.

        Raises:
            ValueError: If fewer than two indices are returned, or if any index
                is negative or `>= num_context`.
        """

        if indices.numel() < 2:
            raise ValueError("Filtering function must return at least two indices.")

        indices = indices.to(device=self._full_context_input.device, dtype=torch.long)

        # Reject out-of-range indices explicitly. Without this check, negative
        # indices from a custom filter would silently wrap around via tensor
        # indexing, and too-large indices would raise a confusing IndexError
        # far away from the filtering code.
        if bool(((indices < 0) | (indices >= num_context)).any()):
            raise ValueError(
                "Filtering function returned indices outside the valid range "
                f"[0, {num_context})."
            )

        unique_indices = torch.unique(indices, sorted=False)
        if unique_indices.numel() < indices.numel():
            warnings.warn(
                "Filtering function returned duplicate indices. Duplicates were "
                "removed before setting context.",
                stacklevel=2,
            )

        return unique_indices

    def _select_context_indices(self, condition_embedded: Tensor) -> Tensor:
        """Select context indices according to the configured filtering strategy.

        Args:
            condition_embedded: Embedded representation of the queried
                observation, as produced by `self.posterior_estimator.embed`.

        Returns:
            Validated, unique context indices to retain.
        """
        num_context = self._full_context_condition_embedded.shape[0]
        k = min(self.filter_size, num_context)

        # If the requested filter size covers the full context, skip filtering
        # (and validation) entirely and use every context point.
        if k >= num_context:
            return torch.arange(num_context, device=self._full_context_input.device)

        if isinstance(self.filtering, str):
            if self.filtering == "knn":
                indices = _knn_filter_indices(
                    condition_embedded, self._full_context_condition_embedded, k
                )
            elif self.filtering == "first":
                indices = _first_filter_indices(k, self._full_context_input.device)
            else:
                raise RuntimeError(f"Unsupported filtering mode: {self.filtering}")

            return self._validate_filter_indices(indices, num_context)

        # Custom callable filter: (condition_embedded, context_input,
        # context_condition_embedded, k) -> indices.
        indices = self.filtering(
            condition_embedded,
            self._full_context_input,
            self._full_context_condition_embedded,
            k,
        )
        return self._validate_filter_indices(indices, num_context)

    def _set_context_for_x_o(self, x_o: Tensor) -> None:
        """Filter and set estimator context for a single queried observation."""
        condition_embedded = self.posterior_estimator.embed(x_o)
        unique_indices = self._select_context_indices(condition_embedded)

        self.posterior_estimator.set_context(
            self._full_context_input[unique_indices],
            self._full_context_condition[unique_indices],
        )

    def sample(
        self,
        sample_shape: Shape = torch.Size(),
        x: Optional[Tensor] = None,
        max_sampling_batch_size: int = 10_000,
        show_progress_bars: bool = True,
        reject_outside_prior: bool = True,
        max_sampling_time: Optional[float] = None,
        return_partial_on_timeout: bool = False,
    ) -> Tensor:
        r"""Sample from the posterior after setting context for the queried `x`.

        Args:
            sample_shape: Shape of the returned sample batch.
            x: Observation to condition on. Uses the default observation if `None`.
            max_sampling_batch_size: Maximum internal sampling batch size.
            show_progress_bars: Whether to display progress bars.
            reject_outside_prior: Whether to reject samples outside prior support.
            max_sampling_time: Optional timeout in seconds.
            return_partial_on_timeout: Whether to return collected samples on timeout.

        Returns:
            Samples from the filtered direct posterior.
        """
        x_for_context = self._x_else_default_x(x)
        self._set_context_for_x_o(x_for_context)
        return super().sample(
            sample_shape=sample_shape,
            x=x,
            max_sampling_batch_size=max_sampling_batch_size,
            show_progress_bars=show_progress_bars,
            reject_outside_prior=reject_outside_prior,
            max_sampling_time=max_sampling_time,
            return_partial_on_timeout=return_partial_on_timeout,
        )

    def sample_batched(
        self,
        sample_shape: Shape,
        x: Tensor,
        max_sampling_batch_size: int = 10_000,
        show_progress_bars: bool = True,
        reject_outside_prior: bool = True,
        max_sampling_time: Optional[float] = None,
        return_partial_on_timeout: bool = False,
    ) -> Tensor:
        """Batched sampling is not supported for observation-dependent filtering."""
        raise NotImplementedError(
            "Filtering makes the context observation dependent. "
            "Batched inference requires sharing context, "
            "which is currently not supported."
        )

    def log_prob(
        self,
        theta: Tensor,
        x: Optional[Tensor] = None,
        norm_posterior: bool = True,
        track_gradients: bool = False,
        leakage_correction_params: Optional[dict] = None,
    ) -> Tensor:
        r"""Evaluate posterior log-probability after setting context for `x`.

        Args:
            theta: Parameters at which to evaluate log-probability.
            x: Observation to condition on. Uses the default observation if `None`.
            norm_posterior: Whether to include leakage correction normalization.
            track_gradients: Whether to evaluate with gradient tracking.
            leakage_correction_params: Optional parameters for leakage correction.

        Returns:
            Posterior log-probabilities for ``theta`` conditioned on ``x``.
        """
        x_for_context = self._x_else_default_x(x)
        self._set_context_for_x_o(x_for_context)
        return super().log_prob(
            theta=theta,
            x=x,
            norm_posterior=norm_posterior,
            track_gradients=track_gradients,
            leakage_correction_params=leakage_correction_params,
        )

    def log_prob_batched(
        self,
        theta: Tensor,
        x: Tensor,
        norm_posterior: bool = True,
        track_gradients: bool = False,
        leakage_correction_params: Optional[dict] = None,
    ) -> Tensor:
        """Batched log-probability is unsupported with per-observation filtering."""
        raise NotImplementedError(
            "Filtering makes the context observation dependent. "
            "Batched inference requires sharing context, "
            "which is currently not supported."
        )

    def map(
        self,
        x=None,
        num_iter=1000,
        num_to_optimize=100,
        learning_rate=0.01,
        init_method="posterior",
        num_init_samples=1000,
        save_best_every=10,
        show_progress_bars=False,
        force_update=False,
    ):
        """MAP is not supported because gradient-based optimization is unavailable."""
        raise NotImplementedError(
            "Computing the MAP requires gradients, which are currently not supported "
            "for NPE-PFN."
        )


def _knn_filter_indices(
condition_embedded: Tensor,
full_context_condition: Tensor,
filter_size: int,
) -> Tensor:
"""Return flattened k-nearest-neighbor context indices."""
distances = torch.cdist(condition_embedded, full_context_condition, p=2)
nn_indices = torch.topk(distances, k=filter_size, largest=False, dim=1).indices
return nn_indices.reshape(-1)


def _first_filter_indices(filter_size: int, device: torch.device) -> Tensor:
"""Return indices of the first `filter_size` context entries."""
return torch.arange(filter_size, device=device)
38 changes: 38 additions & 0 deletions sbi/inference/posteriors/posterior_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,44 @@ def validate(self):
raise ValueError("max_sampling_batch_size must be greater than 0.")


@dataclass(frozen=True)
class FilteredDirectPosteriorParameters(PosteriorParameters):
    """Configuration for constructing a `FilteredDirectPosterior`.

    Fields:
        max_sampling_batch_size: Number of proposal samples drawn per
            internal sampling iteration.
        enable_transform: Whether parameters are mapped to unconstrained
            space during MAP optimization. When False, `theta_transform`
            is an identity transform.
        filter_size: Number of context simulations retained after filtering.
        filter_type: Filtering strategy. Either `"knn"`, `"first"`, or a
            callable returning context indices.
    """

    max_sampling_batch_size: int = 10_000
    enable_transform: bool = True
    filter_size: int = 2048
    filter_type: Union[Literal["knn", "first"], Callable] = "knn"

    def validate(self):
        """Check field values, raising `ValueError` for invalid settings."""

        if not is_positive_int(self.max_sampling_batch_size):
            raise ValueError("max_sampling_batch_size must be greater than 0.")

        # `filter_size - 1` must itself be a positive int, i.e. filter_size >= 2.
        if not is_positive_int(self.filter_size - 1):
            raise ValueError("filter_size must be greater than 1.")

        is_named_mode = isinstance(self.filter_type, str) and self.filter_type in (
            "knn",
            "first",
        )
        if not is_named_mode and not callable(self.filter_type):
            raise ValueError(
                "filter_type must be one of ['knn', 'first'] or a callable."
            )


@dataclass(frozen=True)
class ImportanceSamplingPosteriorParameters(PosteriorParameters):
"""
Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/potentials/posterior_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
mcmc_transform,
within_support,
)
from sbi.utils.torchutils import ensure_theta_batched
from sbi.utils.torchutils import ensure_theta_batched, infer_module_device


def posterior_estimator_based_potential(
Expand Down Expand Up @@ -49,7 +49,7 @@ def posterior_estimator_based_potential(
to unconstrained space.
"""

device = str(next(posterior_estimator.parameters()).device)
device = infer_module_device(posterior_estimator, fallback="cpu")

potential_fn = PosteriorBasedPotential(
posterior_estimator, prior, x_o, device=device
Expand Down
Loading
Loading