diff --git a/libs/schemas/tracking.py b/libs/schemas/tracking.py index fd065b9..212144f 100644 --- a/libs/schemas/tracking.py +++ b/libs/schemas/tracking.py @@ -5,7 +5,7 @@ from __future__ import annotations from pydantic import BaseModel, Field from enum import Enum - +from typing import Optional class TrackState(str, Enum): BORN = "BORN" # first frame this track_id appeared @@ -13,12 +13,15 @@ class TrackState(str, Enum): LOST = "LOST" # not seen for up to max_age frames DEAD = "DEAD" # expired — will not be reassigned - class TrajectoryPoint(BaseModel): + """A single spatial-temporal coordinate snapshot representing an object's historical location.""" + x: float x: float y: float frame_id: int - + interpolated: bool = False + w: Optional[float] = None + h: Optional[float] = None class TrackedObject(BaseModel): track_id: int = Field(..., description="Persistent ID across frames") diff --git a/services/memory/memory.py b/services/memory/memory.py index 6544e61..7fc46a7 100644 --- a/services/memory/memory.py +++ b/services/memory/memory.py @@ -34,8 +34,6 @@ import time from typing import Optional -import numpy as np - from libs.observability.metrics import redis_write_latency from libs.schemas.tracking import TrackLifecycleEvent, TrackState from services.tracking.cross_camera_reid import CrossCameraReID @@ -70,7 +68,7 @@ def __init__(self, redis_client, reid: CrossCameraReID) -> None: def handle_lifecycle_event( self, event: TrackLifecycleEvent, - embedding: Optional[np.ndarray] = None, + embedding: Optional["numpy.ndarray"] = None, ) -> Optional[str]: """ Process a single lifecycle event and return the assigned global_id. @@ -117,7 +115,7 @@ def get_identity(self, global_id: str) -> list[str]: def _handle_born( self, event: TrackLifecycleEvent, - embedding: Optional[np.ndarray], + embedding: Optional["numpy.ndarray"], ) -> str: if embedding is not None: reid_result = self._reid.match_or_create( @@ -162,6 +160,8 @@ def _handle_born( def _handle_lost( self, event: TrackLifecycleEvent, + embedding: Optional["numpy.ndarray"], + ) -> Optional[str]: embedding: Optional[np.ndarray], ) -> tuple[Optional[str], bool]: record = self._load_record(event.camera_id, event.track_id) diff --git a/services/tracking/tracker.py b/services/tracking/tracker.py index b17cfa1..ffd6af6 100644 --- a/services/tracking/tracker.py +++ b/services/tracking/tracker.py @@ -71,11 +71,25 @@ def __init__( camera_id: str = "cam_01", event_logger: TrackEventLogger | None = None, reid_similarity_threshold: float = 0.85, + max_interpolation_gap: int = 10, # Added with a sensible default ) -> None: + """Initialize the tracker with DeepSort hyperparameters and interpolation constraints. + + Args: + fps: Frame rate of the video source. + max_age: Maximum frames to keep a lost track alive before dropping it. + n_init: Number of consecutive frames needed to confirm a track. + max_cosine_distance: Maximum threshold for visual appearance feature matching. + camera_id: Unique identifier string for the source camera. + event_logger: Optional logger interface for tracking state lifecycle events. + reid_similarity_threshold: Minimum confidence needed to reconnect an ID via ReID. + max_interpolation_gap: Maximum frame gap size allowed to fill missing trajectories. + """ self.fps = fps self.camera_id = camera_id self.max_age = max_age # NEW self.REID_SIMILARITY_THRESHOLD = reid_similarity_threshold + self.max_interpolation_gap = max_interpolation_gap # Fixed missing attribute self._tracker = DeepSort( max_age=max_age, @@ -173,15 +187,62 @@ def update( self._emit_lifecycle(TrackState.BORN, tid, zones, 0.0) logger.info(f"Track BORN: #{tid} in zones={zones}") - # ── Dwell time ──────────────────────────────────────────────── + # ── Base Setup & Gap Calculation ────────────────────────────── prev = self._active_tracks.get(tid) - dwell_frames = (prev.dwell_time_frames + 1) if prev else 1 + prev_traj = prev.trajectory if prev else [] + + # Compute gap_frames early so both Dwell Time and Trajectory can use it + gap_frames = max(0, self._frame_id - prev.last_seen_frame - 1) if prev is not None else 0 + + # ── Dwell time ──────────────────────────────────────────────── + if prev: + # Add historic frames, the current frame, and the occlusion gap + dwell_frames = prev.dwell_time_frames + 1 + gap_frames + else: + dwell_frames = 1 + dwell_secs = dwell_frames / self.fps # ── Trajectory ──────────────────────────────────────────────── - prev_traj = prev.trajectory if prev else [] + interpolated_points = [] + max_gap = self.max_interpolation_gap # <-- Replaced self.config string access + + if prev is not None and 0 < gap_frames <= max_gap: + # Added guard condition below to prevent IndexError crashes + if prev.trajectory: + last_pos = {"x": prev.trajectory[-1].x, "y": prev.trajectory[-1].y} + else: + last_pos = {"x": cx, "y": cy} # Fallback to current center coordinates + + new_pos = {"x": cx, "y": cy} + + # Check if previous data contains w and h bounding box metrics + if hasattr(prev, 'bbox') and len(prev.bbox) == 4: + # Calculate old width and height from bbox: [x1, y1, x2, y2] + last_pos["w"] = prev.bbox[2] - prev.bbox[0] + last_pos["h"] = prev.bbox[3] - prev.bbox[1] + # Current width and height + new_pos["w"] = x2 - x1 + new_pos["h"] = y2 - y1 + + # Synthesize intermediate points and wrap them into TrajectoryPoint instances + interpolated_points = [ + TrajectoryPoint( + x=p["x"], + y=p["y"], + frame_id=p["frame_id"], + interpolated=True, + w=p.get("w"), + h=p.get("h") + ) + for p in _interpolate_trajectory(last_pos, new_pos, gap_frames, prev.last_seen_frame + 1) + ] + + # Generate the current frame real point new_point = TrajectoryPoint(x=cx, y=cy, frame_id=self._frame_id) - trajectory = (prev_traj + [new_point])[-self.MAX_TRAJECTORY_LEN :] + + # Merge old history, calculated mid-gap points, and current point cleanly + trajectory = (prev_traj + interpolated_points + [new_point])[-self.MAX_TRAJECTORY_LEN :] obj = TrackedObject( track_id=tid, @@ -241,6 +302,7 @@ def update( del self._active_tracks[tid] self._active_embeddings.pop(tid, None) logger.info(f"Track DEAD: #{tid} after {prev_obj.dwell_time_seconds:.1f}s") + # ── Cleanup expired ReID embeddings ────────────────── expired_ids = [ tid @@ -360,6 +422,37 @@ def main() -> None: writer.release() cv2.destroyAllWindows() +def _interpolate_trajectory( + last_pos: dict, + new_pos: dict, + gap_frames: int, + start_frame_id: int +) -> list: + """Fills trajectory gaps using linear interpolation for temporary missed detections.""" + if gap_frames <= 0: + return [] + + interpolated_points = [] + total_steps = gap_frames + 1 + + x_step = (new_pos['x'] - last_pos['x']) / total_steps + y_step = (new_pos['y'] - last_pos['y']) / total_steps + + for i in range(1, gap_frames + 1): + point = { + "frame_id": start_frame_id + (i - 1), + "x": round(last_pos['x'] + (x_step * i), 2), + "y": round(last_pos['y'] + (y_step * i), 2), + "interpolated": True + } + + if all(k in last_pos and k in new_pos for k in ('w', 'h')): + point['w'] = round(last_pos['w'] + (((new_pos['w'] - last_pos['w']) / total_steps) * i), 2) + point['h'] = round(last_pos['h'] + (((new_pos['h'] - last_pos['h']) / total_steps) * i), 2) + + interpolated_points.append(point) + + return interpolated_points if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/tests/test_tracker.py b/tests/test_tracker.py index dce51e1..6db7be7 100644 --- a/tests/test_tracker.py +++ b/tests/test_tracker.py @@ -15,6 +15,7 @@ from libs.schemas.detection import DetectionFrameSchema, DetectionSchema, BoundingBox from libs.schemas.tracking import TrackedFrame, TrackedObject, TrackState, TrajectoryPoint +from services.tracking.tracker import _interpolate_trajectory # ── Schema unit tests (no tracker needed) ──────────────────────────────────── @@ -484,4 +485,45 @@ def test_reid_expires_after_max_age(MockDeepSort): ) # Should NOT restore old ID - assert result.tracks[0].track_id == 99 \ No newline at end of file + assert result.tracks[0].track_id == 99 + + +def test_interpolate_trajectory_success(): + """Test standard linear interpolation for a 3-frame gap including width and height scaling.""" + last_pos = {"x": 10.0, "y": 20.0, "w": 50.0, "h": 50.0} + new_pos = {"x": 50.0, "y": 60.0, "w": 90.0, "h": 90.0} + gap_frames = 3 + start_frame = 101 + + result = _interpolate_trajectory(last_pos, new_pos, gap_frames, start_frame) + + assert len(result) == 3 + + # Assert step progression and metadata across all items + expected_values = [ + {"frame_id": 101, "x": 20.0, "y": 30.0, "w": 60.0, "h": 60.0}, + {"frame_id": 102, "x": 30.0, "y": 40.0, "w": 70.0, "h": 70.0}, + {"frame_id": 103, "x": 40.0, "y": 50.0, "w": 80.0, "h": 80.0}, + ] + + for idx, expected in enumerate(expected_values): + assert result[idx]["frame_id"] == expected["frame_id"] + assert result[idx]["interpolated"] is True + assert result[idx]["x"] == expected["x"] + assert result[idx]["y"] == expected["y"] + assert result[idx]["w"] == expected["w"] + assert result[idx]["h"] == expected["h"] + +def test_interpolate_trajectory_no_gap(): + last_pos = {"x": 10, "y": 20} + new_pos = {"x": 20, "y": 30} + assert _interpolate_trajectory(last_pos, new_pos, 0, 100) == [] + +def test_interpolate_trajectory_no_movement(): + last_pos = {"x": 100.0, "y": 100.0} + new_pos = {"x": 100.0, "y": 100.0} + gap_frames = 2 + start_frame = 50 + result = _interpolate_trajectory(last_pos, new_pos, gap_frames, start_frame) + assert len(result) == 2 + assert result[0]["x"] == 100.0 \ No newline at end of file