>>> import jax.numpy as jnp
>>> import numpyro
>>> from pyrenew.deterministic import DeterministicVariable
>>>
>>> val = jnp.array([1, 2, 3])
>>> det_var = DeterministicVariable("my var", val)
>>> delta_dist = numpyro.distributions.Delta(val)
>>>
>>> delta_dist.sample(key=None)
Array([1, 2, 3], dtype=int32)
>>> det_var.sample()
Array([1, 2, 3], dtype=int32)
I think we can use `numpyro.distributions.Delta` in all cases where we currently use our own `DeterministicVariable`, which would reduce the complexity of the codebase.