Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions app/api/predictions/[id]/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import type { NextRequest } from "next/server";
import { NextResponse } from "next/server";
import { getCorsHeaders } from "@/lib/networking/getCorsHeaders";
import { getOnePredictionHandler } from "@/lib/predictions/getOnePredictionHandler";

/**
* OPTIONS handler for CORS preflight requests.
*
* @returns A NextResponse with CORS headers.
*/
export async function OPTIONS() {
return new NextResponse(null, {
status: 200,
headers: getCorsHeaders(),
});
}

/**
* GET /api/predictions/{id}
*
* Get a specific engagement prediction by UUID. Returns the full prediction
* including engagement timeline, peak moments, weak spots, and regional
* activation data.
*
* Authentication: x-api-key header or Authorization Bearer token required.
*
* @param request - The request object.
* @param options - Route options containing params.
* @param options.params - Route params containing the prediction UUID.
* @returns A NextResponse with the prediction or error.
*/
export async function GET(
request: NextRequest,
options: { params: Promise<{ id: string }> },
): Promise<NextResponse> {
const { id } = await options.params;
return getOnePredictionHandler(request, id);
Comment on lines +32 to +37
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate id as UUID before delegating to the handler.

The route currently forwards raw id. Add Zod validation at the boundary and return 400 on invalid params.

Example boundary validation
 import type { NextRequest } from "next/server";
 import { NextResponse } from "next/server";
+import { z } from "zod";
 import { getCorsHeaders } from "@/lib/networking/getCorsHeaders";
 import { getOnePredictionHandler } from "@/lib/predictions/getOnePredictionHandler";
+
+const getPredictionParamsSchema = z.object({ id: z.string().uuid() });

 export async function GET(
   request: NextRequest,
   options: { params: Promise<{ id: string }> },
 ): Promise<NextResponse> {
-  const { id } = await options.params;
-  return getOnePredictionHandler(request, id);
+  const parsed = getPredictionParamsSchema.safeParse(await options.params);
+  if (!parsed.success) {
+    return NextResponse.json(
+      { status: "error", error: "Invalid prediction id" },
+      { status: 400, headers: getCorsHeaders() },
+    );
+  }
+  return getOnePredictionHandler(request, parsed.data.id);
 }

As per coding guidelines: “All API endpoints should use a validate function for input parsing using Zod for schema validation.”

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@app/api/predictions/`[id]/route.ts around lines 32 - 37, Validate the
incoming route param `id` in the GET function before calling
getOnePredictionHandler: create a Zod schema (e.g., predictionIdSchema =
z.object({ id: z.string().uuid() })) and parse await options.params with it,
returning a NextResponse with status 400 if validation fails; on success extract
the validated id and pass that to getOnePredictionHandler(request, id). Ensure
this validation occurs in the GET function (which currently reads const { id } =
await options.params) so malformed UUIDs never reach getOnePredictionHandler.

}
49 changes: 49 additions & 0 deletions app/api/predictions/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import type { NextRequest } from "next/server";
import { NextResponse } from "next/server";
import { getCorsHeaders } from "@/lib/networking/getCorsHeaders";
import { postCreatePredictionHandler } from "@/lib/predictions/postCreatePredictionHandler";
import { getListPredictionsHandler } from "@/lib/predictions/getListPredictionsHandler";

/**
* OPTIONS handler for CORS preflight requests.
*
* @returns A NextResponse with CORS headers.
*/
export async function OPTIONS() {
return new NextResponse(null, {
status: 200,
headers: getCorsHeaders(),
});
}

/**
* POST /api/predictions
*
* Run a neural engagement prediction on video, audio, or text content.
* Accepts a file URL and modality, returns an engagement score (0-100),
* timeline, peak moments, weak spots, and regional brain activation data.
* The result is persisted for later retrieval.
*
* Authentication: x-api-key header or Authorization Bearer token required.
*
* @param request - The request object containing the JSON body.
* @returns A NextResponse with the prediction result or error.
*/
export async function POST(request: NextRequest): Promise<NextResponse> {
return postCreatePredictionHandler(request);
}

/**
* GET /api/predictions
*
* List past engagement predictions for the authenticated account.
* Supports limit and offset query parameters for pagination.
*
* Authentication: x-api-key header or Authorization Bearer token required.
*
* @param request - The request object with optional query params.
* @returns A NextResponse with the predictions array or error.
*/
export async function GET(request: NextRequest): Promise<NextResponse> {
return getListPredictionsHandler(request);
}
4 changes: 4 additions & 0 deletions lib/const.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ export const RECOUP_API_KEY = process.env.RECOUP_API_KEY || "";
export const FLAMINGO_GENERATE_URL =
"https://sidney-78147--music-flamingo-musicflamingo-generate.modal.run";

/** TRIBE v2 neural engagement prediction endpoint (Modal) */
export const TRIBE_PREDICT_URL =
"https://sidney-78147--tribe-predict-tribev2predict-predict.modal.run";

/** Snapshot expiration duration (7 days) */
export const SNAPSHOT_EXPIRY_MS = 7 * 24 * 60 * 60 * 1000;

Expand Down
2 changes: 2 additions & 0 deletions lib/mcp/tools/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { registerWebDeepResearchTool } from "./registerWebDeepResearchTool";
import { registerArtistDeepResearchTool } from "./registerArtistDeepResearchTool";
import { registerAllFileTools } from "./files";
import { registerAllFlamingoTools } from "./flamingo";
import { registerAllTribeTools } from "./tribe";
import { registerCreateSegmentsTool } from "./registerCreateSegmentsTool";
import { registerAllYouTubeTools } from "./youtube";
import { registerTranscribeTools } from "./transcribe";
Expand Down Expand Up @@ -53,5 +54,6 @@ export const registerAllTools = (server: McpServer): void => {
registerSendEmailTool(server);
registerUpdateAccountInfoTool(server);
registerCreateSegmentsTool(server);
registerAllTribeTools(server);
registerAllYouTubeTools(server);
};
13 changes: 13 additions & 0 deletions lib/mcp/tools/tribe/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import { registerPredictEngagementTool } from "./registerPredictEngagementTool";
import { registerGetPredictionsTool } from "./registerGetPredictionsTool";

/**
* Registers all TRIBE v2 engagement prediction MCP tools.
*
* @param server - The MCP server instance.
*/
export function registerAllTribeTools(server: McpServer): void {
registerPredictEngagementTool(server);
registerGetPredictionsTool(server);
}
72 changes: 72 additions & 0 deletions lib/mcp/tools/tribe/registerGetPredictionsTool.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js";
import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js";
import { z } from "zod";
import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey";
import { resolveAccountId } from "@/lib/mcp/resolveAccountId";
import { getToolResultSuccess } from "@/lib/mcp/getToolResultSuccess";
import { getToolResultError } from "@/lib/mcp/getToolResultError";
import { selectPredictions } from "@/lib/supabase/predictions/selectPredictions";

const getPredictionsSchema = z.object({
limit: z
.number()
.int()
.min(1)
.max(100)
.default(20)
.describe("Maximum number of predictions to return (1-100, default 20)."),
offset: z
.number()
.int()
.min(0)
.default(0)
.describe("Number of predictions to skip for pagination (default 0)."),
});

type GetPredictionsArgs = z.infer<typeof getPredictionsSchema>;

/**
* Registers the get_predictions MCP tool on the server.
* Lists past engagement predictions for the authenticated account.
*
* @param server - The MCP server instance to register the tool on.
*/
export function registerGetPredictionsTool(server: McpServer): void {
server.registerTool(
"get_predictions",
{
description:
"List past neural engagement predictions for your account. " +
"Returns prediction summaries (id, modality, engagement_score, created_at) " +
"sorted by newest first. Use to compare scores across content iterations.",
inputSchema: getPredictionsSchema,
},
async (
args: GetPredictionsArgs,
extra: RequestHandlerExtra<ServerRequest, ServerNotification>,
) => {
const authInfo = extra.authInfo as McpAuthInfo | undefined;
const { accountId, error } = await resolveAccountId({
authInfo,
accountIdOverride: undefined,
});

if (error) {
return getToolResultError(error);
}

if (!accountId) {
return getToolResultError("Failed to resolve account ID");
}

try {
const predictions = await selectPredictions(accountId, args.limit, args.offset);
return getToolResultSuccess({ predictions });
} catch (err) {
const message = err instanceof Error ? err.message : "Failed to fetch predictions";
return getToolResultError(message);
}
},
);
}
92 changes: 92 additions & 0 deletions lib/mcp/tools/tribe/registerPredictEngagementTool.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js";
import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js";
import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey";
import { resolveAccountId } from "@/lib/mcp/resolveAccountId";
import { getToolResultSuccess } from "@/lib/mcp/getToolResultSuccess";
import { getToolResultError } from "@/lib/mcp/getToolResultError";
import {
createPredictionBodySchema,
type CreatePredictionBody,
} from "@/lib/tribe/validateCreatePredictionBody";
import { processPredictRequest } from "@/lib/tribe/processPredictRequest";
import { insertPrediction } from "@/lib/supabase/predictions/insertPrediction";

/**
* Registers the predict_engagement MCP tool on the server.
* Runs TRIBE v2 neural engagement prediction and persists the result.
*
* @param server - The MCP server instance to register the tool on.
*/
export function registerPredictEngagementTool(server: McpServer): void {
server.registerTool(
"predict_engagement",
{
description:
"Predict neural engagement for video, audio, or text content. " +
"Returns an engagement score (0-100), a per-timestep timeline showing where " +
"attention peaks and drops, and brain region activation data. " +
"Use to compare content iterations — predict v1, edit weak spots, predict v2. " +
"Requires a publicly accessible file_url and modality (video, audio, or text).",
inputSchema: createPredictionBodySchema,
},
async (
args: CreatePredictionBody,
extra: RequestHandlerExtra<ServerRequest, ServerNotification>,
) => {
const authInfo = extra.authInfo as McpAuthInfo | undefined;
const { accountId, error } = await resolveAccountId({
authInfo,
accountIdOverride: undefined,
});

if (error) {
return getToolResultError(error);
}

if (!accountId) {
return getToolResultError("Failed to resolve account ID");
}

let result;
try {
result = await processPredictRequest(args);
} catch (err) {
const message = err instanceof Error ? err.message : "An unexpected error occurred";
return getToolResultError(`Engagement prediction failed: ${message}`);
}

if (result.type === "error") {
return getToolResultError(result.error);
}

const { type: _, ...metrics } = result;

try {
const row = await insertPrediction({
account_id: accountId,
file_url: args.file_url,
modality: args.modality,
...metrics,
});

return getToolResultSuccess({
id: row.id,
file_url: row.file_url,
modality: row.modality,
engagement_score: row.engagement_score,
engagement_timeline: row.engagement_timeline,
peak_moments: row.peak_moments,
weak_spots: row.weak_spots,
regional_activation: row.regional_activation,
total_duration_seconds: row.total_duration_seconds,
elapsed_seconds: row.elapsed_seconds,
created_at: row.created_at,
});
} catch (err) {
const message = err instanceof Error ? err.message : "Failed to save prediction";
return getToolResultError(message);
}
},
);
}
88 changes: 88 additions & 0 deletions lib/predictions/__tests__/getListPredictionsHandler.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { NextRequest, NextResponse } from "next/server";
import { getListPredictionsHandler } from "../getListPredictionsHandler";

vi.mock("@/lib/networking/getCorsHeaders", () => ({
getCorsHeaders: vi.fn(() => ({ "Access-Control-Allow-Origin": "*" })),
}));

vi.mock("@/lib/auth/validateAuthContext", () => ({
validateAuthContext: vi.fn(),
}));

vi.mock("@/lib/supabase/predictions/selectPredictions", () => ({
selectPredictions: vi.fn(),
}));

import { validateAuthContext } from "@/lib/auth/validateAuthContext";
import { selectPredictions } from "@/lib/supabase/predictions/selectPredictions";

describe("getListPredictionsHandler", () => {
beforeEach(() => {
vi.clearAllMocks();
});

it("returns 200 with predictions array", async () => {
vi.mocked(validateAuthContext).mockResolvedValue({
accountId: "account-uuid",
orgId: null,
authToken: "test-token",
});
vi.mocked(selectPredictions).mockResolvedValue([
{
id: "pred-1",
file_url: "https://example.com/v1.mp4",
modality: "video",
engagement_score: 73.2,
created_at: "2026-04-10T00:00:00Z",
},
]);

const request = new NextRequest("http://localhost/api/predictions");
const response = await getListPredictionsHandler(request);
expect(response.status).toBe(200);

const body = await response.json();
expect(body.status).toBe("success");
expect(body.predictions).toHaveLength(1);
expect(body.predictions[0].engagement_score).toBe(73.2);
});

it("passes limit and offset to selectPredictions", async () => {
vi.mocked(validateAuthContext).mockResolvedValue({
accountId: "account-uuid",
orgId: null,
authToken: "test-token",
});
vi.mocked(selectPredictions).mockResolvedValue([]);

const request = new NextRequest("http://localhost/api/predictions?limit=5&offset=10");
await getListPredictionsHandler(request);

expect(selectPredictions).toHaveBeenCalledWith("account-uuid", 5, 10);
});

it("clamps limit to 100", async () => {
vi.mocked(validateAuthContext).mockResolvedValue({
accountId: "account-uuid",
orgId: null,
authToken: "test-token",
});
vi.mocked(selectPredictions).mockResolvedValue([]);

const request = new NextRequest("http://localhost/api/predictions?limit=500");
await getListPredictionsHandler(request);

expect(selectPredictions).toHaveBeenCalledWith("account-uuid", 100, 0);
});

it("returns 401 when auth fails", async () => {
vi.mocked(validateAuthContext).mockResolvedValue(
NextResponse.json({ status: "error" }, { status: 401 }),
);

const request = new NextRequest("http://localhost/api/predictions");
const response = await getListPredictionsHandler(request);
expect(response.status).toBe(401);
});
});
Loading
Loading