Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[flake8]
max-line-length = 99
# E203: black conflict
# E701: black conflict
# F821: lot of issues regarding type annotations
# F722: syntax error in forward annotations (jaxtyping, etc.)
extend-ignore = E203,E701,F821,F722
46 changes: 46 additions & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
name: Lint

on:
pull_request:
types: [opened, synchronize, reopened]

# To cancel a currently running workflow from the same PR, branch or tag when a new workflow is triggered
# https://stackoverflow.com/a/72408109
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Install linters
run: pip install autopep8 flake8

- name: Check formatting with autopep8
run: autopep8 --diff --recursive --exit-code eks tests
# Reads config from [tool.autopep8] in pyproject.toml

- name: Lint with flake8 (critical errors only)
run: flake8 eks tests --select=E9,F63,F7,F82
# Reads config from .flake8 file

- name: Show fix instructions if formatting needed
if: failure()
run: |
echo ""
echo "Linting failed!"
echo ""
echo "To fix formatting issues locally, run:"
echo " autopep8 --in-place --recursive eks tests"
echo ""
echo "To check for flake8 errors locally, run:"
echo " flake8 eks tests --select=E9,F63,F7,F82"
echo ""
File renamed without changes.
36 changes: 24 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,30 @@ implementations, including fast smoothing parameter auto-tuning using GPU-driven
[Here](docs/singlecam_overview.md) is a detailed overview of the workflow.

### Multi-camera datasets
The `multicam_example.py` script demonstrates how to run the EKS code for multi-camera
setups where the pose predictions for a given model are all stored a separate csv file per camera.
We provide example data in the `data/mirror-mouse-separate` directory inside this repo,
for a two-view video of a mouse with cameras named `top` and `bot`.
To run the EKS on the example data provided, execute the following command from inside this repo:
The `multicam_example.py` script supports two modes for multi-camera setups,
depending on whether camera calibration information is available.
In both cases, pose predictions should be stored a separate csv file per camera.

#### Without calibration (linear EKS)
We provide example data in `data/mirror-mouse-separate`,
containing two-view mouse video with cameras named `top` and `bot`.
To run linear EKS on this data , execute the following command from inside this repo:

```console
python scripts/multicam_example.py --input-dir ./data/mirror-mouse-separate --bodypart-list paw1LH paw2LF paw3RF paw4RH --camera-names top bot
```

#### With calibration (nonlinear EKS)

If camera calibration information is available, you can run a nonlinear version of EKS.
Calibration data must be stored in `.toml` files using the [Anipose](https://anipose.readthedocs.io/) format.
We provide example data in `data/fly`, containing multi-view fly video with cameras named
`Cam-A`, `Cam-B`, and `Cam-C`, along with a corresponding `calibration.toml` file.
To run nonlinear EKS on this data, execute the following command from inside this repo:

```console
python scripts/multicam_example.py --input-dir ./data/fly --bodypart-list L1A L1B --camera-names Cam-A Cam-B Cam-C --calibration ./data/fly/calibration.toml
```

### Mirrored multi-camera datasets
The `mirrored_multicam_example.py` script demonstrates how to run the EKS code for multi-camera
Expand Down Expand Up @@ -140,10 +155,7 @@ python scripts/ibl_paw_multiview_example.py --input-dir ./data/ibl-paw

### Authors

Cole Hurwitz

Keemin Lee

Amol Pasarkar

Matt Whiteway
* [Cole Hurwitz](https://github.com/colehurwitz)
* [Keemin Lee](https://github.com/keeminlee)
* [Amol Pasarkar](https://github.com/apasarkar)
* [Matt Whiteway](https://github.com/themattinthehatt)
6 changes: 3 additions & 3 deletions eks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib.metadata
from typing import Any

from eks import *
# from eks import *


# Hacky way to get version from pypackage.toml.
Expand Down Expand Up @@ -28,13 +29,12 @@ def __get_package_version() -> str:
# This works in a development environment where the
# package has not been installed from a distribution.
import warnings
from pathlib import Path

import toml

warnings.warn(
"ensemble-kalman-smoother not pip-installed, getting version from pyproject.toml."
)

pyproject_toml_file = Path(__file__).parent.parent / "pyproject.toml"
__package_version = toml.load(pyproject_toml_file)["project"]["version"]

Expand Down
14 changes: 9 additions & 5 deletions eks/core.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Literal, Tuple, Union

import jax
import numpy as np
import optax
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, extended_kalman_filter, \
extended_kalman_smoother
from jax import numpy as jnp, jit, value_and_grad, lax
from dynamax.nonlinear_gaussian_ssm import (
ParamsNLGSSM,
extended_kalman_filter,
extended_kalman_smoother,
)
from jax import jit, lax
from jax import numpy as jnp
from jax import value_and_grad
from typeguard import typechecked
from typing import Literal, Union, List, Tuple

from eks.marker_array import MarkerArray
from eks.utils import build_R_from_vars, crop_frames, crop_R
Expand Down
10 changes: 5 additions & 5 deletions eks/ibl_pupil_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,8 @@ def run_pupil_kalman_smoother(
A = jnp.diag(jnp.array([s_d_j, s_c_j, s_c_j]))
Q = jnp.diag(jnp.array([
jnp.asarray(diameters_var) * (1.0 - s_d_j**2),
jnp.asarray(x_var) * (1.0 - s_c_j**2),
jnp.asarray(y_var) * (1.0 - s_c_j**2),
jnp.asarray(x_var) * (1.0 - s_c_j**2),
jnp.asarray(y_var) * (1.0 - s_c_j**2),
]))

f_fn = (lambda x: A @ x)
Expand Down Expand Up @@ -467,7 +467,7 @@ def _to_stable_s(u, eps=1e-3):

# Cropping for loss (host-side), then back to JAX
ys_np = np.asarray(ys)
R_np = np.asarray(R)
R_np = np.asarray(R)
if s_frames and len(s_frames) > 0:
y_loss = jnp.asarray(crop_frames(ys_np, s_frames)) # (T', 8)
R_loss = jnp.asarray(crop_R(R_np, s_frames)) # (T', 8, 8)
Expand All @@ -487,8 +487,8 @@ def _nll_from_u(u: jnp.ndarray) -> jnp.ndarray:
A = jnp.diag(jnp.array([s_d, s_c, s_c]))
Q = jnp.diag(jnp.array([
jnp.asarray(diameters_var) * (1.0 - s_d**2),
jnp.asarray(x_var) * (1.0 - s_c**2),
jnp.asarray(y_var) * (1.0 - s_c**2),
jnp.asarray(x_var) * (1.0 - s_c**2),
jnp.asarray(y_var) * (1.0 - s_c**2),
]))
params = _params_linear(m0, S0, A, Q, R_loss, C)
post = extended_kalman_filter(params, y_loss)
Expand Down
8 changes: 4 additions & 4 deletions eks/marker_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,10 @@ def stack(others: List["MarkerArray"], axis: str) -> "MarkerArray":
assert isinstance(other, MarkerArray), \
"All elements in 'others' must be MarkerArray instances."
assert reference.array.shape[:reference.axis_map[axis]] + \
reference.array.shape[reference.axis_map[axis] + 1:] \
== other.array.shape[:reference.axis_map[axis]] + \
other.array.shape[reference.axis_map[axis] + 1:], \
f"Shape mismatch: Cannot stack along '{axis}' due to differing dimensions."
reference.array.shape[reference.axis_map[axis] + 1:] \
== other.array.shape[:reference.axis_map[axis]] + \
other.array.shape[reference.axis_map[axis] + 1:], \
f"Shape mismatch: Cannot stack along '{axis}' due to differing dimensions."

# Stack all arrays along the specified axis
stacked_array = np.concatenate([other.array for other in others],
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ requires-python = ">=3.10"
authors = [
{ name = "Cole Hurwitz"},
{ name = "Keemin Lee"},
{ name = "Matt Whiteaway" },
{ name = "Matt Whiteway" },
]
maintainers = [
{ name = "Matt Whiteway"},
Expand Down Expand Up @@ -44,6 +44,7 @@ dependencies = [
"scikit-learn",
"scipy (>=1.2.0)",
"sleap_io",
"toml",
"tqdm",
"typeguard",
"typing",
Expand All @@ -65,8 +66,9 @@ python = ">=3.10,<3.13"

[project.optional-dependencies]
dev = [
"black",
"autopep8",
"flake8",
"ipython", # dumb dependency issue in fastprogress, installing here so CI doesn't fail
"isort",
"pytest",
]
Expand Down
3 changes: 2 additions & 1 deletion tests/scripts/test_ibl_paw_multicam_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ def test_ibl_paw_multicam_example_defaults(run_script, tmpdir, pytestconfig):
output_dir=tmpdir,
)


def test_ibl_paw_multicam_example_fixed_smooth_param(run_script, tmpdir, pytestconfig):
run_script(
script_file=str(pytestconfig.rootpath / 'scripts' / 'ibl_paw_multiview_example.py'),
input_dir=str(pytestconfig.rootpath / 'data' / 'ibl-paw'),
output_dir=tmpdir,
s=10
)
)
2 changes: 1 addition & 1 deletion tests/scripts/test_multicam_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ def test_multicam_example_fixed_smooth_param_nonlinear(run_script, tmpdir, pytes
camera_names=['Cam-A', 'Cam-B', 'Cam-C'],
calibration=str(pytestconfig.rootpath / 'data' / 'fly' / 'calibration.toml'),
s=10,
)
)
3 changes: 3 additions & 0 deletions tests/test_multicam_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def test_projection_jit_smoke():
# ---------- make_projection_from_camgroup ----------
class _MockCam:
"""Minimal camera mock exposing rotation/translation/K/dist getters."""

def __init__(self, rotation, translation, K, dist):
self._rotation = rotation
self._translation = translation
Expand All @@ -499,6 +500,7 @@ def get_distortions(self): return self._dist

class _MockCamGroup:
"""Mock camgroup that also provides a dummy triangulate(xy_views) API."""

def __init__(self, cameras):
self.cameras = cameras

Expand Down Expand Up @@ -532,6 +534,7 @@ def test_make_projection_from_camgroup_single_point_concat_order():
# ---------- triangulate_3d_models ----------
class _MockMarkerArray:
"""Minimal stand-in for MarkerArray exposing .shape and .get_array()."""

def __init__(self, arr):
self._arr = arr

Expand Down