forked from bayuzen19/dtsense-api
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp_langserve.py
More file actions
121 lines (103 loc) · 3.9 KB
/
app_langserve.py
File metadata and controls
121 lines (103 loc) · 3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from langserve import add_routes
from pydantic import BaseModel, Field
import os
import logging
from typing import Optional, Dict, Any
from src.document_pipeline import DocumentPipeline
from langchain_core.runnables import RunnableLambda
# Configure logging via basicConfig (INFO level, timestamped format) and
# create the logger used throughout this module.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Define request/response models with Pydantic.
# QueryRequest/QueryResponse are used both by the /query endpoint and as the
# input/output types passed to add_routes.
class QueryRequest(BaseModel):
    # Required question text; min_length=1 rejects the empty string.
    question: str = Field(..., min_length=1, description="The question to ask the medical chatbot")


class QueryResponse(BaseModel):
    # Answer text returned to the caller.
    answer: str


# Declared as the response model for the 400 and 500 statuses of /query.
class ErrorResponse(BaseModel):
    error: str
    # Optional additional detail; defaults to None.
    detail: Optional[str] = None
# Initialize FastAPI app
app = FastAPI(
    title="Medical Chatbot API with Groq",
    description="A LangServe application serving a medical chatbot interface using Groq LLM",
    version="1.0.0",
)

# CORS: allowed origins are read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable so a deployment can restrict them without a code change.
# When the variable is unset or blank, all origins are allowed as before.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
# NOTE(review): allow_credentials=True together with a wildcard origin is
# discouraged; set CORS_ALLOW_ORIGINS to explicit origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve the "static" directory under the /static URL prefix.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")

# Jinja2 templates are loaded from the "template" directory.
templates = Jinja2Templates(directory="template")
# Build the DocumentPipeline at import time. Settings come from the
# environment (PDF_DIR, PINECONE_INDEX_NAME) with the defaults shown below;
# any failure is logged and re-raised so the module fails to import.
PDF_DIR = os.getenv("PDF_DIR", "Data")
INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "medicalbot")
try:
    logger.info(f"Initializing DocumentPipeline with PDF_DIR={PDF_DIR}, INDEX_NAME={INDEX_NAME}")
    pipeline = DocumentPipeline(pdf_dir=PDF_DIR, index_name=INDEX_NAME)
    pipeline.load_vectore_store()
    pipeline.create_retrieval_chain()
except Exception as exc:
    logger.error(f"Failed to initialize DocumentPipeline: {exc!s}")
    raise
else:
    logger.info("DocumentPipeline initialized successfully")
# Wrapper around pipeline.query with a dict-in / dict-out shape so it can be
# turned into a Runnable.
def query_pipeline(input_data: Dict[str, Any]) -> Dict[str, Any]:
    """Answer the question carried in ``input_data``.

    Args:
        input_data: Mapping whose ``"question"`` key holds the question.

    Returns:
        A dict with a single ``"answer"`` key holding the value returned by
        ``pipeline.query``.

    Raises:
        ValueError: If the question is missing, empty, or only whitespace.
    """
    question = input_data.get("question", "")
    # A whitespace-only string carries no question either, so reject it here
    # instead of forwarding it to the pipeline.
    if not question or (isinstance(question, str) and not question.strip()):
        raise ValueError("Question cannot be empty")
    # The question is passed through unchanged (not stripped).
    answer = pipeline.query(question)
    return {"answer": answer}
# Wrap query_pipeline in a RunnableLambda; this object is what add_routes
# registers and what the /query endpoint calls.
runnable_pipeline = RunnableLambda(query_pipeline)
@app.get("/", response_class=HTMLResponse)
async def get_chat(request: Request):
    """Serve the chat UI page."""
    # The template context carries only the incoming request object.
    page_context = {"request": request}
    return templates.TemplateResponse("chat_langserve.html", page_context)
@app.get("/health")
async def health_check():
    """Health check endpoint."""
    # Fixed payload; nothing else is inspected before answering.
    status_payload = {"status": "healthy"}
    return status_payload
# Register the RunnableLambda wrapper with LangServe under
# /api/medical-chatbot, passing the Pydantic models as input/output types.
add_routes(
    app,
    runnable_pipeline,
    input_type=QueryRequest,
    output_type=QueryResponse,
    path="/api/medical-chatbot",
)
# Keep original endpoint for backward compatibility.
# A whitespace-only question is answered with HTTP 400; any failure while
# running the pipeline is logged with its traceback and answered with HTTP 500.
@app.post("/query", response_model=QueryResponse, responses={
    400: {"model": ErrorResponse},
    500: {"model": ErrorResponse}
})
async def query(request: QueryRequest):
    """
    Process a question and return the answer from the medical chatbot.
    """
    logger.info(f"Received query: {request.question}")

    # min_length=1 on the model still admits whitespace-only text; treat that
    # as a client error rather than sending it to the pipeline.
    if not request.question.strip():
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    try:
        # Await the async entry point instead of calling the synchronous
        # invoke() directly inside this coroutine, so a slow pipeline call
        # does not stall the event loop.
        result = await runnable_pipeline.ainvoke({"question": request.question})
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Error processing query: {str(e)}")
        # NOTE(review): the detail echoes the exception text to the client;
        # consider a generic message if that text can contain internals.
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e

    logger.info("Query processed successfully")
    return QueryResponse(answer=result["answer"])
if __name__ == "__main__":
    import uvicorn

    # reload=True requires the app to be given as an import string. Derive the
    # module name from this file rather than hard-coding "app:app", which only
    # resolves to this application when the file is named app.py.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=8000, reload=True)