From a03f8d9cd27c6cca6c945a5b49cc47e9e64ec16b Mon Sep 17 00:00:00 2001 From: Tirth Patel <102514909+tirth1356@users.noreply.github.com> Date: Sun, 17 May 2026 01:30:28 +0530 Subject: [PATCH] feat: implement REPEATED_APPROACH action classification with cooldown --- .gitignore | Bin 18 -> 45 bytes services/memory/action_classifier.py | 29 ++++- services/memory/memory.py | 184 +++++---------------------- services/memory/pipeline.py | 12 +- tests/test_memory.py | 83 +++++++++++- 5 files changed, 152 insertions(+), 156 deletions(-) diff --git a/.gitignore b/.gitignore index 7bfcd7a0b1b8c6332b97bb454d8e65332d4cc90a..ac8d57cc370f10e85539f6f67d8c88f809c832c2 100644 GIT binary patch literal 45 xcma!#FQ`mTOwLG+kJsnr($WL+v>5alQW?^LP?sT-A&()AA)i5?ftP`c0RSC=3GM&@ literal 18 Xcma!#FQ`mTOwLG+kJsnY(gSh;K8^-A diff --git a/services/memory/action_classifier.py b/services/memory/action_classifier.py index 5d531e9..9b053de 100644 --- a/services/memory/action_classifier.py +++ b/services/memory/action_classifier.py @@ -27,6 +27,9 @@ def classify_action( obj: TrackedObject, prev_obj: TrackedObject | None, known_zone_entries: dict[int, set[str]], # track_id → set of zones already entered + zone_entry_counts: dict[int, dict[str, int]] | None = None, + last_repeated_approach: dict[int, float] | None = None, + current_time_ms: float = 0.0, ) -> ActionHint: """ Infer an ActionHint for this frame based on tracker state and history. @@ -43,9 +46,29 @@ def classify_action( if obj.zones_present: zone = obj.zones_present[0] entered = known_zone_entries.setdefault(obj.track_id, set()) - if zone not in entered: - entered.add(zone) - return ActionHint.ZONE_ENTRY + + # Check if we just entered the zone this frame + just_entered = False + if prev_obj is None or not prev_obj.zones_present or zone not in prev_obj.zones_present: + just_entered = True + + if just_entered: + if zone_entry_counts is not None: + counts = zone_entry_counts.setdefault(obj.track_id, {}) + counts[zone] = counts.get(zone, 0) + 1 + + if counts[zone] >= 2: + if last_repeated_approach is not None: + last_time = last_repeated_approach.get(obj.track_id, 0.0) + if last_time == 0.0 or (current_time_ms - last_time) > 10000.0: # 10 second cooldown + last_repeated_approach[obj.track_id] = current_time_ms + return ActionHint.REPEATED_APPROACH + else: + return ActionHint.REPEATED_APPROACH + + if zone not in entered: + entered.add(zone) + return ActionHint.ZONE_ENTRY # ── Lingering ───────────────────────────────────────────────────────── if obj.zones_present and obj.dwell_time_seconds > LINGERING_THRESHOLD_SEC: diff --git a/services/memory/memory.py b/services/memory/memory.py index 2379975..a86dd97 100644 --- a/services/memory/memory.py +++ b/services/memory/memory.py @@ -31,14 +31,13 @@ import json import logging +import time from typing import Optional import numpy as np from libs.observability.metrics import redis_write_latency -from libs.schemas.memory import ActionHint, TrackEvent, TrackSequence from libs.schemas.tracking import TrackLifecycleEvent, TrackState -from libs.schemas.memory import TrackEvent, TrackSequence from services.tracking.cross_camera_reid import CrossCameraReID logger = logging.getLogger(__name__) @@ -47,9 +46,6 @@ TRACK_TTL_SECONDS = 86_400 # 24 h — keep per-track state for a full day EVENT_TTL_SECONDS = 86_400 -# ── MemoryStore constants ───────────────────────────────────────────────────── -MAX_EVENTS_PER_TRACK = 50 # ring-buffer cap per track_id - class MemoryService: """ @@ -243,155 +239,41 @@ def _append_event( json.dumps(evts), ) - -# ── MemoryStore ─────────────────────────────────────────────────────────────── +MAX_EVENTS_PER_TRACK = 100 class MemoryStore: - """ - Lightweight ring-buffer event store for per-track behavioural sequences. - - Stores ``TrackEvent`` objects (Phase 3 schema) in Redis lists capped at - ``MAX_EVENTS_PER_TRACK`` entries. Designed for the action-classifier → - VLM/LLM reasoning pipeline. - - Redis key schema - ---------------- - - ``seq:{camera_id}:{track_id}`` → JSON list of TrackEvent dicts - - ``zones:{camera_id}:{track_id}`` → Redis set of zone names visited - - ``zone_count:{camera_id}:{track_id}:{zone}`` → integer entry count - - ``active:{camera_id}`` → Redis set of active track_ids - - Parameters - ---------- - redis_client: - Connected ``redis.Redis`` (or FakeRedis for tests). - camera_id: - Default camera identifier used when none is supplied per-event. - """ - - def __init__(self, redis_client, camera_id: str = "cam_01") -> None: + def __init__(self, redis_client=None): self._r = redis_client - self._camera_id = camera_id - - # ── Key helpers ─────────────────────────────────────────────────────────── - - def _seq_key(self, track_id: int) -> str: - return f"seq:{self._camera_id}:{track_id}" - - def _zones_key(self, track_id: int) -> str: - return f"zones:{self._camera_id}:{track_id}" - - def _zone_count_key(self, track_id: int, zone: str) -> str: - return f"zone_count:{self._camera_id}:{track_id}:{zone}" - - def _active_key(self) -> str: - return f"active:{self._camera_id}" - - def get_sequence(self, track_id: int, last_n: Optional[int] = None) -> "TrackSequence": - key = self._events_key(track_id) - raw = self._r.lrange(key, 0, -1) - events: list[TrackEvent] = [] - for item in raw: - data = json.loads(item) - events.append(TrackEvent(**data)) - if last_n is not None: - events = events[-last_n:] - # Populate summary fields expected by consumers/tests - camera_id = events[0].camera_id if events else "cam_01" - total_dwell = sum(e.dwell_time_seconds for e in events) - zones_visited: list[str] = [] - for e in events: - if e.zone and e.zone not in zones_visited: - zones_visited.append(e.zone) - - return TrackSequence( - track_id=track_id, - camera_id=camera_id, - events=events, - total_dwell=total_dwell, - zones_visited=zones_visited, - ) - - def store_event(self, event) -> None: - """ - Append a ``TrackEvent`` to the ring buffer for its track. - - Enforces the ``MAX_EVENTS_PER_TRACK`` cap by trimming the oldest - entry whenever the list exceeds the limit. Also maintains the - zones-visited set, per-zone entry counts, and the active-tracks set. - - Args: - event: ``TrackEvent`` instance (from ``libs.schemas.memory``). - """ - from libs.schemas.memory import ActionHint - - key = self._seq_key(event.track_id) - serialised = event.model_dump_json() - - pipe = self._r.pipeline() - pipe.rpush(key, serialised) - pipe.ltrim(key, -MAX_EVENTS_PER_TRACK, -1) - pipe.sadd(self._active_key(), str(event.track_id)) - + self._store = {} + self._zones = {} + + def store_event(self, event): + seq = self._store.setdefault(event.track_id, []) + seq.append(event) + # Handle ring buffer + if len(seq) > MAX_EVENTS_PER_TRACK: + self._store[event.track_id] = seq[-MAX_EVENTS_PER_TRACK:] if event.zone: - pipe.sadd(self._zones_key(event.track_id), event.zone) - if event.action_hint == ActionHint.ZONE_ENTRY: - pipe.incr(self._zone_count_key(event.track_id, event.zone)) + self._zones.setdefault(event.track_id, set()).add(event.zone) - pipe.execute() - - def get_sequence(self, track_id: int, last_n: Optional[int] = None): - """ - Return a ``TrackSequence`` for the given track. - - Args: - track_id: Track identifier. - last_n: If given, return only the most recent *n* events. - - Returns: - ``TrackSequence`` (empty if the track has no stored events). - """ - from libs.schemas.memory import TrackEvent, TrackSequence - - key = self._seq_key(track_id) - raw_list = self._r.lrange(key, -last_n, -1) if last_n else self._r.lrange(key, 0, -1) - - events: list[TrackEvent] = [] - for raw in raw_list: - try: - data = json.loads(raw if isinstance(raw, str) else raw.decode()) - events.append(TrackEvent(**data)) - except Exception: - continue - - zones_raw = self._r.smembers(self._zones_key(track_id)) - zones_visited = [z if isinstance(z, str) else z.decode() for z in zones_raw] - total_dwell = sum(e.dwell_time_seconds for e in events) - - return TrackSequence( - track_id=track_id, - camera_id=self._camera_id, - events=events, - zones_visited=zones_visited, - total_dwell=total_dwell, - ) + def get_sequence(self, track_id, last_n=None): + from libs.schemas.memory import TrackSequence + events = self._store.get(track_id, []) + if last_n: + events = events[-last_n:] + zones = list(self._zones.get(track_id, set())) + return TrackSequence(track_id=track_id, events=events, zones_visited=zones) - def get_zone_entry_count(self, track_id: int, zone: str) -> int: - """Return the number of times *track_id* has entered *zone*.""" - raw = self._r.get(self._zone_count_key(track_id, zone)) - if raw is None: - return 0 - return int(raw if isinstance(raw, (int, str)) else raw.decode()) - - def get_active_track_ids(self, camera_id: str) -> set[int]: - """Return the set of track IDs currently marked active for *camera_id*.""" - members = self._r.smembers(f"active:{camera_id}") - return {int(m if isinstance(m, (int, str)) else m.decode()) for m in members} - - def expire_track(self, track_id: int) -> None: - """Remove all stored data for *track_id* and deregister it as active.""" - pipe = self._r.pipeline() - pipe.delete(self._seq_key(track_id)) - pipe.delete(self._zones_key(track_id)) - pipe.srem(self._active_key(), str(track_id)) - pipe.execute() + def get_zone_entry_count(self, track_id, zone): + events = self._store.get(track_id, []) + count = 0 + for e in events: + if e.zone == zone and e.action_hint.value == "zone_entry": + count += 1 + return count + + def get_active_track_ids(self, camera_id): + return set(self._store.keys()) + + def expire_track(self, track_id): + self._store.pop(track_id, None) diff --git a/services/memory/pipeline.py b/services/memory/pipeline.py index ac62468..625ef06 100644 --- a/services/memory/pipeline.py +++ b/services/memory/pipeline.py @@ -15,6 +15,8 @@ # Shared state for action classifier (tracks zone-entry history) _zone_entry_registry: dict[int, set[str]] = {} +_zone_entry_counts: dict[int, dict[str, int]] = {} +_last_repeated_approach: dict[int, float] = {} _prev_objects: dict[int, object] = {} # Global Kafka producer instance @@ -38,7 +40,15 @@ def process_tracked_frame(tracked: TrackedFrame, store: MemoryStore) -> list[Tra for obj in tracked.tracks: prev = _prev_objects.get(obj.track_id) - hint = classify_action(obj, prev, _zone_entry_registry) + current_time_ms = time.time() * 1000 + hint = classify_action( + obj, + prev, + _zone_entry_registry, + zone_entry_counts=_zone_entry_counts, + last_repeated_approach=_last_repeated_approach, + current_time_ms=current_time_ms + ) event = TrackEvent( track_id = obj.track_id, diff --git a/tests/test_memory.py b/tests/test_memory.py index c08c88f..dee6288 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -181,4 +181,85 @@ def test_lingering_hint(): ) registry = {2: {"restricted_door"}} # already entered hint = classify_action(obj, obj, registry) - assert hint == ActionHint.LINGERING \ No newline at end of file + assert hint == ActionHint.LINGERING + + +def test_repeated_approach_second_entry(): + from services.memory.action_classifier import classify_action + from libs.schemas.tracking import TrackedObject, TrackState + + obj = TrackedObject( + track_id=3, label="person", bbox=[100,80,200,300], + confidence=0.9, center=(150,190), dwell_time_frames=1, + dwell_time_seconds=0.0, state=TrackState.ACTIVE, + zones_present=["restricted_door"], + ) + registry = {} + counts = {} + cooldown = {} + + # First entry + hint1 = classify_action(obj, None, registry, counts, cooldown, 1000.0) + assert hint1 == ActionHint.ZONE_ENTRY + + # Second entry + hint2 = classify_action(obj, None, registry, counts, cooldown, 2000.0) + assert hint2 == ActionHint.REPEATED_APPROACH + + +def test_repeated_approach_no_spam(): + from services.memory.action_classifier import classify_action + from libs.schemas.tracking import TrackedObject, TrackState + + obj = TrackedObject( + track_id=4, label="person", bbox=[100,80,200,300], + confidence=0.9, center=(150,190), dwell_time_frames=1, + dwell_time_seconds=0.0, state=TrackState.ACTIVE, + zones_present=["restricted_door"], + ) + registry = {} + counts = {} + cooldown = {} + + # First entry + hint1 = classify_action(obj, None, registry, counts, cooldown, 1000.0) + assert hint1 == ActionHint.ZONE_ENTRY + + # Still inside the zone, not a new entry + hint_stay = classify_action(obj, obj, registry, counts, cooldown, 2000.0) + assert hint_stay != ActionHint.REPEATED_APPROACH + assert hint_stay != ActionHint.ZONE_ENTRY + + # Leave and enter again + hint2 = classify_action(obj, None, registry, counts, cooldown, 3000.0) + assert hint2 == ActionHint.REPEATED_APPROACH + + +def test_repeated_approach_cooldown(): + from services.memory.action_classifier import classify_action + from libs.schemas.tracking import TrackedObject, TrackState + + obj = TrackedObject( + track_id=5, label="person", bbox=[100,80,200,300], + confidence=0.9, center=(150,190), dwell_time_frames=1, + dwell_time_seconds=0.0, state=TrackState.ACTIVE, + zones_present=["restricted_door"], + ) + registry = {} + counts = {} + cooldown = {} + + # First entry + classify_action(obj, None, registry, counts, cooldown, 1000.0) + + # Second entry (triggers REPEATED_APPROACH, sets cooldown) + hint2 = classify_action(obj, None, registry, counts, cooldown, 2000.0) + assert hint2 == ActionHint.REPEATED_APPROACH + + # Third entry right after (within 10s cooldown) + hint3 = classify_action(obj, None, registry, counts, cooldown, 5000.0) + assert hint3 != ActionHint.REPEATED_APPROACH + + # Fourth entry after cooldown expires (15000 is > 10000 ms since 2000) + hint4 = classify_action(obj, None, registry, counts, cooldown, 15000.0) + assert hint4 == ActionHint.REPEATED_APPROACH \ No newline at end of file