From 00d8759b5a892fa76cadda7928cd7669856a2953 Mon Sep 17 00:00:00 2001 From: Sidney Swift <158200036+sidneyswift@users.noreply.github.com> Date: Fri, 10 Apr 2026 11:43:46 -0400 Subject: [PATCH] feat: add neural engagement prediction endpoint (TRIBE v2) New REST resource at /api/predictions with POST (create), GET (list), and GET /:id (detail). Includes Modal integration, Supabase persistence, Zod validation, MCP tools (predict_engagement, get_predictions), and 26 unit tests. Made-with: Cursor --- app/api/predictions/[id]/route.ts | 38 +++++ app/api/predictions/route.ts | 49 ++++++ lib/const.ts | 4 + lib/mcp/tools/index.ts | 2 + lib/mcp/tools/tribe/index.ts | 13 ++ .../tools/tribe/registerGetPredictionsTool.ts | 72 +++++++++ .../tribe/registerPredictEngagementTool.ts | 92 ++++++++++++ .../getListPredictionsHandler.test.ts | 88 +++++++++++ .../__tests__/getOnePredictionHandler.test.ts | 93 ++++++++++++ .../postCreatePredictionHandler.test.ts | 142 ++++++++++++++++++ lib/predictions/getListPredictionsHandler.ts | 40 +++++ lib/predictions/getOnePredictionHandler.ts | 66 ++++++++ .../postCreatePredictionHandler.ts | 89 +++++++++++ lib/supabase/predictions/insertPrediction.ts | 44 ++++++ .../predictions/selectPredictionById.ts | 37 +++++ lib/supabase/predictions/selectPredictions.ts | 36 +++++ lib/tribe/__tests__/callTribePredict.test.ts | 91 +++++++++++ .../__tests__/processPredictRequest.test.ts | 52 +++++++ .../validateCreatePredictionBody.test.ts | 66 ++++++++ lib/tribe/callTribePredict.ts | 36 +++++ lib/tribe/isTribePredictResult.ts | 33 ++++ lib/tribe/processPredictRequest.ts | 33 ++++ lib/tribe/validateCreatePredictionBody.ts | 40 +++++ 23 files changed, 1256 insertions(+) create mode 100644 app/api/predictions/[id]/route.ts create mode 100644 app/api/predictions/route.ts create mode 100644 lib/mcp/tools/tribe/index.ts create mode 100644 lib/mcp/tools/tribe/registerGetPredictionsTool.ts create mode 100644 lib/mcp/tools/tribe/registerPredictEngagementTool.ts create mode 100644 lib/predictions/__tests__/getListPredictionsHandler.test.ts create mode 100644 lib/predictions/__tests__/getOnePredictionHandler.test.ts create mode 100644 lib/predictions/__tests__/postCreatePredictionHandler.test.ts create mode 100644 lib/predictions/getListPredictionsHandler.ts create mode 100644 lib/predictions/getOnePredictionHandler.ts create mode 100644 lib/predictions/postCreatePredictionHandler.ts create mode 100644 lib/supabase/predictions/insertPrediction.ts create mode 100644 lib/supabase/predictions/selectPredictionById.ts create mode 100644 lib/supabase/predictions/selectPredictions.ts create mode 100644 lib/tribe/__tests__/callTribePredict.test.ts create mode 100644 lib/tribe/__tests__/processPredictRequest.test.ts create mode 100644 lib/tribe/__tests__/validateCreatePredictionBody.test.ts create mode 100644 lib/tribe/callTribePredict.ts create mode 100644 lib/tribe/isTribePredictResult.ts create mode 100644 lib/tribe/processPredictRequest.ts create mode 100644 lib/tribe/validateCreatePredictionBody.ts diff --git a/app/api/predictions/[id]/route.ts b/app/api/predictions/[id]/route.ts new file mode 100644 index 000000000..b211b726a --- /dev/null +++ b/app/api/predictions/[id]/route.ts @@ -0,0 +1,38 @@ +import type { NextRequest } from "next/server"; +import { NextResponse } from "next/server"; +import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; +import { getOnePredictionHandler } from "@/lib/predictions/getOnePredictionHandler"; + +/** + * OPTIONS handler for CORS preflight requests. + * + * @returns A NextResponse with CORS headers. + */ +export async function OPTIONS() { + return new NextResponse(null, { + status: 200, + headers: getCorsHeaders(), + }); +} + +/** + * GET /api/predictions/{id} + * + * Get a specific engagement prediction by UUID. Returns the full prediction + * including engagement timeline, peak moments, weak spots, and regional + * activation data. + * + * Authentication: x-api-key header or Authorization Bearer token required. + * + * @param request - The request object. + * @param options - Route options containing params. + * @param options.params - Route params containing the prediction UUID. + * @returns A NextResponse with the prediction or error. + */ +export async function GET( + request: NextRequest, + options: { params: Promise<{ id: string }> }, +): Promise { + const { id } = await options.params; + return getOnePredictionHandler(request, id); +} diff --git a/app/api/predictions/route.ts b/app/api/predictions/route.ts new file mode 100644 index 000000000..10d5e446a --- /dev/null +++ b/app/api/predictions/route.ts @@ -0,0 +1,49 @@ +import type { NextRequest } from "next/server"; +import { NextResponse } from "next/server"; +import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; +import { postCreatePredictionHandler } from "@/lib/predictions/postCreatePredictionHandler"; +import { getListPredictionsHandler } from "@/lib/predictions/getListPredictionsHandler"; + +/** + * OPTIONS handler for CORS preflight requests. + * + * @returns A NextResponse with CORS headers. + */ +export async function OPTIONS() { + return new NextResponse(null, { + status: 200, + headers: getCorsHeaders(), + }); +} + +/** + * POST /api/predictions + * + * Run a neural engagement prediction on video, audio, or text content. + * Accepts a file URL and modality, returns an engagement score (0-100), + * timeline, peak moments, weak spots, and regional brain activation data. + * The result is persisted for later retrieval. + * + * Authentication: x-api-key header or Authorization Bearer token required. + * + * @param request - The request object containing the JSON body. + * @returns A NextResponse with the prediction result or error. + */ +export async function POST(request: NextRequest): Promise { + return postCreatePredictionHandler(request); +} + +/** + * GET /api/predictions + * + * List past engagement predictions for the authenticated account. + * Supports limit and offset query parameters for pagination. + * + * Authentication: x-api-key header or Authorization Bearer token required. + * + * @param request - The request object with optional query params. + * @returns A NextResponse with the predictions array or error. + */ +export async function GET(request: NextRequest): Promise { + return getListPredictionsHandler(request); +} diff --git a/lib/const.ts b/lib/const.ts index a5cccfac8..63bfe21af 100644 --- a/lib/const.ts +++ b/lib/const.ts @@ -36,6 +36,10 @@ export const RECOUP_API_KEY = process.env.RECOUP_API_KEY || ""; export const FLAMINGO_GENERATE_URL = "https://sidney-78147--music-flamingo-musicflamingo-generate.modal.run"; +/** TRIBE v2 neural engagement prediction endpoint (Modal) */ +export const TRIBE_PREDICT_URL = + "https://sidney-78147--tribe-predict-tribev2predict-predict.modal.run"; + /** Snapshot expiration duration (7 days) */ export const SNAPSHOT_EXPIRY_MS = 7 * 24 * 60 * 60 * 1000; diff --git a/lib/mcp/tools/index.ts b/lib/mcp/tools/index.ts index e95da17fb..bd457fbaf 100644 --- a/lib/mcp/tools/index.ts +++ b/lib/mcp/tools/index.ts @@ -13,6 +13,7 @@ import { registerWebDeepResearchTool } from "./registerWebDeepResearchTool"; import { registerArtistDeepResearchTool } from "./registerArtistDeepResearchTool"; import { registerAllFileTools } from "./files"; import { registerAllFlamingoTools } from "./flamingo"; +import { registerAllTribeTools } from "./tribe"; import { registerCreateSegmentsTool } from "./registerCreateSegmentsTool"; import { registerAllYouTubeTools } from "./youtube"; import { registerTranscribeTools } from "./transcribe"; @@ -53,5 +54,6 @@ export const registerAllTools = (server: McpServer): void => { registerSendEmailTool(server); registerUpdateAccountInfoTool(server); registerCreateSegmentsTool(server); + registerAllTribeTools(server); registerAllYouTubeTools(server); }; diff --git a/lib/mcp/tools/tribe/index.ts b/lib/mcp/tools/tribe/index.ts new file mode 100644 index 000000000..05303d96f --- /dev/null +++ b/lib/mcp/tools/tribe/index.ts @@ -0,0 +1,13 @@ +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { registerPredictEngagementTool } from "./registerPredictEngagementTool"; +import { registerGetPredictionsTool } from "./registerGetPredictionsTool"; + +/** + * Registers all TRIBE v2 engagement prediction MCP tools. + * + * @param server - The MCP server instance. + */ +export function registerAllTribeTools(server: McpServer): void { + registerPredictEngagementTool(server); + registerGetPredictionsTool(server); +} diff --git a/lib/mcp/tools/tribe/registerGetPredictionsTool.ts b/lib/mcp/tools/tribe/registerGetPredictionsTool.ts new file mode 100644 index 000000000..bf6ffe4ff --- /dev/null +++ b/lib/mcp/tools/tribe/registerGetPredictionsTool.ts @@ -0,0 +1,72 @@ +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; +import { z } from "zod"; +import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey"; +import { resolveAccountId } from "@/lib/mcp/resolveAccountId"; +import { getToolResultSuccess } from "@/lib/mcp/getToolResultSuccess"; +import { getToolResultError } from "@/lib/mcp/getToolResultError"; +import { selectPredictions } from "@/lib/supabase/predictions/selectPredictions"; + +const getPredictionsSchema = z.object({ + limit: z + .number() + .int() + .min(1) + .max(100) + .default(20) + .describe("Maximum number of predictions to return (1-100, default 20)."), + offset: z + .number() + .int() + .min(0) + .default(0) + .describe("Number of predictions to skip for pagination (default 0)."), +}); + +type GetPredictionsArgs = z.infer; + +/** + * Registers the get_predictions MCP tool on the server. + * Lists past engagement predictions for the authenticated account. + * + * @param server - The MCP server instance to register the tool on. + */ +export function registerGetPredictionsTool(server: McpServer): void { + server.registerTool( + "get_predictions", + { + description: + "List past neural engagement predictions for your account. " + + "Returns prediction summaries (id, modality, engagement_score, created_at) " + + "sorted by newest first. Use to compare scores across content iterations.", + inputSchema: getPredictionsSchema, + }, + async ( + args: GetPredictionsArgs, + extra: RequestHandlerExtra, + ) => { + const authInfo = extra.authInfo as McpAuthInfo | undefined; + const { accountId, error } = await resolveAccountId({ + authInfo, + accountIdOverride: undefined, + }); + + if (error) { + return getToolResultError(error); + } + + if (!accountId) { + return getToolResultError("Failed to resolve account ID"); + } + + try { + const predictions = await selectPredictions(accountId, args.limit, args.offset); + return getToolResultSuccess({ predictions }); + } catch (err) { + const message = err instanceof Error ? err.message : "Failed to fetch predictions"; + return getToolResultError(message); + } + }, + ); +} diff --git a/lib/mcp/tools/tribe/registerPredictEngagementTool.ts b/lib/mcp/tools/tribe/registerPredictEngagementTool.ts new file mode 100644 index 000000000..90b308b87 --- /dev/null +++ b/lib/mcp/tools/tribe/registerPredictEngagementTool.ts @@ -0,0 +1,92 @@ +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; +import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey"; +import { resolveAccountId } from "@/lib/mcp/resolveAccountId"; +import { getToolResultSuccess } from "@/lib/mcp/getToolResultSuccess"; +import { getToolResultError } from "@/lib/mcp/getToolResultError"; +import { + createPredictionBodySchema, + type CreatePredictionBody, +} from "@/lib/tribe/validateCreatePredictionBody"; +import { processPredictRequest } from "@/lib/tribe/processPredictRequest"; +import { insertPrediction } from "@/lib/supabase/predictions/insertPrediction"; + +/** + * Registers the predict_engagement MCP tool on the server. + * Runs TRIBE v2 neural engagement prediction and persists the result. + * + * @param server - The MCP server instance to register the tool on. + */ +export function registerPredictEngagementTool(server: McpServer): void { + server.registerTool( + "predict_engagement", + { + description: + "Predict neural engagement for video, audio, or text content. " + + "Returns an engagement score (0-100), a per-timestep timeline showing where " + + "attention peaks and drops, and brain region activation data. " + + "Use to compare content iterations — predict v1, edit weak spots, predict v2. " + + "Requires a publicly accessible file_url and modality (video, audio, or text).", + inputSchema: createPredictionBodySchema, + }, + async ( + args: CreatePredictionBody, + extra: RequestHandlerExtra, + ) => { + const authInfo = extra.authInfo as McpAuthInfo | undefined; + const { accountId, error } = await resolveAccountId({ + authInfo, + accountIdOverride: undefined, + }); + + if (error) { + return getToolResultError(error); + } + + if (!accountId) { + return getToolResultError("Failed to resolve account ID"); + } + + let result; + try { + result = await processPredictRequest(args); + } catch (err) { + const message = err instanceof Error ? err.message : "An unexpected error occurred"; + return getToolResultError(`Engagement prediction failed: ${message}`); + } + + if (result.type === "error") { + return getToolResultError(result.error); + } + + const { type: _, ...metrics } = result; + + try { + const row = await insertPrediction({ + account_id: accountId, + file_url: args.file_url, + modality: args.modality, + ...metrics, + }); + + return getToolResultSuccess({ + id: row.id, + file_url: row.file_url, + modality: row.modality, + engagement_score: row.engagement_score, + engagement_timeline: row.engagement_timeline, + peak_moments: row.peak_moments, + weak_spots: row.weak_spots, + regional_activation: row.regional_activation, + total_duration_seconds: row.total_duration_seconds, + elapsed_seconds: row.elapsed_seconds, + created_at: row.created_at, + }); + } catch (err) { + const message = err instanceof Error ? err.message : "Failed to save prediction"; + return getToolResultError(message); + } + }, + ); +} diff --git a/lib/predictions/__tests__/getListPredictionsHandler.test.ts b/lib/predictions/__tests__/getListPredictionsHandler.test.ts new file mode 100644 index 000000000..3780eb3cc --- /dev/null +++ b/lib/predictions/__tests__/getListPredictionsHandler.test.ts @@ -0,0 +1,88 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { NextRequest, NextResponse } from "next/server"; +import { getListPredictionsHandler } from "../getListPredictionsHandler"; + +vi.mock("@/lib/networking/getCorsHeaders", () => ({ + getCorsHeaders: vi.fn(() => ({ "Access-Control-Allow-Origin": "*" })), +})); + +vi.mock("@/lib/auth/validateAuthContext", () => ({ + validateAuthContext: vi.fn(), +})); + +vi.mock("@/lib/supabase/predictions/selectPredictions", () => ({ + selectPredictions: vi.fn(), +})); + +import { validateAuthContext } from "@/lib/auth/validateAuthContext"; +import { selectPredictions } from "@/lib/supabase/predictions/selectPredictions"; + +describe("getListPredictionsHandler", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("returns 200 with predictions array", async () => { + vi.mocked(validateAuthContext).mockResolvedValue({ + accountId: "account-uuid", + orgId: null, + authToken: "test-token", + }); + vi.mocked(selectPredictions).mockResolvedValue([ + { + id: "pred-1", + file_url: "https://example.com/v1.mp4", + modality: "video", + engagement_score: 73.2, + created_at: "2026-04-10T00:00:00Z", + }, + ]); + + const request = new NextRequest("http://localhost/api/predictions"); + const response = await getListPredictionsHandler(request); + expect(response.status).toBe(200); + + const body = await response.json(); + expect(body.status).toBe("success"); + expect(body.predictions).toHaveLength(1); + expect(body.predictions[0].engagement_score).toBe(73.2); + }); + + it("passes limit and offset to selectPredictions", async () => { + vi.mocked(validateAuthContext).mockResolvedValue({ + accountId: "account-uuid", + orgId: null, + authToken: "test-token", + }); + vi.mocked(selectPredictions).mockResolvedValue([]); + + const request = new NextRequest("http://localhost/api/predictions?limit=5&offset=10"); + await getListPredictionsHandler(request); + + expect(selectPredictions).toHaveBeenCalledWith("account-uuid", 5, 10); + }); + + it("clamps limit to 100", async () => { + vi.mocked(validateAuthContext).mockResolvedValue({ + accountId: "account-uuid", + orgId: null, + authToken: "test-token", + }); + vi.mocked(selectPredictions).mockResolvedValue([]); + + const request = new NextRequest("http://localhost/api/predictions?limit=500"); + await getListPredictionsHandler(request); + + expect(selectPredictions).toHaveBeenCalledWith("account-uuid", 100, 0); + }); + + it("returns 401 when auth fails", async () => { + vi.mocked(validateAuthContext).mockResolvedValue( + NextResponse.json({ status: "error" }, { status: 401 }), + ); + + const request = new NextRequest("http://localhost/api/predictions"); + const response = await getListPredictionsHandler(request); + expect(response.status).toBe(401); + }); +}); diff --git a/lib/predictions/__tests__/getOnePredictionHandler.test.ts b/lib/predictions/__tests__/getOnePredictionHandler.test.ts new file mode 100644 index 000000000..9099ca49c --- /dev/null +++ b/lib/predictions/__tests__/getOnePredictionHandler.test.ts @@ -0,0 +1,93 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { NextRequest, NextResponse } from "next/server"; +import { getOnePredictionHandler } from "../getOnePredictionHandler"; + +vi.mock("@/lib/networking/getCorsHeaders", () => ({ + getCorsHeaders: vi.fn(() => ({ "Access-Control-Allow-Origin": "*" })), +})); + +vi.mock("@/lib/auth/validateAuthContext", () => ({ + validateAuthContext: vi.fn(), +})); + +vi.mock("@/lib/supabase/predictions/selectPredictionById", () => ({ + selectPredictionById: vi.fn(), +})); + +import { validateAuthContext } from "@/lib/auth/validateAuthContext"; +import { selectPredictionById } from "@/lib/supabase/predictions/selectPredictionById"; + +const MOCK_PREDICTION = { + id: "pred-uuid", + account_id: "account-uuid", + file_url: "https://example.com/video.mp4", + modality: "video", + engagement_score: 73.2, + engagement_timeline: [{ time_seconds: 0, score: 45.1 }], + peak_moments: [{ time_seconds: 12.0, score: 95.4 }], + weak_spots: [{ time_seconds: 6.0, score: 22.1 }], + regional_activation: { visual_cortex: 0.72 }, + total_duration_seconds: 60.0, + elapsed_seconds: 14.2, + created_at: "2026-04-10T00:00:00Z", +}; + +describe("getOnePredictionHandler", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("returns 200 with prediction when found and owned", async () => { + vi.mocked(validateAuthContext).mockResolvedValue({ + accountId: "account-uuid", + orgId: null, + authToken: "test-token", + }); + vi.mocked(selectPredictionById).mockResolvedValue(MOCK_PREDICTION); + + const request = new NextRequest("http://localhost/api/predictions/pred-uuid"); + const response = await getOnePredictionHandler(request, "pred-uuid"); + expect(response.status).toBe(200); + + const body = await response.json(); + expect(body.status).toBe("success"); + expect(body.id).toBe("pred-uuid"); + expect(body.engagement_score).toBe(73.2); + }); + + it("returns 404 when prediction not found", async () => { + vi.mocked(validateAuthContext).mockResolvedValue({ + accountId: "account-uuid", + orgId: null, + authToken: "test-token", + }); + vi.mocked(selectPredictionById).mockResolvedValue(null); + + const request = new NextRequest("http://localhost/api/predictions/nonexistent"); + const response = await getOnePredictionHandler(request, "nonexistent"); + expect(response.status).toBe(404); + }); + + it("returns 404 when prediction belongs to different account", async () => { + vi.mocked(validateAuthContext).mockResolvedValue({ + accountId: "other-account", + orgId: null, + authToken: "test-token", + }); + vi.mocked(selectPredictionById).mockResolvedValue(MOCK_PREDICTION); + + const request = new NextRequest("http://localhost/api/predictions/pred-uuid"); + const response = await getOnePredictionHandler(request, "pred-uuid"); + expect(response.status).toBe(404); + }); + + it("returns 401 when auth fails", async () => { + vi.mocked(validateAuthContext).mockResolvedValue( + NextResponse.json({ status: "error" }, { status: 401 }), + ); + + const request = new NextRequest("http://localhost/api/predictions/pred-uuid"); + const response = await getOnePredictionHandler(request, "pred-uuid"); + expect(response.status).toBe(401); + }); +}); diff --git a/lib/predictions/__tests__/postCreatePredictionHandler.test.ts b/lib/predictions/__tests__/postCreatePredictionHandler.test.ts new file mode 100644 index 000000000..016a1d118 --- /dev/null +++ b/lib/predictions/__tests__/postCreatePredictionHandler.test.ts @@ -0,0 +1,142 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { NextRequest, NextResponse } from "next/server"; +import { postCreatePredictionHandler } from "../postCreatePredictionHandler"; + +vi.mock("@/lib/networking/getCorsHeaders", () => ({ + getCorsHeaders: vi.fn(() => ({ "Access-Control-Allow-Origin": "*" })), +})); + +vi.mock("@/lib/auth/validateAuthContext", () => ({ + validateAuthContext: vi.fn(), +})); + +vi.mock("@/lib/tribe/validateCreatePredictionBody", () => ({ + validateCreatePredictionBody: vi.fn(), +})); + +vi.mock("@/lib/tribe/processPredictRequest", () => ({ + processPredictRequest: vi.fn(), +})); + +vi.mock("@/lib/supabase/predictions/insertPrediction", () => ({ + insertPrediction: vi.fn(), +})); + +import { validateAuthContext } from "@/lib/auth/validateAuthContext"; +import { validateCreatePredictionBody } from "@/lib/tribe/validateCreatePredictionBody"; +import { processPredictRequest } from "@/lib/tribe/processPredictRequest"; +import { insertPrediction } from "@/lib/supabase/predictions/insertPrediction"; + +const MOCK_METRICS = { + type: "success" as const, + engagement_score: 73.2, + engagement_timeline: [{ time_seconds: 0, score: 45.1 }], + peak_moments: [{ time_seconds: 12.0, score: 95.4 }], + weak_spots: [{ time_seconds: 6.0, score: 22.1 }], + regional_activation: { visual_cortex: 0.72 }, + total_duration_seconds: 60.0, + elapsed_seconds: 14.2, +}; + +const MOCK_ROW = { + id: "test-uuid", + account_id: "account-uuid", + file_url: "https://example.com/video.mp4", + modality: "video", + engagement_score: 73.2, + engagement_timeline: [{ time_seconds: 0, score: 45.1 }], + peak_moments: [{ time_seconds: 12.0, score: 95.4 }], + weak_spots: [{ time_seconds: 6.0, score: 22.1 }], + regional_activation: { visual_cortex: 0.72 }, + total_duration_seconds: 60.0, + elapsed_seconds: 14.2, + created_at: "2026-04-10T00:00:00Z", +}; + +describe("postCreatePredictionHandler", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("returns 200 with prediction on success", async () => { + vi.mocked(validateAuthContext).mockResolvedValue({ + accountId: "account-uuid", + orgId: null, + authToken: "test-token", + }); + vi.mocked(validateCreatePredictionBody).mockReturnValue({ + file_url: "https://example.com/video.mp4", + modality: "video", + }); + vi.mocked(processPredictRequest).mockResolvedValue(MOCK_METRICS); + vi.mocked(insertPrediction).mockResolvedValue(MOCK_ROW); + + const request = new NextRequest("http://localhost/api/predictions", { + method: "POST", + body: JSON.stringify({ + file_url: "https://example.com/video.mp4", + modality: "video", + }), + }); + + const response = await postCreatePredictionHandler(request); + expect(response.status).toBe(200); + + const body = await response.json(); + expect(body.status).toBe("success"); + expect(body.id).toBe("test-uuid"); + expect(body.engagement_score).toBe(73.2); + }); + + it("returns 400 on invalid JSON body", async () => { + const request = new NextRequest("http://localhost/api/predictions", { + method: "POST", + body: "not json", + headers: { "Content-Type": "application/json" }, + }); + + const response = await postCreatePredictionHandler(request); + expect(response.status).toBe(400); + }); + + it("returns auth error when validateAuthContext fails", async () => { + vi.mocked(validateAuthContext).mockResolvedValue( + NextResponse.json({ status: "error", error: "Unauthorized" }, { status: 401 }), + ); + + const request = new NextRequest("http://localhost/api/predictions", { + method: "POST", + body: JSON.stringify({ file_url: "https://example.com/v.mp4", modality: "video" }), + }); + + const response = await postCreatePredictionHandler(request); + expect(response.status).toBe(401); + }); + + it("returns 500 when prediction fails", async () => { + vi.mocked(validateAuthContext).mockResolvedValue({ + accountId: "account-uuid", + orgId: null, + authToken: "test-token", + }); + vi.mocked(validateCreatePredictionBody).mockReturnValue({ + file_url: "https://example.com/video.mp4", + modality: "video", + }); + vi.mocked(processPredictRequest).mockResolvedValue({ + type: "error", + error: "Model unavailable", + }); + + const request = new NextRequest("http://localhost/api/predictions", { + method: "POST", + body: JSON.stringify({ file_url: "https://example.com/v.mp4", modality: "video" }), + }); + + const response = await postCreatePredictionHandler(request); + expect(response.status).toBe(500); + + const body = await response.json(); + expect(body.error).toBe("Model unavailable"); + }); +}); diff --git a/lib/predictions/getListPredictionsHandler.ts b/lib/predictions/getListPredictionsHandler.ts new file mode 100644 index 000000000..92964706c --- /dev/null +++ b/lib/predictions/getListPredictionsHandler.ts @@ -0,0 +1,40 @@ +import type { NextRequest } from "next/server"; +import { NextResponse } from "next/server"; +import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; +import { validateAuthContext } from "@/lib/auth/validateAuthContext"; +import { selectPredictions } from "@/lib/supabase/predictions/selectPredictions"; + +/** + * Handler for GET /api/predictions. + * Returns past predictions for the authenticated account. + * + * @param request - The incoming request with optional limit/offset query params. + * @returns A NextResponse with the predictions array or an error. + */ +export async function getListPredictionsHandler( + request: NextRequest, +): Promise { + const authResult = await validateAuthContext(request); + if (authResult instanceof NextResponse) { + return authResult; + } + const { accountId } = authResult; + + const url = new URL(request.url); + const limit = Math.min(Math.max(parseInt(url.searchParams.get("limit") || "20", 10) || 20, 1), 100); + const offset = Math.max(parseInt(url.searchParams.get("offset") || "0", 10) || 0, 0); + + try { + const predictions = await selectPredictions(accountId, limit, offset); + return NextResponse.json( + { status: "success", predictions }, + { status: 200, headers: getCorsHeaders() }, + ); + } catch (err) { + const message = err instanceof Error ? err.message : "Failed to fetch predictions"; + return NextResponse.json( + { status: "error", error: message }, + { status: 500, headers: getCorsHeaders() }, + ); + } +} diff --git a/lib/predictions/getOnePredictionHandler.ts b/lib/predictions/getOnePredictionHandler.ts new file mode 100644 index 000000000..2bbf46a63 --- /dev/null +++ b/lib/predictions/getOnePredictionHandler.ts @@ -0,0 +1,66 @@ +import type { NextRequest } from "next/server"; +import { NextResponse } from "next/server"; +import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; +import { validateAuthContext } from "@/lib/auth/validateAuthContext"; +import { selectPredictionById } from "@/lib/supabase/predictions/selectPredictionById"; + +/** + * Handler for GET /api/predictions/:id. + * Returns a single prediction by UUID. + * + * @param request - The incoming request. + * @param id - The prediction UUID from the URL path. + * @returns A NextResponse with the prediction or an error. + */ +export async function getOnePredictionHandler( + request: NextRequest, + id: string, +): Promise { + const authResult = await validateAuthContext(request); + if (authResult instanceof NextResponse) { + return authResult; + } + const { accountId } = authResult; + + try { + const prediction = await selectPredictionById(id); + + if (!prediction) { + return NextResponse.json( + { status: "error", error: "Prediction not found" }, + { status: 404, headers: getCorsHeaders() }, + ); + } + + if (prediction.account_id !== accountId) { + return NextResponse.json( + { status: "error", error: "Prediction not found" }, + { status: 404, headers: getCorsHeaders() }, + ); + } + + return NextResponse.json( + { + status: "success", + id: prediction.id, + file_url: prediction.file_url, + modality: prediction.modality, + engagement_score: prediction.engagement_score, + engagement_timeline: prediction.engagement_timeline, + peak_moments: prediction.peak_moments, + weak_spots: prediction.weak_spots, + regional_activation: prediction.regional_activation, + total_duration_seconds: prediction.total_duration_seconds, + elapsed_seconds: prediction.elapsed_seconds, + created_at: prediction.created_at, + }, + { status: 200, headers: getCorsHeaders() }, + ); + } catch (err) { + const message = err instanceof Error ? err.message : "Failed to fetch prediction"; + return NextResponse.json( + { status: "error", error: message }, + { status: 500, headers: getCorsHeaders() }, + ); + } +} diff --git a/lib/predictions/postCreatePredictionHandler.ts b/lib/predictions/postCreatePredictionHandler.ts new file mode 100644 index 000000000..26d6846ed --- /dev/null +++ b/lib/predictions/postCreatePredictionHandler.ts @@ -0,0 +1,89 @@ +import type { NextRequest } from "next/server"; +import { NextResponse } from "next/server"; +import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; +import { validateAuthContext } from "@/lib/auth/validateAuthContext"; +import { validateCreatePredictionBody } from "@/lib/tribe/validateCreatePredictionBody"; +import { processPredictRequest } from "@/lib/tribe/processPredictRequest"; +import { insertPrediction } from "@/lib/supabase/predictions/insertPrediction"; + +/** + * Handler for POST /api/predictions. + * Authenticates, validates body, runs TRIBE v2 via Modal, persists result. + * + * @param request - The incoming request with a JSON body. + * @returns A NextResponse with the prediction result or an error. + */ +export async function postCreatePredictionHandler( + request: NextRequest, +): Promise { + let body: unknown; + try { + body = await request.json(); + } catch { + return NextResponse.json( + { status: "error", error: "Request body must be valid JSON" }, + { status: 400, headers: getCorsHeaders() }, + ); + } + + const authResult = await validateAuthContext(request); + if (authResult instanceof NextResponse) { + return authResult; + } + const { accountId } = authResult; + + const validated = validateCreatePredictionBody(body); + if (validated instanceof NextResponse) { + return validated; + } + + const result = await processPredictRequest(validated); + + if (result.type === "error") { + return NextResponse.json( + { status: "error", error: result.error }, + { status: 500, headers: getCorsHeaders() }, + ); + } + + const { type: _, ...metrics } = result; + + try { + const row = await insertPrediction({ + account_id: accountId, + file_url: validated.file_url, + modality: validated.modality, + engagement_score: metrics.engagement_score, + engagement_timeline: metrics.engagement_timeline, + peak_moments: metrics.peak_moments, + weak_spots: metrics.weak_spots, + regional_activation: metrics.regional_activation, + total_duration_seconds: metrics.total_duration_seconds, + elapsed_seconds: metrics.elapsed_seconds, + }); + + return NextResponse.json( + { + status: "success", + id: row.id, + file_url: row.file_url, + modality: row.modality, + engagement_score: row.engagement_score, + engagement_timeline: row.engagement_timeline, + peak_moments: row.peak_moments, + weak_spots: row.weak_spots, + regional_activation: row.regional_activation, + total_duration_seconds: row.total_duration_seconds, + elapsed_seconds: row.elapsed_seconds, + created_at: row.created_at, + }, + { status: 200, headers: getCorsHeaders() }, + ); + } catch (err) { + const message = err instanceof Error ? err.message : "Failed to save prediction"; + return NextResponse.json( + { status: "error", error: message }, + { status: 500, headers: getCorsHeaders() }, + ); + } +} diff --git a/lib/supabase/predictions/insertPrediction.ts b/lib/supabase/predictions/insertPrediction.ts new file mode 100644 index 000000000..9e28b1126 --- /dev/null +++ b/lib/supabase/predictions/insertPrediction.ts @@ -0,0 +1,44 @@ +import supabase from "../serverClient"; + +interface PredictionInsert { + account_id: string; + file_url: string; + modality: string; + engagement_score: number; + engagement_timeline: unknown; + peak_moments: unknown; + weak_spots: unknown; + regional_activation: unknown; + total_duration_seconds: number; + elapsed_seconds: number; +} + +interface PredictionRow extends PredictionInsert { + id: string; + created_at: string; +} + +/** + * Inserts a new prediction row after TRIBE v2 returns results. + * + * @param prediction - The prediction data to insert. + * @returns The inserted prediction row with id and created_at. + * @throws Error if the insert fails. + */ +export async function insertPrediction(prediction: PredictionInsert): Promise { + const { data, error } = await supabase + .from("predictions") + .insert(prediction) + .select() + .single(); + + if (error) { + throw new Error(`Failed to insert prediction: ${error.message}`); + } + + if (!data) { + throw new Error("Failed to insert prediction: No data returned"); + } + + return data as PredictionRow; +} diff --git a/lib/supabase/predictions/selectPredictionById.ts b/lib/supabase/predictions/selectPredictionById.ts new file mode 100644 index 000000000..e699bbd28 --- /dev/null +++ b/lib/supabase/predictions/selectPredictionById.ts @@ -0,0 +1,37 @@ +import supabase from "../serverClient"; + +interface PredictionRow { + id: string; + account_id: string; + file_url: string; + modality: string; + engagement_score: number; + engagement_timeline: unknown; + peak_moments: unknown; + weak_spots: unknown; + regional_activation: unknown; + total_duration_seconds: number; + elapsed_seconds: number; + created_at: string; +} + +/** + * Selects a single prediction by its UUID. + * + * @param id - The prediction UUID. + * @returns The full prediction row, or null if not found. + */ +export async function selectPredictionById(id: string): Promise { + const { data, error } = await supabase + .from("predictions") + .select("*") + .eq("id", id) + .single(); + + if (error) { + if (error.code === "PGRST116") return null; + throw new Error(`Failed to select prediction: ${error.message}`); + } + + return data as PredictionRow | null; +} diff --git a/lib/supabase/predictions/selectPredictions.ts b/lib/supabase/predictions/selectPredictions.ts new file mode 100644 index 000000000..9da04ab8f --- /dev/null +++ b/lib/supabase/predictions/selectPredictions.ts @@ -0,0 +1,36 @@ +import supabase from "../serverClient"; + +interface PredictionSummary { + id: string; + file_url: string; + modality: string; + engagement_score: number; + created_at: string; +} + +/** + * Selects predictions for an account, sorted by creation date descending. + * + * @param accountId - The account UUID to filter by. + * @param limit - Maximum number of rows (default 20, max 100). + * @param offset - Number of rows to skip (default 0). + * @returns Array of prediction summaries. + */ +export async function selectPredictions( + accountId: string, + limit = 20, + offset = 0, +): Promise { + const { data, error } = await supabase + .from("predictions") + .select("id, file_url, modality, engagement_score, created_at") + .eq("account_id", accountId) + .order("created_at", { ascending: false }) + .range(offset, offset + limit - 1); + + if (error) { + throw new Error(`Failed to select predictions: ${error.message}`); + } + + return (data ?? []) as PredictionSummary[]; +} diff --git a/lib/tribe/__tests__/callTribePredict.test.ts b/lib/tribe/__tests__/callTribePredict.test.ts new file mode 100644 index 000000000..a5e3e0674 --- /dev/null +++ b/lib/tribe/__tests__/callTribePredict.test.ts @@ -0,0 +1,91 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { callTribePredict } from "../callTribePredict"; + +const mockFetch = vi.fn(); +vi.stubGlobal("fetch", mockFetch); + +const VALID_RESPONSE = { + engagement_score: 73.2, + engagement_timeline: [{ time_seconds: 0, score: 45.1 }], + peak_moments: [{ time_seconds: 12.0, score: 95.4 }], + weak_spots: [{ time_seconds: 6.0, score: 22.1 }], + regional_activation: { visual_cortex: 0.72, auditory_cortex: 0.85 }, + total_duration_seconds: 60.0, + elapsed_seconds: 14.2, +}; + +describe("callTribePredict", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("sends correct payload to Modal endpoint", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => Promise.resolve(VALID_RESPONSE), + }); + + await callTribePredict({ + file_url: "https://example.com/video.mp4", + modality: "video", + }); + + expect(mockFetch).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + file_url: "https://example.com/video.mp4", + modality: "video", + }), + }), + ); + }); + + it("returns parsed response on success", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => Promise.resolve(VALID_RESPONSE), + }); + + const result = await callTribePredict({ + file_url: "https://example.com/video.mp4", + modality: "video", + }); + + expect(result.engagement_score).toBe(73.2); + expect(result.engagement_timeline).toHaveLength(1); + expect(result.peak_moments).toHaveLength(1); + expect(result.weak_spots).toHaveLength(1); + }); + + it("throws on non-ok response", async () => { + mockFetch.mockResolvedValue({ + ok: false, + status: 500, + text: () => Promise.resolve("Internal Server Error"), + }); + + await expect( + callTribePredict({ + file_url: "https://example.com/video.mp4", + modality: "video", + }), + ).rejects.toThrow("Engagement prediction failed (status 500)"); + }); + + it("throws on unexpected response shape", async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ unexpected: "data" }), + }); + + await expect( + callTribePredict({ + file_url: "https://example.com/video.mp4", + modality: "video", + }), + ).rejects.toThrow("unexpected response shape"); + }); +}); diff --git a/lib/tribe/__tests__/processPredictRequest.test.ts b/lib/tribe/__tests__/processPredictRequest.test.ts new file mode 100644 index 000000000..d35505a7e --- /dev/null +++ b/lib/tribe/__tests__/processPredictRequest.test.ts @@ -0,0 +1,52 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { processPredictRequest } from "../processPredictRequest"; + +vi.mock("../callTribePredict", () => ({ + callTribePredict: vi.fn(), +})); + +import { callTribePredict } from "../callTribePredict"; + +const MOCK_RESULT = { + engagement_score: 73.2, + engagement_timeline: [{ time_seconds: 0, score: 45.1 }], + peak_moments: [{ time_seconds: 12.0, score: 95.4 }], + weak_spots: [{ time_seconds: 6.0, score: 22.1 }], + regional_activation: { visual_cortex: 0.72 }, + total_duration_seconds: 60.0, + elapsed_seconds: 14.2, +}; + +describe("processPredictRequest", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("returns success with metrics on successful prediction", async () => { + vi.mocked(callTribePredict).mockResolvedValue(MOCK_RESULT); + + const result = await processPredictRequest({ + file_url: "https://example.com/video.mp4", + modality: "video", + }); + + expect(result.type).toBe("success"); + if (result.type === "success") { + expect(result.engagement_score).toBe(73.2); + } + }); + + it("returns error when callTribePredict throws", async () => { + vi.mocked(callTribePredict).mockRejectedValue(new Error("Connection refused")); + + const result = await processPredictRequest({ + file_url: "https://example.com/video.mp4", + modality: "video", + }); + + expect(result.type).toBe("error"); + if (result.type === "error") { + expect(result.error).toBe("Connection refused"); + } + }); +}); diff --git a/lib/tribe/__tests__/validateCreatePredictionBody.test.ts b/lib/tribe/__tests__/validateCreatePredictionBody.test.ts new file mode 100644 index 000000000..06742f5de --- /dev/null +++ b/lib/tribe/__tests__/validateCreatePredictionBody.test.ts @@ -0,0 +1,66 @@ +import { describe, it, expect } from "vitest"; +import { NextResponse } from "next/server"; +import { validateCreatePredictionBody } from "../validateCreatePredictionBody"; + +describe("validateCreatePredictionBody", () => { + it("accepts valid video prediction body", () => { + const result = validateCreatePredictionBody({ + file_url: "https://storage.example.com/video.mp4", + modality: "video", + }); + expect(result).not.toBeInstanceOf(NextResponse); + expect(result).toEqual({ + file_url: "https://storage.example.com/video.mp4", + modality: "video", + }); + }); + + it("accepts valid audio prediction body", () => { + const result = validateCreatePredictionBody({ + file_url: "https://storage.example.com/track.mp3", + modality: "audio", + }); + expect(result).not.toBeInstanceOf(NextResponse); + }); + + it("accepts valid text prediction body", () => { + const result = validateCreatePredictionBody({ + file_url: "https://storage.example.com/lyrics.txt", + modality: "text", + }); + expect(result).not.toBeInstanceOf(NextResponse); + }); + + it("rejects missing file_url", () => { + const result = validateCreatePredictionBody({ modality: "video" }); + expect(result).toBeInstanceOf(NextResponse); + }); + + it("rejects invalid file_url", () => { + const result = validateCreatePredictionBody({ + file_url: "not-a-url", + modality: "video", + }); + expect(result).toBeInstanceOf(NextResponse); + }); + + it("rejects missing modality", () => { + const result = validateCreatePredictionBody({ + file_url: "https://example.com/file.mp4", + }); + expect(result).toBeInstanceOf(NextResponse); + }); + + it("rejects invalid modality", () => { + const result = validateCreatePredictionBody({ + file_url: "https://example.com/file.mp4", + modality: "image", + }); + expect(result).toBeInstanceOf(NextResponse); + }); + + it("rejects empty body", () => { + const result = validateCreatePredictionBody({}); + expect(result).toBeInstanceOf(NextResponse); + }); +}); diff --git a/lib/tribe/callTribePredict.ts b/lib/tribe/callTribePredict.ts new file mode 100644 index 000000000..babc52052 --- /dev/null +++ b/lib/tribe/callTribePredict.ts @@ -0,0 +1,36 @@ +import { TRIBE_PREDICT_URL } from "@/lib/const"; +import { isTribePredictResult, type TribePredictResult } from "./isTribePredictResult"; +import type { CreatePredictionBody } from "./validateCreatePredictionBody"; + +/** + * Calls the TRIBE v2 predict endpoint on Modal. + * Sends file_url and modality, receives engagement metrics. + * + * @param params - Validated prediction request parameters. + * @returns The engagement prediction result from Modal. + * @throws Error on network failure or unexpected response shape. + */ +export async function callTribePredict( + params: CreatePredictionBody, +): Promise { + const response = await fetch(TRIBE_PREDICT_URL, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + file_url: params.file_url, + modality: params.modality, + }), + }); + + if (!response.ok) { + const errorText = await response.text().catch(() => "Unknown error"); + console.error(`TRIBE v2 returned ${response.status}: ${errorText}`); + throw new Error(`Engagement prediction failed (status ${response.status})`); + } + + const data = await response.json(); + if (!isTribePredictResult(data)) { + throw new Error("TRIBE v2 returned an unexpected response shape"); + } + return data; +} diff --git a/lib/tribe/isTribePredictResult.ts b/lib/tribe/isTribePredictResult.ts new file mode 100644 index 000000000..01b87300f --- /dev/null +++ b/lib/tribe/isTribePredictResult.ts @@ -0,0 +1,33 @@ +/** + * Response shape from the TRIBE v2 /predict endpoint on Modal. + */ +export interface TribePredictResult { + engagement_score: number; + engagement_timeline: Array<{ time_seconds: number; score: number }>; + peak_moments: Array<{ time_seconds: number; score: number }>; + weak_spots: Array<{ time_seconds: number; score: number }>; + regional_activation: Record; + total_duration_seconds: number; + elapsed_seconds: number; +} + +/** + * Type guard for validating the TRIBE v2 predict API response. + * + * @param value - Unknown parsed JSON payload. + * @returns True when payload has the expected engagement metrics shape. + */ +export function isTribePredictResult(value: unknown): value is TribePredictResult { + if (!value || typeof value !== "object") return false; + const c = value as Record; + return ( + typeof c.engagement_score === "number" && + Array.isArray(c.engagement_timeline) && + Array.isArray(c.peak_moments) && + Array.isArray(c.weak_spots) && + typeof c.regional_activation === "object" && + c.regional_activation !== null && + typeof c.total_duration_seconds === "number" && + typeof c.elapsed_seconds === "number" + ); +} diff --git a/lib/tribe/processPredictRequest.ts b/lib/tribe/processPredictRequest.ts new file mode 100644 index 000000000..8842b7e24 --- /dev/null +++ b/lib/tribe/processPredictRequest.ts @@ -0,0 +1,33 @@ +import { callTribePredict } from "./callTribePredict"; +import type { CreatePredictionBody } from "./validateCreatePredictionBody"; +import type { TribePredictResult } from "./isTribePredictResult"; + +interface PredictSuccess extends TribePredictResult { + type: "success"; +} + +interface PredictError { + type: "error"; + error: string; +} + +export type PredictResult = PredictSuccess | PredictError; + +/** + * Shared business logic for engagement prediction. + * Used by both POST /api/predictions and the predict_engagement MCP tool. + * + * @param params - Validated prediction request parameters. + * @returns Discriminated union with type "success" or "error". + */ +export async function processPredictRequest( + params: CreatePredictionBody, +): Promise { + try { + const result = await callTribePredict(params); + return { type: "success", ...result }; + } catch (err) { + const message = err instanceof Error ? err.message : "Engagement prediction failed"; + return { type: "error", error: message }; + } +} diff --git a/lib/tribe/validateCreatePredictionBody.ts b/lib/tribe/validateCreatePredictionBody.ts new file mode 100644 index 000000000..949eb40b2 --- /dev/null +++ b/lib/tribe/validateCreatePredictionBody.ts @@ -0,0 +1,40 @@ +import { NextResponse } from "next/server"; +import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; +import { z } from "zod"; + +export const createPredictionBodySchema = z.object({ + file_url: z + .string({ message: "file_url is required" }) + .url("file_url must be a valid URL"), + modality: z.enum(["video", "audio", "text"], { + message: "modality must be video, audio, or text", + }), +}); + +export type CreatePredictionBody = z.infer; + +/** + * Validates request body for POST /api/predictions. + * + * @param body - The request body. + * @returns A NextResponse with an error if validation fails, or the validated body. + */ +export function validateCreatePredictionBody( + body: unknown, +): NextResponse | CreatePredictionBody { + const result = createPredictionBodySchema.safeParse(body); + + if (!result.success) { + const firstError = result.error.issues[0]; + return NextResponse.json( + { + status: "error", + missing_fields: firstError.path, + error: firstError.message, + }, + { status: 400, headers: getCorsHeaders() }, + ); + } + + return result.data; +}