Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 159 additions & 35 deletions apps/ask-gateway/app/chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,36 @@
from .mcp_client import McpClientError, McpHttpClient
from .models import AskStreamRequest, ToolCall
from .response_synthesizer import synthesize_final_response
from .usage_tracker import (
resolve_model_for_user_async,
record_usage_async,
get_user_usage_async,
)
from . import supabase_store

MAX_TOOL_ITERATIONS = 30
MAX_TOOL_ITERATIONS = 20

# Base system prompt for the agentic chat loop.  NOTE: this is a runtime string
# sent to the LLM verbatim — edit wording with care.  Term codes are hard-coded
# here (upcoming: Fall 2026 = 1272, current: Spring 2026 = 1264); they must be
# bumped each semester — TODO(review): consider deriving from config/settings.
_SYSTEM_PROMPT = """\
You are an AI course assistant for Princeton University students. You help students \
find courses, understand workload, compare options, and make informed decisions.

You have access to tools that search courses, get course details, evaluations, \
instructor info, and more. Use them to answer accurately.
The upcoming term is Fall 2026 (term code 1272). Unless the user specifies otherwise, \
default to searching and discussing courses for Fall 2026. The current term is Spring 2026 (1264).

Guidelines:
- Always use tools to look up real data. Do not fabricate course information.
- After receiving tool results, synthesize a helpful, conversational response.
- When comparing courses, highlight key differences (rating, workload, schedule).
- Format course codes as "DEPT NNN" (e.g., COS 226, not COS226).
- Keep responses concise but thorough. Use bullet points and bold for readability.
- If a course is not found, say so honestly and suggest alternatives.
- When searching for courses, prefer term 1272 (Fall 2026) unless the user asks about a different term.
"""

# Appended to _SYSTEM_PROMPT only when the request carries a netid (i.e. the
# user is authenticated), enabling the TigerJunction schedule tools.  Runtime
# string — sent to the LLM verbatim; leading blank line is intentional so it
# concatenates cleanly after the base prompt.
_SCHEDULE_PROMPT_ADDENDUM = """

You also have access to the user's TigerJunction (junction.tigerapps.org) schedule.
- Get their schedules with get_user_schedules (no userId needed — you are already authenticated)
When the user asks about "my schedule", "my courses", or wants to add/remove/manage courses, use tools.
When the user wants to find courses that fit their schedule, use search_courses with the scheduleId parameter — this combines all search filters (department, text, days, time, instructor, distribution) with schedule conflict checking.
"""


Expand Down Expand Up @@ -120,15 +133,37 @@ async def _stream_agentic(
request_id = request_id or str(uuid.uuid4())
conversation_id = payload.conversationId or str(uuid.uuid4())
prompt = payload.messages[-1].content
mcp_client = McpHttpClient(self._settings)
mcp_url = self._settings.junction_mcp_url if payload.netid else None
mcp_client = McpHttpClient(self._settings, netid=payload.netid, mcp_url=mcp_url)
llm_client = OpenAiLlmClient(self._settings)
session_id: str | None = None

# Quota enforcement
quota_before: dict | None = None
effective_model: str | None = payload.model
if payload.netid:
quota_before = await resolve_model_for_user_async(payload.netid)
if quota_before["blocked"]:
yield sse_event(
"quota_exhausted",
{
"percentUsed": 100,
"resetSeconds": quota_before["resetSeconds"],
"requestId": request_id,
},
)
return
effective_model = quota_before["model"]

try:
yield sse_event("status", {"phase": "starting", "requestId": request_id})

system_prompt = _SYSTEM_PROMPT
if payload.netid:
system_prompt += _SCHEDULE_PROMPT_ADDENDUM

messages: list[dict[str, Any]] = [
{"role": "system", "content": _SYSTEM_PROMPT},
{"role": "system", "content": system_prompt},
*[m.model_dump() for m in payload.messages],
]

Expand All @@ -139,6 +174,12 @@ async def _stream_agentic(
# list_tools initializes the session, so capture it
session_id = mcp_client._session_id
collected_usage: dict[str, Any] | None = None
# Accumulate usage across all LLM iterations (tool-calling loop)
total_cost = 0.0
total_input_tokens = 0
total_output_tokens = 0
# Track tool calls for persistence
persisted_tool_events: list[dict[str, Any]] = []

for iteration in range(MAX_TOOL_ITERATIONS):
if is_disconnected():
Expand All @@ -150,13 +191,18 @@ async def _stream_agentic(
finish_reason: str | None = None

async for chunk in llm_client.stream_chat(
messages=messages, tools=llm_tools, model=payload.model
messages=messages, tools=llm_tools, model=effective_model
):
if chunk.get("type") == "done":
break

if chunk.get("usage"):
collected_usage = chunk["usage"]
total_cost += collected_usage.get("cost") or 0
total_input_tokens += collected_usage.get("prompt_tokens") or 0
total_output_tokens += (
collected_usage.get("completion_tokens") or 0
)

choices = chunk.get("choices", [])
if not choices:
Expand Down Expand Up @@ -200,11 +246,47 @@ async def _stream_agentic(
# since some models like Gemini use "stop" even with tool calls).
if not collected_tool_calls:
usage = {
"inputTokens": (collected_usage or {}).get("prompt_tokens", 0),
"outputTokens": (collected_usage or {}).get(
"completion_tokens", 0
),
"inputTokens": total_input_tokens,
"outputTokens": total_output_tokens,
}

# Record cost and get updated quota
quota_after: dict | None = None
if payload.netid:
if total_cost > 0:
await record_usage_async(payload.netid, total_cost)
quota_after = await get_user_usage_async(payload.netid)

# Save conversation to Supabase
conv_title = (
payload.messages[0].content[:80]
if payload.messages
else "New chat"
)
await supabase_store.save_message(
conversation_id, payload.netid, conv_title, "user", prompt
)
# Save tool calls/results
for te in persisted_tool_events:
await supabase_store.save_message(
conversation_id,
payload.netid,
conv_title,
te["type"],
json.dumps(te, default=str),
)
await supabase_store.save_message(
conversation_id,
payload.netid,
conv_title,
"assistant",
collected_content,
cost=total_cost if total_cost > 0 else None,
input_tokens=total_input_tokens or None,
output_tokens=total_output_tokens or None,
model=effective_model,
)

yield sse_event(
"status",
{
Expand All @@ -213,15 +295,22 @@ async def _stream_agentic(
**({"sessionId": session_id} if session_id else {}),
},
)
yield sse_event(
"done",
{
"conversationId": conversation_id,
"requestId": request_id,
**({"sessionId": session_id} if session_id else {}),
"usage": usage,
},
)

done_data: dict[str, Any] = {
"conversationId": conversation_id,
"requestId": request_id,
**({"sessionId": session_id} if session_id else {}),
"usage": usage,
}
if quota_after is not None:
done_data["quota"] = {
"percentUsed": quota_after["percentUsed"],
"tier": quota_after["tier"],
"tierChanged": quota_before is not None
and quota_before["tier"] != quota_after["tier"],
"resetSeconds": quota_after["resetSeconds"],
}
yield sse_event("done", done_data)
return

yield sse_event(
Expand Down Expand Up @@ -267,10 +356,25 @@ async def _stream_agentic(
"sessionId": session_id,
},
)
persisted_tool_events.append(
{
"type": "tool_call",
"name": tool_name,
"arguments": tool_args,
}
)
result = await asyncio.wait_for(
mcp_client.call_tool(tool_name, tool_args),
timeout=self._settings.tool_timeout_seconds,
)
persisted_tool_events.append(
{
"type": "tool_result",
"name": tool_name,
"ok": True,
"result": result,
}
)
yield sse_event(
"tool_result",
{
Expand Down Expand Up @@ -354,8 +458,24 @@ async def _stream_deterministic(
request_id = request_id or str(uuid.uuid4())
conversation_id = payload.conversationId or str(uuid.uuid4())
prompt = payload.messages[-1].content
mcp_client = McpHttpClient(self._settings)
mcp_url = self._settings.junction_mcp_url if payload.netid else None
mcp_client = McpHttpClient(self._settings, netid=payload.netid, mcp_url=mcp_url)
session_id: str | None = None

# Quota enforcement (deterministic doesn't call LLM, but still check)
if payload.netid:
det_quota = await resolve_model_for_user_async(payload.netid)
if det_quota["blocked"]:
yield sse_event(
"quota_exhausted",
{
"percentUsed": 100,
"resetSeconds": det_quota["resetSeconds"],
"requestId": request_id,
},
)
return

try:
yield sse_event("status", {"phase": "starting", "requestId": request_id})
tool_calls = _plan_tools(prompt, payload.term)
Expand Down Expand Up @@ -432,18 +552,24 @@ async def _stream_deterministic(
**({"sessionId": session_id} if session_id else {}),
},
)
yield sse_event(
"done",
{
"conversationId": conversation_id,
"requestId": request_id,
**({"sessionId": session_id} if session_id else {}),
"usage": {
"inputTokens": 0,
"outputTokens": len(response_text.split()),
},
det_done_data: dict[str, Any] = {
"conversationId": conversation_id,
"requestId": request_id,
**({"sessionId": session_id} if session_id else {}),
"usage": {
"inputTokens": 0,
"outputTokens": len(response_text.split()),
},
)
}
if payload.netid:
det_q = await get_user_usage_async(payload.netid)
det_done_data["quota"] = {
"percentUsed": det_q["percentUsed"],
"tier": det_q["tier"],
"tierChanged": False,
"resetSeconds": det_q["resetSeconds"],
}
yield sse_event("done", det_done_data)
except asyncio.CancelledError:
yield sse_event(
"error",
Expand Down Expand Up @@ -505,5 +631,3 @@ def _extract_reasoning(delta: dict[str, Any]) -> str:
if isinstance(reasoning, str):
return reasoning
return ""


3 changes: 3 additions & 0 deletions apps/ask-gateway/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def _env_bool(name: str, default: bool) -> bool:
class Settings:
gateway_api_token: str = os.getenv("ASK_GATEWAY_API_TOKEN", "")
mcp_url: str = os.getenv("JUNCTION_MCP_URL", "http://localhost:3000/mcp")
junction_mcp_url: str = os.getenv("JUNCTION_MCP_URL_SCHEDULE", "http://localhost:3000/junction/mcp")
mcp_token: str = os.getenv("JUNCTION_MCP_TOKEN", "")
mcp_protocol_version: str = os.getenv("MCP_PROTOCOL_VERSION", "2025-03-26")
tool_timeout_seconds: float = float(os.getenv("ASK_TOOL_TIMEOUT_SECONDS", "10"))
Expand All @@ -33,3 +34,5 @@ class Settings:
ask_llm_timeout_seconds: float = float(os.getenv("ASK_LLM_TIMEOUT_SECONDS", "12"))
ask_llm_planner_enabled: bool = _env_bool("ASK_LLM_PLANNER_ENABLED", False)
ask_llm_synthesis_enabled: bool = _env_bool("ASK_LLM_SYNTHESIS_ENABLED", False)
supabase_url: str = os.getenv("SUPABASE_URL", "")
supabase_service_role_key: str = os.getenv("SUPABASE_SERVICE_ROLE_KEY", "")
1 change: 1 addition & 0 deletions apps/ask-gateway/app/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ async def stream_chat(
"model": model or self._settings.ask_llm_model,
"messages": messages,
"stream": True,
"stream_options": {"include_usage": True},
}
if tools:
request["tools"] = tools
Expand Down
42 changes: 42 additions & 0 deletions apps/ask-gateway/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from .chat_service import ChatService
from .config import Settings
from .models import AskStreamRequest
from .usage_tracker import get_user_usage_async
from . import supabase_store

app = FastAPI(title="Ask Gateway", version="1.0.0")
logger = logging.getLogger("ask-gateway")
Expand All @@ -36,6 +38,46 @@ async def health() -> dict[str, str]:
return {"status": "ok"}


@app.get("/ask/quota")
async def get_quota(
    netid: str,
    authorization: str | None = Header(default=None),
    settings: Settings = Depends(get_settings),
) -> dict:
    """Return the current usage/quota snapshot for *netid*.

    Gateway bearer auth is enforced first; a blank ``netid`` yields 400.
    """
    _validate_gateway_auth(settings, authorization)
    if netid:
        return await get_user_usage_async(netid)
    raise HTTPException(status_code=400, detail="netid is required")


@app.get("/ask/conversations")
async def list_conversations(
    netid: str,
    authorization: str | None = Header(default=None),
    settings: Settings = Depends(get_settings),
) -> list:
    """List stored conversations for *netid* from the Supabase store.

    Gateway bearer auth is enforced first; a blank ``netid`` yields 400.
    """
    _validate_gateway_auth(settings, authorization)
    if netid:
        return await supabase_store.list_conversations(netid)
    raise HTTPException(status_code=400, detail="netid is required")


@app.get("/ask/conversations/{conv_id}/messages")
async def get_conversation_messages(
    conv_id: str,
    netid: str,
    authorization: str | None = Header(default=None),
    settings: Settings = Depends(get_settings),
) -> list:
    """Fetch the message history of one conversation.

    The store is queried with both ``conv_id`` and ``netid`` — presumably the
    lookup is scoped to the owner; confirm against supabase_store.  A ``None``
    result maps to 404; a blank ``netid`` yields 400.
    """
    _validate_gateway_auth(settings, authorization)
    if not netid:
        raise HTTPException(status_code=400, detail="netid is required")
    history = await supabase_store.get_conversation_messages(conv_id, netid)
    if history is None:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return history


@app.post("/ask/stream")
async def ask_stream(
payload: AskStreamRequest,
Expand Down
Loading