-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrag-server.py
More file actions
152 lines (138 loc) · 5.49 KB
/
rag-server.py
File metadata and controls
152 lines (138 loc) · 5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""RAG Service API.
This FastAPI application exposes three endpoints:
* ``/{collection}/upload`` – Accepts a JSON payload containing a list of strings and stores the embeddings in Qdrant.
* ``/{collection}/upload_pdf`` – Accepts a PDF file, extracts text, chunks it, generates embeddings, and stores them.
* ``/{collection}/query`` – Takes a prompt, queries the vector store and returns a response from an LLM.
The service uses `sentence-transformers` for embeddings and `qdrant-client` to interact with Qdrant. All endpoints return a JSON response or the raw content from the LLM.
"""
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from typing import List
from uuid import uuid4
from qdrant_client.models import PointStruct
import os
import requests
from fastapi import UploadFile, File
from io import BytesIO
import PyPDF2
import re
def chunk_text(text: str, chunk_size: int = 60, overlap: int = 30) -> list[str]:
    """Split *text* into overlapping, word-based chunks.

    Words are delimited by runs of whitespace; leading and trailing
    whitespace is ignored, so empty or whitespace-only input yields no
    chunks. Consecutive chunks start ``chunk_size - overlap`` words apart,
    and iteration stops as soon as a chunk reaches the last word, so no
    trailing chunk is fully contained in its predecessor.

    Parameters
    ----------
    text: str
        Raw text to split.
    chunk_size: int, optional
        Number of words per chunk. Must be at least 1.
    overlap: int, optional
        Number of words shared by consecutive chunks. Must be non-negative
        and smaller than ``chunk_size``.

    Returns
    -------
    list[str]
        List of text chunks, each a single-space-joined run of words.

    Raises
    ------
    ValueError
        If ``chunk_size`` or ``overlap`` is out of range.
    """
    if overlap >= chunk_size:
        raise ValueError("Overlap must be smaller than chunk size.")
    if chunk_size < 1:
        raise ValueError("Chunk size must be at least 1.")
    if overlap < 0:
        raise ValueError("Overlap must not be negative.")
    # str.split() with no separator discards the empty strings that
    # re.split(r"\s+", ...) produces for leading/trailing whitespace,
    # so "" and "  \n" both give an empty word list.
    words = text.split()
    step_size = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step_size):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        if end >= len(words):
            # This chunk already reaches the last word; any later start
            # would yield a chunk entirely covered by this one.
            break
    return chunks
app = FastAPI()
# Loaded once at import time and shared by every endpoint below.
model = SentenceTransformer("BAAI/bge-m3")
# Single quotes inside the f-string keep this line valid before Python 3.12
# (reusing the outer quote type is only allowed from 3.12, PEP 701).
# Falls back to Qdrant's default HTTP port when QDRANT_PORT is unset,
# instead of building the URL "http://localhost:None".
client = QdrantClient(url=f"http://localhost:{os.getenv('QDRANT_PORT', '6333')}")
class UploadPayload(BaseModel):
    """Request body for ``/{collection}/upload``: the raw strings to embed."""

    # Each entry becomes one point in the target Qdrant collection.
    texts: list[str]
class Query(BaseModel):
    """Request body for ``/{collection}/query``: the user's question."""

    # Encoded with the embedding model to search the vector store.
    text: str
@app.post("/{collection}/upload")
def upload(collection: str, payload: UploadPayload):
    """Embed a list of text strings and upsert them into Qdrant.

    Each string is encoded with the shared ``model`` and stored as one point
    in *collection* under a random UUID. The original text is kept in the
    point payload under ``"text"`` so the query endpoint can read it back as
    context.

    Parameters
    ----------
    collection: str
        Name of the target Qdrant collection (taken from the URL path).
    payload: UploadPayload
        Request body holding the strings to store.

    Returns
    -------
    dict
        ``{"status": "success", "count": <number of strings stored>}``.
    """
    texts = payload.texts
    if not texts:
        # Nothing to embed: skip the model call and the Qdrant round-trip.
        return {"status": "success", "count": 0}
    print("Encoding list of texts to embeddings...")
    embeddings = model.encode(texts, convert_to_numpy=True)
    # Iterating the embedding matrix yields one row (vector) per input text.
    points = [
        PointStruct(
            id=str(uuid4()),
            vector=vector.tolist(),
            payload={"text": text},
        )
        for text, vector in zip(texts, embeddings)
    ]
    print("Upserting embeddings to collection...")
    client.upsert(collection_name=collection, points=points)
    return {"status": "success", "count": len(texts)}
@app.post("/{collection}/upload_pdf")
async def upload_pdf(collection: str, file: UploadFile = File(...)):
    """Extract text from an uploaded PDF, chunk it, and store it in Qdrant.

    The PDF is read fully into memory, text is extracted page by page (one
    newline appended per page), split with ``chunk_text``, and each chunk is
    embedded and upserted into *collection* with its text in the ``"text"``
    payload field.

    Parameters
    ----------
    collection: str
        Name of the target Qdrant collection (taken from the URL path).
    file: UploadFile
        The uploaded PDF document.

    Returns
    -------
    dict
        ``{"status": "success", "count": <number of chunks stored>}``.
    """
    print("Reading PDF file...")
    pdf_bytes = await file.read()
    reader = PyPDF2.PdfReader(BytesIO(pdf_bytes))
    # One join instead of repeated ``+=``; ``or ""`` guards against a page
    # for which no text is returned.
    text = "".join((page.extract_text() or "") + "\n" for page in reader.pages)
    # Whitespace-only chunks (e.g. a scanned PDF without a text layer)
    # would otherwise be embedded and stored as empty points.
    chunks = [chunk for chunk in chunk_text(text) if chunk.strip()]
    print(f"Split into {len(chunks)} chunks.")
    if not chunks:
        return {"status": "success", "count": 0}
    embeddings = model.encode(chunks, convert_to_numpy=True)
    points = [
        PointStruct(
            id=str(uuid4()),
            vector=vector.tolist(),
            payload={"text": chunk},
        )
        for chunk, vector in zip(chunks, embeddings)
    ]
    print("Upserting embeddings to collection...")
    client.upsert(collection_name=collection, points=points)
    return {"status": "success", "count": len(chunks)}
@app.post("/{collection}/query")
def embed(collection: str, prompt: Query):
    """Answer a question using retrieval-augmented generation.

    The question text is embedded, the 3 nearest points are retrieved from
    *collection*, their ``"text"`` payloads are formatted as a numbered
    context list, and the augmented prompt is POSTed to the LLM server's
    ``/completion`` endpoint on ``localhost:$LLAMA_PORT``.

    Parameters
    ----------
    collection: str
        Name of the Qdrant collection to search (taken from the URL path).
    prompt: Query
        Request body holding the user's question.

    Returns
    -------
    bytes
        The raw body of the LLM server's response.
    """
    question = prompt.text
    print("Converting prompt to embedding...")
    embedding = model.encode(question, convert_to_numpy=True).tolist()
    print("Query vector database for embedding...")
    search_result = client.query_points(
        collection_name=collection,
        query=embedding,
        with_payload=True,
        limit=3,
    ).points
    print("Creating augmented prompt...")
    context_lines = [item.payload["text"] for item in search_result]
    # One numbered entry per line: chunks routinely contain commas, so a
    # ", " separator made the entry boundaries ambiguous.
    context = "\n".join(
        f"{rank}. {line}" for rank, line in enumerate(context_lines, start=1)
    )
    augmented_prompt = (
        "\n"
        "You are a AI assistant augmented with an Vector Store.\n"
        "To help you answer the questions, a context will be provided. "
        "This context is generated by querying the vector store with the user question.\n"
        "Answer the question at the end using only the information available in the context.\n"
        "Answer always in a summarizing manner in full sentences and never mention the context.\n"
        "-------------\n"
        f"Context:\n{context}\n-------------\n"
        # Interpolate the question text itself, not the Query model
        # (whose str() renders as "text='...'").
        f"Question: {question}\nHelpful answer:\n"
    )
    payload = {"prompt": augmented_prompt}
    print(f"Created the augmented prompt: {augmented_prompt}")
    print("Waiting for llm response...")
    headers = {"Content-Type": "application/json"}
    # Read the port outside the f-string so the URL line does not reuse the
    # outer quote type (only valid from Python 3.12, PEP 701).
    llama_port = os.getenv("LLAMA_PORT")
    response = requests.post(
        f"http://localhost:{llama_port}/completion",
        json=payload,
        headers=headers,
    )
    # NOTE(review): returning raw bytes makes FastAPI JSON-encode them as a
    # string, so clients receive the LLM's JSON body wrapped in a JSON
    # string, and LLM-side HTTP errors are passed through with status 200.
    return response.content