From 0a9b6ab5c4ddebf82f0685e83781a1f12da3f8d1 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 4 May 2026 20:56:30 +0700 Subject: [PATCH 1/8] feat(api): implement subscription management portal endpoints - Added a new route for managing subscription portals, including OPTIONS and POST handlers. - The OPTIONS handler responds with CORS headers for preflight requests. - The POST handler creates a subscription management session using Stripe, validating the request and returning the session ID and URL. - Introduced validation for incoming requests to ensure proper structure and authentication. - Added tests for both OPTIONS and POST handlers to verify functionality and error handling. - Implemented utility functions for creating billing portal sessions and validating requests. Files added: - app/api/subscriptions/portal/route.ts - lib/stripe/createSubscriptionPortalHandler.ts - lib/stripe/validateCreateSubscriptionPortalRequest.ts - lib/stripe/createBillingPortalSession.ts - lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId.ts - Tests for the new functionality in __tests__ directory. This commit enhances the subscription management capabilities of the API, allowing users to manage their subscriptions effectively. --- .../portal/__tests__/route.options.test.ts | 14 ++ .../__tests__/route.post.outcomes.test.ts | 120 ++++++++++++++++++ .../__tests__/route.post.validation.test.ts | 84 ++++++++++++ .../portal/__tests__/route.test.ts | 11 ++ .../portal/__tests__/routeTestMocks.ts | 21 +++ app/api/subscriptions/portal/route.ts | 29 +++++ .../createSubscriptionPortalHandler.test.ts | 91 +++++++++++++ ...ateCreateSubscriptionPortalRequest.test.ts | 117 +++++++++++++++++ lib/stripe/createBillingPortalSession.ts | 12 ++ lib/stripe/createSubscriptionPortalHandler.ts | 44 +++++++ lib/stripe/createSubscriptionPortalSchemas.ts | 8 ++ ...validateCreateSubscriptionPortalRequest.ts | 42 ++++++ .../selectStripeBillingCustomerByAccountId.ts | 23 ++++ 13 files changed, 616 insertions(+) create mode 100644 app/api/subscriptions/portal/__tests__/route.options.test.ts create mode 100644 app/api/subscriptions/portal/__tests__/route.post.outcomes.test.ts create mode 100644 app/api/subscriptions/portal/__tests__/route.post.validation.test.ts create mode 100644 app/api/subscriptions/portal/__tests__/route.test.ts create mode 100644 app/api/subscriptions/portal/__tests__/routeTestMocks.ts create mode 100644 app/api/subscriptions/portal/route.ts create mode 100644 lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts create mode 100644 lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.test.ts create mode 100644 lib/stripe/createBillingPortalSession.ts create mode 100644 lib/stripe/createSubscriptionPortalHandler.ts create mode 100644 lib/stripe/createSubscriptionPortalSchemas.ts create mode 100644 lib/stripe/validateCreateSubscriptionPortalRequest.ts create mode 100644 lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId.ts diff --git a/app/api/subscriptions/portal/__tests__/route.options.test.ts b/app/api/subscriptions/portal/__tests__/route.options.test.ts new file mode 100644 index 000000000..38b2bfdb7 --- /dev/null +++ b/app/api/subscriptions/portal/__tests__/route.options.test.ts @@ -0,0 +1,14 @@ +import "./routeTestMocks"; +import { describe, it, expect } from "vitest"; +import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; + +const { OPTIONS } = await import("../route"); + +describe("OPTIONS /api/subscriptions/portal", () => { + it("returns 200 with CORS headers", async () => { + const res = await OPTIONS(); + expect(res.status).toBe(200); + expect(getCorsHeaders).toHaveBeenCalled(); + expect(res.headers.get("Access-Control-Allow-Origin")).toBe("*"); + }); +}); diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.test.ts new file mode 100644 index 000000000..d2217a4b1 --- /dev/null +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.test.ts @@ -0,0 +1,120 @@ +import "./routeTestMocks"; +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { NextRequest, NextResponse } from "next/server"; +import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; +import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; + +const { POST } = await import("../route"); + +const ACCOUNT = "123e4567-e89b-12d3-a456-426614174001"; + +describe("POST /api/subscriptions/portal (handler outcomes)", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.mocked(validateCreateSubscriptionPortalRequest).mockReset(); + vi.spyOn(console, "error").mockImplementation(() => undefined); + }); + + afterEach(() => { + vi.mocked(console.error).mockRestore(); + }); + + it("returns validation response unchanged", async () => { + const err = NextResponse.json({ error: "bad" }, { status: 400 }); + vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue(err); + const req = new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + body: "{}", + }); + expect(await POST(req)).toBe(err); + expect(selectStripeBillingCustomerByAccountId).not.toHaveBeenCalled(); + }); + + it("returns 400 when no billing customer", async () => { + vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + accountId: ACCOUNT, + returnUrl: "https://chat.recoupable.com/billing", + }); + vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue(null); + + const res = await POST( + new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), + ); + expect(res.status).toBe(400); + await expect(res.json()).resolves.toEqual({ error: "Billing customer not found" }); + expect(createBillingPortalSession).not.toHaveBeenCalled(); + }); + + it("returns 200 with id and url", async () => { + vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + accountId: ACCOUNT, + returnUrl: "https://chat.recoupable.com/billing", + }); + vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue({ + id: 1, + account_id: ACCOUNT, + customer_id: "cus_test_123", + email: null, + provider: "stripe", + }); + vi.mocked(createBillingPortalSession).mockResolvedValue({ + id: "bps_test_abc", + url: "https://billing.example.com/session/abc", + } as Awaited>); + + const res = await POST( + new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), + ); + expect(res.status).toBe(200); + await expect(res.json()).resolves.toEqual({ + id: "bps_test_abc", + url: "https://billing.example.com/session/abc", + }); + }); + + it("returns 400 when session.url is null", async () => { + vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + accountId: ACCOUNT, + returnUrl: "https://chat.recoupable.com/billing", + }); + vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue({ + id: 1, + account_id: ACCOUNT, + customer_id: "cus_test_123", + email: null, + provider: "stripe", + }); + vi.mocked(createBillingPortalSession).mockResolvedValue({ + id: "bps_test_abc", + url: null, + } as Awaited>); + + const res = await POST( + new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), + ); + expect(res.status).toBe(400); + await expect(res.json()).resolves.toEqual({ error: "Billing portal URL missing" }); + }); + + it("returns 500 when createBillingPortalSession throws", async () => { + vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + accountId: ACCOUNT, + returnUrl: "https://chat.recoupable.com/billing", + }); + vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue({ + id: 1, + account_id: ACCOUNT, + customer_id: "cus_test_123", + email: null, + provider: "stripe", + }); + vi.mocked(createBillingPortalSession).mockRejectedValue(new Error("Stripe down")); + + const res = await POST( + new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), + ); + expect(res.status).toBe(500); + await expect(res.json()).resolves.toEqual({ error: "Internal server error" }); + }); +}); diff --git a/app/api/subscriptions/portal/__tests__/route.post.validation.test.ts b/app/api/subscriptions/portal/__tests__/route.post.validation.test.ts new file mode 100644 index 000000000..c1f730c8d --- /dev/null +++ b/app/api/subscriptions/portal/__tests__/route.post.validation.test.ts @@ -0,0 +1,84 @@ +import "./routeTestMocks"; +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { NextRequest, NextResponse } from "next/server"; +import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; +import { validateAuthContext } from "@/lib/auth/validateAuthContext"; + +const { POST } = await import("../route"); + +async function loadRealValidate() { + const mod = await vi.importActual< + typeof import("@/lib/stripe/validateCreateSubscriptionPortalRequest") + >("@/lib/stripe/validateCreateSubscriptionPortalRequest"); + return mod.validateCreateSubscriptionPortalRequest; +} + +describe("POST /api/subscriptions/portal (validation)", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.mocked(validateCreateSubscriptionPortalRequest).mockReset(); + vi.spyOn(console, "error").mockImplementation(() => undefined); + }); + + afterEach(() => { + vi.mocked(console.error).mockRestore(); + }); + + it("returns 400 when body is invalid JSON", async () => { + vi.mocked(validateCreateSubscriptionPortalRequest).mockImplementationOnce( + await loadRealValidate(), + ); + const res = await POST( + new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + headers: { "content-type": "application/json" }, + body: "not-json", + }), + ); + expect(res.status).toBe(400); + await expect(res.json()).resolves.toEqual({ error: "Invalid JSON body" }); + expect(createBillingPortalSession).not.toHaveBeenCalled(); + }); + + it("returns 400 when returnUrl is missing", async () => { + vi.mocked(validateCreateSubscriptionPortalRequest).mockImplementationOnce( + await loadRealValidate(), + ); + const res = await POST( + new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + headers: { "content-type": "application/json", "x-api-key": "k" }, + body: JSON.stringify({}), + }), + ); + expect(res.status).toBe(400); + const body = await res.json(); + expect(body).toEqual({ error: expect.stringMatching(/returnUrl|Invalid input/i) }); + expect(createBillingPortalSession).not.toHaveBeenCalled(); + }); + + it("returns 401 when not authenticated", async () => { + vi.mocked(validateAuthContext).mockResolvedValueOnce( + NextResponse.json( + { status: "error", error: "Exactly one of x-api-key or Authorization must be provided" }, + { status: 401 }, + ), + ); + vi.mocked(validateCreateSubscriptionPortalRequest).mockImplementationOnce( + await loadRealValidate(), + ); + const res = await POST( + new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ returnUrl: "https://chat.recoupable.com/billing" }), + }), + ); + expect(res.status).toBe(401); + await expect(res.json()).resolves.toEqual({ + error: "Exactly one of x-api-key or Authorization must be provided", + }); + expect(createBillingPortalSession).not.toHaveBeenCalled(); + }); +}); diff --git a/app/api/subscriptions/portal/__tests__/route.test.ts b/app/api/subscriptions/portal/__tests__/route.test.ts new file mode 100644 index 000000000..4d5d51255 --- /dev/null +++ b/app/api/subscriptions/portal/__tests__/route.test.ts @@ -0,0 +1,11 @@ +import "./routeTestMocks"; +import { describe, it, expect } from "vitest"; + +const { POST, OPTIONS } = await import("../route"); + +describe("app/api/subscriptions/portal/route", () => { + it("exports POST and OPTIONS handlers", () => { + expect(typeof POST).toBe("function"); + expect(typeof OPTIONS).toBe("function"); + }); +}); diff --git a/app/api/subscriptions/portal/__tests__/routeTestMocks.ts b/app/api/subscriptions/portal/__tests__/routeTestMocks.ts new file mode 100644 index 000000000..8026e9019 --- /dev/null +++ b/app/api/subscriptions/portal/__tests__/routeTestMocks.ts @@ -0,0 +1,21 @@ +import { vi } from "vitest"; + +vi.mock("@/lib/auth/validateAuthContext", () => ({ + validateAuthContext: vi.fn(), +})); + +vi.mock("@/lib/networking/getCorsHeaders", () => ({ + getCorsHeaders: vi.fn(() => ({ "Access-Control-Allow-Origin": "*" })), +})); + +vi.mock("@/lib/stripe/validateCreateSubscriptionPortalRequest", () => ({ + validateCreateSubscriptionPortalRequest: vi.fn(), +})); + +vi.mock("@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId", () => ({ + selectStripeBillingCustomerByAccountId: vi.fn(), +})); + +vi.mock("@/lib/stripe/createBillingPortalSession", () => ({ + createBillingPortalSession: vi.fn(), +})); diff --git a/app/api/subscriptions/portal/route.ts b/app/api/subscriptions/portal/route.ts new file mode 100644 index 000000000..89d0f319e --- /dev/null +++ b/app/api/subscriptions/portal/route.ts @@ -0,0 +1,29 @@ +import { NextRequest, NextResponse } from "next/server"; +import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; +import { createSubscriptionPortalHandler } from "@/lib/stripe/createSubscriptionPortalHandler"; + +/** + * OPTIONS handler for CORS preflight requests. + * + * @returns A NextResponse with CORS headers. + */ +export async function OPTIONS() { + return new NextResponse(null, { + status: 200, + headers: getCorsHeaders(), + }); +} + +/** + * POST /api/subscriptions/portal: creates a subscription management (billing portal) session. + * + * @param request - The incoming HTTP request. + * @returns A NextResponse with portal session `id` and `url`, or an error body. + */ +export async function POST(request: NextRequest) { + return createSubscriptionPortalHandler(request); +} + +export const dynamic = "force-dynamic"; +export const fetchCache = "force-no-store"; +export const revalidate = 0; diff --git a/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts b/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts new file mode 100644 index 000000000..9874dccbe --- /dev/null +++ b/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts @@ -0,0 +1,91 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { NextRequest, NextResponse } from "next/server"; +import { createSubscriptionPortalHandler } from "@/lib/stripe/createSubscriptionPortalHandler"; +import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; +import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; + +vi.mock("@/lib/networking/getCorsHeaders", () => ({ + getCorsHeaders: vi.fn(() => ({ "Access-Control-Allow-Origin": "*" })), +})); + +vi.mock("@/lib/stripe/validateCreateSubscriptionPortalRequest", () => ({ + validateCreateSubscriptionPortalRequest: vi.fn(), +})); + +vi.mock("@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId", () => ({ + selectStripeBillingCustomerByAccountId: vi.fn(), +})); + +vi.mock("@/lib/stripe/createBillingPortalSession", () => ({ + createBillingPortalSession: vi.fn(), +})); + +const ACCOUNT = "123e4567-e89b-12d3-a456-426614174000"; + +describe("createSubscriptionPortalHandler", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.spyOn(console, "error").mockImplementation(() => undefined); + }); + afterEach(() => vi.mocked(console.error).mockRestore()); + + it("returns validation response unchanged", async () => { + const err = NextResponse.json({ error: "bad" }, { status: 400 }); + vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue(err); + const req = new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + body: "{}", + }); + expect(await createSubscriptionPortalHandler(req)).toBe(err); + expect(selectStripeBillingCustomerByAccountId).not.toHaveBeenCalled(); + }); + + it("returns 200 with id and url", async () => { + vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + accountId: ACCOUNT, + returnUrl: "https://chat.recoupable.com/billing", + }); + vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue({ + id: 1, + account_id: ACCOUNT, + customer_id: "cus_test_123", + email: null, + provider: "stripe", + }); + vi.mocked(createBillingPortalSession).mockResolvedValue({ + id: "bps_test_abc", + url: "https://billing.example.com/session/abc", + } as Awaited>); + + const res = await createSubscriptionPortalHandler( + new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), + ); + expect(res.status).toBe(200); + await expect(res.json()).resolves.toEqual({ + id: "bps_test_abc", + url: "https://billing.example.com/session/abc", + }); + }); + + it("returns 500 when createBillingPortalSession throws", async () => { + vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + accountId: ACCOUNT, + returnUrl: "https://chat.recoupable.com/billing", + }); + vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue({ + id: 1, + account_id: ACCOUNT, + customer_id: "cus_test_123", + email: null, + provider: "stripe", + }); + vi.mocked(createBillingPortalSession).mockRejectedValue(new Error("Stripe down")); + + const res = await createSubscriptionPortalHandler( + new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), + ); + expect(res.status).toBe(500); + await expect(res.json()).resolves.toEqual({ error: "Internal server error" }); + }); +}); diff --git a/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.test.ts b/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.test.ts new file mode 100644 index 000000000..0c77984b7 --- /dev/null +++ b/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.test.ts @@ -0,0 +1,117 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { NextRequest, NextResponse } from "next/server"; +import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { validateAuthContext } from "@/lib/auth/validateAuthContext"; + +vi.mock("@/lib/networking/getCorsHeaders", () => ({ + getCorsHeaders: vi.fn(() => ({ "Access-Control-Allow-Origin": "*" })), +})); + +vi.mock("@/lib/auth/validateAuthContext", () => ({ + validateAuthContext: vi.fn(), +})); + +const ACCOUNT = "123e4567-e89b-12d3-a456-426614174000"; + +describe("validateCreateSubscriptionPortalRequest", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("returns 400 { error } for invalid JSON", async () => { + const req = new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + headers: { "Content-Type": "application/json", "x-api-key": "k" }, + body: "not-json", + }); + const res = await validateCreateSubscriptionPortalRequest(req); + expect(res).toBeInstanceOf(NextResponse); + expect((res as NextResponse).status).toBe(400); + await expect((res as NextResponse).json()).resolves.toEqual({ error: "Invalid JSON body" }); + expect(validateAuthContext).not.toHaveBeenCalled(); + }); + + it("returns 400 { error } when returnUrl is missing", async () => { + const req = new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + headers: { "Content-Type": "application/json", "x-api-key": "k" }, + body: JSON.stringify({}), + }); + const res = await validateCreateSubscriptionPortalRequest(req); + expect((res as NextResponse).status).toBe(400); + const j = await (res as NextResponse).json(); + expect(j).toEqual({ error: expect.stringMatching(/returnUrl|Invalid input/i) }); + }); + + it("returns 400 for unknown body keys (strict)", async () => { + const req = new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + headers: { "Content-Type": "application/json", "x-api-key": "k" }, + body: JSON.stringify({ + returnUrl: "https://chat.recoupable.com/billing", + extra: true, + }), + }); + const res = await validateCreateSubscriptionPortalRequest(req); + expect((res as NextResponse).status).toBe(400); + }); + + it("maps auth failure to { error } and preserves status", async () => { + vi.mocked(validateAuthContext).mockResolvedValue( + NextResponse.json( + { status: "error", error: "Exactly one of x-api-key or Authorization must be provided" }, + { status: 401 }, + ), + ); + const req = new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ returnUrl: "https://chat.recoupable.com/billing" }), + }); + const res = await validateCreateSubscriptionPortalRequest(req); + expect((res as NextResponse).status).toBe(401); + await expect((res as NextResponse).json()).resolves.toEqual({ + error: "Exactly one of x-api-key or Authorization must be provided", + }); + }); + + it("passes accountId to validateAuthContext when provided", async () => { + vi.mocked(validateAuthContext).mockResolvedValue({ + accountId: ACCOUNT, + orgId: null, + authToken: "t", + }); + const otherAccount = "123e4567-e89b-12d3-a456-426614174099"; + const req = new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + headers: { "Content-Type": "application/json", "x-api-key": "k" }, + body: JSON.stringify({ + returnUrl: "https://chat.recoupable.com/billing", + accountId: otherAccount, + }), + }); + await validateCreateSubscriptionPortalRequest(req); + expect(validateAuthContext).toHaveBeenCalledWith(req, { accountId: otherAccount }); + }); + + it("returns accountId and returnUrl when auth succeeds", async () => { + vi.mocked(validateAuthContext).mockResolvedValue({ + accountId: ACCOUNT, + orgId: null, + authToken: "t", + }); + const req = new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + headers: { "Content-Type": "application/json", "x-api-key": "k" }, + body: JSON.stringify({ + returnUrl: "https://chat.recoupable.com/billing", + }), + }); + const out = await validateCreateSubscriptionPortalRequest(req); + expect(out).toEqual({ + accountId: ACCOUNT, + returnUrl: "https://chat.recoupable.com/billing", + }); + expect(validateAuthContext).toHaveBeenCalledWith(req, {}); + }); +}); diff --git a/lib/stripe/createBillingPortalSession.ts b/lib/stripe/createBillingPortalSession.ts new file mode 100644 index 000000000..afb40bc53 --- /dev/null +++ b/lib/stripe/createBillingPortalSession.ts @@ -0,0 +1,12 @@ +import type Stripe from "stripe"; +import stripeClient from "@/lib/stripe/client"; + +export async function createBillingPortalSession( + stripeCustomerId: string, + returnUrl: string, +): Promise { + return stripeClient.billingPortal.sessions.create({ + customer: stripeCustomerId, + return_url: returnUrl, + }); +} diff --git a/lib/stripe/createSubscriptionPortalHandler.ts b/lib/stripe/createSubscriptionPortalHandler.ts new file mode 100644 index 000000000..e5cf26386 --- /dev/null +++ b/lib/stripe/createSubscriptionPortalHandler.ts @@ -0,0 +1,44 @@ +import { NextRequest, NextResponse } from "next/server"; +import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; +import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; +import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; +import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; + +export async function createSubscriptionPortalHandler(request: NextRequest): Promise { + try { + const validated = await validateCreateSubscriptionPortalRequest(request); + if (validated instanceof NextResponse) { + return validated; + } + + const billingCustomer = await selectStripeBillingCustomerByAccountId(validated.accountId); + if (!billingCustomer) { + return NextResponse.json( + { error: "Billing customer not found" }, + { status: 400, headers: getCorsHeaders() }, + ); + } + + const session = await createBillingPortalSession( + billingCustomer.customer_id, + validated.returnUrl, + ); + if (!session.url) { + return NextResponse.json( + { error: "Billing portal URL missing" }, + { status: 400, headers: getCorsHeaders() }, + ); + } + + return NextResponse.json( + { id: session.id, url: session.url }, + { status: 200, headers: getCorsHeaders() }, + ); + } catch (error) { + console.error("[createSubscriptionPortalHandler]", error); + return NextResponse.json( + { error: "Internal server error" }, + { status: 500, headers: getCorsHeaders() }, + ); + } +} diff --git a/lib/stripe/createSubscriptionPortalSchemas.ts b/lib/stripe/createSubscriptionPortalSchemas.ts new file mode 100644 index 000000000..2a3d62927 --- /dev/null +++ b/lib/stripe/createSubscriptionPortalSchemas.ts @@ -0,0 +1,8 @@ +import { z } from "zod"; + +export const createSubscriptionPortalBodySchema = z + .object({ + returnUrl: z.string().min(1, "returnUrl is required").url("returnUrl must be a valid URL"), + accountId: z.string().uuid("accountId must be a valid UUID").optional(), + }) + .strict(); diff --git a/lib/stripe/validateCreateSubscriptionPortalRequest.ts b/lib/stripe/validateCreateSubscriptionPortalRequest.ts new file mode 100644 index 000000000..e7964658a --- /dev/null +++ b/lib/stripe/validateCreateSubscriptionPortalRequest.ts @@ -0,0 +1,42 @@ +import { NextRequest, NextResponse } from "next/server"; +import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; +import { validateAuthContext } from "@/lib/auth/validateAuthContext"; +import { createSubscriptionPortalBodySchema } from "@/lib/stripe/createSubscriptionPortalSchemas"; +import { mapToSubscriptionSessionError } from "@/lib/stripe/mapToSubscriptionSessionError"; + +export type ValidatedCreateSubscriptionPortalRequest = { + accountId: string; + returnUrl: string; +}; + +export async function validateCreateSubscriptionPortalRequest( + request: NextRequest, +): Promise { + let body: unknown; + try { + body = await request.json(); + } catch { + return NextResponse.json( + { error: "Invalid JSON body" }, + { status: 400, headers: getCorsHeaders() }, + ); + } + + const parsed = createSubscriptionPortalBodySchema.safeParse(body); + if (!parsed.success) { + const first = parsed.error.issues[0]; + return NextResponse.json({ error: first.message }, { status: 400, headers: getCorsHeaders() }); + } + + const authContext = await validateAuthContext(request, { + accountId: parsed.data.accountId, + }); + if (authContext instanceof NextResponse) { + return mapToSubscriptionSessionError(authContext); + } + + return { + accountId: authContext.accountId, + returnUrl: parsed.data.returnUrl, + }; +} diff --git a/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId.ts b/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId.ts new file mode 100644 index 000000000..4ca95f203 --- /dev/null +++ b/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId.ts @@ -0,0 +1,23 @@ +import supabase from "@/lib/supabase/serverClient"; +import type { Tables } from "@/types/database.types"; + +/** + * Returns the Stripe billing_customers row for an account, if one exists. + */ +export async function selectStripeBillingCustomerByAccountId( + accountId: string, +): Promise | null> { + const { data, error } = await supabase + .from("billing_customers") + .select("*") + .eq("account_id", accountId) + .eq("provider", "stripe") + .maybeSingle(); + + if (error) { + console.error("selectStripeBillingCustomerByAccountId:", error); + return null; + } + + return data ?? null; +} From 0bba531a3b65cf6a55be796d4eb81ed2492ac80c Mon Sep 17 00:00:00 2001 From: john Date: Mon, 4 May 2026 22:26:07 +0700 Subject: [PATCH 2/8] Remove outdated tests for subscription portal and validation request - Deleted test files for the subscription portal outcomes and validation request, as they are no longer needed. - Updated the `selectStripeBillingCustomerByAccountId` function to throw an error instead of returning null on failure, improving error handling. This cleanup enhances the codebase by removing obsolete tests and refining error management in the billing customer selection process. --- .../route.post.outcomes.early.test.ts | 58 ++++++++++++++++ ....ts => route.post.outcomes.portal.test.ts} | 69 ++++--------------- ...ateSubscriptionPortalRequest.auth.test.ts} | 57 ++------------- ...eateSubscriptionPortalRequest.body.test.ts | 52 ++++++++++++++ .../selectStripeBillingCustomerByAccountId.ts | 2 +- 5 files changed, 131 insertions(+), 107 deletions(-) create mode 100644 app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts rename app/api/subscriptions/portal/__tests__/{route.post.outcomes.test.ts => route.post.outcomes.portal.test.ts} (62%) rename lib/stripe/__tests__/{validateCreateSubscriptionPortalRequest.test.ts => validateCreateSubscriptionPortalRequest.auth.test.ts} (55%) create mode 100644 lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.body.test.ts diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts new file mode 100644 index 000000000..c4b930f25 --- /dev/null +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts @@ -0,0 +1,58 @@ +import "./routeTestMocks"; +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { NextRequest, NextResponse } from "next/server"; +import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; +import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; + +const { POST } = await import("../route"); + +const ACCOUNT = "123e4567-e89b-12d3-a456-426614174001"; + +describe("POST /api/subscriptions/portal (handler outcomes — validation & missing customer)", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.mocked(validateCreateSubscriptionPortalRequest).mockReset(); + vi.spyOn(console, "error").mockImplementation(() => undefined); + }); + + afterEach(() => vi.mocked(console.error).mockRestore()); + + it("returns validation response unchanged", async () => { + const err = NextResponse.json({ error: "bad" }, { status: 400 }); + vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue(err); + const req = new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + body: "{}", + }); + expect(await POST(req)).toBe(err); + expect(selectStripeBillingCustomerByAccountId).not.toHaveBeenCalled(); + }); + + it("returns 400 when no billing customer", async () => { + vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + accountId: ACCOUNT, + returnUrl: "https://chat.recoupable.com/billing", + }); + vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue(null); + const res = await POST( + new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), + ); + expect(res.status).toBe(400); + await expect(res.json()).resolves.toEqual({ error: "Billing customer not found" }); + expect(createBillingPortalSession).not.toHaveBeenCalled(); + }); + + it("returns 500 when billing customer lookup fails", async () => { + vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + accountId: ACCOUNT, + returnUrl: "https://chat.recoupable.com/billing", + }); + vi.mocked(selectStripeBillingCustomerByAccountId).mockRejectedValue(new Error("supabase down")); + const res = await POST( + new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), + ); + expect(res.status).toBe(500); + await expect(res.json()).resolves.toEqual({ error: "Internal server error" }); + }); +}); diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.test.ts similarity index 62% rename from app/api/subscriptions/portal/__tests__/route.post.outcomes.test.ts rename to app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.test.ts index d2217a4b1..9dbf823bf 100644 --- a/app/api/subscriptions/portal/__tests__/route.post.outcomes.test.ts +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.test.ts @@ -1,6 +1,6 @@ import "./routeTestMocks"; import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; -import { NextRequest, NextResponse } from "next/server"; +import { NextRequest } from "next/server"; import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; @@ -9,60 +9,33 @@ const { POST } = await import("../route"); const ACCOUNT = "123e4567-e89b-12d3-a456-426614174001"; -describe("POST /api/subscriptions/portal (handler outcomes)", () => { +const billingRow = { + id: 1, + account_id: ACCOUNT, + customer_id: "cus_test_123", + email: null, + provider: "stripe" as const, +}; + +describe("POST /api/subscriptions/portal (handler outcomes — portal session)", () => { beforeEach(() => { vi.clearAllMocks(); vi.mocked(validateCreateSubscriptionPortalRequest).mockReset(); vi.spyOn(console, "error").mockImplementation(() => undefined); }); - afterEach(() => { - vi.mocked(console.error).mockRestore(); - }); - - it("returns validation response unchanged", async () => { - const err = NextResponse.json({ error: "bad" }, { status: 400 }); - vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue(err); - const req = new NextRequest("http://localhost/api/subscriptions/portal", { - method: "POST", - body: "{}", - }); - expect(await POST(req)).toBe(err); - expect(selectStripeBillingCustomerByAccountId).not.toHaveBeenCalled(); - }); - - it("returns 400 when no billing customer", async () => { - vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ - accountId: ACCOUNT, - returnUrl: "https://chat.recoupable.com/billing", - }); - vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue(null); - - const res = await POST( - new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), - ); - expect(res.status).toBe(400); - await expect(res.json()).resolves.toEqual({ error: "Billing customer not found" }); - expect(createBillingPortalSession).not.toHaveBeenCalled(); - }); + afterEach(() => vi.mocked(console.error).mockRestore()); it("returns 200 with id and url", async () => { vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue({ - id: 1, - account_id: ACCOUNT, - customer_id: "cus_test_123", - email: null, - provider: "stripe", - }); + vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue(billingRow); vi.mocked(createBillingPortalSession).mockResolvedValue({ id: "bps_test_abc", url: "https://billing.example.com/session/abc", } as Awaited>); - const res = await POST( new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), ); @@ -78,18 +51,11 @@ describe("POST /api/subscriptions/portal (handler outcomes)", () => { accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue({ - id: 1, - account_id: ACCOUNT, - customer_id: "cus_test_123", - email: null, - provider: "stripe", - }); + vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue(billingRow); vi.mocked(createBillingPortalSession).mockResolvedValue({ id: "bps_test_abc", url: null, } as Awaited>); - const res = await POST( new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), ); @@ -102,15 +68,8 @@ describe("POST /api/subscriptions/portal (handler outcomes)", () => { accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue({ - id: 1, - account_id: ACCOUNT, - customer_id: "cus_test_123", - email: null, - provider: "stripe", - }); + vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue(billingRow); vi.mocked(createBillingPortalSession).mockRejectedValue(new Error("Stripe down")); - const res = await POST( new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), ); diff --git a/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.test.ts b/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.auth.test.ts similarity index 55% rename from lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.test.ts rename to lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.auth.test.ts index 0c77984b7..8033d47e3 100644 --- a/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.test.ts +++ b/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.auth.test.ts @@ -13,48 +13,8 @@ vi.mock("@/lib/auth/validateAuthContext", () => ({ const ACCOUNT = "123e4567-e89b-12d3-a456-426614174000"; -describe("validateCreateSubscriptionPortalRequest", () => { - beforeEach(() => { - vi.clearAllMocks(); - }); - - it("returns 400 { error } for invalid JSON", async () => { - const req = new NextRequest("http://localhost/api/subscriptions/portal", { - method: "POST", - headers: { "Content-Type": "application/json", "x-api-key": "k" }, - body: "not-json", - }); - const res = await validateCreateSubscriptionPortalRequest(req); - expect(res).toBeInstanceOf(NextResponse); - expect((res as NextResponse).status).toBe(400); - await expect((res as NextResponse).json()).resolves.toEqual({ error: "Invalid JSON body" }); - expect(validateAuthContext).not.toHaveBeenCalled(); - }); - - it("returns 400 { error } when returnUrl is missing", async () => { - const req = new NextRequest("http://localhost/api/subscriptions/portal", { - method: "POST", - headers: { "Content-Type": "application/json", "x-api-key": "k" }, - body: JSON.stringify({}), - }); - const res = await validateCreateSubscriptionPortalRequest(req); - expect((res as NextResponse).status).toBe(400); - const j = await (res as NextResponse).json(); - expect(j).toEqual({ error: expect.stringMatching(/returnUrl|Invalid input/i) }); - }); - - it("returns 400 for unknown body keys (strict)", async () => { - const req = new NextRequest("http://localhost/api/subscriptions/portal", { - method: "POST", - headers: { "Content-Type": "application/json", "x-api-key": "k" }, - body: JSON.stringify({ - returnUrl: "https://chat.recoupable.com/billing", - extra: true, - }), - }); - const res = await validateCreateSubscriptionPortalRequest(req); - expect((res as NextResponse).status).toBe(400); - }); +describe("validateCreateSubscriptionPortalRequest (auth)", () => { + beforeEach(() => vi.clearAllMocks()); it("maps auth failure to { error } and preserves status", async () => { vi.mocked(validateAuthContext).mockResolvedValue( @@ -69,8 +29,8 @@ describe("validateCreateSubscriptionPortalRequest", () => { body: JSON.stringify({ returnUrl: "https://chat.recoupable.com/billing" }), }); const res = await validateCreateSubscriptionPortalRequest(req); - expect((res as NextResponse).status).toBe(401); - await expect((res as NextResponse).json()).resolves.toEqual({ + expect(res.status).toBe(401); + await expect(res.json()).resolves.toEqual({ error: "Exactly one of x-api-key or Authorization must be provided", }); }); @@ -103,15 +63,10 @@ describe("validateCreateSubscriptionPortalRequest", () => { const req = new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", headers: { "Content-Type": "application/json", "x-api-key": "k" }, - body: JSON.stringify({ - returnUrl: "https://chat.recoupable.com/billing", - }), + body: JSON.stringify({ returnUrl: "https://chat.recoupable.com/billing" }), }); const out = await validateCreateSubscriptionPortalRequest(req); - expect(out).toEqual({ - accountId: ACCOUNT, - returnUrl: "https://chat.recoupable.com/billing", - }); + expect(out).toEqual({ accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing" }); expect(validateAuthContext).toHaveBeenCalledWith(req, {}); }); }); diff --git a/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.body.test.ts b/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.body.test.ts new file mode 100644 index 000000000..e114457f2 --- /dev/null +++ b/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.body.test.ts @@ -0,0 +1,52 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { NextRequest } from "next/server"; +import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { validateAuthContext } from "@/lib/auth/validateAuthContext"; + +vi.mock("@/lib/networking/getCorsHeaders", () => ({ + getCorsHeaders: vi.fn(() => ({ "Access-Control-Allow-Origin": "*" })), +})); + +vi.mock("@/lib/auth/validateAuthContext", () => ({ + validateAuthContext: vi.fn(), +})); + +describe("validateCreateSubscriptionPortalRequest (body)", () => { + beforeEach(() => vi.clearAllMocks()); + + it("returns 400 { error } for invalid JSON", async () => { + const req = new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + headers: { "Content-Type": "application/json", "x-api-key": "k" }, + body: "not-json", + }); + const res = await validateCreateSubscriptionPortalRequest(req); + expect(res.status).toBe(400); + await expect(res.json()).resolves.toEqual({ error: "Invalid JSON body" }); + expect(validateAuthContext).not.toHaveBeenCalled(); + }); + + it("returns 400 { error } when returnUrl is missing", async () => { + const req = new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + headers: { "Content-Type": "application/json", "x-api-key": "k" }, + body: JSON.stringify({}), + }); + const res = await validateCreateSubscriptionPortalRequest(req); + expect(res.status).toBe(400); + const j = await res.json(); + expect(j).toEqual({ error: expect.stringMatching(/returnUrl|Invalid input/i) }); + }); + + it("returns 400 for unknown body keys (strict)", async () => { + const req = new NextRequest("http://localhost/api/subscriptions/portal", { + method: "POST", + headers: { "Content-Type": "application/json", "x-api-key": "k" }, + body: JSON.stringify({ + returnUrl: "https://chat.recoupable.com/billing", + extra: true, + }), + }); + expect((await validateCreateSubscriptionPortalRequest(req)).status).toBe(400); + }); +}); diff --git a/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId.ts b/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId.ts index 4ca95f203..bc166858e 100644 --- a/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId.ts +++ b/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId.ts @@ -16,7 +16,7 @@ export async function selectStripeBillingCustomerByAccountId( if (error) { console.error("selectStripeBillingCustomerByAccountId:", error); - return null; + throw error; } return data ?? null; From a63bd0c6fef3ad857446fd063674a1ee33c2ab30 Mon Sep 17 00:00:00 2001 From: john Date: Mon, 4 May 2026 22:34:06 +0700 Subject: [PATCH 3/8] Remove obsolete test files for subscription portal outcomes and route validation - Deleted the test files for the subscription portal outcomes and the main route, as they are no longer relevant to the current implementation. - This cleanup helps streamline the test suite and maintain focus on active tests. --- ...route.post.outcomes.portal.errors.test.ts} | 44 +++++------------- ...route.post.outcomes.portal.success.test.ts | 46 +++++++++++++++++++ .../portal/__tests__/route.test.ts | 11 ----- 3 files changed, 58 insertions(+), 43 deletions(-) rename app/api/subscriptions/portal/__tests__/{route.post.outcomes.portal.test.ts => route.post.outcomes.portal.errors.test.ts} (56%) create mode 100644 app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts delete mode 100644 app/api/subscriptions/portal/__tests__/route.test.ts diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts similarity index 56% rename from app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.test.ts rename to app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts index 9dbf823bf..00a204809 100644 --- a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.test.ts +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts @@ -9,7 +9,7 @@ const { POST } = await import("../route"); const ACCOUNT = "123e4567-e89b-12d3-a456-426614174001"; -const billingRow = { +const row = { id: 1, account_id: ACCOUNT, customer_id: "cus_test_123", @@ -17,7 +17,15 @@ const billingRow = { provider: "stripe" as const, }; -describe("POST /api/subscriptions/portal (handler outcomes — portal session)", () => { +function mockValidated() { + vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + accountId: ACCOUNT, + returnUrl: "https://chat.recoupable.com/billing", + }); + vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue(row); +} + +describe("POST /api/subscriptions/portal (portal session errors)", () => { beforeEach(() => { vi.clearAllMocks(); vi.mocked(validateCreateSubscriptionPortalRequest).mockReset(); @@ -26,32 +34,8 @@ describe("POST /api/subscriptions/portal (handler outcomes — portal session)", afterEach(() => vi.mocked(console.error).mockRestore()); - it("returns 200 with id and url", async () => { - vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ - accountId: ACCOUNT, - returnUrl: "https://chat.recoupable.com/billing", - }); - vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue(billingRow); - vi.mocked(createBillingPortalSession).mockResolvedValue({ - id: "bps_test_abc", - url: "https://billing.example.com/session/abc", - } as Awaited>); - const res = await POST( - new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), - ); - expect(res.status).toBe(200); - await expect(res.json()).resolves.toEqual({ - id: "bps_test_abc", - url: "https://billing.example.com/session/abc", - }); - }); - it("returns 400 when session.url is null", async () => { - vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ - accountId: ACCOUNT, - returnUrl: "https://chat.recoupable.com/billing", - }); - vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue(billingRow); + mockValidated(); vi.mocked(createBillingPortalSession).mockResolvedValue({ id: "bps_test_abc", url: null, @@ -64,11 +48,7 @@ describe("POST /api/subscriptions/portal (handler outcomes — portal session)", }); it("returns 500 when createBillingPortalSession throws", async () => { - vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ - accountId: ACCOUNT, - returnUrl: "https://chat.recoupable.com/billing", - }); - vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue(billingRow); + mockValidated(); vi.mocked(createBillingPortalSession).mockRejectedValue(new Error("Stripe down")); const res = await POST( new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts new file mode 100644 index 000000000..0699f057b --- /dev/null +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts @@ -0,0 +1,46 @@ +import "./routeTestMocks"; +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { NextRequest } from "next/server"; +import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; +import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; + +const { POST } = await import("../route"); + +const ACCOUNT = "123e4567-e89b-12d3-a456-426614174001"; + +describe("POST /api/subscriptions/portal (200)", () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.mocked(validateCreateSubscriptionPortalRequest).mockReset(); + vi.spyOn(console, "error").mockImplementation(() => undefined); + }); + + afterEach(() => vi.mocked(console.error).mockRestore()); + + it("returns id and url when portal session is created", async () => { + vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + accountId: ACCOUNT, + returnUrl: "https://chat.recoupable.com/billing", + }); + vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue({ + id: 1, + account_id: ACCOUNT, + customer_id: "cus_test_123", + email: null, + provider: "stripe", + }); + vi.mocked(createBillingPortalSession).mockResolvedValue({ + id: "bps_test_abc", + url: "https://billing.example.com/session/abc", + } as Awaited>); + const res = await POST( + new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), + ); + expect(res.status).toBe(200); + await expect(res.json()).resolves.toEqual({ + id: "bps_test_abc", + url: "https://billing.example.com/session/abc", + }); + }); +}); diff --git a/app/api/subscriptions/portal/__tests__/route.test.ts b/app/api/subscriptions/portal/__tests__/route.test.ts deleted file mode 100644 index 4d5d51255..000000000 --- a/app/api/subscriptions/portal/__tests__/route.test.ts +++ /dev/null @@ -1,11 +0,0 @@ -import "./routeTestMocks"; -import { describe, it, expect } from "vitest"; - -const { POST, OPTIONS } = await import("../route"); - -describe("app/api/subscriptions/portal/route", () => { - it("exports POST and OPTIONS handlers", () => { - expect(typeof POST).toBe("function"); - expect(typeof OPTIONS).toBe("function"); - }); -}); From f583a6ded1a65cf325ce0dd3346a2969041af71d Mon Sep 17 00:00:00 2001 From: john Date: Mon, 4 May 2026 22:50:00 +0700 Subject: [PATCH 4/8] Refactor subscription portal validation and schema - Removed the optional `accountId` field from the `createSubscriptionPortalBodySchema` as it is no longer required. - Updated the `validateCreateSubscriptionPortalRequest` function to no longer pass `accountId` to `validateAuthContext`, simplifying the authentication context validation. - Adjusted related tests to reflect the changes in the schema and validation logic, ensuring they accurately verify the expected behavior without the `accountId` dependency. --- ...eateSubscriptionPortalRequest.auth.test.ts | 21 +------------------ lib/stripe/createSubscriptionPortalSchemas.ts | 1 - ...validateCreateSubscriptionPortalRequest.ts | 4 +--- 3 files changed, 2 insertions(+), 24 deletions(-) diff --git a/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.auth.test.ts b/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.auth.test.ts index 8033d47e3..abfa822b9 100644 --- a/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.auth.test.ts +++ b/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.auth.test.ts @@ -35,26 +35,7 @@ describe("validateCreateSubscriptionPortalRequest (auth)", () => { }); }); - it("passes accountId to validateAuthContext when provided", async () => { - vi.mocked(validateAuthContext).mockResolvedValue({ - accountId: ACCOUNT, - orgId: null, - authToken: "t", - }); - const otherAccount = "123e4567-e89b-12d3-a456-426614174099"; - const req = new NextRequest("http://localhost/api/subscriptions/portal", { - method: "POST", - headers: { "Content-Type": "application/json", "x-api-key": "k" }, - body: JSON.stringify({ - returnUrl: "https://chat.recoupable.com/billing", - accountId: otherAccount, - }), - }); - await validateCreateSubscriptionPortalRequest(req); - expect(validateAuthContext).toHaveBeenCalledWith(req, { accountId: otherAccount }); - }); - - it("returns accountId and returnUrl when auth succeeds", async () => { + it("returns accountId from auth and returnUrl from body when auth succeeds", async () => { vi.mocked(validateAuthContext).mockResolvedValue({ accountId: ACCOUNT, orgId: null, diff --git a/lib/stripe/createSubscriptionPortalSchemas.ts b/lib/stripe/createSubscriptionPortalSchemas.ts index 2a3d62927..16d9ed68e 100644 --- a/lib/stripe/createSubscriptionPortalSchemas.ts +++ b/lib/stripe/createSubscriptionPortalSchemas.ts @@ -3,6 +3,5 @@ import { z } from "zod"; export const createSubscriptionPortalBodySchema = z .object({ returnUrl: z.string().min(1, "returnUrl is required").url("returnUrl must be a valid URL"), - accountId: z.string().uuid("accountId must be a valid UUID").optional(), }) .strict(); diff --git a/lib/stripe/validateCreateSubscriptionPortalRequest.ts b/lib/stripe/validateCreateSubscriptionPortalRequest.ts index e7964658a..90d8d7d49 100644 --- a/lib/stripe/validateCreateSubscriptionPortalRequest.ts +++ b/lib/stripe/validateCreateSubscriptionPortalRequest.ts @@ -28,9 +28,7 @@ export async function validateCreateSubscriptionPortalRequest( return NextResponse.json({ error: first.message }, { status: 400, headers: getCorsHeaders() }); } - const authContext = await validateAuthContext(request, { - accountId: parsed.data.accountId, - }); + const authContext = await validateAuthContext(request, {}); if (authContext instanceof NextResponse) { return mapToSubscriptionSessionError(authContext); } From ef8441ad10db2f5a5afd3076a6c642a373472559 Mon Sep 17 00:00:00 2001 From: john Date: Tue, 5 May 2026 04:04:04 +0700 Subject: [PATCH 5/8] Refactor subscription portal validation to use new body validation function - Replaced instances of `validateCreateSubscriptionPortalRequest` with `validateCreateSubscriptionPortalBody` across multiple test files and the handler. - Updated related test mocks and implementations to align with the new validation function. - Removed obsolete validation request and schema files, streamlining the codebase and improving clarity in the subscription portal handling logic. --- .../__tests__/route.post.outcomes.early.test.ts | 10 +++++----- .../route.post.outcomes.portal.errors.test.ts | 6 +++--- .../route.post.outcomes.portal.success.test.ts | 6 +++--- .../__tests__/route.post.validation.test.ts | 16 ++++++++-------- .../portal/__tests__/routeTestMocks.ts | 4 ++-- .../createSubscriptionPortalHandler.test.ts | 12 ++++++------ ...ateCreateSubscriptionPortalBody.auth.test.ts} | 8 ++++---- ...ateCreateSubscriptionPortalBody.body.test.ts} | 10 +++++----- lib/stripe/createSubscriptionPortalHandler.ts | 4 ++-- lib/stripe/createSubscriptionPortalSchemas.ts | 7 ------- ...s => validateCreateSubscriptionPortalBody.ts} | 16 ++++++++++++---- 11 files changed, 50 insertions(+), 49 deletions(-) rename lib/stripe/__tests__/{validateCreateSubscriptionPortalRequest.auth.test.ts => validateCreateSubscriptionPortalBody.auth.test.ts} (85%) rename lib/stripe/__tests__/{validateCreateSubscriptionPortalRequest.body.test.ts => validateCreateSubscriptionPortalBody.body.test.ts} (81%) delete mode 100644 lib/stripe/createSubscriptionPortalSchemas.ts rename lib/stripe/{validateCreateSubscriptionPortalRequest.ts => validateCreateSubscriptionPortalBody.ts} (68%) diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts index c4b930f25..be923526d 100644 --- a/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts @@ -1,7 +1,7 @@ import "./routeTestMocks"; import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextRequest, NextResponse } from "next/server"; -import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; @@ -12,7 +12,7 @@ const ACCOUNT = "123e4567-e89b-12d3-a456-426614174001"; describe("POST /api/subscriptions/portal (handler outcomes — validation & missing customer)", () => { beforeEach(() => { vi.clearAllMocks(); - vi.mocked(validateCreateSubscriptionPortalRequest).mockReset(); + vi.mocked(validateCreateSubscriptionPortalBody).mockReset(); vi.spyOn(console, "error").mockImplementation(() => undefined); }); @@ -20,7 +20,7 @@ describe("POST /api/subscriptions/portal (handler outcomes — validation & miss it("returns validation response unchanged", async () => { const err = NextResponse.json({ error: "bad" }, { status: 400 }); - vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue(err); + vi.mocked(validateCreateSubscriptionPortalBody).mockResolvedValue(err); const req = new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}", @@ -30,7 +30,7 @@ describe("POST /api/subscriptions/portal (handler outcomes — validation & miss }); it("returns 400 when no billing customer", async () => { - vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + vi.mocked(validateCreateSubscriptionPortalBody).mockResolvedValue({ accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); @@ -44,7 +44,7 @@ describe("POST /api/subscriptions/portal (handler outcomes — validation & miss }); it("returns 500 when billing customer lookup fails", async () => { - vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + vi.mocked(validateCreateSubscriptionPortalBody).mockResolvedValue({ accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts index 00a204809..755507b54 100644 --- a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts @@ -1,7 +1,7 @@ import "./routeTestMocks"; import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextRequest } from "next/server"; -import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; @@ -18,7 +18,7 @@ const row = { }; function mockValidated() { - vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + vi.mocked(validateCreateSubscriptionPortalBody).mockResolvedValue({ accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); @@ -28,7 +28,7 @@ function mockValidated() { describe("POST /api/subscriptions/portal (portal session errors)", () => { beforeEach(() => { vi.clearAllMocks(); - vi.mocked(validateCreateSubscriptionPortalRequest).mockReset(); + vi.mocked(validateCreateSubscriptionPortalBody).mockReset(); vi.spyOn(console, "error").mockImplementation(() => undefined); }); diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts index 0699f057b..5e2825ad9 100644 --- a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts @@ -1,7 +1,7 @@ import "./routeTestMocks"; import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextRequest } from "next/server"; -import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; @@ -12,14 +12,14 @@ const ACCOUNT = "123e4567-e89b-12d3-a456-426614174001"; describe("POST /api/subscriptions/portal (200)", () => { beforeEach(() => { vi.clearAllMocks(); - vi.mocked(validateCreateSubscriptionPortalRequest).mockReset(); + vi.mocked(validateCreateSubscriptionPortalBody).mockReset(); vi.spyOn(console, "error").mockImplementation(() => undefined); }); afterEach(() => vi.mocked(console.error).mockRestore()); it("returns id and url when portal session is created", async () => { - vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + vi.mocked(validateCreateSubscriptionPortalBody).mockResolvedValue({ accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); diff --git a/app/api/subscriptions/portal/__tests__/route.post.validation.test.ts b/app/api/subscriptions/portal/__tests__/route.post.validation.test.ts index c1f730c8d..1238b6490 100644 --- a/app/api/subscriptions/portal/__tests__/route.post.validation.test.ts +++ b/app/api/subscriptions/portal/__tests__/route.post.validation.test.ts @@ -1,7 +1,7 @@ import "./routeTestMocks"; import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextRequest, NextResponse } from "next/server"; -import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; import { validateAuthContext } from "@/lib/auth/validateAuthContext"; @@ -9,15 +9,15 @@ const { POST } = await import("../route"); async function loadRealValidate() { const mod = await vi.importActual< - typeof import("@/lib/stripe/validateCreateSubscriptionPortalRequest") - >("@/lib/stripe/validateCreateSubscriptionPortalRequest"); - return mod.validateCreateSubscriptionPortalRequest; + typeof import("@/lib/stripe/validateCreateSubscriptionPortalBody") + >("@/lib/stripe/validateCreateSubscriptionPortalBody"); + return mod.validateCreateSubscriptionPortalBody; } describe("POST /api/subscriptions/portal (validation)", () => { beforeEach(() => { vi.clearAllMocks(); - vi.mocked(validateCreateSubscriptionPortalRequest).mockReset(); + vi.mocked(validateCreateSubscriptionPortalBody).mockReset(); vi.spyOn(console, "error").mockImplementation(() => undefined); }); @@ -26,7 +26,7 @@ describe("POST /api/subscriptions/portal (validation)", () => { }); it("returns 400 when body is invalid JSON", async () => { - vi.mocked(validateCreateSubscriptionPortalRequest).mockImplementationOnce( + vi.mocked(validateCreateSubscriptionPortalBody).mockImplementationOnce( await loadRealValidate(), ); const res = await POST( @@ -42,7 +42,7 @@ describe("POST /api/subscriptions/portal (validation)", () => { }); it("returns 400 when returnUrl is missing", async () => { - vi.mocked(validateCreateSubscriptionPortalRequest).mockImplementationOnce( + vi.mocked(validateCreateSubscriptionPortalBody).mockImplementationOnce( await loadRealValidate(), ); const res = await POST( @@ -65,7 +65,7 @@ describe("POST /api/subscriptions/portal (validation)", () => { { status: 401 }, ), ); - vi.mocked(validateCreateSubscriptionPortalRequest).mockImplementationOnce( + vi.mocked(validateCreateSubscriptionPortalBody).mockImplementationOnce( await loadRealValidate(), ); const res = await POST( diff --git a/app/api/subscriptions/portal/__tests__/routeTestMocks.ts b/app/api/subscriptions/portal/__tests__/routeTestMocks.ts index 8026e9019..3bae8a41f 100644 --- a/app/api/subscriptions/portal/__tests__/routeTestMocks.ts +++ b/app/api/subscriptions/portal/__tests__/routeTestMocks.ts @@ -8,8 +8,8 @@ vi.mock("@/lib/networking/getCorsHeaders", () => ({ getCorsHeaders: vi.fn(() => ({ "Access-Control-Allow-Origin": "*" })), })); -vi.mock("@/lib/stripe/validateCreateSubscriptionPortalRequest", () => ({ - validateCreateSubscriptionPortalRequest: vi.fn(), +vi.mock("@/lib/stripe/validateCreateSubscriptionPortalBody", () => ({ + validateCreateSubscriptionPortalBody: vi.fn(), })); vi.mock("@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId", () => ({ diff --git a/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts b/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts index 9874dccbe..5d4bed705 100644 --- a/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts +++ b/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts @@ -1,7 +1,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextRequest, NextResponse } from "next/server"; import { createSubscriptionPortalHandler } from "@/lib/stripe/createSubscriptionPortalHandler"; -import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; @@ -9,8 +9,8 @@ vi.mock("@/lib/networking/getCorsHeaders", () => ({ getCorsHeaders: vi.fn(() => ({ "Access-Control-Allow-Origin": "*" })), })); -vi.mock("@/lib/stripe/validateCreateSubscriptionPortalRequest", () => ({ - validateCreateSubscriptionPortalRequest: vi.fn(), +vi.mock("@/lib/stripe/validateCreateSubscriptionPortalBody", () => ({ + validateCreateSubscriptionPortalBody: vi.fn(), })); vi.mock("@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId", () => ({ @@ -32,7 +32,7 @@ describe("createSubscriptionPortalHandler", () => { it("returns validation response unchanged", async () => { const err = NextResponse.json({ error: "bad" }, { status: 400 }); - vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue(err); + vi.mocked(validateCreateSubscriptionPortalBody).mockResolvedValue(err); const req = new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}", @@ -42,7 +42,7 @@ describe("createSubscriptionPortalHandler", () => { }); it("returns 200 with id and url", async () => { - vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + vi.mocked(validateCreateSubscriptionPortalBody).mockResolvedValue({ accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); @@ -69,7 +69,7 @@ describe("createSubscriptionPortalHandler", () => { }); it("returns 500 when createBillingPortalSession throws", async () => { - vi.mocked(validateCreateSubscriptionPortalRequest).mockResolvedValue({ + vi.mocked(validateCreateSubscriptionPortalBody).mockResolvedValue({ accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); diff --git a/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.auth.test.ts b/lib/stripe/__tests__/validateCreateSubscriptionPortalBody.auth.test.ts similarity index 85% rename from lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.auth.test.ts rename to lib/stripe/__tests__/validateCreateSubscriptionPortalBody.auth.test.ts index abfa822b9..a022385f2 100644 --- a/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.auth.test.ts +++ b/lib/stripe/__tests__/validateCreateSubscriptionPortalBody.auth.test.ts @@ -1,6 +1,6 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; import { NextRequest, NextResponse } from "next/server"; -import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; import { validateAuthContext } from "@/lib/auth/validateAuthContext"; vi.mock("@/lib/networking/getCorsHeaders", () => ({ @@ -13,7 +13,7 @@ vi.mock("@/lib/auth/validateAuthContext", () => ({ const ACCOUNT = "123e4567-e89b-12d3-a456-426614174000"; -describe("validateCreateSubscriptionPortalRequest (auth)", () => { +describe("validateCreateSubscriptionPortalBody (auth)", () => { beforeEach(() => vi.clearAllMocks()); it("maps auth failure to { error } and preserves status", async () => { @@ -28,7 +28,7 @@ describe("validateCreateSubscriptionPortalRequest (auth)", () => { headers: { "Content-Type": "application/json" }, body: JSON.stringify({ returnUrl: "https://chat.recoupable.com/billing" }), }); - const res = await validateCreateSubscriptionPortalRequest(req); + const res = await validateCreateSubscriptionPortalBody(req); expect(res.status).toBe(401); await expect(res.json()).resolves.toEqual({ error: "Exactly one of x-api-key or Authorization must be provided", @@ -46,7 +46,7 @@ describe("validateCreateSubscriptionPortalRequest (auth)", () => { headers: { "Content-Type": "application/json", "x-api-key": "k" }, body: JSON.stringify({ returnUrl: "https://chat.recoupable.com/billing" }), }); - const out = await validateCreateSubscriptionPortalRequest(req); + const out = await validateCreateSubscriptionPortalBody(req); expect(out).toEqual({ accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing" }); expect(validateAuthContext).toHaveBeenCalledWith(req, {}); }); diff --git a/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.body.test.ts b/lib/stripe/__tests__/validateCreateSubscriptionPortalBody.body.test.ts similarity index 81% rename from lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.body.test.ts rename to lib/stripe/__tests__/validateCreateSubscriptionPortalBody.body.test.ts index e114457f2..d7891fa69 100644 --- a/lib/stripe/__tests__/validateCreateSubscriptionPortalRequest.body.test.ts +++ b/lib/stripe/__tests__/validateCreateSubscriptionPortalBody.body.test.ts @@ -1,6 +1,6 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; import { NextRequest } from "next/server"; -import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; import { validateAuthContext } from "@/lib/auth/validateAuthContext"; vi.mock("@/lib/networking/getCorsHeaders", () => ({ @@ -11,7 +11,7 @@ vi.mock("@/lib/auth/validateAuthContext", () => ({ validateAuthContext: vi.fn(), })); -describe("validateCreateSubscriptionPortalRequest (body)", () => { +describe("validateCreateSubscriptionPortalBody (body)", () => { beforeEach(() => vi.clearAllMocks()); it("returns 400 { error } for invalid JSON", async () => { @@ -20,7 +20,7 @@ describe("validateCreateSubscriptionPortalRequest (body)", () => { headers: { "Content-Type": "application/json", "x-api-key": "k" }, body: "not-json", }); - const res = await validateCreateSubscriptionPortalRequest(req); + const res = await validateCreateSubscriptionPortalBody(req); expect(res.status).toBe(400); await expect(res.json()).resolves.toEqual({ error: "Invalid JSON body" }); expect(validateAuthContext).not.toHaveBeenCalled(); @@ -32,7 +32,7 @@ describe("validateCreateSubscriptionPortalRequest (body)", () => { headers: { "Content-Type": "application/json", "x-api-key": "k" }, body: JSON.stringify({}), }); - const res = await validateCreateSubscriptionPortalRequest(req); + const res = await validateCreateSubscriptionPortalBody(req); expect(res.status).toBe(400); const j = await res.json(); expect(j).toEqual({ error: expect.stringMatching(/returnUrl|Invalid input/i) }); @@ -47,6 +47,6 @@ describe("validateCreateSubscriptionPortalRequest (body)", () => { extra: true, }), }); - expect((await validateCreateSubscriptionPortalRequest(req)).status).toBe(400); + expect((await validateCreateSubscriptionPortalBody(req)).status).toBe(400); }); }); diff --git a/lib/stripe/createSubscriptionPortalHandler.ts b/lib/stripe/createSubscriptionPortalHandler.ts index e5cf26386..50e7488ab 100644 --- a/lib/stripe/createSubscriptionPortalHandler.ts +++ b/lib/stripe/createSubscriptionPortalHandler.ts @@ -2,11 +2,11 @@ import { NextRequest, NextResponse } from "next/server"; import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; -import { validateCreateSubscriptionPortalRequest } from "@/lib/stripe/validateCreateSubscriptionPortalRequest"; +import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; export async function createSubscriptionPortalHandler(request: NextRequest): Promise { try { - const validated = await validateCreateSubscriptionPortalRequest(request); + const validated = await validateCreateSubscriptionPortalBody(request); if (validated instanceof NextResponse) { return validated; } diff --git a/lib/stripe/createSubscriptionPortalSchemas.ts b/lib/stripe/createSubscriptionPortalSchemas.ts deleted file mode 100644 index 16d9ed68e..000000000 --- a/lib/stripe/createSubscriptionPortalSchemas.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { z } from "zod"; - -export const createSubscriptionPortalBodySchema = z - .object({ - returnUrl: z.string().min(1, "returnUrl is required").url("returnUrl must be a valid URL"), - }) - .strict(); diff --git a/lib/stripe/validateCreateSubscriptionPortalRequest.ts b/lib/stripe/validateCreateSubscriptionPortalBody.ts similarity index 68% rename from lib/stripe/validateCreateSubscriptionPortalRequest.ts rename to lib/stripe/validateCreateSubscriptionPortalBody.ts index 90d8d7d49..3f026e84f 100644 --- a/lib/stripe/validateCreateSubscriptionPortalRequest.ts +++ b/lib/stripe/validateCreateSubscriptionPortalBody.ts @@ -1,17 +1,25 @@ import { NextRequest, NextResponse } from "next/server"; +import { z } from "zod"; import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; import { validateAuthContext } from "@/lib/auth/validateAuthContext"; -import { createSubscriptionPortalBodySchema } from "@/lib/stripe/createSubscriptionPortalSchemas"; import { mapToSubscriptionSessionError } from "@/lib/stripe/mapToSubscriptionSessionError"; -export type ValidatedCreateSubscriptionPortalRequest = { +export const createSubscriptionPortalBodySchema = z + .object({ + returnUrl: z.string().min(1, "returnUrl is required").url("returnUrl must be a valid URL"), + }) + .strict(); + +export type CreateSubscriptionPortalBody = z.infer; + +export type ValidatedCreateSubscriptionPortalBody = { accountId: string; returnUrl: string; }; -export async function validateCreateSubscriptionPortalRequest( +export async function validateCreateSubscriptionPortalBody( request: NextRequest, -): Promise { +): Promise { let body: unknown; try { body = await request.json(); From a04ef32e21d1182fdfeb635d6169090cf7ce2e3d Mon Sep 17 00:00:00 2001 From: Sweets Sweetman Date: Wed, 6 May 2026 12:11:36 -0500 Subject: [PATCH 6/8] refactor(supabase): rename selectStripeBillingCustomerByAccountId to selectBillingCustomers Match the repo's select naming convention. The function now takes optional { accountId, provider } filters and returns an array, matching sibling helpers like selectAccounts. Handler picks the first row. --- .../route.post.outcomes.early.test.ts | 8 ++-- .../route.post.outcomes.portal.errors.test.ts | 4 +- ...route.post.outcomes.portal.success.test.ts | 18 +++++---- .../portal/__tests__/routeTestMocks.ts | 4 +- .../createSubscriptionPortalHandler.test.ts | 40 ++++++++++--------- lib/stripe/createSubscriptionPortalHandler.ts | 7 +++- .../selectBillingCustomers.ts | 26 ++++++++++++ .../selectStripeBillingCustomerByAccountId.ts | 23 ----------- 8 files changed, 71 insertions(+), 59 deletions(-) create mode 100644 lib/supabase/billing_customers/selectBillingCustomers.ts delete mode 100644 lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId.ts diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts index be923526d..4819dfb03 100644 --- a/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts @@ -3,7 +3,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextRequest, NextResponse } from "next/server"; import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; -import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; +import { selectBillingCustomers } from "@/lib/supabase/billing_customers/selectBillingCustomers"; const { POST } = await import("../route"); @@ -26,7 +26,7 @@ describe("POST /api/subscriptions/portal (handler outcomes — validation & miss body: "{}", }); expect(await POST(req)).toBe(err); - expect(selectStripeBillingCustomerByAccountId).not.toHaveBeenCalled(); + expect(selectBillingCustomers).not.toHaveBeenCalled(); }); it("returns 400 when no billing customer", async () => { @@ -34,7 +34,7 @@ describe("POST /api/subscriptions/portal (handler outcomes — validation & miss accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue(null); + vi.mocked(selectBillingCustomers).mockResolvedValue([]); const res = await POST( new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), ); @@ -48,7 +48,7 @@ describe("POST /api/subscriptions/portal (handler outcomes — validation & miss accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectStripeBillingCustomerByAccountId).mockRejectedValue(new Error("supabase down")); + vi.mocked(selectBillingCustomers).mockRejectedValue(new Error("supabase down")); const res = await POST( new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), ); diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts index 755507b54..fed6c766d 100644 --- a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts @@ -3,7 +3,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextRequest } from "next/server"; import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; -import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; +import { selectBillingCustomers } from "@/lib/supabase/billing_customers/selectBillingCustomers"; const { POST } = await import("../route"); @@ -22,7 +22,7 @@ function mockValidated() { accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue(row); + vi.mocked(selectBillingCustomers).mockResolvedValue([row]); } describe("POST /api/subscriptions/portal (portal session errors)", () => { diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts index 5e2825ad9..648844c04 100644 --- a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts @@ -3,7 +3,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextRequest } from "next/server"; import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; -import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; +import { selectBillingCustomers } from "@/lib/supabase/billing_customers/selectBillingCustomers"; const { POST } = await import("../route"); @@ -23,13 +23,15 @@ describe("POST /api/subscriptions/portal (200)", () => { accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue({ - id: 1, - account_id: ACCOUNT, - customer_id: "cus_test_123", - email: null, - provider: "stripe", - }); + vi.mocked(selectBillingCustomers).mockResolvedValue([ + { + id: 1, + account_id: ACCOUNT, + customer_id: "cus_test_123", + email: null, + provider: "stripe", + }, + ]); vi.mocked(createBillingPortalSession).mockResolvedValue({ id: "bps_test_abc", url: "https://billing.example.com/session/abc", diff --git a/app/api/subscriptions/portal/__tests__/routeTestMocks.ts b/app/api/subscriptions/portal/__tests__/routeTestMocks.ts index 3bae8a41f..c266a72b5 100644 --- a/app/api/subscriptions/portal/__tests__/routeTestMocks.ts +++ b/app/api/subscriptions/portal/__tests__/routeTestMocks.ts @@ -12,8 +12,8 @@ vi.mock("@/lib/stripe/validateCreateSubscriptionPortalBody", () => ({ validateCreateSubscriptionPortalBody: vi.fn(), })); -vi.mock("@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId", () => ({ - selectStripeBillingCustomerByAccountId: vi.fn(), +vi.mock("@/lib/supabase/billing_customers/selectBillingCustomers", () => ({ + selectBillingCustomers: vi.fn(), })); vi.mock("@/lib/stripe/createBillingPortalSession", () => ({ diff --git a/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts b/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts index 5d4bed705..4bf44a9b1 100644 --- a/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts +++ b/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts @@ -2,7 +2,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextRequest, NextResponse } from "next/server"; import { createSubscriptionPortalHandler } from "@/lib/stripe/createSubscriptionPortalHandler"; import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; -import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; +import { selectBillingCustomers } from "@/lib/supabase/billing_customers/selectBillingCustomers"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; vi.mock("@/lib/networking/getCorsHeaders", () => ({ @@ -13,8 +13,8 @@ vi.mock("@/lib/stripe/validateCreateSubscriptionPortalBody", () => ({ validateCreateSubscriptionPortalBody: vi.fn(), })); -vi.mock("@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId", () => ({ - selectStripeBillingCustomerByAccountId: vi.fn(), +vi.mock("@/lib/supabase/billing_customers/selectBillingCustomers", () => ({ + selectBillingCustomers: vi.fn(), })); vi.mock("@/lib/stripe/createBillingPortalSession", () => ({ @@ -38,7 +38,7 @@ describe("createSubscriptionPortalHandler", () => { body: "{}", }); expect(await createSubscriptionPortalHandler(req)).toBe(err); - expect(selectStripeBillingCustomerByAccountId).not.toHaveBeenCalled(); + expect(selectBillingCustomers).not.toHaveBeenCalled(); }); it("returns 200 with id and url", async () => { @@ -46,13 +46,15 @@ describe("createSubscriptionPortalHandler", () => { accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue({ - id: 1, - account_id: ACCOUNT, - customer_id: "cus_test_123", - email: null, - provider: "stripe", - }); + vi.mocked(selectBillingCustomers).mockResolvedValue([ + { + id: 1, + account_id: ACCOUNT, + customer_id: "cus_test_123", + email: null, + provider: "stripe", + }, + ]); vi.mocked(createBillingPortalSession).mockResolvedValue({ id: "bps_test_abc", url: "https://billing.example.com/session/abc", @@ -73,13 +75,15 @@ describe("createSubscriptionPortalHandler", () => { accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectStripeBillingCustomerByAccountId).mockResolvedValue({ - id: 1, - account_id: ACCOUNT, - customer_id: "cus_test_123", - email: null, - provider: "stripe", - }); + vi.mocked(selectBillingCustomers).mockResolvedValue([ + { + id: 1, + account_id: ACCOUNT, + customer_id: "cus_test_123", + email: null, + provider: "stripe", + }, + ]); vi.mocked(createBillingPortalSession).mockRejectedValue(new Error("Stripe down")); const res = await createSubscriptionPortalHandler( diff --git a/lib/stripe/createSubscriptionPortalHandler.ts b/lib/stripe/createSubscriptionPortalHandler.ts index 50e7488ab..1cf78faa8 100644 --- a/lib/stripe/createSubscriptionPortalHandler.ts +++ b/lib/stripe/createSubscriptionPortalHandler.ts @@ -1,6 +1,6 @@ import { NextRequest, NextResponse } from "next/server"; import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; -import { selectStripeBillingCustomerByAccountId } from "@/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId"; +import { selectBillingCustomers } from "@/lib/supabase/billing_customers/selectBillingCustomers"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; @@ -11,7 +11,10 @@ export async function createSubscriptionPortalHandler(request: NextRequest): Pro return validated; } - const billingCustomer = await selectStripeBillingCustomerByAccountId(validated.accountId); + const [billingCustomer] = await selectBillingCustomers({ + accountId: validated.accountId, + provider: "stripe", + }); if (!billingCustomer) { return NextResponse.json( { error: "Billing customer not found" }, diff --git a/lib/supabase/billing_customers/selectBillingCustomers.ts b/lib/supabase/billing_customers/selectBillingCustomers.ts new file mode 100644 index 000000000..f13d22953 --- /dev/null +++ b/lib/supabase/billing_customers/selectBillingCustomers.ts @@ -0,0 +1,26 @@ +import supabase from "@/lib/supabase/serverClient"; +import type { Tables } from "@/types/database.types"; + +/** + * Select rows from `billing_customers`, optionally filtered by account and provider. + */ +export async function selectBillingCustomers({ + accountId, + provider, +}: { + accountId?: string; + provider?: string; +} = {}): Promise[]> { + let query = supabase.from("billing_customers").select("*"); + + if (accountId) query = query.eq("account_id", accountId); + if (provider) query = query.eq("provider", provider); + + const { data, error } = await query; + + if (error) { + throw error; + } + + return data ?? []; +} diff --git a/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId.ts b/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId.ts deleted file mode 100644 index bc166858e..000000000 --- a/lib/supabase/billing_customers/selectStripeBillingCustomerByAccountId.ts +++ /dev/null @@ -1,23 +0,0 @@ -import supabase from "@/lib/supabase/serverClient"; -import type { Tables } from "@/types/database.types"; - -/** - * Returns the Stripe billing_customers row for an account, if one exists. - */ -export async function selectStripeBillingCustomerByAccountId( - accountId: string, -): Promise | null> { - const { data, error } = await supabase - .from("billing_customers") - .select("*") - .eq("account_id", accountId) - .eq("provider", "stripe") - .maybeSingle(); - - if (error) { - console.error("selectStripeBillingCustomerByAccountId:", error); - throw error; - } - - return data ?? null; -} From 58eed67c4fd890a9a00e34d4fb01f20cda480f23 Mon Sep 17 00:00:00 2001 From: Sweets Sweetman Date: Wed, 6 May 2026 12:17:22 -0500 Subject: [PATCH 7/8] fix(supabase): type provider param as billing_provider enum MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Vercel build failed because Supabase types `billing_customers.provider` as the strict enum union, not string. Local pnpm test passes because vitest doesn't typecheck — only `next build` does, which is what runs on Vercel. --- lib/supabase/billing_customers/selectBillingCustomers.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/supabase/billing_customers/selectBillingCustomers.ts b/lib/supabase/billing_customers/selectBillingCustomers.ts index f13d22953..e254b4cbf 100644 --- a/lib/supabase/billing_customers/selectBillingCustomers.ts +++ b/lib/supabase/billing_customers/selectBillingCustomers.ts @@ -1,5 +1,7 @@ import supabase from "@/lib/supabase/serverClient"; -import type { Tables } from "@/types/database.types"; +import type { Database, Tables } from "@/types/database.types"; + +type BillingProvider = Database["public"]["Enums"]["billing_provider"]; /** * Select rows from `billing_customers`, optionally filtered by account and provider. @@ -9,7 +11,7 @@ export async function selectBillingCustomers({ provider, }: { accountId?: string; - provider?: string; + provider?: BillingProvider; } = {}): Promise[]> { let query = supabase.from("billing_customers").select("*"); From 9e75112b0a8ccdbe162b0db2ad8084b7deb9e5b7 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 7 May 2026 01:58:44 +0700 Subject: [PATCH 8/8] refactor(subscriptions): replace selectBillingCustomers with getActiveSubscriptionDetails This commit updates the subscription portal handler to utilize the new `getActiveSubscriptionDetails` function instead of the deprecated `selectBillingCustomers`. The changes include updating test cases to reflect the new logic, modifying error messages for clarity, and ensuring that the handler correctly checks for active subscriptions. Additionally, the `selectBillingCustomers` function has been removed from the codebase. --- .../route.post.outcomes.early.test.ts | 16 +++++----- .../route.post.outcomes.portal.errors.test.ts | 14 +++----- ...route.post.outcomes.portal.success.test.ts | 14 +++----- .../portal/__tests__/routeTestMocks.ts | 4 +-- .../createSubscriptionPortalHandler.test.ts | 32 ++++++------------- lib/stripe/createSubscriptionPortalHandler.ts | 13 +++----- .../selectBillingCustomers.ts | 28 ---------------- 7 files changed, 33 insertions(+), 88 deletions(-) delete mode 100644 lib/supabase/billing_customers/selectBillingCustomers.ts diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts index 4819dfb03..c3b101998 100644 --- a/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.early.test.ts @@ -3,13 +3,13 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextRequest, NextResponse } from "next/server"; import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; -import { selectBillingCustomers } from "@/lib/supabase/billing_customers/selectBillingCustomers"; +import { getActiveSubscriptionDetails } from "@/lib/stripe/getActiveSubscriptionDetails"; const { POST } = await import("../route"); const ACCOUNT = "123e4567-e89b-12d3-a456-426614174001"; -describe("POST /api/subscriptions/portal (handler outcomes — validation & missing customer)", () => { +describe("POST /api/subscriptions/portal (handler outcomes — validation & no subscription)", () => { beforeEach(() => { vi.clearAllMocks(); vi.mocked(validateCreateSubscriptionPortalBody).mockReset(); @@ -26,29 +26,29 @@ describe("POST /api/subscriptions/portal (handler outcomes — validation & miss body: "{}", }); expect(await POST(req)).toBe(err); - expect(selectBillingCustomers).not.toHaveBeenCalled(); + expect(getActiveSubscriptionDetails).not.toHaveBeenCalled(); }); - it("returns 400 when no billing customer", async () => { + it("returns 400 when no active subscription", async () => { vi.mocked(validateCreateSubscriptionPortalBody).mockResolvedValue({ accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectBillingCustomers).mockResolvedValue([]); + vi.mocked(getActiveSubscriptionDetails).mockResolvedValue(null); const res = await POST( new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), ); expect(res.status).toBe(400); - await expect(res.json()).resolves.toEqual({ error: "Billing customer not found" }); + await expect(res.json()).resolves.toEqual({ error: "No active subscription found" }); expect(createBillingPortalSession).not.toHaveBeenCalled(); }); - it("returns 500 when billing customer lookup fails", async () => { + it("returns 500 when subscription lookup fails", async () => { vi.mocked(validateCreateSubscriptionPortalBody).mockResolvedValue({ accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectBillingCustomers).mockRejectedValue(new Error("supabase down")); + vi.mocked(getActiveSubscriptionDetails).mockRejectedValue(new Error("stripe down")); const res = await POST( new NextRequest("http://localhost/api/subscriptions/portal", { method: "POST", body: "{}" }), ); diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts index fed6c766d..db6db0b6d 100644 --- a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.errors.test.ts @@ -3,26 +3,20 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextRequest } from "next/server"; import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; -import { selectBillingCustomers } from "@/lib/supabase/billing_customers/selectBillingCustomers"; +import { getActiveSubscriptionDetails } from "@/lib/stripe/getActiveSubscriptionDetails"; const { POST } = await import("../route"); const ACCOUNT = "123e4567-e89b-12d3-a456-426614174001"; -const row = { - id: 1, - account_id: ACCOUNT, - customer_id: "cus_test_123", - email: null, - provider: "stripe" as const, -}; - function mockValidated() { vi.mocked(validateCreateSubscriptionPortalBody).mockResolvedValue({ accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectBillingCustomers).mockResolvedValue([row]); + vi.mocked(getActiveSubscriptionDetails).mockResolvedValue({ + customer: "cus_test_123", + } as Awaited>); } describe("POST /api/subscriptions/portal (portal session errors)", () => { diff --git a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts index 648844c04..f225a9a15 100644 --- a/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts +++ b/app/api/subscriptions/portal/__tests__/route.post.outcomes.portal.success.test.ts @@ -3,7 +3,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextRequest } from "next/server"; import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; -import { selectBillingCustomers } from "@/lib/supabase/billing_customers/selectBillingCustomers"; +import { getActiveSubscriptionDetails } from "@/lib/stripe/getActiveSubscriptionDetails"; const { POST } = await import("../route"); @@ -23,15 +23,9 @@ describe("POST /api/subscriptions/portal (200)", () => { accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectBillingCustomers).mockResolvedValue([ - { - id: 1, - account_id: ACCOUNT, - customer_id: "cus_test_123", - email: null, - provider: "stripe", - }, - ]); + vi.mocked(getActiveSubscriptionDetails).mockResolvedValue({ + customer: "cus_test_123", + } as Awaited>); vi.mocked(createBillingPortalSession).mockResolvedValue({ id: "bps_test_abc", url: "https://billing.example.com/session/abc", diff --git a/app/api/subscriptions/portal/__tests__/routeTestMocks.ts b/app/api/subscriptions/portal/__tests__/routeTestMocks.ts index c266a72b5..2ba147680 100644 --- a/app/api/subscriptions/portal/__tests__/routeTestMocks.ts +++ b/app/api/subscriptions/portal/__tests__/routeTestMocks.ts @@ -12,8 +12,8 @@ vi.mock("@/lib/stripe/validateCreateSubscriptionPortalBody", () => ({ validateCreateSubscriptionPortalBody: vi.fn(), })); -vi.mock("@/lib/supabase/billing_customers/selectBillingCustomers", () => ({ - selectBillingCustomers: vi.fn(), +vi.mock("@/lib/stripe/getActiveSubscriptionDetails", () => ({ + getActiveSubscriptionDetails: vi.fn(), })); vi.mock("@/lib/stripe/createBillingPortalSession", () => ({ diff --git a/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts b/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts index 4bf44a9b1..56ebb7ba8 100644 --- a/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts +++ b/lib/stripe/__tests__/createSubscriptionPortalHandler.test.ts @@ -2,8 +2,8 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextRequest, NextResponse } from "next/server"; import { createSubscriptionPortalHandler } from "@/lib/stripe/createSubscriptionPortalHandler"; import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; -import { selectBillingCustomers } from "@/lib/supabase/billing_customers/selectBillingCustomers"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; +import { getActiveSubscriptionDetails } from "@/lib/stripe/getActiveSubscriptionDetails"; vi.mock("@/lib/networking/getCorsHeaders", () => ({ getCorsHeaders: vi.fn(() => ({ "Access-Control-Allow-Origin": "*" })), @@ -13,8 +13,8 @@ vi.mock("@/lib/stripe/validateCreateSubscriptionPortalBody", () => ({ validateCreateSubscriptionPortalBody: vi.fn(), })); -vi.mock("@/lib/supabase/billing_customers/selectBillingCustomers", () => ({ - selectBillingCustomers: vi.fn(), +vi.mock("@/lib/stripe/getActiveSubscriptionDetails", () => ({ + getActiveSubscriptionDetails: vi.fn(), })); vi.mock("@/lib/stripe/createBillingPortalSession", () => ({ @@ -38,7 +38,7 @@ describe("createSubscriptionPortalHandler", () => { body: "{}", }); expect(await createSubscriptionPortalHandler(req)).toBe(err); - expect(selectBillingCustomers).not.toHaveBeenCalled(); + expect(getActiveSubscriptionDetails).not.toHaveBeenCalled(); }); it("returns 200 with id and url", async () => { @@ -46,15 +46,9 @@ describe("createSubscriptionPortalHandler", () => { accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectBillingCustomers).mockResolvedValue([ - { - id: 1, - account_id: ACCOUNT, - customer_id: "cus_test_123", - email: null, - provider: "stripe", - }, - ]); + vi.mocked(getActiveSubscriptionDetails).mockResolvedValue({ + customer: "cus_test_123", + } as Awaited>); vi.mocked(createBillingPortalSession).mockResolvedValue({ id: "bps_test_abc", url: "https://billing.example.com/session/abc", @@ -75,15 +69,9 @@ describe("createSubscriptionPortalHandler", () => { accountId: ACCOUNT, returnUrl: "https://chat.recoupable.com/billing", }); - vi.mocked(selectBillingCustomers).mockResolvedValue([ - { - id: 1, - account_id: ACCOUNT, - customer_id: "cus_test_123", - email: null, - provider: "stripe", - }, - ]); + vi.mocked(getActiveSubscriptionDetails).mockResolvedValue({ + customer: "cus_test_123", + } as Awaited>); vi.mocked(createBillingPortalSession).mockRejectedValue(new Error("Stripe down")); const res = await createSubscriptionPortalHandler( diff --git a/lib/stripe/createSubscriptionPortalHandler.ts b/lib/stripe/createSubscriptionPortalHandler.ts index 1cf78faa8..d6b659f9b 100644 --- a/lib/stripe/createSubscriptionPortalHandler.ts +++ b/lib/stripe/createSubscriptionPortalHandler.ts @@ -1,7 +1,7 @@ import { NextRequest, NextResponse } from "next/server"; import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; -import { selectBillingCustomers } from "@/lib/supabase/billing_customers/selectBillingCustomers"; import { createBillingPortalSession } from "@/lib/stripe/createBillingPortalSession"; +import { getActiveSubscriptionDetails } from "@/lib/stripe/getActiveSubscriptionDetails"; import { validateCreateSubscriptionPortalBody } from "@/lib/stripe/validateCreateSubscriptionPortalBody"; export async function createSubscriptionPortalHandler(request: NextRequest): Promise { @@ -11,19 +11,16 @@ export async function createSubscriptionPortalHandler(request: NextRequest): Pro return validated; } - const [billingCustomer] = await selectBillingCustomers({ - accountId: validated.accountId, - provider: "stripe", - }); - if (!billingCustomer) { + const subscription = await getActiveSubscriptionDetails(validated.accountId); + if (!subscription) { return NextResponse.json( - { error: "Billing customer not found" }, + { error: "No active subscription found" }, { status: 400, headers: getCorsHeaders() }, ); } const session = await createBillingPortalSession( - billingCustomer.customer_id, + subscription.customer as string, validated.returnUrl, ); if (!session.url) { diff --git a/lib/supabase/billing_customers/selectBillingCustomers.ts b/lib/supabase/billing_customers/selectBillingCustomers.ts deleted file mode 100644 index e254b4cbf..000000000 --- a/lib/supabase/billing_customers/selectBillingCustomers.ts +++ /dev/null @@ -1,28 +0,0 @@ -import supabase from "@/lib/supabase/serverClient"; -import type { Database, Tables } from "@/types/database.types"; - -type BillingProvider = Database["public"]["Enums"]["billing_provider"]; - -/** - * Select rows from `billing_customers`, optionally filtered by account and provider. - */ -export async function selectBillingCustomers({ - accountId, - provider, -}: { - accountId?: string; - provider?: BillingProvider; -} = {}): Promise[]> { - let query = supabase.from("billing_customers").select("*"); - - if (accountId) query = query.eq("account_id", accountId); - if (provider) query = query.eq("provider", provider); - - const { data, error } = await query; - - if (error) { - throw error; - } - - return data ?? []; -}