|
11 | 11 | class RateLimitTracker: |
12 | 12 | def __init__(self): |
13 | 13 | self.client_limits = {} # client_id -> rate_limit_info |
14 | | - |
| 14 | + |
| 15 | + def _parse_header_value(self, value: Any) -> Optional[int]: |
| 16 | + """ |
| 17 | + Parses a header value that could be a number or "unlimited". |
| 18 | + Returns None if the value is "unlimited", None, or invalid. |
| 19 | + """ |
| 20 | + # Handle non-string types (lists, dicts, None, etc.) |
| 21 | + if value is None: |
| 22 | + return None |
| 23 | + if isinstance(value, (list, dict)): |
| 24 | + return None |
| 25 | + |
| 26 | + # Convert to string and normalize |
| 27 | + try: |
| 28 | + str_value = str(value).strip().lower() |
| 29 | + except Exception: |
| 30 | + return None |
| 31 | + |
| 32 | + if not str_value or str_value == "unlimited": |
| 33 | + return None |
| 34 | + |
| 35 | + try: |
| 36 | + # Handle floats by converting to float first, then int |
| 37 | + parsed = int(float(str_value)) |
| 38 | + # Rate limits can't be negative |
| 39 | + return parsed if parsed >= 0 else None |
| 40 | + except (ValueError, TypeError): |
| 41 | + return None |
| 42 | + |
15 | 43 | def update_rate_limit(self, client_id: str, headers: Dict[str, str]): |
16 | 44 | if client_id not in self.client_limits: |
17 | 45 | self.client_limits[client_id] = {} |
18 | | - |
| 46 | + |
19 | 47 | limit_info = self.client_limits[client_id] |
20 | | - |
21 | | - # Update rate limit headers |
22 | | - if 'X-Ratelimit-Limit-Requests' in headers: |
23 | | - limit_info['limit'] = int(headers['X-Ratelimit-Limit-Requests']) |
24 | | - if 'X-Ratelimit-Remaining-Requests' in headers: |
25 | | - limit_info['remaining'] = int(headers['X-Ratelimit-Remaining-Requests']) |
26 | | - if 'X-Ratelimit-Reset-Requests' in headers: |
27 | | - limit_info['reset_time'] = headers['X-Ratelimit-Reset-Requests'] |
28 | | - if 'retry-after' in headers: |
29 | | - limit_info['retry_after'] = int(headers['retry-after']) |
30 | | - |
| 48 | + |
| 49 | + # Get header values (case-insensitive lookup) |
| 50 | + headers_lower = {k.lower(): v for k, v in headers.items()} |
| 51 | + |
| 52 | + limit_requests = headers_lower.get('x-ratelimit-limit-requests') |
| 53 | + remaining_requests = headers_lower.get('x-ratelimit-remaining-requests') |
| 54 | + reset_requests = headers_lower.get('x-ratelimit-reset-requests') |
| 55 | + retry_after = headers_lower.get('retry-after') |
| 56 | + |
| 57 | + # Only update numeric values if they are valid numbers (not "unlimited") |
| 58 | + parsed_limit = self._parse_header_value(limit_requests) |
| 59 | + parsed_remaining = self._parse_header_value(remaining_requests) |
| 60 | + parsed_retry_after = self._parse_header_value(retry_after) |
| 61 | + |
| 62 | + if parsed_limit is not None: |
| 63 | + limit_info['limit'] = parsed_limit |
| 64 | + if parsed_remaining is not None: |
| 65 | + limit_info['remaining'] = parsed_remaining |
| 66 | + # Mark as unlimited if header explicitly says "unlimited" |
| 67 | + if remaining_requests is not None: |
| 68 | + try: |
| 69 | + if str(remaining_requests).strip().lower() == "unlimited": |
| 70 | + limit_info['is_unlimited'] = True |
| 71 | + except Exception: |
| 72 | + pass |
| 73 | + if reset_requests: |
| 74 | + limit_info['reset_time'] = reset_requests |
| 75 | + if parsed_retry_after is not None: |
| 76 | + limit_info['retry_after'] = parsed_retry_after |
| 77 | + |
31 | 78 | limit_info['last_updated'] = time.time() |
32 | | - |
| 79 | + |
33 | 80 | def is_rate_limited(self, client_id: str) -> bool: |
34 | 81 | if client_id not in self.client_limits: |
35 | 82 | return False |
36 | | - |
| 83 | + |
37 | 84 | limit_info = self.client_limits[client_id] |
38 | | - return limit_info.get('remaining', 1) <= 0 |
| 85 | + # If marked as unlimited, never rate limited |
| 86 | + if limit_info.get('is_unlimited', False): |
| 87 | + return False |
| 88 | + # Only consider rate limited if remaining is explicitly set and is 0 or less |
| 89 | + remaining = limit_info.get('remaining') |
| 90 | + return remaining is not None and remaining <= 0 |
39 | 91 |
|
40 | 92 | def get_retry_after(self, client_id: str) -> Optional[int]: |
41 | 93 | if client_id not in self.client_limits: |
|
0 commit comments