Skip to content
9 changes: 6 additions & 3 deletions libs/schemas/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,23 @@
from __future__ import annotations
from pydantic import BaseModel, Field
from enum import Enum

from typing import Optional

class TrackState(str, Enum):
BORN = "BORN" # first frame this track_id appeared
ACTIVE = "ACTIVE" # confirmed, currently visible
LOST = "LOST" # not seen for up to max_age frames
DEAD = "DEAD" # expired β€” will not be reassigned


class TrajectoryPoint(BaseModel):
"""A single spatial-temporal coordinate snapshot representing an object's historical location."""
x: float
x: float
y: float
frame_id: int

interpolated: bool = False
w: Optional[float] = None
h: Optional[float] = None

class TrackedObject(BaseModel):
track_id: int = Field(..., description="Persistent ID across frames")
Expand Down
8 changes: 4 additions & 4 deletions services/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
import time
from typing import Optional

import numpy as np

from libs.observability.metrics import redis_write_latency
from libs.schemas.tracking import TrackLifecycleEvent, TrackState
from services.tracking.cross_camera_reid import CrossCameraReID
Expand Down Expand Up @@ -70,7 +68,7 @@ def __init__(self, redis_client, reid: CrossCameraReID) -> None:
def handle_lifecycle_event(
self,
event: TrackLifecycleEvent,
embedding: Optional[np.ndarray] = None,
embedding: Optional["numpy.ndarray"] = None,
) -> Optional[str]:
"""
Process a single lifecycle event and return the assigned global_id.
Expand Down Expand Up @@ -117,7 +115,7 @@ def get_identity(self, global_id: str) -> list[str]:
def _handle_born(
self,
event: TrackLifecycleEvent,
embedding: Optional[np.ndarray],
embedding: Optional["numpy.ndarray"],
) -> str:
if embedding is not None:
reid_result = self._reid.match_or_create(
Expand Down Expand Up @@ -162,6 +160,8 @@ def _handle_born(
def _handle_lost(
self,
event: TrackLifecycleEvent,
embedding: Optional["numpy.ndarray"],
) -> Optional[str]:
embedding: Optional[np.ndarray],
) -> tuple[Optional[str], bool]:
record = self._load_record(event.camera_id, event.track_id)
Expand Down
103 changes: 98 additions & 5 deletions services/tracking/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,25 @@ def __init__(
camera_id: str = "cam_01",
event_logger: TrackEventLogger | None = None,
reid_similarity_threshold: float = 0.85,
max_interpolation_gap: int = 10, # Added with a sensible default
) -> None:
"""Initialize the tracker with DeepSort hyperparameters and interpolation constraints.

Args:
fps: Frame rate of the video source.
max_age: Maximum frames to keep a lost track alive before dropping it.
n_init: Number of consecutive frames needed to confirm a track.
max_cosine_distance: Maximum threshold for visual appearance feature matching.
camera_id: Unique identifier string for the source camera.
event_logger: Optional logger interface for tracking state lifecycle events.
reid_similarity_threshold: Minimum confidence needed to reconnect an ID via ReID.
max_interpolation_gap: Maximum frame gap size allowed to fill missing trajectories.
"""
self.fps = fps
self.camera_id = camera_id
self.max_age = max_age # NEW
self.REID_SIMILARITY_THRESHOLD = reid_similarity_threshold
self.max_interpolation_gap = max_interpolation_gap # Fixed missing attribute

self._tracker = DeepSort(
max_age=max_age,
Expand Down Expand Up @@ -173,15 +187,62 @@ def update(
self._emit_lifecycle(TrackState.BORN, tid, zones, 0.0)
logger.info(f"Track BORN: #{tid} in zones={zones}")

# ── Dwell time ────────────────────────────────────────────────
# ── Base Setup & Gap Calculation ──────────────────────────────
prev = self._active_tracks.get(tid)
dwell_frames = (prev.dwell_time_frames + 1) if prev else 1
prev_traj = prev.trajectory if prev else []

# Compute gap_frames early so both Dwell Time and Trajectory can use it
gap_frames = max(0, self._frame_id - prev.last_seen_frame - 1) if prev is not None else 0

# ── Dwell time ────────────────────────────────────────────────
if prev:
# Add historic frames, the current frame, and the occlusion gap
dwell_frames = prev.dwell_time_frames + 1 + gap_frames
else:
dwell_frames = 1

dwell_secs = dwell_frames / self.fps

# ── Trajectory ────────────────────────────────────────────────
prev_traj = prev.trajectory if prev else []
interpolated_points = []
max_gap = self.max_interpolation_gap # <-- Replaced self.config string access

if prev is not None and 0 < gap_frames <= max_gap:
# Added guard condition below to prevent IndexError crashes
if prev.trajectory:
last_pos = {"x": prev.trajectory[-1].x, "y": prev.trajectory[-1].y}
else:
last_pos = {"x": cx, "y": cy} # Fallback to current center coordinates

new_pos = {"x": cx, "y": cy}

# Check if previous data contains w and h bounding box metrics
if hasattr(prev, 'bbox') and len(prev.bbox) == 4:
# Calculate old width and height from bbox: [x1, y1, x2, y2]
last_pos["w"] = prev.bbox[2] - prev.bbox[0]
last_pos["h"] = prev.bbox[3] - prev.bbox[1]
# Current width and height
new_pos["w"] = x2 - x1
new_pos["h"] = y2 - y1

# Synthesize intermediate points and wrap them into TrajectoryPoint instances
interpolated_points = [
TrajectoryPoint(
x=p["x"],
y=p["y"],
frame_id=p["frame_id"],
interpolated=True,
w=p.get("w"),
h=p.get("h")
)
for p in _interpolate_trajectory(last_pos, new_pos, gap_frames, prev.last_seen_frame + 1)
]

# Generate the current frame real point
new_point = TrajectoryPoint(x=cx, y=cy, frame_id=self._frame_id)
trajectory = (prev_traj + [new_point])[-self.MAX_TRAJECTORY_LEN :]

# Merge old history, calculated mid-gap points, and current point cleanly
trajectory = (prev_traj + interpolated_points + [new_point])[-self.MAX_TRAJECTORY_LEN :]

obj = TrackedObject(
track_id=tid,
Expand Down Expand Up @@ -241,6 +302,7 @@ def update(
del self._active_tracks[tid]
self._active_embeddings.pop(tid, None)
logger.info(f"Track DEAD: #{tid} after {prev_obj.dwell_time_seconds:.1f}s")

# ── Cleanup expired ReID embeddings ──────────────────
expired_ids = [
tid
Expand Down Expand Up @@ -360,6 +422,37 @@ def main() -> None:
writer.release()
cv2.destroyAllWindows()

def _interpolate_trajectory(
last_pos: dict,
new_pos: dict,
gap_frames: int,
start_frame_id: int
) -> list:
"""Fills trajectory gaps using linear interpolation for temporary missed detections."""
if gap_frames <= 0:
return []

interpolated_points = []
total_steps = gap_frames + 1

x_step = (new_pos['x'] - last_pos['x']) / total_steps
y_step = (new_pos['y'] - last_pos['y']) / total_steps

for i in range(1, gap_frames + 1):
point = {
"frame_id": start_frame_id + (i - 1),
"x": round(last_pos['x'] + (x_step * i), 2),
"y": round(last_pos['y'] + (y_step * i), 2),
"interpolated": True
}

if all(k in last_pos and k in new_pos for k in ('w', 'h')):
point['w'] = round(last_pos['w'] + (((new_pos['w'] - last_pos['w']) / total_steps) * i), 2)
point['h'] = round(last_pos['h'] + (((new_pos['h'] - last_pos['h']) / total_steps) * i), 2)

interpolated_points.append(point)

return interpolated_points

if __name__ == "__main__":
main()
main()
44 changes: 43 additions & 1 deletion tests/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from libs.schemas.detection import DetectionFrameSchema, DetectionSchema, BoundingBox
from libs.schemas.tracking import TrackedFrame, TrackedObject, TrackState, TrajectoryPoint
from services.tracking.tracker import _interpolate_trajectory


# ── Schema unit tests (no tracker needed) ────────────────────────────────────
Expand Down Expand Up @@ -484,4 +485,45 @@ def test_reid_expires_after_max_age(MockDeepSort):
)

# Should NOT restore old ID
assert result.tracks[0].track_id == 99
assert result.tracks[0].track_id == 99


def test_interpolate_trajectory_success():
"""Test standard linear interpolation for a 3-frame gap including width and height scaling."""
last_pos = {"x": 10.0, "y": 20.0, "w": 50.0, "h": 50.0}
new_pos = {"x": 50.0, "y": 60.0, "w": 90.0, "h": 90.0}
gap_frames = 3
start_frame = 101

result = _interpolate_trajectory(last_pos, new_pos, gap_frames, start_frame)

assert len(result) == 3

# Assert step progression and metadata across all items
expected_values = [
{"frame_id": 101, "x": 20.0, "y": 30.0, "w": 60.0, "h": 60.0},
{"frame_id": 102, "x": 30.0, "y": 40.0, "w": 70.0, "h": 70.0},
{"frame_id": 103, "x": 40.0, "y": 50.0, "w": 80.0, "h": 80.0},
]

for idx, expected in enumerate(expected_values):
assert result[idx]["frame_id"] == expected["frame_id"]
assert result[idx]["interpolated"] is True
assert result[idx]["x"] == expected["x"]
assert result[idx]["y"] == expected["y"]
assert result[idx]["w"] == expected["w"]
assert result[idx]["h"] == expected["h"]

def test_interpolate_trajectory_no_gap():
last_pos = {"x": 10, "y": 20}
new_pos = {"x": 20, "y": 30}
assert _interpolate_trajectory(last_pos, new_pos, 0, 100) == []

def test_interpolate_trajectory_no_movement():
last_pos = {"x": 100.0, "y": 100.0}
new_pos = {"x": 100.0, "y": 100.0}
gap_frames = 2
start_frame = 50
result = _interpolate_trajectory(last_pos, new_pos, gap_frames, start_frame)
assert len(result) == 2
assert result[0]["x"] == 100.0
Loading