Skip to content

Refactor NRE loss logic to Strategy Pattern (Phase 2 of #1241)#1826

Open
Sumit6307 wants to merge 6 commits intosbi-dev:mainfrom
Sumit6307:refactor/nre-strategy-phase2
Open

Refactor NRE loss logic to Strategy Pattern (Phase 2 of #1241)#1826
Sumit6307 wants to merge 6 commits intosbi-dev:mainfrom
Sumit6307:refactor/nre-strategy-phase2

Conversation

@Sumit6307
Copy link
Copy Markdown

Summary

Building exactly on the maintainers' feedback and unified vision established in Phase 1 (NPE-C refactoring PR #1755), this PR executes Phase 2: fully extracting the NRE loss calculations (NRE_A, NRE_B, NRE_C, and BNRE) into composable, isolated strategies that conform to an NRELossStrategy Protocol.

By outsourcing AALRLoss, SRELoss, CNRELoss, and BNRELoss mathematically intensive routines to nre_loss.py, we entirely eliminate _loss() and _classifier_logits() from RatioEstimatorTrainer and its subclasses, creating a fully modular Ratio Estimation architecture.

Motivation

Presently, each NRE variant embeds complex classification and contrastive atom-generation logic tightly inside its respective _loss overridden method. As outlined in Option (a) of #1241, trainers should only orchestrate the training loop, while composable protocol-compliant Objects handle mathematical formulations.

Key Changes

  1. NRELossStrategy Protocol: Added in sbi/inference/trainers/nre/nre_loss.py requiring a stateless call(neural_net, device, theta, x, **kwargs).
  2. Strategy Extractions:
    • AALRLoss (replaces NRE_A._loss)
    • SRELoss (replaces NRE_B._loss)
    • CNRELoss (replaces NRE_C._loss)
    • BNRELoss (replaces BNRE._loss)
  3. Decoupled Atom Generation: _classifier_logits() was extracted from nre_base.py into nre_loss.py as a tightly integrated helper utility.
  4. Trainer Refactoring:
    • RatioEstimatorTrainer delegates to _loss_strategy: Optional[NRELossStrategy] inside _get_losses.
    • NRE_A, NRE_B, NRE_C, and BNRE dynamically instantiate and inject their default strategies into train() natively.

Checklist

  • Extracted all NRE mathematical derivations safely into stateless strategies.
  • Code strictly mirrors the architectural decisions established in Phase 1 (NPE).
  • ruff check formatting verified.
  • Ready for maintainer review!

@Sumit6307
Copy link
Copy Markdown
Author

Sumit6307 commented Mar 25, 2026

@janfb Please Check this PR @janfb

@codecov
Copy link
Copy Markdown

codecov bot commented Mar 26, 2026

❌ 15 Tests Failed:

Tests completed Failed Passed Skipped
1509 15 1494 40
View the top 3 failed test(s) by shortest run time
tests/posterior_parameters_test.py::test_if_warning_raised_for_deprecated_build_posterior_parameters[params1]
Stack Traces | 0s run time
@pytest.fixture(scope="session")
    def get_inference():
        def simulator(theta):
            return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1
    
        num_dim = 3
        prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
        theta = prior.sample((300,))
        x = simulator(theta)
    
>       inference = NRE(prior=prior)
E       TypeError: Can't instantiate abstract class NRE_B with abstract method _loss

tests/posterior_parameters_test.py:39: TypeError
tests/posterior_parameters_test.py::test_build_posterior_works_on_default_args[build_posterior_arguments1]
Stack Traces | 0.001s run time
@pytest.fixture(scope="session")
    def get_inference():
        def simulator(theta):
            return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1
    
        num_dim = 3
        prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
        theta = prior.sample((300,))
        x = simulator(theta)
    
>       inference = NRE(prior=prior)
E       TypeError: Can't instantiate abstract class NRE_B with abstract method _loss

tests/posterior_parameters_test.py:39: TypeError
tests/posterior_nn_test.py::test_batched_mcmc_sample_log_prob_with_different_x[sample_shape0-resample-1-NRE_C]
Stack Traces | 0.004s run time
snlre_method = <class 'sbi.inference.trainers.nre.nre_c.NRE_C'>
x_o_batch_dim = 1
mcmc_params_fast = MCMCPosteriorParameters(method='slice_np_vectorized', thin=1, warmup_steps=1, num_chains=1, init_strategy='resample', init_strategy_parameters=None, num_workers=1, mp_context='spawn')
init_strategy = 'resample', sample_shape = (5,)

    @pytest.mark.mcmc
    @pytest.mark.parametrize("snlre_method", [NRE_C])  # it's independent of the method
    @pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
    @pytest.mark.parametrize("init_strategy", ["proposal", "resample"])
    @pytest.mark.parametrize(
        "sample_shape",
        (
            (5,),  # less than num_chains
            (4, 2),  # 2D batch
            (15,),  # not divisible by num_chains
        ),
    )
    def test_batched_mcmc_sample_log_prob_with_different_x(
        snlre_method: type,
        x_o_batch_dim: bool,
        mcmc_params_fast: MCMCPosteriorParameters,
        init_strategy: str,
        sample_shape: torch.Size,
    ):
        num_dim = 2
        num_simulations = 100
        num_chains = 10
    
        prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
        simulator = diagonal_linear_gaussian
    
>       inference = snlre_method(prior=prior)
E       TypeError: Can't instantiate abstract class NRE_C with abstract method _loss

tests/posterior_nn_test.py:206: TypeError
tests/posterior_nn_test.py::test_batched_mcmc_sample_log_prob_with_different_x[sample_shape1-proposal-1-NRE_C]
Stack Traces | 0.004s run time
snlre_method = <class 'sbi.inference.trainers.nre.nre_c.NRE_C'>
x_o_batch_dim = 1
mcmc_params_fast = MCMCPosteriorParameters(method='slice_np_vectorized', thin=1, warmup_steps=1, num_chains=1, init_strategy='resample', init_strategy_parameters=None, num_workers=1, mp_context='spawn')
init_strategy = 'proposal', sample_shape = (4, 2)

    @pytest.mark.mcmc
    @pytest.mark.parametrize("snlre_method", [NRE_C])  # it's independent of the method
    @pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
    @pytest.mark.parametrize("init_strategy", ["proposal", "resample"])
    @pytest.mark.parametrize(
        "sample_shape",
        (
            (5,),  # less than num_chains
            (4, 2),  # 2D batch
            (15,),  # not divisible by num_chains
        ),
    )
    def test_batched_mcmc_sample_log_prob_with_different_x(
        snlre_method: type,
        x_o_batch_dim: bool,
        mcmc_params_fast: MCMCPosteriorParameters,
        init_strategy: str,
        sample_shape: torch.Size,
    ):
        num_dim = 2
        num_simulations = 100
        num_chains = 10
    
        prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
        simulator = diagonal_linear_gaussian
    
>       inference = snlre_method(prior=prior)
E       TypeError: Can't instantiate abstract class NRE_C with abstract method _loss

tests/posterior_nn_test.py:206: TypeError
tests/posterior_nn_test.py::test_importance_posterior_sample_log_prob[NRE_C]
Stack Traces | 0.004s run time
snplre_method = <class 'sbi.inference.trainers.nre.nre_c.NRE_C'>

    @pytest.mark.parametrize("snplre_method", [NPE_C, NLE_A, NRE_A, NRE_B, NRE_C])
    def test_importance_posterior_sample_log_prob(snplre_method: type):
        num_dim = 2
        num_simulations = 1000
    
        prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
        simulator = diagonal_linear_gaussian
    
>       inference = snplre_method(prior=prior)
E       TypeError: Can't instantiate abstract class NRE_C with abstract method _loss

tests/posterior_nn_test.py:94: TypeError
tests/density_estimator_test.py::test_trainers_with_valid_and_invalid_estimator_builders[params0-NRE_B]
Stack Traces | 0.005s run time
params = {'classifier': <function build_classifier at 0x7f06dcf8e170>}
trainer_class = <class 'sbi.inference.trainers.nre.nre_b.NRE_B'>

    @pytest.mark.parametrize(
        ("params", "trainer_class"),
        [
            # Valid builders
            pytest.param(dict(classifier=build_classifier), NRE),
            pytest.param(dict(density_estimator=build_estimator), NPE),
            pytest.param(dict(density_estimator=build_estimator), NLE),
            pytest.param(dict(vf_estimator=build_vf_estimator_fmpe), FMPE),
            pytest.param(dict(vf_estimator=build_vf_estimator_npse), NPSE),
            # Invalid builders
            pytest.param(
                dict(classifier=build_estimator_missing_args),
                NRE,
                marks=pytest.mark.xfail(
                    raises=TypeError,
                    reason="Missing required parameters in classifier builder.",
                ),
            ),
            pytest.param(
                dict(density_estimator=build_estimator_missing_args),
                NPE,
                marks=pytest.mark.xfail(
                    raises=TypeError,
                    reason="Missing required parameters in density estimator builder.",
                ),
            ),
            pytest.param(
                dict(density_estimator=build_estimator_missing_args),
                NLE,
                marks=pytest.mark.xfail(
                    raises=TypeError,
                    reason="Missing required parameters in density estimator builder.",
                ),
            ),
            pytest.param(
                dict(vf_estimator=build_estimator_missing_args),
                FMPE,
                marks=pytest.mark.xfail(
                    raises=TypeError,
                    reason="Missing required parameters in vf_estimator builder.",
                ),
            ),
            pytest.param(
                dict(vf_estimator=build_estimator_missing_args),
                NPSE,
                marks=pytest.mark.xfail(
                    raises=TypeError,
                    reason="Missing required parameters in vf_estimator builder.",
                ),
            ),
            pytest.param(
                dict(classifier=build_estimator_missing_return),
                NRE,
                marks=pytest.mark.xfail(
                    raises=AssertionError,
                    reason="Missing return of RatioEstimator in classifier builder.",
                ),
            ),
            pytest.param(
                dict(density_estimator=build_estimator_missing_return),
                NPE,
                marks=pytest.mark.xfail(
                    raises=AttributeError,
                    reason="Missing return of type ConditionalEstimator"
                    " in density estimator builder.",
                ),
            ),
            pytest.param(
                dict(density_estimator=build_estimator_missing_return),
                NLE,
                marks=pytest.mark.xfail(
                    raises=AssertionError,
                    reason="Missing return of type ConditionalEstimator"
                    " in density estimator builder.",
                ),
            ),
            pytest.param(
                dict(vf_estimator=build_estimator_missing_return),
                FMPE,
                marks=pytest.mark.xfail(
                    raises=AttributeError,
                    reason="Missing return of type ConditionalVectorFieldEstimator"
                    " in density estimator builder.",
                ),
            ),
            pytest.param(
                dict(vf_estimator=build_estimator_missing_return),
                NPSE,
                marks=pytest.mark.xfail(
                    raises=AttributeError,
                    reason="Missing return of type ConditionalVectorFieldEstimator"
                    " in density estimator builder.",
                ),
            ),
        ],
    )
    def test_trainers_with_valid_and_invalid_estimator_builders(
        params: Dict, trainer_class: type[NeuralInference]
    ):
        """
        Test trainers classes work with valid classifier builders and fail
        with invalid ones.
    
        Args:
            params: Parameters passed to the trainer class.
            trainer_class: Trainer classes.
        """
    
        def simulator(theta):
            return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1
    
        num_dim = 3
        prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
        theta = prior.sample((300,))
        x = simulator(theta)
    
>       inference = trainer_class(**params)
E       TypeError: Can't instantiate abstract class NRE_B with abstract method _loss

tests/density_estimator_test.py:748: TypeError
tests/inference/npe_loss_test.py::test_importance_weighted_loss_initialization_and_call
Stack Traces | 0.005s run time
theta = tensor([[ 0.6614,  0.2669],
        [ 0.0617,  0.6213],
        [-0.4519, -0.1661],
        [-1.5228,  0.3817],
        [-1.0276, -0.5631]])
x = tensor([[-0.8923, -0.0583],
        [-0.1955, -0.9656],
        [ 0.4224,  0.2673],
        [-0.4212, -0.5107],
        [-1.5727, -0.1232]])
masks = tensor([1., 1., 1., 1., 1.])
prior = MultivariateNormal(loc: torch.Size([2]), covariance_matrix: torch.Size([2, 2]))

    def test_importance_weighted_loss_initialization_and_call(theta, x, masks, prior):
>       neural_net = DummyEstimator()

tests/inference/npe_loss_test.py:114: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <[AttributeError("'DummyEstimator' object has no attribute '_modules'") raised in repr()] DummyEstimator object at 0x7f06d279b1f0>
is_mixture = False

    def __init__(self, is_mixture=False):
>       super().__init__()
E       TypeError: ConditionalDensityEstimator.__init__() missing 3 required positional arguments: 'net', 'input_shape', and 'condition_shape'

tests/inference/npe_loss_test.py:23: TypeError
tests/posterior_nn_test.py::test_batched_mcmc_sample_log_prob_with_different_x[sample_shape1-proposal-2-NRE_C]
Stack Traces | 0.005s run time
snlre_method = <class 'sbi.inference.trainers.nre.nre_c.NRE_C'>
x_o_batch_dim = 2
mcmc_params_fast = MCMCPosteriorParameters(method='slice_np_vectorized', thin=1, warmup_steps=1, num_chains=1, init_strategy='resample', init_strategy_parameters=None, num_workers=1, mp_context='spawn')
init_strategy = 'proposal', sample_shape = (4, 2)

    @pytest.mark.mcmc
    @pytest.mark.parametrize("snlre_method", [NRE_C])  # it's independent of the method
    @pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
    @pytest.mark.parametrize("init_strategy", ["proposal", "resample"])
    @pytest.mark.parametrize(
        "sample_shape",
        (
            (5,),  # less than num_chains
            (4, 2),  # 2D batch
            (15,),  # not divisible by num_chains
        ),
    )
    def test_batched_mcmc_sample_log_prob_with_different_x(
        snlre_method: type,
        x_o_batch_dim: bool,
        mcmc_params_fast: MCMCPosteriorParameters,
        init_strategy: str,
        sample_shape: torch.Size,
    ):
        num_dim = 2
        num_simulations = 100
        num_chains = 10
    
        prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
        simulator = diagonal_linear_gaussian
    
>       inference = snlre_method(prior=prior)
E       TypeError: Can't instantiate abstract class NRE_C with abstract method _loss

tests/posterior_nn_test.py:206: TypeError
tests/posterior_nn_test.py::test_batched_mcmc_sample_log_prob_with_different_x[sample_shape2-proposal-1-NRE_C]
Stack Traces | 0.005s run time
snlre_method = <class 'sbi.inference.trainers.nre.nre_c.NRE_C'>
x_o_batch_dim = 1
mcmc_params_fast = MCMCPosteriorParameters(method='slice_np_vectorized', thin=1, warmup_steps=1, num_chains=1, init_strategy='resample', init_strategy_parameters=None, num_workers=1, mp_context='spawn')
init_strategy = 'proposal', sample_shape = (15,)

    @pytest.mark.mcmc
    @pytest.mark.parametrize("snlre_method", [NRE_C])  # it's independent of the method
    @pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
    @pytest.mark.parametrize("init_strategy", ["proposal", "resample"])
    @pytest.mark.parametrize(
        "sample_shape",
        (
            (5,),  # less than num_chains
            (4, 2),  # 2D batch
            (15,),  # not divisible by num_chains
        ),
    )
    def test_batched_mcmc_sample_log_prob_with_different_x(
        snlre_method: type,
        x_o_batch_dim: bool,
        mcmc_params_fast: MCMCPosteriorParameters,
        init_strategy: str,
        sample_shape: torch.Size,
    ):
        num_dim = 2
        num_simulations = 100
        num_chains = 10
    
        prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
        simulator = diagonal_linear_gaussian
    
>       inference = snlre_method(prior=prior)
E       TypeError: Can't instantiate abstract class NRE_C with abstract method _loss

tests/posterior_nn_test.py:206: TypeError
tests/linearGaussian_snre_test.py::test_api_nre_multiple_trials_and_rounds_map[NRE_B-1]
Stack Traces | 0.007s run time
num_dim = 1, nre_method = <class 'sbi.inference.trainers.nre.nre_b.NRE_B'>
mcmc_params_fast = MCMCPosteriorParameters(method='slice_np_vectorized', thin=1, warmup_steps=1, num_chains=1, init_strategy='resample', init_strategy_parameters=None, num_workers=1, mp_context='spawn')
num_rounds = 2, num_samples = 12, num_simulations = 100

    @pytest.mark.mcmc
    @pytest.mark.parametrize("num_dim", (1,))  # dim 3 is tested below.
    @pytest.mark.parametrize("nre_method", (NRE_B, NRE_C))
    def test_api_nre_multiple_trials_and_rounds_map(
        num_dim: int,
        nre_method: RatioEstimator,
        mcmc_params_fast: MCMCPosteriorParameters,
        num_rounds: int = 2,
        num_samples: int = 12,
        num_simulations: int = 100,
    ):
        """Test NRE API with 2 rounds, different priors num trials and MAP."""
        prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
    
        simulator = diagonal_linear_gaussian
>       inference = nre_method(prior=prior, classifier="mlp", show_progress_bars=False)
E       TypeError: Can't instantiate abstract class NRE_B with abstract method _loss

tests/linearGaussian_snre_test.py:56: TypeError
tests/linearGaussian_snre_test.py::test_c2st_sre_on_linearGaussian[NRE_B]
Stack Traces | 0.007s run time
nre_method = <class 'sbi.inference.trainers.nre.nre_b.NRE_B'>
mcmc_params_accurate = MCMCPosteriorParameters(method='slice_np_vectorized', thin=2, warmup_steps=50, num_chains=20, init_strategy='resample', init_strategy_parameters=None, num_workers=1, mp_context='spawn')

    @pytest.mark.mcmc
    @pytest.mark.parametrize("nre_method", (NRE_B, NRE_C))
    def test_c2st_sre_on_linearGaussian(
        nre_method: RatioEstimator, mcmc_params_accurate: MCMCPosteriorParameters
    ):
        """Test whether SRE infers well a simple example with available ground truth.
    
        This example has different number of parameters theta than number of x. This test
        also acts as the only functional test for SRE not marked as slow.
    
        """
    
        theta_dim = 3
        x_dim = 2
        discard_dims = theta_dim - x_dim
        num_samples = 500
        num_simulations = 2100
    
        likelihood_shift = -1.0 * ones(
            x_dim
        )  # likelihood_mean will be likelihood_shift+theta
        likelihood_cov = 0.3 * eye(x_dim)
    
        prior_mean = zeros(theta_dim)
        prior_cov = eye(theta_dim)
        prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
    
        def simulator(theta):
            return linear_gaussian(
                theta, likelihood_shift, likelihood_cov, num_discarded_dims=discard_dims
            )
    
>       inference = nre_method(classifier="resnet", show_progress_bars=False)
E       TypeError: Can't instantiate abstract class NRE_B with abstract method _loss

tests/linearGaussian_snre_test.py:107: TypeError
tests/linearGaussian_snre_test.py::test_c2st_sre_on_linearGaussian[NRE_C]
Stack Traces | 0.007s run time
nre_method = <class 'sbi.inference.trainers.nre.nre_c.NRE_C'>
mcmc_params_accurate = MCMCPosteriorParameters(method='slice_np_vectorized', thin=2, warmup_steps=50, num_chains=20, init_strategy='resample', init_strategy_parameters=None, num_workers=1, mp_context='spawn')

    @pytest.mark.mcmc
    @pytest.mark.parametrize("nre_method", (NRE_B, NRE_C))
    def test_c2st_sre_on_linearGaussian(
        nre_method: RatioEstimator, mcmc_params_accurate: MCMCPosteriorParameters
    ):
        """Test whether SRE infers well a simple example with available ground truth.
    
        This example has different number of parameters theta than number of x. This test
        also acts as the only functional test for SRE not marked as slow.
    
        """
    
        theta_dim = 3
        x_dim = 2
        discard_dims = theta_dim - x_dim
        num_samples = 500
        num_simulations = 2100
    
        likelihood_shift = -1.0 * ones(
            x_dim
        )  # likelihood_mean will be likelihood_shift+theta
        likelihood_cov = 0.3 * eye(x_dim)
    
        prior_mean = zeros(theta_dim)
        prior_cov = eye(theta_dim)
        prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
    
        def simulator(theta):
            return linear_gaussian(
                theta, likelihood_shift, likelihood_cov, num_discarded_dims=discard_dims
            )
    
>       inference = nre_method(classifier="resnet", show_progress_bars=False)
E       TypeError: Can't instantiate abstract class NRE_C with abstract method _loss

tests/linearGaussian_snre_test.py:107: TypeError
tests/pyroutils_test.py::test_estimator_distribution_basic_properties[NRE_B-RatioEstimatorDistribution-5]
Stack Traces | 0.007s run time
trainer_cls = <class 'sbi.inference.trainers.nre.nre_b.NRE_B'>
distribution_cls = <class 'sbi.utils.pyroutils.RatioEstimatorDistribution'>
num_dim = 5, num_simulations = 100

    @pytest.mark.parametrize("num_dim", [2, 5])
    @pytest.mark.parametrize(
        "trainer_cls, distribution_cls",
        [
            (NLE, ConditionalDensityEstimatorDistribution),
            (NPE, ConditionalDensityEstimatorDistribution),
            (NRE, RatioEstimatorDistribution),
        ],
    )
    def test_estimator_distribution_basic_properties(
        trainer_cls,
        distribution_cls,
        num_dim,
        num_simulations: int = 100,
    ):
        """Test basic properties of the estimator distribution."""
        if num_dim == 0:
            prior = torch.distributions.Normal(0.0, 1.0)
        else:
            prior = torch.distributions.MultivariateNormal(
                loc=torch.zeros(num_dim), covariance_matrix=torch.diag(torch.ones(num_dim))
            )
        theta = prior.sample(torch.Size([num_simulations]))
        x = torch.distributions.Normal(theta, 1.0).sample()
>       trainer = trainer_cls(prior=prior).append_simulations(theta=theta, x=x)
E       TypeError: Can't instantiate abstract class NRE_B with abstract method _loss

tests/pyroutils_test.py:246: TypeError
tests/save_and_load_test.py::test_picklability[NRE_B-VIPosteriorParameters]
Stack Traces | 0.008s run time
inference_method = <class 'sbi.inference.trainers.nre.nre_b.NRE_B'>
posterior_parameters = <class 'sbi.inference.posteriors.posterior_parameters.VIPosteriorParameters'>
tmp_path = PosixPath('.../pytest-0/popen-gw0/test_picklability_NRE_B_VIPost0')
mcmc_params_fast = MCMCPosteriorParameters(method='slice_np_vectorized', thin=1, warmup_steps=1, num_chains=1, init_strategy='resample', init_strategy_parameters=None, num_workers=1, mp_context='spawn')

    @pytest.mark.parametrize(
        "inference_method, posterior_parameters",
        (
            (NPE, DirectPosteriorParameters),
            (NPSE, VectorFieldPosteriorParameters),
            (FMPE, VectorFieldPosteriorParameters),
            pytest.param(NLE, MCMCPosteriorParameters, marks=pytest.mark.mcmc),
            pytest.param(NRE, MCMCPosteriorParameters, marks=pytest.mark.mcmc),
            pytest.param(NRE, VIPosteriorParameters, marks=pytest.mark.mcmc),
            (NRE, RejectionPosteriorParameters),
        ),
    )
    def test_picklability(
        inference_method,
        posterior_parameters,
        tmp_path,
        mcmc_params_fast: MCMCPosteriorParameters,
    ):
        num_dim = 2
        prior = utils.BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
        x_o = torch.zeros(1, num_dim)
    
        theta = prior.sample((500,))
        x = theta + 1.0 + torch.randn_like(theta) * 0.1
    
>       inference = inference_method(prior=prior)
E       TypeError: Can't instantiate abstract class NRE_B with abstract method _loss

tests/save_and_load_test.py:46: TypeError
tests/embedding_net_test.py::test_embedding_net_api[mlp-2-NRE]
Stack Traces | 0.009s run time
method = 'NRE', num_dim = 2, embedding_net = 'mlp'
mcmc_params_fast = MCMCPosteriorParameters(method='slice_np_vectorized', thin=1, warmup_steps=1, num_chains=1, init_strategy='resample', init_strategy_parameters=None, num_workers=1, mp_context='spawn')

    @pytest.mark.parametrize(
        "method",
        [
            pytest.param("NPE"),
            pytest.param("NLE", marks=pytest.mark.mcmc),
            pytest.param("NRE", marks=pytest.mark.mcmc),
        ],
    )
    @pytest.mark.parametrize("num_dim", [1, 2])
    @pytest.mark.parametrize("embedding_net", ["mlp"])
    def test_embedding_net_api(
        method, num_dim: int, embedding_net: str, mcmc_params_fast: MCMCPosteriorParameters
    ):
        """Tests the API when using a preconfigured embedding net."""
        model_map = {"NPE": "maf", "NLE": "maf", "NRE": "resnet"}
        model = model_map[method]
        x_o = zeros(1, num_dim)
    
        # likelihood_mean will be likelihood_shift+theta
        likelihood_shift = -1.0 * ones(num_dim)
        likelihood_cov = 0.3 * eye(num_dim)
    
        prior = utils.BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))
    
        def simulator(theta):
            return linear_gaussian(theta, likelihood_shift, likelihood_cov)
    
        if embedding_net == "mlp":
            embedding = FCEmbedding(input_dim=num_dim)
        else:
            raise NameError(f"{embedding_net} not supported.")
    
        _test_embedding_forward_pass(embedding, (num_dim,), 20)
    
        posterior_parameters = mcmc_params_fast if method in ("NLE", "NRE") else None
>       _train_and_infer_with_embedding(
            prior,
            x_o,
            simulator,
            embedding,
            model,
            method,
            posterior_parameters=posterior_parameters,
        )

tests/embedding_net_test.py:77: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

prior = BoxUniform(Uniform(low: torch.Size([2]), high: torch.Size([2])), 1)
xo = tensor([[0., 0.]])
simulator = <function test_embedding_net_api.<locals>.simulator at 0x7f06d8b70a60>
net = FCEmbedding(
  (net): Sequential(
    (0): Linear(in_features=2, out_features=50, bias=True)
    (1): Identity()
    (2): ReLU()
    (3): Linear(in_features=50, out_features=20, bias=True)
    (4): ReLU()
  )
)
model = 'resnet', method = 'NRE'
posterior_parameters = MCMCPosteriorParameters(method='slice_np_vectorized', thin=1, warmup_steps=1, num_chains=1, init_strategy='resample', init_strategy_parameters=None, num_workers=1, mp_context='spawn')

    def _train_and_infer_with_embedding(
        prior: utils.BoxUniform | MultivariateNormal,
        xo: Tensor,
        simulator: Callable,
        net: nn.Module,
        model: str,
        method: str,
        posterior_parameters: MCMCPosteriorParameters | None = None,
    ):
        """Train a small inference pipeline and smoke test posterior sampling."""
    
        builders = {"NPE": posterior_nn, "NLE": likelihood_nn, "NRE": classifier_nn}
        trainers = {"NPE": NPE, "NLE": NLE, "NRE": NRE}
    
        num_simulations = 100
        theta = prior.sample(torch.Size((num_simulations,)))
        x = simulator(theta)
    
        net_key = "embedding_net_x" if method == "NRE" else "embedding_net"
        estimator = builders[method](model=model, **{net_key: net})
    
        trainer_key = "classifier" if method == "NRE" else "density_estimator"
>       trainer = trainers[method](
            prior,
            **{trainer_key: estimator},
            show_progress_bars=False,
        )
E       TypeError: Can't instantiate abstract class NRE_B with abstract method _loss

tests/embedding_net_test.py:374: TypeError

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant