Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions eks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,16 @@ def compute_stats(data_x, data_y, data_lh):
conf_per_keypoint = jnp.sum(data_lh, axis=0)
mean_conf_per_keypoint = conf_per_keypoint / n_models

var_x = jnp.nanvar(data_x, axis=0) / mean_conf_per_keypoint if var_mode in [
'conf_weighted_var', 'confidence_weighted_var'] else jnp.nanvar(data_x, axis=0)
var_y = jnp.nanvar(data_y, axis=0) / mean_conf_per_keypoint if var_mode in [
'conf_weighted_var', 'confidence_weighted_var'] else jnp.nanvar(data_y, axis=0)
if n_models == 1:
single_var = 1.0 / jnp.maximum(mean_conf_per_keypoint, 0.05)
var_x = single_var
var_y = single_var
elif var_mode in ['conf_weighted_var', 'confidence_weighted_var']:
var_x = jnp.nanvar(data_x, axis=0) / mean_conf_per_keypoint
var_y = jnp.nanvar(data_y, axis=0) / mean_conf_per_keypoint
else:
var_x = jnp.nanvar(data_x, axis=0)
var_y = jnp.nanvar(data_y, axis=0)

# Replace NaNs in variance with chosen value
var_x = jnp.nan_to_num(var_x, nan=nan_replacement)
Expand Down
48 changes: 48 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,54 @@ def test_jax_ensemble_nan_variance():
assert np.all(var_y == nan_replacement), "NaNs in var_y were not replaced"


def test_jax_ensemble_single_network():
"""
Test that ensemble() produces valid (positive, finite) variance estimates when n_models=1.

With a single network there is no cross-model spread to measure, so nanvar returns 0.
The function must fall back to a likelihood-based variance proxy so that downstream
Kalman filtering receives non-zero observation noise.
"""
n_models = 1
n_cameras = 2
n_frames = 10
n_keypoints = 3

rng = np.random.default_rng(0)
data = rng.random((n_models, n_cameras, n_frames, n_keypoints, 3))
# Likelihoods in (0, 1) — not all ones, so a likelihood-based fallback can be non-trivial
data[..., 2] = rng.uniform(0.5, 1.0, size=(n_models, n_cameras, n_frames, n_keypoints))

marker_array = MarkerArray(data, data_fields=["x", "y", "likelihood"])

for avg_mode in ("median", "mean"):
for var_mode in ("var", "confidence_weighted_var"):
result = ensemble(marker_array, avg_mode=avg_mode, var_mode=var_mode)

# Shape must still be correct
expected_shape = (1, n_cameras, n_frames, n_keypoints, 5)
assert result.array.shape == expected_shape, (
f"[{avg_mode}, {var_mode}] Expected shape {expected_shape}, "
f"got {result.array.shape}"
)

var_x = np.array(result.array[..., 2])
var_y = np.array(result.array[..., 3])

# All variance values must be finite
assert np.all(np.isfinite(var_x)), \
f"[{avg_mode}, {var_mode}] var_x contains non-finite values"
assert np.all(np.isfinite(var_y)), \
f"[{avg_mode}, {var_mode}] var_y contains non-finite values"

# Variance must be strictly positive — zero variance from nanvar(single sample)
# is not acceptable as observation noise in the Kalman filter
assert np.all(var_x > 0), \
f"[{avg_mode}, {var_mode}] var_x is zero (nanvar of single sample fallback needed)"
assert np.all(var_y > 0), \
f"[{avg_mode}, {var_mode}] var_y is zero (nanvar of single sample fallback needed)"


def test_jax_ensemble_zero_likelihood():
"""Test that zero likelihood does not cause NaNs in variance calculations."""
n_models = 3
Expand Down