# counters.py — scalable like-counter demo (web-page chrome and copied line-number gutter removed)
"""
Scalable Like Counter System
=============================
Features:
- Sharded counters with per-shard locks (no global bottleneck)
- O(1) reads via running_shard_total
- Dynamic shard scaling (16 → 32 → 128 → 1024)
- Worker pool for parallel event processing
- Duplicate-like prevention (in-memory; swap for Redis in prod)
- Periodic aggregator flush to simulated DB
- Shard monitor that auto-scales on queue depth
"""
import random
import threading
import time
from queue import Queue
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
INITIAL_SHARDS = 16  # shard count at startup; only grows via rescale()
NUM_WORKERS = 4  # parallel event-processor threads
FLUSH_INTERVAL = 10  # seconds between aggregator flushes
MONITOR_INTERVAL = 5  # seconds between shard-monitor checks
# Queue-depth thresholds → target shard count.
# Checked in ascending order: a depth above a threshold triggers a rescale
# to the paired shard count (one step per monitor cycle).
SCALE_THRESHOLDS = [
    (500, 32),
    (2_000, 128),
    (10_000, 1_024),
]
# ---------------------------------------------------------------------------
# DynamicShardedCounter
# ---------------------------------------------------------------------------
class DynamicShardedCounter:
    """
    Thread-safe sharded counter.

    Design notes:
    - Per-shard locks let unrelated writers increment in parallel.
    - A separate running total under its own lock keeps read() O(1).
    - rescale() rebuilds the shard array under a single resize lock so no
      unflushed count is ever lost.
    """

    def __init__(self, initial_shards: int = INITIAL_SHARDS):
        self._num_shards = initial_shards
        self._shards = [0] * initial_shards
        self._shard_locks = [threading.Lock() for _ in range(initial_shards)]
        self._running_total = 0
        self._total_lock = threading.Lock()
        # RLock so rescale() can call increment() internally if needed
        self._resize_lock = threading.RLock()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def increment(self, user_id: int) -> None:
        """Record one like for *user_id* (write path)."""
        with self._resize_lock:  # paused while a resize is in flight
            idx = hash(user_id) % self._num_shards
            with self._shard_locks[idx]:
                self._shards[idx] += 1
                with self._total_lock:
                    self._running_total += 1  # keeps read() O(1)

    def read(self) -> int:
        """Return the current like count — O(1)."""
        with self._total_lock:
            return self._running_total

    def flush(self) -> int:
        """
        Drain every shard, zero the running total, and return the drained
        batch for the caller (the aggregator) to commit to the DB.
        """
        with self._resize_lock:
            drained = 0
            for idx, lock in enumerate(self._shard_locks):
                with lock:
                    drained += self._shards[idx]
                    self._shards[idx] = 0
            with self._total_lock:
                self._running_total = 0
            return drained

    def rescale(self, new_shard_count: int) -> None:
        """
        Grow the shard array without losing unflushed counts.

        Under the resize lock (which pauses all increments): sum the live
        shard values, rebuild shards and locks at the new size, then spread
        the carried sum evenly across the new shards (divmod avoids any
        rounding loss). Downscaling is intentionally a no-op.
        """
        with self._resize_lock:
            old_count = self._num_shards
            if new_shard_count <= old_count:
                return  # never downscale in this impl
            # Capture live (unflushed) counts before rebuilding.
            carried = sum(self._shards)
            self._num_shards = new_shard_count
            self._shard_locks = [threading.Lock() for _ in range(new_shard_count)]
            # Redistribute carried counts without rounding loss.
            base, remainder = divmod(carried, new_shard_count)
            self._shards = [
                base + (1 if i < remainder else 0) for i in range(new_shard_count)
            ]
            print(
                f"[Resharding] {old_count} → {new_shard_count} shards | "
                f"Carried {carried} unflushed likes"
            )

    @property
    def num_shards(self) -> int:
        # Read without a lock: single attribute load, used for monitoring only.
        return self._num_shards
# ---------------------------------------------------------------------------
# Duplicate-like guard
# ---------------------------------------------------------------------------
class LikeDeduplicator:
    """
    Remembers (user_id, post_id) pairs that were already counted.
    In-memory seen set. In production replace with Redis SETNX + TTL.
    """

    def __init__(self):
        self._seen: set = set()
        self._lock = threading.Lock()

    def is_duplicate(self, user_id: int, post_id: int) -> bool:
        """Return True if this like was seen before; otherwise record it."""
        like_key = (user_id, post_id)
        with self._lock:
            already_counted = like_key in self._seen
            if not already_counted:
                self._seen.add(like_key)
            return already_counted
# ---------------------------------------------------------------------------
# LikeSystem — wires everything together
# ---------------------------------------------------------------------------
class LikeSystem:
    """
    Wires the sharded counter, deduplicator, and background threads together.

    Flow: like() → dedupe → event queue → worker threads → sharded counter
    → periodic aggregator flush → simulated DB. A monitor thread grows the
    shard count when the queue backs up.
    """

    def __init__(self):
        self.event_queue = Queue()
        self.counter = DynamicShardedCounter(INITIAL_SHARDS)
        self.deduplicator = LikeDeduplicator()
        self._db_like_count = 0  # simulated persistent DB
        self._db_lock = threading.Lock()
        self._running = False

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------
    def like(self, user_id: int, post_id: int) -> bool:
        """
        Called by the web layer when a user taps ♥.
        Returns False if the like is a duplicate (already counted).
        """
        if self.deduplicator.is_duplicate(user_id, post_id):
            return False
        self.event_queue.put((user_id, post_id))
        return True

    def get_like_count(self) -> int:
        """O(1) read: in-flight shards + DB snapshot."""
        with self._db_lock:
            return self._db_like_count + self.counter.read()

    # ------------------------------------------------------------------
    # Background threads
    # ------------------------------------------------------------------
    def _event_processor(self) -> None:
        """Worker: pulls events off the queue → increments a shard."""
        while self._running:
            try:
                user_id, post_id = self.event_queue.get(timeout=1)
            except Empty:
                # Timeout with no event — loop and re-check _running.
                # (Fix: the old bare `except Exception: pass` also swallowed
                # real errors from counter.increment.)
                continue
            try:
                self.counter.increment(user_id)
            finally:
                # Always balance get() with task_done(), even if increment
                # raises, so a queue join() can never deadlock.
                self.event_queue.task_done()

    def _aggregator(self) -> None:
        """Flush shards → DB every FLUSH_INTERVAL seconds."""
        while self._running:
            time.sleep(FLUSH_INTERVAL)
            batch = self.counter.flush()
            with self._db_lock:
                self._db_like_count += batch
            print(
                f"[Aggregator] Flushed {batch:,} likes | "
                f"DB total: {self._db_like_count:,} | "
                f"Queue depth: {self.event_queue.qsize():,} | "
                f"Shards: {self.counter.num_shards}"
            )

    def _shard_monitor(self) -> None:
        """Auto-scale shards based on queue depth (one step per cycle)."""
        while self._running:
            time.sleep(MONITOR_INTERVAL)
            depth = self.event_queue.qsize()
            for threshold, target in SCALE_THRESHOLDS:
                if depth > threshold and self.counter.num_shards < target:
                    self.counter.rescale(target)
                    break

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def start(self) -> None:
        """Spawn worker, aggregator, and monitor threads (all daemons)."""
        self._running = True
        for _ in range(NUM_WORKERS):
            threading.Thread(
                target=self._event_processor, daemon=True, name="EventProcessor"
            ).start()
        threading.Thread(
            target=self._aggregator, daemon=True, name="Aggregator"
        ).start()
        threading.Thread(
            target=self._shard_monitor, daemon=True, name="ShardMonitor"
        ).start()
        print(
            f"[LikeSystem] Started | "
            f"shards={INITIAL_SHARDS} | workers={NUM_WORKERS}"
        )

    def stop(self) -> None:
        """Signal background threads to exit on their next loop check."""
        self._running = False
# ---------------------------------------------------------------------------
# Traffic simulation
# ---------------------------------------------------------------------------
def simulate_normal_traffic(system: LikeSystem) -> None:
    """Generate steady low-volume likes forever (~100 likes/s)."""
    while True:
        random_user = random.randint(1, 100_000)
        system.like(random_user, post_id=123)
        time.sleep(0.01)
def simulate_viral_spike(system: LikeSystem) -> None:
    """
    Fires after 20 s to simulate a post going viral.
    Hammers the system with 500 likes/s to trigger auto-rescaling.
    """
    time.sleep(20)
    print("\n[Simulation] *** VIRAL SPIKE STARTING ***\n")
    while True:
        random_user = random.randint(1, 10_000_000)
        system.like(random_user, post_id=123)
        time.sleep(0.002)
def print_stats(system: LikeSystem) -> None:
    """Every 3 s, print a live O(1) read of the like count."""
    while True:
        time.sleep(3)
        stats_line = (
            f"[Read] Like count (O1): {system.get_like_count():,} | "
            f"Queue: {system.event_queue.qsize():,} | "
            f"Shards: {system.counter.num_shards}"
        )
        print(stats_line)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    system = LikeSystem()
    system.start()

    # Background drivers: steady load, a delayed viral burst, and a stats printer.
    for simulator in (simulate_normal_traffic, simulate_viral_spike, print_stats):
        threading.Thread(target=simulator, args=(system,), daemon=True).start()

    # Keep the main thread alive until Ctrl-C; daemons die with it.
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\n[LikeSystem] Shutting down...")
        system.stop()