Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/lib/cache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ interface CacheBackend {
set(key: string, value: unknown, ttlSeconds: number): Promise<void>;
del(key: string): Promise<void>;
scanDel(prefix: string): Promise<void>;
rateLimitHit(key: string, windowSec: number, now: number): Promise<RateLimitBucket>;
}

export type RateLimitBucket = {
count: number;
resetAt: number;
};

function rateLimitResetAt(ttlSeconds: number, windowSec: number, now: number): number {
return now + Math.max(1, ttlSeconds > 0 ? ttlSeconds : windowSec) * 1000;
}

function blockedRateLimitBucket(windowSec: number, now: number): RateLimitBucket {
return { count: Number.MAX_SAFE_INTEGER, resetAt: now + windowSec * 1000 };
}

class MemoryBackend implements CacheBackend {
Expand Down Expand Up @@ -41,6 +55,22 @@ class MemoryBackend implements CacheBackend {
if (k.startsWith(prefix)) this.store.delete(k);
}
}

async rateLimitHit(key: string, windowSec: number, now: number): Promise<RateLimitBucket> {
const hit = this.store.get(key);
const count = typeof hit?.value === 'number' ? hit.value : 0;
const expired = !hit || hit.expiresAt <= now || count <= 0;

if (expired) {
const resetAt = now + windowSec * 1000;
this.store.set(key, { value: 1, expiresAt: resetAt });
return { count: 1, resetAt };
}

const next = count + 1;
this.store.set(key, { value: next, expiresAt: hit.expiresAt });
return { count: next, resetAt: hit.expiresAt };
}
}

export class UpstashBackend implements CacheBackend {
Expand Down Expand Up @@ -90,6 +120,18 @@ export class UpstashBackend implements CacheBackend {
// ignore
}
}

async rateLimitHit(key: string, windowSec: number, now: number): Promise<RateLimitBucket> {
try {
const count = await this.redis.incr(key);
if (count === 1) await this.redis.expire(key, windowSec);
const ttl = await this.redis.ttl(key);
if (ttl <= 0) await this.redis.expire(key, windowSec);
return { count, resetAt: rateLimitResetAt(ttl, windowSec, now) };
} catch {
return blockedRateLimitBucket(windowSec, now);
}
}
}

export class IoRedisBackend implements CacheBackend {
Expand Down Expand Up @@ -136,6 +178,18 @@ export class IoRedisBackend implements CacheBackend {
// ignore
}
}

async rateLimitHit(key: string, windowSec: number, now: number): Promise<RateLimitBucket> {
try {
const count = await this.redis.incr(key);
if (count === 1) await this.redis.expire(key, windowSec);
const ttl = await this.redis.ttl(key);
if (ttl <= 0) await this.redis.expire(key, windowSec);
return { count, resetAt: rateLimitResetAt(ttl, windowSec, now) };
} catch {
return blockedRateLimitBucket(windowSec, now);
}
}
}

let backend: CacheBackend = pickDefaultBackend();
Expand Down Expand Up @@ -184,3 +238,11 @@ export function cacheDel(key: string): Promise<void> {
export function cacheDelByPrefix(prefix: string): Promise<void> {
return backend.scanDel(prefix);
}

export function cacheRateLimitHit(
key: string,
windowSec: number,
now: number,
): Promise<RateLimitBucket> {
return backend.rateLimitHit(key, windowSec, now);
}
11 changes: 11 additions & 0 deletions src/lib/rate-limit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ describe('rateLimit', () => {
expect(blocked.remaining).toBe(0);
});

it('does not allow concurrent bursts past the limit', async () => {
const opts = { namespace: 'test', key: 'burst', limit: 5, windowSec: 60 };
const results = await Promise.all(Array.from({ length: 20 }, () => rateLimit(opts)));
const firstResetAt = results.at(0)?.resetAt;

expect(results.filter((r) => r.ok)).toHaveLength(5);
expect(results.filter((r) => !r.ok)).toHaveLength(15);
expect(firstResetAt).toBeDefined();
expect(results.every((r) => r.resetAt === firstResetAt)).toBe(true);
});

it('separate keys do not share budget', async () => {
const a = await rateLimit({ namespace: 'test', key: 'a', limit: 1, windowSec: 60 });
const b = await rateLimit({ namespace: 'test', key: 'b', limit: 1, windowSec: 60 });
Expand Down
25 changes: 5 additions & 20 deletions src/lib/rate-limit.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { cacheGet, cacheSet } from './cache';
import { cacheRateLimitHit } from './cache';

export type RateLimitOptions = {
namespace: string;
Expand All @@ -18,28 +18,13 @@ export type RateLimitResult = {
* If we ever need true sliding window precision, swap the backend without touching callers.
*/
export async function rateLimit(opts: RateLimitOptions): Promise<RateLimitResult> {
const bucketKey = `rl:${opts.namespace}:${opts.key}`;
const bucketKey = `rl:v2:${opts.namespace}:${opts.key}`;
const now = Date.now();
const ttlMs = opts.windowSec * 1000;

const existing = await cacheGet<{ count: number; resetAt: number }>(bucketKey);
if (!existing || existing.resetAt <= now) {
const fresh = { count: 1, resetAt: now + ttlMs };
await cacheSet(bucketKey, fresh, opts.windowSec);
return { ok: true, remaining: opts.limit - 1, resetAt: fresh.resetAt };
}

if (existing.count >= opts.limit) {
return { ok: false, remaining: 0, resetAt: existing.resetAt };
}

const next = { count: existing.count + 1, resetAt: existing.resetAt };
const remainingTtl = Math.max(1, Math.ceil((existing.resetAt - now) / 1000));
await cacheSet(bucketKey, next, remainingTtl);
const next = await cacheRateLimitHit(bucketKey, opts.windowSec, now);

return {
ok: true,
ok: next.count <= opts.limit,
remaining: Math.max(0, opts.limit - next.count),
resetAt: existing.resetAt,
resetAt: next.resetAt,
};
}