-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
182 lines (153 loc) · 5.73 KB
/
main.py
File metadata and controls
182 lines (153 loc) · 5.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import asyncio
import logging
import os
import tempfile
import time
from datetime import datetime, UTC
from typing import List, Dict, Any, Optional

import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

from src.api.routes import router
from src.api.task_routes import router as task_router
from src.config import Config
from src.middleware.security import SecurityHeadersMiddleware
# Logging configuration: root logger at INFO, records prefixed with a
# second-resolution timestamp, the logger name and the level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt=_LOG_DATE_FORMAT,
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Shared slowapi limiter keyed on the client's remote address, with a default
# limit string of 100 requests per minute (applied wherever the limiter is used).
limiter = Limiter(
    default_limits=["100/minute"],
    key_func=get_remote_address,
)
class HealthCheckCache:
    """Small in-process TTL cache for boolean health-check results.

    Each entry expires ``ttl_seconds`` after it was stored. Ages are measured
    with ``time.monotonic()`` so wall-clock adjustments (NTP sync, manual clock
    changes) can neither keep an entry fresh forever nor expire it early.
    """

    def __init__(self, ttl_seconds: int = 300):
        """Create an empty cache whose entries live for ``ttl_seconds`` seconds."""
        self.ttl_seconds = ttl_seconds
        # key -> (cached result, monotonic timestamp taken when it was stored)
        self.cache: Dict[str, tuple[bool, float]] = {}

    def get(self, key: str) -> Optional[bool]:
        """Return the cached result for ``key``.

        Returns ``None`` when the key was never stored or its entry is older
        than ``ttl_seconds``. Expired entries are evicted on lookup.
        """
        entry = self.cache.get(key)
        if entry is None:
            return None
        result, stored_at = entry
        if time.monotonic() - stored_at < self.ttl_seconds:
            return result
        # Expired: drop it so stale results are not retained.
        self.cache.pop(key, None)
        return None

    def set(self, key: str, value: bool) -> None:
        """Store ``value`` under ``key`` and restart its expiry clock."""
        self.cache[key] = (value, time.monotonic())


# Shared cache used by the deep health check; 300 s is the 5-minute window
# described in the /health/deep docstring.
health_cache = HealthCheckCache(ttl_seconds=300)
async def check_api_connectivity() -> bool:
    """Check that the configured generative model answers a trivial prompt.

    Sends a short "ping" prompt (max 10 output tokens) and treats a non-empty
    ``response.text`` as healthy. Both successes and failures are stored in
    ``health_cache`` under ``"api_connectivity"``, so repeated probes within the
    TTL do not hit the API again.

    The SDK call used here is synchronous, so it is run in a worker thread via
    ``asyncio.to_thread``. This keeps a slow or hanging upstream from blocking
    the event loop, which would stall every other request.

    Returns:
        True if the API responded with text, False otherwise (errors are
        logged, never raised).
    """
    cached = health_cache.get("api_connectivity")
    if cached is not None:
        return cached
    try:
        response = await asyncio.to_thread(
            Config.CLIENT.models.generate_content,
            model=Config.MODEL_NAME,
            contents="ping",
            config=Config.get_generation_config(max_output_tokens=10),
        )
        result = bool(response.text)
    except Exception as e:
        # Best-effort probe: any failure simply means "unhealthy".
        logger.error("API health check failed: %s", e)
        result = False
    health_cache.set("api_connectivity", result)
    return result
def check_disk_writable() -> bool:
    """Check that new files can be created and written in ``Config.INDEX_DIR``.

    Writes a few bytes to a uniquely named temporary file in the index
    directory. The file is deleted automatically when the context manager
    exits. A unique name keeps concurrent probes from racing on a single shared
    file, and the context manager removes the file even if the write fails.

    Returns:
        True if the probe file could be created, written and removed; False
        otherwise (the failure is logged).
    """
    try:
        with tempfile.NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            dir=Config.INDEX_DIR,
            prefix=".health_check_",
        ) as probe:
            probe.write("ok")
            # Push the data to the OS now so write errors surface inside the try.
            probe.flush()
        return True
    except Exception as e:
        logger.error("Disk write check failed: %s", e)
        return False
def check_indices_directory() -> bool:
    """Check that ``Config.INDEX_DIR`` exists and its contents can be read.

    A directory needs read permission (to list entries) and search/execute
    permission (to open the files inside it), so both are required.

    Returns:
        True if the path is a directory with read and search access; False
        otherwise (unexpected errors are logged).
    """
    try:
        index_dir = Config.INDEX_DIR
        return os.path.isdir(index_dir) and os.access(index_dir, os.R_OK | os.X_OK)
    except Exception as e:
        logger.error("Indices directory check failed: %s", e)
        return False
def initialize_directories() -> None:
    """Create the raw-data and index directories if they do not exist yet.

    Missing parent directories are created too; existing directories are left
    as they are (``exist_ok=True``).
    """
    for directory in (Config.RAW_DATA_DIR, Config.INDEX_DIR):
        os.makedirs(directory, exist_ok=True)
def create_app() -> FastAPI:
    """Build and configure the TreeRAG FastAPI application.

    Steps:
      * ensure the data/index directories exist;
      * register the security-headers and CORS middleware;
      * attach the module-level rate limiter and its 429 exception handler;
      * define the ``/health`` (shallow) and ``/health/deep`` endpoints;
      * mount the API router under ``/api`` and the task router under
        ``/api/tasks``.

    Returns:
        The fully configured FastAPI instance.
    """
    initialize_directories()
    app = FastAPI(
        title="TreeRAG API",
        version="1.0.0",
        description="AI-powered regulatory document consultation system",
        docs_url="/docs",
        redoc_url="/redoc"
    )
    # Only the local frontend dev server (port 3000) may make cross-origin calls.
    allowed_origins: List[str] = [
        "http://localhost:3000",
        "http://127.0.0.1:3000",
    ]
    app.add_middleware(SecurityHeadersMiddleware)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        allow_headers=["*"],
        # NOTE(review): on credentialed requests browsers treat "*" in
        # Access-Control-Expose-Headers literally, not as a wildcard. List the
        # headers explicitly if the frontend needs to read them.
        expose_headers=["*"],
        max_age=3600,
    )
    # Expose the limiter on app.state, where slowapi expects to find it.
    app.state.limiter = limiter
    # NOTE(review): no SlowAPIMiddleware is added here, so the limiter's
    # default_limits presumably only apply to routes decorated with
    # @limiter.limit. Confirm this is intended.
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

    @app.get("/health")
    async def health_check():
        """Shallow health check - service is alive."""
        return {
            "status": "healthy",
            "service": "TreeRAG API",
            "timestamp": datetime.now(UTC).isoformat()
        }

    @app.get("/health/deep")
    async def deep_health_check() -> Dict[str, Any]:
        """
        Deep health check - validates all critical dependencies.
        Results are cached for 5 minutes to avoid unnecessary API calls
        and rate limit issues with Kubernetes readiness probes.

        Checks the model API (cached), index-directory writability, and
        index-directory readability. Overall ``status`` is "healthy" only if
        all three pass, otherwise "degraded".
        """
        # NOTE(review): a "degraded" result is still returned with HTTP 200,
        # so a status-code-based readiness probe never fails. Consider 503.
        health_status = {
            "status": "healthy",
            "service": "TreeRAG API",
            "timestamp": datetime.now(UTC).isoformat(),
            "checks": {},
            "cache_info": {
                "ttl_seconds": health_cache.ttl_seconds,
                "description": "API check results cached to prevent rate limiting"
            }
        }
        # Capture cache state BEFORE the check: check_api_connectivity always
        # refreshes the cache, so asking afterwards would report "cached": True
        # on every call.
        api_cached = health_cache.get("api_connectivity") is not None
        api_ok = await check_api_connectivity()
        health_status["checks"]["gemini_api"] = {
            "status": "healthy" if api_ok else "unhealthy",
            "message": "API accessible" if api_ok else "API connection failed",
            "cached": api_cached
        }
        disk_ok = check_disk_writable()
        health_status["checks"]["disk_storage"] = {
            "status": "healthy" if disk_ok else "unhealthy",
            "message": "Writable" if disk_ok else "No write permission"
        }
        indices_ok = check_indices_directory()
        health_status["checks"]["indices_directory"] = {
            "status": "healthy" if indices_ok else "unhealthy",
            "message": "Accessible" if indices_ok else "Directory missing"
        }
        all_healthy = api_ok and disk_ok and indices_ok
        health_status["status"] = "healthy" if all_healthy else "degraded"
        return health_status

    app.include_router(router, prefix="/api", tags=["API"])
    app.include_router(task_router, prefix="/api/tasks", tags=["Tasks"])
    return app
# Module-level ASGI application, imported by uvicorn as "main:app".
app = create_app()

if __name__ == "__main__":
    # Direct-run entry point. The defaults are the previous hard-coded values
    # (0.0.0.0:8000 with auto-reload). API_HOST, API_PORT and API_RELOAD
    # override them, so the file can be run without the dev-only auto-reloader.
    uvicorn.run(
        "main:app",
        host=os.getenv("API_HOST", "0.0.0.0"),
        port=int(os.getenv("API_PORT", "8000")),
        reload=os.getenv("API_RELOAD", "true").strip().lower() in {"1", "true", "yes", "on"},
        log_level="info"
    )