-
Notifications
You must be signed in to change notification settings - Fork 9
feat: add neural engagement prediction endpoint (TRIBE v2) #425
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sidneyswift
wants to merge
1
commit into
test
Choose a base branch
from
feature/predictions-endpoint
base: test
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| import type { NextRequest } from "next/server"; | ||
| import { NextResponse } from "next/server"; | ||
| import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; | ||
| import { getOnePredictionHandler } from "@/lib/predictions/getOnePredictionHandler"; | ||
|
|
||
| /** | ||
| * OPTIONS handler for CORS preflight requests. | ||
| * | ||
| * @returns A NextResponse with CORS headers. | ||
| */ | ||
| export async function OPTIONS() { | ||
| return new NextResponse(null, { | ||
| status: 200, | ||
| headers: getCorsHeaders(), | ||
| }); | ||
| } | ||
|
|
||
| /** | ||
| * GET /api/predictions/{id} | ||
| * | ||
| * Get a specific engagement prediction by UUID. Returns the full prediction | ||
| * including engagement timeline, peak moments, weak spots, and regional | ||
| * activation data. | ||
| * | ||
| * Authentication: x-api-key header or Authorization Bearer token required. | ||
| * | ||
| * @param request - The request object. | ||
| * @param options - Route options containing params. | ||
| * @param options.params - Route params containing the prediction UUID. | ||
| * @returns A NextResponse with the prediction or error. | ||
| */ | ||
| export async function GET( | ||
| request: NextRequest, | ||
| options: { params: Promise<{ id: string }> }, | ||
| ): Promise<NextResponse> { | ||
| const { id } = await options.params; | ||
| return getOnePredictionHandler(request, id); | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| import type { NextRequest } from "next/server"; | ||
| import { NextResponse } from "next/server"; | ||
| import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; | ||
| import { postCreatePredictionHandler } from "@/lib/predictions/postCreatePredictionHandler"; | ||
| import { getListPredictionsHandler } from "@/lib/predictions/getListPredictionsHandler"; | ||
|
|
||
| /** | ||
| * OPTIONS handler for CORS preflight requests. | ||
| * | ||
| * @returns A NextResponse with CORS headers. | ||
| */ | ||
| export async function OPTIONS() { | ||
| return new NextResponse(null, { | ||
| status: 200, | ||
| headers: getCorsHeaders(), | ||
| }); | ||
| } | ||
|
|
||
| /** | ||
| * POST /api/predictions | ||
| * | ||
| * Run a neural engagement prediction on video, audio, or text content. | ||
| * Accepts a file URL and modality, returns an engagement score (0-100), | ||
| * timeline, peak moments, weak spots, and regional brain activation data. | ||
| * The result is persisted for later retrieval. | ||
| * | ||
| * Authentication: x-api-key header or Authorization Bearer token required. | ||
| * | ||
| * @param request - The request object containing the JSON body. | ||
| * @returns A NextResponse with the prediction result or error. | ||
| */ | ||
| export async function POST(request: NextRequest): Promise<NextResponse> { | ||
| return postCreatePredictionHandler(request); | ||
| } | ||
|
|
||
| /** | ||
| * GET /api/predictions | ||
| * | ||
| * List past engagement predictions for the authenticated account. | ||
| * Supports limit and offset query parameters for pagination. | ||
| * | ||
| * Authentication: x-api-key header or Authorization Bearer token required. | ||
| * | ||
| * @param request - The request object with optional query params. | ||
| * @returns A NextResponse with the predictions array or error. | ||
| */ | ||
| export async function GET(request: NextRequest): Promise<NextResponse> { | ||
| return getListPredictionsHandler(request); | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; | ||
| import { registerPredictEngagementTool } from "./registerPredictEngagementTool"; | ||
| import { registerGetPredictionsTool } from "./registerGetPredictionsTool"; | ||
|
|
||
| /** | ||
| * Registers all TRIBE v2 engagement prediction MCP tools. | ||
| * | ||
| * @param server - The MCP server instance. | ||
| */ | ||
| export function registerAllTribeTools(server: McpServer): void { | ||
| registerPredictEngagementTool(server); | ||
| registerGetPredictionsTool(server); | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; | ||
| import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; | ||
| import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; | ||
| import { z } from "zod"; | ||
| import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey"; | ||
| import { resolveAccountId } from "@/lib/mcp/resolveAccountId"; | ||
| import { getToolResultSuccess } from "@/lib/mcp/getToolResultSuccess"; | ||
| import { getToolResultError } from "@/lib/mcp/getToolResultError"; | ||
| import { selectPredictions } from "@/lib/supabase/predictions/selectPredictions"; | ||
|
|
||
| const getPredictionsSchema = z.object({ | ||
| limit: z | ||
| .number() | ||
| .int() | ||
| .min(1) | ||
| .max(100) | ||
| .default(20) | ||
| .describe("Maximum number of predictions to return (1-100, default 20)."), | ||
| offset: z | ||
| .number() | ||
| .int() | ||
| .min(0) | ||
| .default(0) | ||
| .describe("Number of predictions to skip for pagination (default 0)."), | ||
| }); | ||
|
|
||
| type GetPredictionsArgs = z.infer<typeof getPredictionsSchema>; | ||
|
|
||
| /** | ||
| * Registers the get_predictions MCP tool on the server. | ||
| * Lists past engagement predictions for the authenticated account. | ||
| * | ||
| * @param server - The MCP server instance to register the tool on. | ||
| */ | ||
| export function registerGetPredictionsTool(server: McpServer): void { | ||
| server.registerTool( | ||
| "get_predictions", | ||
| { | ||
| description: | ||
| "List past neural engagement predictions for your account. " + | ||
| "Returns prediction summaries (id, modality, engagement_score, created_at) " + | ||
| "sorted by newest first. Use to compare scores across content iterations.", | ||
| inputSchema: getPredictionsSchema, | ||
| }, | ||
| async ( | ||
| args: GetPredictionsArgs, | ||
| extra: RequestHandlerExtra<ServerRequest, ServerNotification>, | ||
| ) => { | ||
| const authInfo = extra.authInfo as McpAuthInfo | undefined; | ||
| const { accountId, error } = await resolveAccountId({ | ||
| authInfo, | ||
| accountIdOverride: undefined, | ||
| }); | ||
|
|
||
| if (error) { | ||
| return getToolResultError(error); | ||
| } | ||
|
|
||
| if (!accountId) { | ||
| return getToolResultError("Failed to resolve account ID"); | ||
| } | ||
|
|
||
| try { | ||
| const predictions = await selectPredictions(accountId, args.limit, args.offset); | ||
| return getToolResultSuccess({ predictions }); | ||
| } catch (err) { | ||
| const message = err instanceof Error ? err.message : "Failed to fetch predictions"; | ||
| return getToolResultError(message); | ||
| } | ||
| }, | ||
| ); | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; | ||
| import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; | ||
| import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; | ||
| import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey"; | ||
| import { resolveAccountId } from "@/lib/mcp/resolveAccountId"; | ||
| import { getToolResultSuccess } from "@/lib/mcp/getToolResultSuccess"; | ||
| import { getToolResultError } from "@/lib/mcp/getToolResultError"; | ||
| import { | ||
| createPredictionBodySchema, | ||
| type CreatePredictionBody, | ||
| } from "@/lib/tribe/validateCreatePredictionBody"; | ||
| import { processPredictRequest } from "@/lib/tribe/processPredictRequest"; | ||
| import { insertPrediction } from "@/lib/supabase/predictions/insertPrediction"; | ||
|
|
||
| /** | ||
| * Registers the predict_engagement MCP tool on the server. | ||
| * Runs TRIBE v2 neural engagement prediction and persists the result. | ||
| * | ||
| * @param server - The MCP server instance to register the tool on. | ||
| */ | ||
| export function registerPredictEngagementTool(server: McpServer): void { | ||
| server.registerTool( | ||
| "predict_engagement", | ||
| { | ||
| description: | ||
| "Predict neural engagement for video, audio, or text content. " + | ||
| "Returns an engagement score (0-100), a per-timestep timeline showing where " + | ||
| "attention peaks and drops, and brain region activation data. " + | ||
| "Use to compare content iterations — predict v1, edit weak spots, predict v2. " + | ||
| "Requires a publicly accessible file_url and modality (video, audio, or text).", | ||
| inputSchema: createPredictionBodySchema, | ||
| }, | ||
| async ( | ||
| args: CreatePredictionBody, | ||
| extra: RequestHandlerExtra<ServerRequest, ServerNotification>, | ||
| ) => { | ||
| const authInfo = extra.authInfo as McpAuthInfo | undefined; | ||
| const { accountId, error } = await resolveAccountId({ | ||
| authInfo, | ||
| accountIdOverride: undefined, | ||
| }); | ||
|
|
||
| if (error) { | ||
| return getToolResultError(error); | ||
| } | ||
|
|
||
| if (!accountId) { | ||
| return getToolResultError("Failed to resolve account ID"); | ||
| } | ||
|
|
||
| let result; | ||
| try { | ||
| result = await processPredictRequest(args); | ||
| } catch (err) { | ||
| const message = err instanceof Error ? err.message : "An unexpected error occurred"; | ||
| return getToolResultError(`Engagement prediction failed: ${message}`); | ||
| } | ||
|
|
||
| if (result.type === "error") { | ||
| return getToolResultError(result.error); | ||
| } | ||
|
|
||
| const { type: _, ...metrics } = result; | ||
|
|
||
| try { | ||
| const row = await insertPrediction({ | ||
| account_id: accountId, | ||
| file_url: args.file_url, | ||
| modality: args.modality, | ||
| ...metrics, | ||
| }); | ||
|
|
||
| return getToolResultSuccess({ | ||
| id: row.id, | ||
| file_url: row.file_url, | ||
| modality: row.modality, | ||
| engagement_score: row.engagement_score, | ||
| engagement_timeline: row.engagement_timeline, | ||
| peak_moments: row.peak_moments, | ||
| weak_spots: row.weak_spots, | ||
| regional_activation: row.regional_activation, | ||
| total_duration_seconds: row.total_duration_seconds, | ||
| elapsed_seconds: row.elapsed_seconds, | ||
| created_at: row.created_at, | ||
| }); | ||
| } catch (err) { | ||
| const message = err instanceof Error ? err.message : "Failed to save prediction"; | ||
| return getToolResultError(message); | ||
| } | ||
| }, | ||
| ); | ||
| } |
88 changes: 88 additions & 0 deletions
88
lib/predictions/__tests__/getListPredictionsHandler.test.ts
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| import { describe, it, expect, vi, beforeEach } from "vitest"; | ||
| import { NextRequest, NextResponse } from "next/server"; | ||
| import { getListPredictionsHandler } from "../getListPredictionsHandler"; | ||
|
|
||
| vi.mock("@/lib/networking/getCorsHeaders", () => ({ | ||
| getCorsHeaders: vi.fn(() => ({ "Access-Control-Allow-Origin": "*" })), | ||
| })); | ||
|
|
||
| vi.mock("@/lib/auth/validateAuthContext", () => ({ | ||
| validateAuthContext: vi.fn(), | ||
| })); | ||
|
|
||
| vi.mock("@/lib/supabase/predictions/selectPredictions", () => ({ | ||
| selectPredictions: vi.fn(), | ||
| })); | ||
|
|
||
| import { validateAuthContext } from "@/lib/auth/validateAuthContext"; | ||
| import { selectPredictions } from "@/lib/supabase/predictions/selectPredictions"; | ||
|
|
||
| describe("getListPredictionsHandler", () => { | ||
| beforeEach(() => { | ||
| vi.clearAllMocks(); | ||
| }); | ||
|
|
||
| it("returns 200 with predictions array", async () => { | ||
| vi.mocked(validateAuthContext).mockResolvedValue({ | ||
| accountId: "account-uuid", | ||
| orgId: null, | ||
| authToken: "test-token", | ||
| }); | ||
| vi.mocked(selectPredictions).mockResolvedValue([ | ||
| { | ||
| id: "pred-1", | ||
| file_url: "https://example.com/v1.mp4", | ||
| modality: "video", | ||
| engagement_score: 73.2, | ||
| created_at: "2026-04-10T00:00:00Z", | ||
| }, | ||
| ]); | ||
|
|
||
| const request = new NextRequest("http://localhost/api/predictions"); | ||
| const response = await getListPredictionsHandler(request); | ||
| expect(response.status).toBe(200); | ||
|
|
||
| const body = await response.json(); | ||
| expect(body.status).toBe("success"); | ||
| expect(body.predictions).toHaveLength(1); | ||
| expect(body.predictions[0].engagement_score).toBe(73.2); | ||
| }); | ||
|
|
||
| it("passes limit and offset to selectPredictions", async () => { | ||
| vi.mocked(validateAuthContext).mockResolvedValue({ | ||
| accountId: "account-uuid", | ||
| orgId: null, | ||
| authToken: "test-token", | ||
| }); | ||
| vi.mocked(selectPredictions).mockResolvedValue([]); | ||
|
|
||
| const request = new NextRequest("http://localhost/api/predictions?limit=5&offset=10"); | ||
| await getListPredictionsHandler(request); | ||
|
|
||
| expect(selectPredictions).toHaveBeenCalledWith("account-uuid", 5, 10); | ||
| }); | ||
|
|
||
| it("clamps limit to 100", async () => { | ||
| vi.mocked(validateAuthContext).mockResolvedValue({ | ||
| accountId: "account-uuid", | ||
| orgId: null, | ||
| authToken: "test-token", | ||
| }); | ||
| vi.mocked(selectPredictions).mockResolvedValue([]); | ||
|
|
||
| const request = new NextRequest("http://localhost/api/predictions?limit=500"); | ||
| await getListPredictionsHandler(request); | ||
|
|
||
| expect(selectPredictions).toHaveBeenCalledWith("account-uuid", 100, 0); | ||
| }); | ||
|
|
||
| it("returns 401 when auth fails", async () => { | ||
| vi.mocked(validateAuthContext).mockResolvedValue( | ||
| NextResponse.json({ status: "error" }, { status: 401 }), | ||
| ); | ||
|
|
||
| const request = new NextRequest("http://localhost/api/predictions"); | ||
| const response = await getListPredictionsHandler(request); | ||
| expect(response.status).toBe(401); | ||
| }); | ||
| }); |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Validate
idas UUID before delegating to the handler.The route currently forwards raw
id. Add Zod validation at the boundary and return 400 on invalid params.Example boundary validation
import type { NextRequest } from "next/server"; import { NextResponse } from "next/server"; +import { z } from "zod"; import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; import { getOnePredictionHandler } from "@/lib/predictions/getOnePredictionHandler"; + +const getPredictionParamsSchema = z.object({ id: z.string().uuid() }); export async function GET( request: NextRequest, options: { params: Promise<{ id: string }> }, ): Promise<NextResponse> { - const { id } = await options.params; - return getOnePredictionHandler(request, id); + const parsed = getPredictionParamsSchema.safeParse(await options.params); + if (!parsed.success) { + return NextResponse.json( + { status: "error", error: "Invalid prediction id" }, + { status: 400, headers: getCorsHeaders() }, + ); + } + return getOnePredictionHandler(request, parsed.data.id); }As per coding guidelines: “All API endpoints should use a validate function for input parsing using Zod for schema validation.”
🤖 Prompt for AI Agents