Skip to content

Commit 1f1327d

Browse files
committed
Remove decorator override for ratelimit
The ratelimit keys in the cache, include a global or scope prefix in the key name. At the decorator level, there is no prefix, and hence it would use the global prefix and end up overriding global ratelimit records for the requests made by the user in the current window. Thinking more of it, I don't think we need decorator level overrides. Scope isolation should be sufficient.
1 parent 331f806 commit 1f1327d

2 files changed

Lines changed: 24 additions & 143 deletions

File tree

brainzutils/flask/test/test_ratelimit.py

Lines changed: 11 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -112,80 +112,18 @@ def index():
112112
set_rate_limits(self.max_token_requests, self.max_ip_requests, self.ratelimit_window)
113113
self.make_requests(client, self.max_token_requests, token="Token %s" % valid_user)
114114

115-
def test_custom_ip_limit(self):
116-
"""Test that per_ip_limit parameter overrides global limit."""
117-
custom_limit = 2
118-
119-
@self.app.route("/custom")
120-
@ratelimit(per_ip_limit=custom_limit, window=60)
121-
def custom_endpoint():
122-
return "OK"
123-
124-
client = self.app.test_client()
125-
126-
# First request should succeed
127-
response = client.get("/custom")
128-
self.assertEqual(response.status_code, 200)
129-
self.assertEqual(response.headers["X-RateLimit-Limit"], str(custom_limit))
130-
self.assertEqual(response.headers["X-RateLimit-Remaining"], "1")
131-
132-
response = client.get("/custom")
133-
self.assertEqual(response.status_code, 200)
134-
self.assertEqual(response.headers["X-RateLimit-Remaining"], "0")
135-
136-
response = client.get("/custom")
137-
self.assertEqual(response.status_code, 429)
138-
139-
def test_custom_window(self):
140-
"""Test that window parameter works correctly."""
141-
@self.app.route("/short-window")
142-
@ratelimit(per_ip_limit=1, window=2)
143-
def short_window_endpoint():
144-
return "OK"
145-
146-
client = self.app.test_client()
147-
148-
response = client.get("/short-window")
149-
self.assertEqual(response.status_code, 200)
150-
response = client.get("/short-window")
151-
self.assertEqual(response.status_code, 429)
152-
153-
sleep(2.5)
154-
response = client.get("/short-window")
155-
self.assertEqual(response.status_code, 200)
156-
157-
def test_headers_contain_correct_values(self):
158-
"""Test that rate limit headers contain expected values."""
159-
limit = 5
160-
window = 30
161-
162-
@self.app.route("/headers")
163-
@ratelimit(per_ip_limit=limit, window=window)
164-
def headers_endpoint():
165-
return "OK"
166-
167-
client = self.app.test_client()
168-
response = client.get("/headers")
169-
170-
self.assertEqual(response.status_code, 200)
171-
self.assertIn("X-RateLimit-Limit", response.headers)
172-
self.assertIn("X-RateLimit-Remaining", response.headers)
173-
self.assertIn("X-RateLimit-Reset-In", response.headers)
174-
175-
self.assertEqual(response.headers["X-RateLimit-Limit"], str(limit))
176-
self.assertEqual(response.headers["X-RateLimit-Remaining"], str(limit - 1))
177-
self.assertLessEqual(int(response.headers["X-RateLimit-Reset-In"]), window)
178-
self.assertIn("X-RateLimit-Reset", response.headers)
179-
180115
def test_scope_isolation(self):
181116
"""Test that different scopes have independent rate limit buckets."""
117+
set_rate_limits(per_token=100, per_ip=2, window=60, scope="scope_a")
118+
set_rate_limits(per_token=100, per_ip=2, window=60, scope="scope_b")
119+
182120
@self.app.route("/scope-a")
183-
@ratelimit(scope="scope_a", per_ip_limit=2, window=60)
121+
@ratelimit(scope="scope_a")
184122
def scope_a_endpoint():
185123
return "A"
186124

187125
@self.app.route("/scope-b")
188-
@ratelimit(scope="scope_b", per_ip_limit=2, window=60)
126+
@ratelimit(scope="scope_b")
189127
def scope_b_endpoint():
190128
return "B"
191129

@@ -245,13 +183,16 @@ def shared_2_endpoint():
245183

246184
def test_no_scope_vs_scoped(self):
247185
"""Test that unscoped and scoped endpoints have separate buckets."""
186+
set_rate_limits(per_token=100, per_ip=1, window=60)
187+
set_rate_limits(per_token=100, per_ip=1, window=60, scope="my_scope")
188+
248189
@self.app.route("/unscoped")
249-
@ratelimit(per_ip_limit=1, window=60)
190+
@ratelimit()
250191
def unscoped_endpoint():
251192
return "Unscoped"
252193

253194
@self.app.route("/scoped")
254-
@ratelimit(scope="my_scope", per_ip_limit=1, window=60)
195+
@ratelimit(scope="my_scope")
255196
def scoped_endpoint():
256197
return "Scoped"
257198

@@ -290,31 +231,7 @@ def test_set_and_get_scope_limits(self):
290231
self.assertIsNotNone(result)
291232
self.assertEqual(result["per_token"], None)
292233
self.assertEqual(result["per_ip"], None)
293-
self.assertEqual(result["window"], None
294-
)
295-
296-
def test_decorator_overrides_cache_scope_limits(self):
297-
"""Test that decorator parameters override scope limits from cache."""
298-
scope = "override_scope"
299-
set_rate_limits(per_token=100, per_ip=10, window=60, scope=scope)
300-
301-
@self.app.route("/override")
302-
@ratelimit(scope=scope, per_ip_limit=2)
303-
def override_endpoint():
304-
return "OK"
305-
306-
client = self.app.test_client()
307-
308-
response = client.get("/override")
309-
self.assertEqual(response.status_code, 200)
310-
self.assertEqual(response.headers["X-RateLimit-Limit"], "2")
311-
312-
response = client.get("/override")
313-
self.assertEqual(response.status_code, 200)
314-
315-
# 3rd request should fail (limit is 2, not 10)
316-
response = client.get("/override")
317-
self.assertEqual(response.status_code, 429)
234+
self.assertEqual(result["window"], None)
318235

319236
def test_scope_cache_values_stored_correctly(self):
320237
"""Test that scope limits are stored in cache with correct keys."""
@@ -359,11 +276,6 @@ def global_endpoint():
359276
def scope_priority_endpoint():
360277
return "OK"
361278

362-
@self.app.route("/decorator-priority")
363-
@ratelimit(scope=scope, per_ip_limit=2)
364-
def decorator_priority_endpoint():
365-
return "OK"
366-
367279
client = self.app.test_client()
368280

369281
response = client.get("/global")
@@ -373,7 +285,3 @@ def decorator_priority_endpoint():
373285
response = client.get("/scope-priority")
374286
self.assertEqual(response.status_code, 200)
375287
self.assertEqual(response.headers["X-RateLimit-Limit"], "3")
376-
377-
response = client.get("/decorator-priority")
378-
self.assertEqual(response.status_code, 200)
379-
self.assertEqual(response.headers["X-RateLimit-Limit"], "2")

brainzutils/ratelimit.py

Lines changed: 13 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -78,23 +78,15 @@ def after_request_callbacks(response):
7878
def index():
7979
return '<html><body>test</body></html>'
8080
81-
You can also pass custom rate limit parameters directly to the decorator to override
82-
the global/cached values::
83-
84-
@app.route('/expensive')
85-
@ratelimit(per_token_limit=10, per_ip_limit=5, window=60)
86-
def expensive_endpoint():
87-
return 'This endpoint has stricter rate limits'
88-
89-
Use the scope parameter to isolate rate limits for different endpoints::
81+
Use the scope parameter to isolate and apply custom rate limits for different endpoints::
9082
9183
@app.route('/api/v1/search')
9284
@ratelimit(scope='search')
9385
def search():
9486
return 'Search results'
9587
9688
@app.route('/api/v1/upload')
97-
@ratelimit(scope='upload', per_ip_limit=5, window=60)
89+
@ratelimit(scope='upload')
9890
def upload():
9991
return 'Upload complete'
10092
@@ -114,9 +106,8 @@ def upload():
114106
set_rate_limits(per_token=10, per_ip=5, window=120, scope='upload')
115107
116108
Limit resolution order (first non-None value wins):
117-
1. Decorator parameters (per_token_limit, per_ip_limit, window)
118-
2. Scope-specific limits from cache (if scope is provided)
119-
3. Global limits from cache
109+
1. Scope-specific limits from cache (if scope is provided)
110+
2. Global limits from cache
120111
121112
4. To enable token based rate limiting, callers need to pass the Authorization header (see above)
122113
and the application needs to provide a user validation function::
@@ -240,51 +231,39 @@ def on_over_limit(limit):
240231

241232
def _get_rate_limit_helper(
242233
limit_type: Literal["per_ip", "per_token"],
243-
_global: dict, _scope: dict = None, _local: dict = None
234+
_global: dict,
235+
_scope: dict = None
244236
) -> dict:
245237
values = {}
246-
if _local is not None and (local_limit := _local.get(limit_type)) is not None:
247-
values["limit"] = local_limit
248-
elif _scope is not None and (scope_limit := _scope.get(limit_type)) is not None:
238+
if _scope is not None and (scope_limit := _scope.get(limit_type)) is not None:
249239
values["limit"] = scope_limit
250240
else:
251241
values["limit"] = _global[limit_type]
252242

253-
if _local is not None and (local_window := _local.get("window")) is not None:
254-
values["window"] = local_window
255-
elif _scope is not None and (scope_window := _scope.get("window")) is not None:
243+
if _scope is not None and (scope_window := _scope.get("window")) is not None:
256244
values["window"] = scope_window
257245
else:
258246
values["window"] = _global["window"]
259247

260248
return values
261249

262250

263-
def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window=None, scope=None):
251+
def get_rate_limit_data(request, scope=None):
264252
"""Fetch key for the given request. If an Authorization header is provided,
265253
the caller will get a better and personalized rate limit. If no header is provided,
266254
the caller will be rate limited by IP, which gets an overall lower rate limit.
267255
This should encourage callers to always provide the Authorization token
268256
269257
Args:
270258
request: The Flask request object
271-
per_token_limit: Optional override for per-token limit (uses cache value if None)
272-
per_ip_limit: Optional override for per-IP limit (uses cache value if None)
273-
window: Optional override for rate limit window in seconds (uses cache value if None)
274259
scope: Optional scope name to check for scope-specific limits in cache
275260
276261
Limit resolution order (first non-None value wins):
277-
1. Decorator parameters (per_token_limit, per_ip_limit, window)
278-
2. Scope-specific limits from cache (if scope is provided)
262+
1. Scope-specific limits from cache (if scope is provided)
279263
3. Global limits from cache
280264
"""
281265
_global = get_rate_limits()
282266
_scope = get_rate_limits(scope) if scope else None
283-
_local = {
284-
"per_token": per_token_limit,
285-
"per_ip": per_ip_limit,
286-
"window": window
287-
}
288267

289268
# If a user verification function is provided, parse the Authorization header and try to look up that user
290269
if ratelimit_user_validation:
@@ -294,7 +273,7 @@ def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window
294273
is_valid = ratelimit_user_validation(auth_token)
295274
if is_valid:
296275
values = _get_rate_limit_helper(
297-
"per_token", _global=_global, _scope=_scope, _local=_local
276+
"per_token", _global=_global, _scope=_scope
298277
)
299278
values["key"] = auth_token
300279
return values
@@ -306,21 +285,18 @@ def get_rate_limit_data(request, per_token_limit=None, per_ip_limit=None, window
306285
ip = request.remote_addr
307286

308287
values = _get_rate_limit_helper(
309-
"per_ip", _global=_global, _scope=_scope, _local=_local
288+
"per_ip", _global=_global, _scope=_scope
310289
)
311290
values["key"] = ip
312291
return values
313292

314293

315-
def ratelimit(per_token_limit=None, per_ip_limit=None, window=None, scope=None):
294+
def ratelimit(scope=None):
316295
"""
317296
This is the decorator that should be applied to all view functions that should be
318297
rate limited.
319298
320299
Args:
321-
per_token_limit: Optional override for per-token limit (uses cache value if None)
322-
per_ip_limit: Optional override for per-IP limit (uses cache value if None)
323-
window: Optional override for rate limit window in seconds (uses cache value if None)
324300
scope: Optional scope to isolate rate limits for different endpoints.
325301
If provided, the rate limit key will be scoped with this value,
326302
allowing different endpoints to have separate rate limit buckets..
@@ -329,9 +305,6 @@ def decorator(f):
329305
def rate_limited(*args, **kwargs):
330306
data = get_rate_limit_data(
331307
request,
332-
per_token_limit=per_token_limit,
333-
per_ip_limit=per_ip_limit,
334-
window=window,
335308
scope=scope
336309
)
337310
key = f"{scope}:{data['key']}" if scope else data["key"]

0 commit comments

Comments
 (0)