From 8a4f02056a72c21360de4843097c95820d68a18b Mon Sep 17 00:00:00 2001 From: GhanshyamJha Date: Thu, 21 May 2026 21:26:43 +0530 Subject: [PATCH] fix: make rate limiter atomic --- src/lib/cache.ts | 62 ++++++++++++++++++++++++++++++++++++++ src/lib/rate-limit.test.ts | 11 +++++++ src/lib/rate-limit.ts | 25 +++------------ 3 files changed, 78 insertions(+), 20 deletions(-) diff --git a/src/lib/cache.ts b/src/lib/cache.ts index 467233f..79b1b32 100644 --- a/src/lib/cache.ts +++ b/src/lib/cache.ts @@ -13,6 +13,20 @@ interface CacheBackend { set(key: string, value: unknown, ttlSeconds: number): Promise; del(key: string): Promise; scanDel(prefix: string): Promise; + rateLimitHit(key: string, windowSec: number, now: number): Promise; +} + +export type RateLimitBucket = { + count: number; + resetAt: number; +}; + +function rateLimitResetAt(ttlSeconds: number, windowSec: number, now: number): number { + return now + Math.max(1, ttlSeconds > 0 ? ttlSeconds : windowSec) * 1000; +} + +function blockedRateLimitBucket(windowSec: number, now: number): RateLimitBucket { + return { count: Number.MAX_SAFE_INTEGER, resetAt: now + windowSec * 1000 }; } class MemoryBackend implements CacheBackend { @@ -41,6 +55,22 @@ class MemoryBackend implements CacheBackend { if (k.startsWith(prefix)) this.store.delete(k); } } + + async rateLimitHit(key: string, windowSec: number, now: number): Promise { + const hit = this.store.get(key); + const count = typeof hit?.value === 'number' ? hit.value : 0; + const expired = !hit || hit.expiresAt <= now || count <= 0; + + if (expired) { + const resetAt = now + windowSec * 1000; + this.store.set(key, { value: 1, expiresAt: resetAt }); + return { count: 1, resetAt }; + } + + const next = count + 1; + this.store.set(key, { value: next, expiresAt: hit.expiresAt }); + return { count: next, resetAt: hit.expiresAt }; + } } export class UpstashBackend implements CacheBackend { @@ -90,6 +120,18 @@ export class UpstashBackend implements CacheBackend { // ignore } } + + async rateLimitHit(key: string, windowSec: number, now: number): Promise { + try { + const count = await this.redis.incr(key); + if (count === 1) await this.redis.expire(key, windowSec); + const ttl = await this.redis.ttl(key); + if (ttl <= 0) await this.redis.expire(key, windowSec); + return { count, resetAt: rateLimitResetAt(ttl, windowSec, now) }; + } catch { + return blockedRateLimitBucket(windowSec, now); + } + } } export class IoRedisBackend implements CacheBackend { @@ -136,6 +178,18 @@ export class IoRedisBackend implements CacheBackend { // ignore } } + + async rateLimitHit(key: string, windowSec: number, now: number): Promise { + try { + const count = await this.redis.incr(key); + if (count === 1) await this.redis.expire(key, windowSec); + const ttl = await this.redis.ttl(key); + if (ttl <= 0) await this.redis.expire(key, windowSec); + return { count, resetAt: rateLimitResetAt(ttl, windowSec, now) }; + } catch { + return blockedRateLimitBucket(windowSec, now); + } + } } let backend: CacheBackend = pickDefaultBackend(); @@ -184,3 +238,11 @@ export function cacheDel(key: string): Promise { export function cacheDelByPrefix(prefix: string): Promise { return backend.scanDel(prefix); } + +export function cacheRateLimitHit( + key: string, + windowSec: number, + now: number, +): Promise { + return backend.rateLimitHit(key, windowSec, now); +} diff --git a/src/lib/rate-limit.test.ts b/src/lib/rate-limit.test.ts index ae5b91b..b28609b 100644 --- a/src/lib/rate-limit.test.ts +++ b/src/lib/rate-limit.test.ts @@ -34,6 +34,17 @@ describe('rateLimit', () => { expect(blocked.remaining).toBe(0); }); + it('does not allow concurrent bursts past the limit', async () => { + const opts = { namespace: 'test', key: 'burst', limit: 5, windowSec: 60 }; + const results = await Promise.all(Array.from({ length: 20 }, () => rateLimit(opts))); + const firstResetAt = results.at(0)?.resetAt; + + expect(results.filter((r) => r.ok)).toHaveLength(5); + expect(results.filter((r) => !r.ok)).toHaveLength(15); + expect(firstResetAt).toBeDefined(); + expect(results.every((r) => r.resetAt === firstResetAt)).toBe(true); + }); + it('separate keys do not share budget', async () => { const a = await rateLimit({ namespace: 'test', key: 'a', limit: 1, windowSec: 60 }); const b = await rateLimit({ namespace: 'test', key: 'b', limit: 1, windowSec: 60 }); diff --git a/src/lib/rate-limit.ts b/src/lib/rate-limit.ts index 36e9c12..9a7794f 100644 --- a/src/lib/rate-limit.ts +++ b/src/lib/rate-limit.ts @@ -1,4 +1,4 @@ -import { cacheGet, cacheSet } from './cache'; +import { cacheRateLimitHit } from './cache'; export type RateLimitOptions = { namespace: string; @@ -18,28 +18,13 @@ export type RateLimitResult = { * If we ever need true sliding window precision, swap the backend without touching callers. */ export async function rateLimit(opts: RateLimitOptions): Promise { - const bucketKey = `rl:${opts.namespace}:${opts.key}`; + const bucketKey = `rl:v2:${opts.namespace}:${opts.key}`; const now = Date.now(); - const ttlMs = opts.windowSec * 1000; - - const existing = await cacheGet<{ count: number; resetAt: number }>(bucketKey); - if (!existing || existing.resetAt <= now) { - const fresh = { count: 1, resetAt: now + ttlMs }; - await cacheSet(bucketKey, fresh, opts.windowSec); - return { ok: true, remaining: opts.limit - 1, resetAt: fresh.resetAt }; - } - - if (existing.count >= opts.limit) { - return { ok: false, remaining: 0, resetAt: existing.resetAt }; - } - - const next = { count: existing.count + 1, resetAt: existing.resetAt }; - const remainingTtl = Math.max(1, Math.ceil((existing.resetAt - now) / 1000)); - await cacheSet(bucketKey, next, remainingTtl); + const next = await cacheRateLimitHit(bucketKey, opts.windowSec, now); return { - ok: true, + ok: next.count <= opts.limit, remaining: Math.max(0, opts.limit - next.count), - resetAt: existing.resetAt, + resetAt: next.resetAt, }; }