diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3886b64..999ebd5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -84,7 +84,7 @@ jobs: cache-dependency-glob: pyproject.toml - name: Install Python dependencies - run: uv sync --extra dev + run: uv sync --extra dev --extra oauth - name: Lint Python with Ruff run: uv run ruff check py_src/ tests/ @@ -262,7 +262,7 @@ jobs: cache-dependency-glob: pyproject.toml - name: Install Python dependencies - run: uv sync --extra dev + run: uv sync --extra dev --extra oauth - name: Build native extension with maturin uses: PyO3/maturin-action@v1 diff --git a/Makefile b/Makefile deleted file mode 100644 index c74ddd1..0000000 --- a/Makefile +++ /dev/null @@ -1,9 +0,0 @@ -.PHONY: docs docs-serve - -docs: - uv run zensical build - @echo "Copying markdown sources..." - cd docs && find . -name '*.md' -exec install -D {} ../site/_sources/{} \; - -docs-serve: - uv run zensical serve diff --git a/dashboard/src/components/layout/header.tsx b/dashboard/src/components/layout/header.tsx index a66e8d2..35d06fc 100644 --- a/dashboard/src/components/layout/header.tsx +++ b/dashboard/src/components/layout/header.tsx @@ -1,5 +1,6 @@ import { Search } from "lucide-react"; import { Button, Kbd } from "@/components/ui"; +import { UserMenu } from "@/features/auth"; import { useCommandPalette } from "@/providers"; import { LastRefreshed } from "./last-refreshed"; import { MobileMenu } from "./mobile-menu"; @@ -37,6 +38,7 @@ export function Header() {
+
); diff --git a/dashboard/src/components/layout/sidebar.tsx b/dashboard/src/components/layout/sidebar.tsx index f3ddd13..5f9e209 100644 --- a/dashboard/src/components/layout/sidebar.tsx +++ b/dashboard/src/components/layout/sidebar.tsx @@ -14,6 +14,7 @@ import { Server, Settings2, Skull, + Webhook as WebhookIcon, } from "lucide-react"; import { useBranding, useExternalLinks } from "@/features/settings"; import { cn } from "@/lib/cn"; @@ -57,7 +58,11 @@ const NAV: NavGroup[] = [ }, { title: "Configuration", - items: [{ to: "/settings", label: "Settings", icon: Cog }], + items: [ + { to: "/tasks", label: "Tasks", icon: ListTree }, + { to: "/webhooks", label: "Webhooks", icon: WebhookIcon }, + { to: "/settings", label: "Settings", icon: Cog }, + ], }, ]; diff --git a/dashboard/src/features/auth/api.test.ts b/dashboard/src/features/auth/api.test.ts new file mode 100644 index 0000000..9a4e8e9 --- /dev/null +++ b/dashboard/src/features/auth/api.test.ts @@ -0,0 +1,26 @@ +import { describe, expect, it } from "vitest"; +import { oauthStartUrl } from "./api"; + +describe("oauthStartUrl", () => { + it("returns the slot-rooted path when no next is supplied", () => { + expect(oauthStartUrl("google")).toBe("/api/auth/oauth/start/google"); + }); + + it("URL-encodes the next path so it survives querystring parsing", () => { + expect(oauthStartUrl("google", "/jobs?status=failed")).toBe( + "/api/auth/oauth/start/google?next=%2Fjobs%3Fstatus%3Dfailed", + ); + }); + + it("URL-encodes provider slots that contain reserved characters", () => { + // OIDC slot names must match ^[a-z][a-z0-9_-]{0,31}$ at the server, + // so this is defence-in-depth — but the encoding must not break the + // slot regex on the way out. + expect(oauthStartUrl("acme-okta", "/")).toBe("/api/auth/oauth/start/acme-okta?next=%2F"); + }); + + it("ignores empty / undefined next gracefully", () => { + expect(oauthStartUrl("github", undefined)).toBe("/api/auth/oauth/start/github"); + expect(oauthStartUrl("github", "")).toBe("/api/auth/oauth/start/github"); + }); +}); diff --git a/dashboard/src/features/auth/api.ts b/dashboard/src/features/auth/api.ts new file mode 100644 index 0000000..c4ae22a --- /dev/null +++ b/dashboard/src/features/auth/api.ts @@ -0,0 +1,51 @@ +import { api } from "@/lib/api-client"; +import type { + AuthStatus, + LoginResponse, + ProvidersResponse, + SetupResponse, + WhoamiResponse, +} from "./types"; + +export function fetchAuthStatus(signal?: AbortSignal): Promise { + return api.get("/api/auth/status", { signal }); +} + +export function fetchProviders(signal?: AbortSignal): Promise { + return api.get("/api/auth/providers", { signal }); +} + +/** Browser URL the user is sent to when they click an OAuth provider button. + * + * The server's ``/api/auth/oauth/start/{slot}`` endpoint will mint state and + * 302 to the provider. We append ``next`` so the post-login callback can + * land the user back where they were trying to go. + */ +export function oauthStartUrl(slot: string, next?: string): string { + const base = `/api/auth/oauth/start/${encodeURIComponent(slot)}`; + if (!next) return base; + return `${base}?next=${encodeURIComponent(next)}`; +} + +export function fetchWhoami(signal?: AbortSignal): Promise { + return api.get("/api/auth/whoami", { signal }); +} + +export function login(username: string, password: string): Promise { + return api.post("/api/auth/login", { username, password }); +} + +export function logout(): Promise<{ ok: boolean }> { + return api.post<{ ok: boolean }>("/api/auth/logout"); +} + +export function setup(username: string, password: string): Promise { + return api.post("/api/auth/setup", { username, password }); +} + +export function changePassword(oldPassword: string, newPassword: string): Promise<{ ok: boolean }> { + return api.post<{ ok: boolean }>("/api/auth/change-password", { + old_password: oldPassword, + new_password: newPassword, + }); +} diff --git a/dashboard/src/features/auth/components/auth-gate.tsx b/dashboard/src/features/auth/components/auth-gate.tsx new file mode 100644 index 0000000..4bcafc6 --- /dev/null +++ b/dashboard/src/features/auth/components/auth-gate.tsx @@ -0,0 +1,46 @@ +import { useNavigate } from "@tanstack/react-router"; +import type { ReactNode } from "react"; +import { useEffect } from "react"; +import { Skeleton } from "@/components/ui"; +import { useAuthStatus, useWhoami } from "../hooks"; + +/** + * Wraps the authenticated portion of the dashboard. + * + * - When setup is required, redirects to ``/login`` (which shows the setup + * form). + * - When the user isn't signed in, redirects to ``/login``. + * - While loading, renders a centered skeleton so the page never flashes + * raw content. + * + * Once a session resolves, children render normally. + */ +export function AuthGate({ children }: { children: ReactNode }) { + const navigate = useNavigate(); + const status = useAuthStatus(); + const whoami = useWhoami(); + + const setupRequired = status.data?.setup_required === true; + const authenticated = !!whoami.data?.user; + const loading = status.isLoading || whoami.isLoading; + + useEffect(() => { + if (loading) return; + if (setupRequired || !authenticated) { + void navigate({ to: "/login" }); + } + }, [loading, setupRequired, authenticated, navigate]); + + if (loading || setupRequired || !authenticated) { + return ( +
+
+ + +
+
+ ); + } + + return <>{children}; +} diff --git a/dashboard/src/features/auth/components/login-form.tsx b/dashboard/src/features/auth/components/login-form.tsx new file mode 100644 index 0000000..1485775 --- /dev/null +++ b/dashboard/src/features/auth/components/login-form.tsx @@ -0,0 +1,140 @@ +import { useNavigate, useSearch } from "@tanstack/react-router"; +import { AlertCircle, LogIn } from "lucide-react"; +import { type FormEvent, useState } from "react"; +import { Button } from "@/components/ui"; +import { Input } from "@/components/ui/input"; +import { ApiError } from "@/lib/api-client"; +import { useAuthProviders, useLogin } from "../hooks"; +import { OAuthButton } from "./oauth-button"; + +const ERROR_MESSAGES: Record = { + invalid_credentials: "Invalid username or password.", + setup_required: "Dashboard setup is required before login.", +}; + +export function LoginForm() { + const navigate = useNavigate(); + const search = useSearch({ strict: false }) as { next?: string } | undefined; + const nextPath = typeof search?.next === "string" ? search.next : undefined; + + const [username, setUsername] = useState(""); + const [password, setPassword] = useState(""); + const login = useLogin(); + const providers = useAuthProviders(); + + // Default to password-on while the providers query is in flight so the + // form doesn't flash empty on the first render. + const passwordEnabled = providers.data?.password_enabled ?? true; + const oauthProviders = providers.data?.providers ?? []; + const hasOAuth = oauthProviders.length > 0; + + function onSubmit(event: FormEvent): void { + event.preventDefault(); + login.mutate( + { username, password }, + { + onSuccess: () => { + void navigate({ to: nextPath ?? "/" }); + }, + }, + ); + } + + const error = errorMessage(login.error); + const disabled = login.isPending || !username || !password; + + return ( +
+
+

Sign in

+

+ {passwordEnabled + ? "Enter your dashboard credentials to continue." + : "Choose a provider to continue."} +

+
+ + {hasOAuth ? ( +
+ {oauthProviders.map((provider) => ( + + ))} +
+ ) : null} + + {hasOAuth && passwordEnabled ? ( +
+
+ or sign in with password +
+
+ ) : null} + + {passwordEnabled ? ( +
+ + + {error ? ( +
+ + {error} +
+ ) : null} + +
+ ) : null} + + {!passwordEnabled && !hasOAuth ? ( +
+ + + No login methods are configured. Set{" "} + TASKITO_DASHBOARD_PASSWORD_AUTH_ENABLED=true or configure an OAuth + provider. + +
+ ) : null} +
+ ); +} + +function errorMessage(error: unknown): string | null { + if (!error) return null; + if (error instanceof ApiError) { + const code = + typeof error.body === "object" && error.body && "error" in error.body + ? String((error.body as { error: unknown }).error) + : ""; + return ERROR_MESSAGES[code] ?? error.message ?? "Sign-in failed."; + } + return "Sign-in failed."; +} diff --git a/dashboard/src/features/auth/components/oauth-button.tsx b/dashboard/src/features/auth/components/oauth-button.tsx new file mode 100644 index 0000000..e5f53bc --- /dev/null +++ b/dashboard/src/features/auth/components/oauth-button.tsx @@ -0,0 +1,85 @@ +import { KeyRound } from "lucide-react"; +import { oauthStartUrl } from "../api"; +import type { AuthProvider } from "../types"; + +interface OAuthButtonProps { + provider: AuthProvider; + /** Path to send the user to after a successful login. Validated server-side. */ + next?: string; +} + +/** "Sign in with X" button — renders as a plain anchor so the browser + * follows the 302 from ``/api/auth/oauth/start/{slot}`` natively. + * + * Styling matches the dashboard's design system without depending on the + * primary :class:`Button` component (we need anchor semantics, not button). + */ +export function OAuthButton({ provider, next }: OAuthButtonProps) { + return ( + + + Continue with {provider.label} + + ); +} + +function ProviderIcon({ type }: { type: AuthProvider["type"] }) { + if (type === "google") { + return ; + } + if (type === "github") { + return ; + } + // Generic OIDC — operator-configured SSO. + return ; +} + +/** Official Google "G" mark — inlined SVG so we don't pull in a brand-asset + * dependency. Matches Google's brand guidelines for sign-in buttons. + */ +function GoogleGlyph() { + return ( + + Google + + + + + + ); +} + +/** GitHub Octocat (Mark) — inlined so we don't depend on a brand-icon set + * that might drop it (lucide-react 1.x removed brand icons). + */ +function GitHubGlyph() { + return ( + + GitHub + + + ); +} diff --git a/dashboard/src/features/auth/components/setup-form.tsx b/dashboard/src/features/auth/components/setup-form.tsx new file mode 100644 index 0000000..2a1cd6e --- /dev/null +++ b/dashboard/src/features/auth/components/setup-form.tsx @@ -0,0 +1,116 @@ +import { useNavigate } from "@tanstack/react-router"; +import { AlertCircle, ShieldCheck } from "lucide-react"; +import { type FormEvent, useState } from "react"; +import { Button } from "@/components/ui"; +import { Input } from "@/components/ui/input"; +import { ApiError } from "@/lib/api-client"; +import { useLogin, useSetup } from "../hooks"; + +export function SetupForm() { + const navigate = useNavigate(); + const [username, setUsername] = useState(""); + const [password, setPassword] = useState(""); + const [confirm, setConfirm] = useState(""); + const [formError, setFormError] = useState(null); + const setup = useSetup(); + const login = useLogin(); + + function onSubmit(event: FormEvent): void { + event.preventDefault(); + setFormError(null); + if (password !== confirm) { + setFormError("Passwords don't match."); + return; + } + setup.mutate( + { username, password }, + { + onSuccess: () => { + // Auto-login as the brand-new admin so the user lands on the + // dashboard without an extra hop. + login.mutate( + { username, password }, + { + onSuccess: () => { + void navigate({ to: "/" }); + }, + }, + ); + }, + }, + ); + } + + const pending = setup.isPending || login.isPending; + const disabled = pending || !username || !password || !confirm; + const apiError = errorMessage(setup.error ?? login.error); + const error = formError ?? apiError; + + return ( +
+
+

Create the first admin

+

+ Set up the initial dashboard administrator. You'll be signed in automatically. +

+
+ + + + {error ? ( +
+ + {error} +
+ ) : null} + +
+ ); +} + +function errorMessage(error: unknown): string | null { + if (!error) return null; + if (error instanceof ApiError) return error.message; + return "Setup failed."; +} diff --git a/dashboard/src/features/auth/components/user-menu.tsx b/dashboard/src/features/auth/components/user-menu.tsx new file mode 100644 index 0000000..52b5c13 --- /dev/null +++ b/dashboard/src/features/auth/components/user-menu.tsx @@ -0,0 +1,59 @@ +import { useNavigate } from "@tanstack/react-router"; +import { LogOut, User as UserIcon } from "lucide-react"; +import { + Button, + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from "@/components/ui"; +import { useLogout, useWhoami } from "../hooks"; + +export function UserMenu() { + const { data } = useWhoami(); + const navigate = useNavigate(); + const logout = useLogout(); + + if (!data?.user) return null; + + const { username, role } = data.user; + + function onLogout() { + logout.mutate(undefined, { + onSettled: () => { + void navigate({ to: "/login" }); + }, + }); + } + + return ( + + + + + + +
+ {username} + {role} +
+
+ + + Sign out + +
+
+ ); +} diff --git a/dashboard/src/features/auth/hooks.ts b/dashboard/src/features/auth/hooks.ts new file mode 100644 index 0000000..b1481a3 --- /dev/null +++ b/dashboard/src/features/auth/hooks.ts @@ -0,0 +1,116 @@ +import { queryOptions, useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { ApiError } from "@/lib/api-client"; +import { + changePassword, + fetchAuthStatus, + fetchProviders, + fetchWhoami, + login, + logout, + setup, +} from "./api"; +import type { WhoamiResponse } from "./types"; + +export const AUTH_STATUS_KEY = ["auth", "status"] as const; +export const WHOAMI_KEY = ["auth", "whoami"] as const; +export const PROVIDERS_KEY = ["auth", "providers"] as const; + +export function authStatusQuery() { + return queryOptions({ + queryKey: AUTH_STATUS_KEY, + queryFn: ({ signal }) => fetchAuthStatus(signal), + staleTime: 60_000, + }); +} + +/** + * Resolve the current session. ``data`` is ``null`` when no session is + * active (the server returns 401, which we trap so the rest of the app can + * test for ``data === null`` without a try/catch). + */ +export function whoamiQuery() { + return queryOptions({ + queryKey: WHOAMI_KEY, + queryFn: async ({ signal }): Promise => { + try { + return await fetchWhoami(signal); + } catch (e) { + if (e instanceof ApiError && (e.status === 401 || e.status === 404)) { + return null; + } + throw e; + } + }, + staleTime: 30_000, + retry: (failureCount, error) => { + if (error instanceof ApiError && error.status >= 400 && error.status < 500) { + return false; + } + return failureCount < 2; + }, + }); +} + +/** List of OAuth providers exposed by the server. */ +export function providersQuery() { + return queryOptions({ + queryKey: PROVIDERS_KEY, + queryFn: ({ signal }) => fetchProviders(signal), + staleTime: 60_000, + }); +} + +export function useAuthStatus() { + return useQuery(authStatusQuery()); +} + +export function useWhoami() { + return useQuery(whoamiQuery()); +} + +export function useAuthProviders() { + return useQuery(providersQuery()); +} + +export function useLogin() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: ({ username, password }: { username: string; password: string }) => + login(username, password), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: WHOAMI_KEY }); + await qc.invalidateQueries({ queryKey: AUTH_STATUS_KEY }); + }, + }); +} + +export function useLogout() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: logout, + onSettled: async () => { + qc.setQueryData(WHOAMI_KEY, null); + // Drop every cached query — there will be no further data to show + // until the user logs back in. + qc.clear(); + }, + }); +} + +export function useSetup() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: ({ username, password }: { username: string; password: string }) => + setup(username, password), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: AUTH_STATUS_KEY }); + }, + }); +} + +export function useChangePassword() { + return useMutation({ + mutationFn: ({ oldPassword, newPassword }: { oldPassword: string; newPassword: string }) => + changePassword(oldPassword, newPassword), + }); +} diff --git a/dashboard/src/features/auth/index.ts b/dashboard/src/features/auth/index.ts new file mode 100644 index 0000000..40babe4 --- /dev/null +++ b/dashboard/src/features/auth/index.ts @@ -0,0 +1,27 @@ +export { AuthGate } from "./components/auth-gate"; +export { LoginForm } from "./components/login-form"; +export { OAuthButton } from "./components/oauth-button"; +export { SetupForm } from "./components/setup-form"; +export { UserMenu } from "./components/user-menu"; +export { + authStatusQuery, + providersQuery, + useAuthProviders, + useAuthStatus, + useChangePassword, + useLogin, + useLogout, + useSetup, + useWhoami, + whoamiQuery, +} from "./hooks"; +export type { + AuthProvider, + AuthSession, + AuthStatus, + AuthUser, + LoginResponse, + ProvidersResponse, + SetupResponse, + WhoamiResponse, +} from "./types"; diff --git a/dashboard/src/features/auth/types.ts b/dashboard/src/features/auth/types.ts new file mode 100644 index 0000000..e61b632 --- /dev/null +++ b/dashboard/src/features/auth/types.ts @@ -0,0 +1,47 @@ +export interface AuthUser { + username: string; + role: string; + created_at: number; + last_login_at: number | null; +} + +export interface AuthSession { + username: string; + role: string; + expires_at: number; + csrf_token: string; +} + +export interface LoginResponse { + user: AuthUser; + session: AuthSession; +} + +export interface SetupResponse { + user: AuthUser; +} + +export interface AuthStatus { + setup_required: boolean; +} + +export interface WhoamiResponse { + user: AuthUser; + csrf_token: string; + expires_at: number; +} + +/** One entry in the providers listing response. */ +export interface AuthProvider { + /** Stable URL-safe identifier used in the callback path. */ + slot: string; + /** Human-readable button label. */ + label: string; + /** Provider type, drives which icon is rendered. */ + type: "google" | "github" | "oidc"; +} + +export interface ProvidersResponse { + password_enabled: boolean; + providers: AuthProvider[]; +} diff --git a/dashboard/src/features/tasks/api.ts b/dashboard/src/features/tasks/api.ts new file mode 100644 index 0000000..e1232fd --- /dev/null +++ b/dashboard/src/features/tasks/api.ts @@ -0,0 +1,26 @@ +import { api } from "@/lib/api-client"; +import type { QueueEntry, QueueOverridePatch, TaskEntry, TaskOverridePatch } from "./types"; + +export function listTasks(signal?: AbortSignal): Promise { + return api.get("/api/tasks", { signal }); +} + +export function listQueues(signal?: AbortSignal): Promise { + return api.get("/api/queues", { signal }); +} + +export function putTaskOverride(name: string, patch: TaskOverridePatch): Promise { + return api.put(`/api/tasks/${encodeURIComponent(name)}/override`, patch); +} + +export function clearTaskOverride(name: string): Promise<{ cleared: boolean }> { + return api.delete<{ cleared: boolean }>(`/api/tasks/${encodeURIComponent(name)}/override`); +} + +export function putQueueOverride(name: string, patch: QueueOverridePatch): Promise { + return api.put(`/api/queues/${encodeURIComponent(name)}/override`, patch); +} + +export function clearQueueOverride(name: string): Promise<{ cleared: boolean }> { + return api.delete<{ cleared: boolean }>(`/api/queues/${encodeURIComponent(name)}/override`); +} diff --git a/dashboard/src/features/tasks/components/middleware-toggles.tsx b/dashboard/src/features/tasks/components/middleware-toggles.tsx new file mode 100644 index 0000000..c61241d --- /dev/null +++ b/dashboard/src/features/tasks/components/middleware-toggles.tsx @@ -0,0 +1,99 @@ +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { Power } from "lucide-react"; +import { toast } from "sonner"; +import { ErrorState, Skeleton } from "@/components/ui"; +import { api } from "@/lib/api-client"; + +interface TaskMiddlewareEntry { + name: string; + class_path: string; + disabled: boolean; + effective: boolean; +} + +interface TaskMiddlewareResponse { + task: string; + middleware: TaskMiddlewareEntry[]; +} + +interface Props { + taskName: string; +} + +const queryKey = (task: string) => ["tasks", task, "middleware"] as const; + +export function MiddlewareToggles({ taskName }: Props) { + const qc = useQueryClient(); + const query = useQuery({ + queryKey: queryKey(taskName), + queryFn: ({ signal }) => + api.get(`/api/tasks/${encodeURIComponent(taskName)}/middleware`, { + signal, + }), + }); + + const mutation = useMutation({ + mutationFn: ({ mwName, enabled }: { mwName: string; enabled: boolean }) => + api.put( + `/api/tasks/${encodeURIComponent(taskName)}/middleware/${encodeURIComponent(mwName)}`, + { enabled }, + ), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: queryKey(taskName) }); + }, + onError: () => toast.error("Failed to update middleware"), + }); + + if (query.isLoading) { + return ; + } + if (query.error) { + return ( + + ); + } + const entries = query.data?.middleware ?? []; + if (entries.length === 0) { + return ( +
+ No middleware registered for this task. +
+ ); + } + + return ( +
    + {entries.map((entry) => { + const enabled = !entry.disabled; + return ( +
  • +
    +
    {entry.name}
    +
    {entry.class_path}
    +
    + +
  • + ); + })} +
+ ); +} diff --git a/dashboard/src/features/tasks/components/task-list-table.tsx b/dashboard/src/features/tasks/components/task-list-table.tsx new file mode 100644 index 0000000..4c69e32 --- /dev/null +++ b/dashboard/src/features/tasks/components/task-list-table.tsx @@ -0,0 +1,132 @@ +import { ListTree } from "lucide-react"; +import { useState } from "react"; +import { + Badge, + Button, + EmptyState, + Sheet, + SheetContent, + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui"; +import type { TaskEntry } from "../types"; +import { TaskOverrideForm } from "./task-override-form"; + +interface Props { + tasks: TaskEntry[]; +} + +export function TaskListTable({ tasks }: Props) { + const [editing, setEditing] = useState(null); + + if (tasks.length === 0) { + return ( + + ); + } + + return ( + <> +
+ + + + Task + Queue + Rate limit + Concurrency + Retries + Timeout + Override + + + + + {tasks.map((task) => ( + + {task.name} + + {task.queue} + + + (v == null ? "—" : String(v))} + /> + + + (v == null ? "—" : String(v))} + /> + + + String(v)} + /> + + + `${v}s`} + /> + + + {task.paused ? ( + Paused + ) : task.override ? ( + Override + ) : ( + Default + )} + + + + + + ))} + +
+
+ + !open && setEditing(null)}> + + {editing ? setEditing(null)} /> : null} + + + + ); +} + +interface CellProps { + effective: T; + decoratorDefault: T; + formatter: (v: T) => string; +} + +function EffectiveCell({ effective, decoratorDefault, formatter }: CellProps) { + const overridden = effective !== decoratorDefault; + return ( + + {formatter(effective)} + + ); +} diff --git a/dashboard/src/features/tasks/components/task-override-form.tsx b/dashboard/src/features/tasks/components/task-override-form.tsx new file mode 100644 index 0000000..5e962b5 --- /dev/null +++ b/dashboard/src/features/tasks/components/task-override-form.tsx @@ -0,0 +1,237 @@ +import { Save, Trash2 } from "lucide-react"; +import { type FormEvent, useState } from "react"; +import { Button, Input, Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui"; +import { useClearTaskOverride, useSetTaskOverride } from "../hooks"; +import type { TaskEntry, TaskOverridePatch } from "../types"; +import { MiddlewareToggles } from "./middleware-toggles"; + +interface Props { + task: TaskEntry; + onDone?: () => void; +} + +/** + * Side-panel form for editing a task's overrides. Empty inputs mean + * "inherit the decorator default" (the override field is omitted / + * cleared); a non-empty value overrides the default. Submit applies the + * change; ``Clear`` removes the override entirely. + */ +export function TaskOverrideForm({ task, onDone }: Props) { + const setOverride = useSetTaskOverride(); + const clearOverride = useClearTaskOverride(); + + const o = task.override ?? {}; + const [rateLimit, setRateLimit] = useState(o.rate_limit ?? ""); + const [maxConcurrent, setMaxConcurrent] = useState( + o.max_concurrent != null ? String(o.max_concurrent) : "", + ); + const [maxRetries, setMaxRetries] = useState(o.max_retries != null ? String(o.max_retries) : ""); + const [timeout, setTimeoutValue] = useState(o.timeout != null ? String(o.timeout) : ""); + const [priority, setPriority] = useState(o.priority != null ? String(o.priority) : ""); + const [paused, setPaused] = useState(o.paused ?? false); + + function buildPatch(): TaskOverridePatch | null { + const patch: TaskOverridePatch = {}; + const numOr = (raw: string, name: keyof TaskOverridePatch) => { + if (raw === "") { + patch[name] = null as never; + } else { + const v = Number(raw); + if (!Number.isFinite(v)) return false; + (patch as Record)[name] = v; + } + return true; + }; + patch.rate_limit = rateLimit ? rateLimit : null; + if (!numOr(maxConcurrent, "max_concurrent")) return null; + if (!numOr(maxRetries, "max_retries")) return null; + if (!numOr(timeout, "timeout")) return null; + if (!numOr(priority, "priority")) return null; + patch.paused = paused; + return patch; + } + + function onSubmit(event: FormEvent): void { + event.preventDefault(); + const patch = buildPatch(); + if (!patch) return; + setOverride.mutate({ name: task.name, patch }, { onSuccess: () => onDone?.() }); + } + + return ( +
+
+

{task.name}

+

Queue · {task.queue}

+
+ + + Overrides + Middleware + + + clearOverride.mutate(task.name, { onSuccess: () => onDone?.() })} + /> + + + + + +
+ ); +} + +interface OverrideFormProps { + task: TaskEntry; + onSubmit: (e: FormEvent) => void; + rateLimit: string; + setRateLimit: (v: string) => void; + maxConcurrent: string; + setMaxConcurrent: (v: string) => void; + maxRetries: string; + setMaxRetries: (v: string) => void; + timeoutValue: string; + setTimeoutValue: (v: string) => void; + priority: string; + setPriority: (v: string) => void; + paused: boolean; + setPaused: (v: boolean) => void; + saving: boolean; + clearing: boolean; + onClear: () => void; +} + +function OverrideForm({ + task, + onSubmit, + rateLimit, + setRateLimit, + maxConcurrent, + setMaxConcurrent, + maxRetries, + setMaxRetries, + timeoutValue, + setTimeoutValue, + priority, + setPriority, + paused, + setPaused, + saving, + clearing, + onClear, +}: OverrideFormProps) { + return ( +
+

+ Overrides apply on the next worker restart; pausing takes effect immediately. +

+ + + + + + +
+ + +
+ + ); +} + +interface FieldProps { + id: string; + label: string; + value: string; + onChange: (v: string) => void; + defaultValue: string; + type: "text" | "number"; + placeholder?: string; +} + +function NumberField({ id, label, value, onChange, defaultValue, type, placeholder }: FieldProps) { + return ( + + ); +} diff --git a/dashboard/src/features/tasks/hooks.ts b/dashboard/src/features/tasks/hooks.ts new file mode 100644 index 0000000..2e91188 --- /dev/null +++ b/dashboard/src/features/tasks/hooks.ts @@ -0,0 +1,100 @@ +import { queryOptions, useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { toast } from "sonner"; +import { ApiError } from "@/lib/api-client"; +import { + clearQueueOverride, + clearTaskOverride, + listQueues, + listTasks, + putQueueOverride, + putTaskOverride, +} from "./api"; +import type { QueueOverridePatch, TaskOverridePatch } from "./types"; + +const TASKS_KEY = ["tasks"] as const; +const QUEUES_KEY = ["queues-overrides"] as const; + +function describeError(error: unknown): string | undefined { + if (error instanceof ApiError && error.status >= 400 && error.status < 500) { + return error.message; + } + return undefined; +} + +export function tasksQuery() { + return queryOptions({ + queryKey: TASKS_KEY, + queryFn: ({ signal }) => listTasks(signal), + }); +} + +export function queuesQuery() { + return queryOptions({ + queryKey: QUEUES_KEY, + queryFn: ({ signal }) => listQueues(signal), + }); +} + +export function useTasks() { + return useQuery(tasksQuery()); +} + +export function useQueues() { + return useQuery(queuesQuery()); +} + +export function useSetTaskOverride() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: ({ name, patch }: { name: string; patch: TaskOverridePatch }) => + putTaskOverride(name, patch), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: TASKS_KEY }); + toast.success("Override saved", { + description: "Applied on next worker restart.", + }); + }, + onError: (error) => + toast.error("Failed to save override", { description: describeError(error) }), + }); +} + +export function useClearTaskOverride() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: (name: string) => clearTaskOverride(name), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: TASKS_KEY }); + toast.success("Override cleared"); + }, + onError: (error) => + toast.error("Failed to clear override", { description: describeError(error) }), + }); +} + +export function useSetQueueOverride() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: ({ name, patch }: { name: string; patch: QueueOverridePatch }) => + putQueueOverride(name, patch), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: QUEUES_KEY }); + toast.success("Queue override saved"); + }, + onError: (error) => + toast.error("Failed to save queue override", { description: describeError(error) }), + }); +} + +export function useClearQueueOverride() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: (name: string) => clearQueueOverride(name), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: QUEUES_KEY }); + toast.success("Queue override cleared"); + }, + onError: (error) => + toast.error("Failed to clear queue override", { description: describeError(error) }), + }); +} diff --git a/dashboard/src/features/tasks/index.ts b/dashboard/src/features/tasks/index.ts new file mode 100644 index 0000000..e9eae07 --- /dev/null +++ b/dashboard/src/features/tasks/index.ts @@ -0,0 +1,19 @@ +export { TaskListTable } from "./components/task-list-table"; +export { TaskOverrideForm } from "./components/task-override-form"; +export { + queuesQuery, + tasksQuery, + useClearQueueOverride, + useClearTaskOverride, + useQueues, + useSetQueueOverride, + useSetTaskOverride, + useTasks, +} from "./hooks"; +export type { + QueueEntry, + QueueOverridePatch, + TaskDefaults, + TaskEntry, + TaskOverridePatch, +} from "./types"; diff --git a/dashboard/src/features/tasks/types.ts b/dashboard/src/features/tasks/types.ts new file mode 100644 index 0000000..01b46cb --- /dev/null +++ b/dashboard/src/features/tasks/types.ts @@ -0,0 +1,41 @@ +export interface TaskDefaults { + max_retries: number; + retry_backoff: number; + timeout: number; + priority: number; + rate_limit: string | null; + max_concurrent: number | null; +} + +export interface TaskOverridePatch { + rate_limit?: string | null; + max_concurrent?: number | null; + max_retries?: number | null; + retry_backoff?: number | null; + timeout?: number | null; + priority?: number | null; + paused?: boolean; +} + +export interface TaskEntry { + name: string; + queue: string; + defaults: TaskDefaults; + override: TaskOverridePatch | null; + effective: TaskDefaults; + paused: boolean; +} + +export interface QueueOverridePatch { + rate_limit?: string | null; + max_concurrent?: number | null; + paused?: boolean; +} + +export interface QueueEntry { + name: string; + defaults: Record; + override: QueueOverridePatch | null; + effective: Record; + paused: boolean; +} diff --git a/dashboard/src/features/webhooks/api.ts b/dashboard/src/features/webhooks/api.ts new file mode 100644 index 0000000..e5e0be1 --- /dev/null +++ b/dashboard/src/features/webhooks/api.ts @@ -0,0 +1,77 @@ +import { api } from "@/lib/api-client"; +import type { + CreateWebhookInput, + DeliveryListPage, + DeliveryStatus, + ReplayDeliveryResult, + RotateSecretResult, + TestWebhookResult, + UpdateWebhookInput, + Webhook, + WebhookDelivery, +} from "./types"; + +export function listWebhooks(signal?: AbortSignal): Promise { + return api.get("/api/webhooks", { signal }); +} + +export function getWebhook(id: string, signal?: AbortSignal): Promise { + return api.get(`/api/webhooks/${id}`, { signal }); +} + +export function createWebhook(input: CreateWebhookInput): Promise { + return api.post("/api/webhooks", input); +} + +export function updateWebhook(id: string, input: UpdateWebhookInput): Promise { + return api.put(`/api/webhooks/${id}`, input); +} + +export function deleteWebhook(id: string): Promise<{ deleted: true }> { + return api.delete<{ deleted: true }>(`/api/webhooks/${id}`); +} + +export function rotateWebhookSecret(id: string): Promise { + return api.post(`/api/webhooks/${id}/rotate-secret`); +} + +export function testWebhook(id: string): Promise { + return api.post(`/api/webhooks/${id}/test`); +} + +export function listEventTypes(signal?: AbortSignal): Promise { + return api.get("/api/event-types", { signal }); +} + +export function listDeliveries( + subscriptionId: string, + options: { status?: DeliveryStatus; limit?: number; offset?: number; signal?: AbortSignal } = {}, +): Promise { + return api.get(`/api/webhooks/${subscriptionId}/deliveries`, { + signal: options.signal, + params: { + status: options.status, + limit: options.limit, + offset: options.offset, + }, + }); +} + +export function getDelivery( + subscriptionId: string, + deliveryId: string, + signal?: AbortSignal, +): Promise { + return api.get(`/api/webhooks/${subscriptionId}/deliveries/${deliveryId}`, { + signal, + }); +} + +export function replayDelivery( + subscriptionId: string, + deliveryId: string, +): Promise { + return api.post( + `/api/webhooks/${subscriptionId}/deliveries/${deliveryId}/replay`, + ); +} diff --git a/dashboard/src/features/webhooks/components/create-webhook-dialog.tsx b/dashboard/src/features/webhooks/components/create-webhook-dialog.tsx new file mode 100644 index 0000000..9e5eece --- /dev/null +++ b/dashboard/src/features/webhooks/components/create-webhook-dialog.tsx @@ -0,0 +1,169 @@ +import { AlertCircle, Plus } from "lucide-react"; +import { type FormEvent, useState } from "react"; +import { + Button, + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, + Input, +} from "@/components/ui"; +import { ApiError } from "@/lib/api-client"; +import { useCreateWebhook } from "../hooks"; +import type { Webhook } from "../types"; +import { EventTypeMultiSelect } from "./event-type-multi-select"; +import { SecretReveal } from "./secret-reveal"; +import { TaskFilterInput } from "./task-filter-input"; + +export function CreateWebhookDialog() { + const [open, setOpen] = useState(false); + const [url, setUrl] = useState(""); + const [description, setDescription] = useState(""); + const [events, setEvents] = useState([]); + const [taskFilter, setTaskFilter] = useState(null); + const [generateSecret, setGenerateSecret] = useState(true); + const [createdWebhook, setCreatedWebhook] = useState(null); + const create = useCreateWebhook(); + + function reset() { + setUrl(""); + setDescription(""); + setEvents([]); + setTaskFilter(null); + setGenerateSecret(true); + setCreatedWebhook(null); + create.reset(); + } + + function onOpenChange(next: boolean) { + if (!next) reset(); + setOpen(next); + } + + function onSubmit(event: FormEvent): void { + event.preventDefault(); + create.mutate( + { + url, + description: description || null, + events, + task_filter: taskFilter, + generate_secret: generateSecret, + }, + { onSuccess: (webhook) => setCreatedWebhook(webhook) }, + ); + } + + const errorMessage = + create.error instanceof ApiError + ? create.error.message + : create.error + ? "Failed to create webhook." + : null; + + return ( + + + + + + {createdWebhook ? ( + onOpenChange(false)} /> + ) : ( +
+ + New webhook + + Subscribe an HTTP endpoint to job lifecycle events. + + + + +
+ Events + + + Leave empty to subscribe to every event. + +
+ + + {errorMessage ? ( +
+ + {errorMessage} +
+ ) : null} + + + + + + )} +
+
+ ); +} + +function SuccessView({ webhook, onDone }: { webhook: Webhook; onDone: () => void }) { + return ( +
+ + Webhook created + + Deliveries will start immediately for the events you selected. + + +
+
URL
+
{webhook.url}
+
+ {webhook.secret ? : null} + + + +
+ ); +} diff --git a/dashboard/src/features/webhooks/components/delivery-list-table.tsx b/dashboard/src/features/webhooks/components/delivery-list-table.tsx new file mode 100644 index 0000000..10b0f56 --- /dev/null +++ b/dashboard/src/features/webhooks/components/delivery-list-table.tsx @@ -0,0 +1,183 @@ +import { History, RotateCcw } from "lucide-react"; +import { useState } from "react"; +import { + Badge, + Button, + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + EmptyState, + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui"; +import { formatRelative } from "@/lib/time"; +import { useReplayDelivery } from "../hooks"; +import type { DeliveryStatus, WebhookDelivery } from "../types"; + +interface Props { + subscriptionId: string; + deliveries: WebhookDelivery[]; +} + +function statusTone(status: DeliveryStatus): "success" | "danger" | "warning" | "neutral" { + if (status === "delivered") return "success"; + if (status === "dead") return "danger"; + if (status === "failed") return "warning"; + return "neutral"; +} + +export function DeliveryListTable({ subscriptionId, deliveries }: Props) { + const [inspecting, setInspecting] = useState(null); + const replay = useReplayDelivery(subscriptionId); + + if (deliveries.length === 0) { + return ( + + ); + } + + return ( + <> +
+ + + + When + Event + Status + Code + Latency + Attempts + + + + + {deliveries.map((delivery) => ( + setInspecting(delivery)} + > + + {formatRelative(delivery.created_at)} + + {delivery.event} + + {delivery.status} + + {delivery.response_code ?? "—"} + + {delivery.latency_ms !== null ? `${delivery.latency_ms} ms` : "—"} + + + {delivery.attempts} + + e.stopPropagation()}> + + + + ))} + +
+
+ + !open && setInspecting(null)}> + + {inspecting ? ( + <> + + Delivery details + + {inspecting.event} ·{" "} + {inspecting.status} + + + +
+ +
+ + ) : null} +
+
+ + ); +} + +function DeliveryDetail({ delivery }: { delivery: WebhookDelivery }) { + return ( +
+ + + + {delivery.error ? ( + + {delivery.error} + + } + /> + ) : null} +
+
Payload
+
+          {JSON.stringify(delivery.payload, null, 2)}
+        
+
+ {delivery.response_body ? ( +
+
+ Response body (truncated) +
+
+            {delivery.response_body}
+          
+
+ ) : null} +
+ ); +} + +function Row({ label, value }: { label: string; value: React.ReactNode }) { + return ( +
+
{label}
+
{value}
+
+ ); +} diff --git a/dashboard/src/features/webhooks/components/event-type-multi-select.tsx b/dashboard/src/features/webhooks/components/event-type-multi-select.tsx new file mode 100644 index 0000000..7f2716b --- /dev/null +++ b/dashboard/src/features/webhooks/components/event-type-multi-select.tsx @@ -0,0 +1,108 @@ +import { Check, ChevronDown } from "lucide-react"; +import { useState } from "react"; +import { + Badge, + Button, + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui"; +import { cn } from "@/lib/cn"; +import { useEventTypes } from "../hooks"; + +interface Props { + value: string[]; + onChange: (next: string[]) => void; + placeholder?: string; + /** When ``true``, an empty array means "all events" and is rendered as a hint. */ + allowAll?: boolean; +} + +export function EventTypeMultiSelect({ + value, + onChange, + placeholder = "All events", + allowAll = true, +}: Props) { + const { data: events = [] } = useEventTypes(); + const [open, setOpen] = useState(false); + + function toggle(event: string) { + if (value.includes(event)) { + onChange(value.filter((e) => e !== event)); + } else { + onChange([...value, event]); + } + } + + const label = + value.length === 0 + ? allowAll + ? placeholder + : "Select events…" + : `${value.length} event${value.length === 1 ? "" : "s"} selected`; + + return ( +
+
+ + {open ? ( +
+ + + + No events match. + + {events.map((event) => { + const selected = value.includes(event); + return ( + toggle(event)} + className="cursor-pointer" + > + + {event} + + ); + })} + + + +
+ ) : null} +
+ {value.length > 0 ? ( +
+ {value.map((event) => ( + + {event} + + + ))} +
+ ) : null} +
+ ); +} diff --git a/dashboard/src/features/webhooks/components/secret-reveal.tsx b/dashboard/src/features/webhooks/components/secret-reveal.tsx new file mode 100644 index 0000000..4005b4a --- /dev/null +++ b/dashboard/src/features/webhooks/components/secret-reveal.tsx @@ -0,0 +1,64 @@ +import { Check, Copy, Eye, EyeOff, KeyRound } from "lucide-react"; +import { useState } from "react"; +import { Button } from "@/components/ui"; + +interface Props { + secret: string; + hint?: string; +} + +/** + * One-shot secret display. Shows a masked value, lets the user reveal and + * copy it, and reminds them that the secret won't be shown again. Used by + * the create response and the rotate-secret response. + */ +export function SecretReveal({ secret, hint }: Props) { + const [shown, setShown] = useState(false); + const [copied, setCopied] = useState(false); + + async function copyToClipboard() { + try { + await navigator.clipboard.writeText(secret); + setCopied(true); + setTimeout(() => setCopied(false), 1500); + } catch { + // Clipboard write can fail (e.g. http context); the user can still + // select-and-copy the visible value. + } + } + + return ( +
+
+ + {hint ?? "Signing secret"} +
+
+ + {shown ? secret : "•".repeat(Math.min(secret.length, 48))} + + + +
+

+ Store this securely — it will not be shown again. +

+
+ ); +} diff --git a/dashboard/src/features/webhooks/components/task-filter-input.tsx b/dashboard/src/features/webhooks/components/task-filter-input.tsx new file mode 100644 index 0000000..1b5c821 --- /dev/null +++ b/dashboard/src/features/webhooks/components/task-filter-input.tsx @@ -0,0 +1,81 @@ +import { X } from "lucide-react"; +import { type KeyboardEvent, useState } from "react"; +import { Badge, Input } from "@/components/ui"; + +interface Props { + value: string[] | null; + onChange: (next: string[] | null) => void; +} + +/** + * Free-form task name list input. ``null`` means "deliver for every task"; + * an empty array means "deliver for no task" (effectively disabled). + * + * Tasks are added by typing a name and pressing Enter, comma, or space. + */ +export function TaskFilterInput({ value, onChange }: Props) { + const [draft, setDraft] = useState(""); + const enabled = value !== null; + const tasks = value ?? []; + + function commitDraft() { + const trimmed = draft.trim(); + if (!trimmed) return; + if (!tasks.includes(trimmed)) onChange([...tasks, trimmed]); + setDraft(""); + } + + function onKeyDown(event: KeyboardEvent) { + if (event.key === "Enter" || event.key === "," || event.key === " ") { + event.preventDefault(); + commitDraft(); + } else if (event.key === "Backspace" && !draft && tasks.length > 0) { + onChange(tasks.slice(0, -1)); + } + } + + function remove(task: string) { + onChange(tasks.filter((t) => t !== task)); + } + + return ( +
+ + {enabled ? ( + <> + setDraft(e.target.value)} + onKeyDown={onKeyDown} + onBlur={commitDraft} + /> + {tasks.length > 0 ? ( +
+ {tasks.map((task) => ( + + {task} + + + ))} +
+ ) : null} + + ) : null} +
+ ); +} diff --git a/dashboard/src/features/webhooks/components/webhook-list-table.tsx b/dashboard/src/features/webhooks/components/webhook-list-table.tsx new file mode 100644 index 0000000..966d87f --- /dev/null +++ b/dashboard/src/features/webhooks/components/webhook-list-table.tsx @@ -0,0 +1,111 @@ +import { Webhook as WebhookIcon } from "lucide-react"; +import { + Badge, + EmptyState, + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui"; +import type { Webhook } from "../types"; +import { WebhookRowActions } from "./webhook-row-actions"; + +interface Props { + webhooks: Webhook[]; +} + +export function WebhookListTable({ webhooks }: Props) { + if (webhooks.length === 0) { + return ( + + ); + } + + return ( +
+ + + + URL + Events + Task filter + Retries + Status + + + + + {webhooks.map((wh) => ( + + +
+ {wh.url} + {wh.description ? ( + {wh.description} + ) : null} +
+
+ + {wh.events.length === 0 ? ( + All events + ) : ( +
+ {wh.events.slice(0, 3).map((event) => ( + + {event} + + ))} + {wh.events.length > 3 ? ( + + +{wh.events.length - 3} more + + ) : null} +
+ )} +
+ + {wh.task_filter === null ? ( + All tasks + ) : wh.task_filter.length === 0 ? ( + Disabled + ) : ( +
+ {wh.task_filter.slice(0, 2).map((task) => ( + + {task} + + ))} + {wh.task_filter.length > 2 ? ( + + +{wh.task_filter.length - 2} + + ) : null} +
+ )} +
+ + {wh.max_retries}× / {wh.timeout_seconds}s + + + {wh.enabled ? ( + Enabled + ) : ( + Disabled + )} + + + + +
+ ))} +
+
+
+ ); +} diff --git a/dashboard/src/features/webhooks/components/webhook-row-actions.tsx b/dashboard/src/features/webhooks/components/webhook-row-actions.tsx new file mode 100644 index 0000000..69d300a --- /dev/null +++ b/dashboard/src/features/webhooks/components/webhook-row-actions.tsx @@ -0,0 +1,151 @@ +import { Link } from "@tanstack/react-router"; +import { + Eye, + History, + MoreHorizontal, + Power, + PowerOff, + RotateCcw, + Send, + Trash2, +} from "lucide-react"; +import { useState } from "react"; +import { + Button, + ConfirmDialog, + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from "@/components/ui"; +import { DestructiveConfirmDialog } from "@/components/ui/destructive-confirm-dialog"; +import { useDeleteWebhook, useRotateSecret, useTestWebhook, useUpdateWebhook } from "../hooks"; +import type { Webhook } from "../types"; +import { SecretReveal } from "./secret-reveal"; + +interface Props { + webhook: Webhook; +} + +export function WebhookRowActions({ webhook }: Props) { + const update = useUpdateWebhook(); + const remove = useDeleteWebhook(); + const rotate = useRotateSecret(); + const test = useTestWebhook(); + + const [confirmDelete, setConfirmDelete] = useState(false); + const [confirmRotate, setConfirmRotate] = useState(false); + const [revealedSecret, setRevealedSecret] = useState(null); + + function onToggleEnabled() { + update.mutate({ + id: webhook.id, + input: { enabled: !webhook.enabled }, + }); + } + + function onRotate() { + rotate.mutate(webhook.id, { + onSuccess: (result) => { + setRevealedSecret(result.secret); + }, + }); + } + + return ( + <> + + + + + + + + View deliveries + + + test.mutate(webhook.id)} + disabled={test.isPending || !webhook.enabled} + > + Send test + + + {webhook.enabled ? ( + <> + Disable + + ) : ( + <> + Enable + + )} + + setConfirmRotate(true)}> + Rotate secret + + + setConfirmDelete(true)} + className="text-danger focus:text-danger" + > + Delete + + + + + { + await remove.mutateAsync(webhook.id); + }} + /> + + { + setConfirmRotate(false); + onRotate(); + }} + /> + + !open && setRevealedSecret(null)} + > + + + New signing secret + Configure your receiver with this value. + + {revealedSecret ? : null} + + + + + ); +} diff --git a/dashboard/src/features/webhooks/hooks.ts b/dashboard/src/features/webhooks/hooks.ts new file mode 100644 index 0000000..89570af --- /dev/null +++ b/dashboard/src/features/webhooks/hooks.ts @@ -0,0 +1,175 @@ +import { queryOptions, useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { toast } from "sonner"; +import { ApiError } from "@/lib/api-client"; +import { + createWebhook, + deleteWebhook, + getWebhook, + listDeliveries, + listEventTypes, + listWebhooks, + replayDelivery, + rotateWebhookSecret, + testWebhook, + updateWebhook, +} from "./api"; +import type { CreateWebhookInput, DeliveryStatus, UpdateWebhookInput, Webhook } from "./types"; + +const KEY = ["webhooks"] as const; +const EVENT_TYPES_KEY = ["webhooks", "event-types"] as const; + +function describeError(error: unknown): string | undefined { + if (error instanceof ApiError && error.status >= 400 && error.status < 500) { + return error.message; + } + return undefined; +} + +export function webhooksQuery() { + return queryOptions({ + queryKey: KEY, + queryFn: ({ signal }) => listWebhooks(signal), + }); +} + +export function webhookQuery(id: string) { + return queryOptions({ + queryKey: [...KEY, id], + queryFn: ({ signal }) => getWebhook(id, signal), + }); +} + +export function eventTypesQuery() { + return queryOptions({ + queryKey: EVENT_TYPES_KEY, + queryFn: ({ signal }) => listEventTypes(signal), + staleTime: 5 * 60 * 1000, + }); +} + +export function useWebhooks() { + return useQuery(webhooksQuery()); +} + +export function useEventTypes() { + return useQuery(eventTypesQuery()); +} + +export function useCreateWebhook() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: (input: CreateWebhookInput) => createWebhook(input), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: KEY }); + toast.success("Webhook created"); + }, + onError: (error) => + toast.error("Failed to create webhook", { description: describeError(error) }), + }); +} + +export function useUpdateWebhook() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: ({ id, input }: { id: string; input: UpdateWebhookInput }) => + updateWebhook(id, input), + onMutate: async ({ id, input }) => { + await qc.cancelQueries({ queryKey: KEY }); + const prev = qc.getQueryData(KEY); + if (prev) { + qc.setQueryData( + KEY, + prev.map((w) => (w.id === id ? { ...w, ...input } : w)), + ); + } + return { prev }; + }, + onError: (error, _vars, context) => { + if (context?.prev) qc.setQueryData(KEY, context.prev); + toast.error("Failed to update webhook", { description: describeError(error) }); + }, + onSettled: async () => { + await qc.invalidateQueries({ queryKey: KEY }); + }, + }); +} + +export function useDeleteWebhook() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: (id: string) => deleteWebhook(id), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: KEY }); + toast.success("Webhook deleted"); + }, + onError: (error) => + toast.error("Failed to delete webhook", { description: describeError(error) }), + }); +} + +export function useRotateSecret() { + return useMutation({ + mutationFn: (id: string) => rotateWebhookSecret(id), + onError: (error) => + toast.error("Failed to rotate secret", { description: describeError(error) }), + }); +} + +export function deliveriesQuery( + subscriptionId: string, + options: { status?: DeliveryStatus; limit?: number; offset?: number } = {}, +) { + return queryOptions({ + queryKey: [...KEY, subscriptionId, "deliveries", options] as const, + queryFn: ({ signal }) => listDeliveries(subscriptionId, { ...options, signal }), + }); +} + +export function useDeliveries( + subscriptionId: string, + options: { status?: DeliveryStatus; limit?: number; offset?: number } = {}, +) { + return useQuery(deliveriesQuery(subscriptionId, options)); +} + +export function useReplayDelivery(subscriptionId: string) { + const qc = useQueryClient(); + return useMutation({ + mutationFn: (deliveryId: string) => replayDelivery(subscriptionId, deliveryId), + onSuccess: async (result) => { + await qc.invalidateQueries({ queryKey: [...KEY, subscriptionId, "deliveries"] }); + if (result.delivered) { + toast.success("Delivery replayed", { + description: `Endpoint returned ${result.status}`, + }); + } else { + toast.error("Replay failed", { + description: result.status + ? `Endpoint returned ${result.status}` + : "No response received from endpoint", + }); + } + }, + onError: (error) => toast.error("Replay failed", { description: describeError(error) }), + }); +} + +export function useTestWebhook() { + return useMutation({ + mutationFn: (id: string) => testWebhook(id), + onSuccess: (result) => { + if (result.delivered) { + toast.success("Test event delivered", { + description: `Endpoint returned ${result.status}`, + }); + } else { + toast.error("Test event failed", { + description: result.status + ? `Endpoint returned ${result.status}` + : "No response received from endpoint", + }); + } + }, + onError: (error) => toast.error("Test event failed", { description: describeError(error) }), + }); +} diff --git a/dashboard/src/features/webhooks/index.ts b/dashboard/src/features/webhooks/index.ts new file mode 100644 index 0000000..d1b9361 --- /dev/null +++ b/dashboard/src/features/webhooks/index.ts @@ -0,0 +1,33 @@ +export { CreateWebhookDialog } from "./components/create-webhook-dialog"; +export { DeliveryListTable } from "./components/delivery-list-table"; +export { EventTypeMultiSelect } from "./components/event-type-multi-select"; +export { SecretReveal } from "./components/secret-reveal"; +export { TaskFilterInput } from "./components/task-filter-input"; +export { WebhookListTable } from "./components/webhook-list-table"; +export { WebhookRowActions } from "./components/webhook-row-actions"; +export { + deliveriesQuery, + eventTypesQuery, + useCreateWebhook, + useDeleteWebhook, + useDeliveries, + useEventTypes, + useReplayDelivery, + useRotateSecret, + useTestWebhook, + useUpdateWebhook, + useWebhooks, + webhookQuery, + webhooksQuery, +} from "./hooks"; +export type { + CreateWebhookInput, + DeliveryListPage, + DeliveryStatus, + ReplayDeliveryResult, + RotateSecretResult, + TestWebhookResult, + UpdateWebhookInput, + Webhook, + WebhookDelivery, +} from "./types"; diff --git a/dashboard/src/features/webhooks/types.ts b/dashboard/src/features/webhooks/types.ts new file mode 100644 index 0000000..a48e90c --- /dev/null +++ b/dashboard/src/features/webhooks/types.ts @@ -0,0 +1,93 @@ +/** + * Shape of a persisted webhook subscription returned by the dashboard API. + * + * The ``secret`` field is only present on the response to the *create* and + * *rotate-secret* endpoints — every other endpoint redacts it and exposes + * only ``has_secret`` so the raw value can't leak in repeated reads. + */ +export interface Webhook { + id: string; + url: string; + events: string[]; + task_filter: string[] | null; + headers: Record; + has_secret: boolean; + secret?: string; + max_retries: number; + timeout_seconds: number; + retry_backoff: number; + enabled: boolean; + description: string | null; + created_at: number; + updated_at: number; +} + +export interface CreateWebhookInput { + url: string; + events?: string[]; + task_filter?: string[] | null; + headers?: Record; + secret?: string | null; + generate_secret?: boolean; + max_retries?: number; + timeout_seconds?: number; + retry_backoff?: number; + description?: string | null; +} + +export type UpdateWebhookInput = Partial< + Pick< + Webhook, + | "url" + | "events" + | "task_filter" + | "headers" + | "max_retries" + | "timeout_seconds" + | "retry_backoff" + | "enabled" + | "description" + > +>; + +export interface TestWebhookResult { + status: number | null; + delivered: boolean; +} + +export interface RotateSecretResult { + id: string; + secret: string; +} + +export type DeliveryStatus = "delivered" | "failed" | "dead" | "pending"; + +export interface WebhookDelivery { + id: string; + subscription_id: string; + event: string; + payload: Record; + task_name: string | null; + job_id: string | null; + status: DeliveryStatus; + attempts: number; + response_code: number | null; + response_body: string | null; + latency_ms: number | null; + error: string | null; + created_at: number; + completed_at: number | null; +} + +export interface DeliveryListPage { + items: WebhookDelivery[]; + total: number; + limit: number; + offset: number; +} + +export interface ReplayDeliveryResult { + replayed_of: string; + status: number | null; + delivered: boolean; +} diff --git a/dashboard/src/lib/api-client.test.ts b/dashboard/src/lib/api-client.test.ts index 920bc96..6c82d8a 100644 --- a/dashboard/src/lib/api-client.test.ts +++ b/dashboard/src/lib/api-client.test.ts @@ -162,3 +162,47 @@ describe("ApiError", () => { expect(err).toBeInstanceOf(Error); }); }); + +describe("CSRF cookie forwarding", () => { + beforeEach(() => { + vi.spyOn(globalThis, "fetch").mockResolvedValue(jsonResponse({ ok: true })); + // Vitest's default env is node; api-client falls back to "no cookie" when + // ``document`` is undefined, so stub a minimal document object here. + vi.stubGlobal("document", { cookie: "taskito_csrf=test-csrf-token; path=/" }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + vi.unstubAllGlobals(); + }); + + it("attaches X-CSRF-Token to POST when the cookie is set", async () => { + await api.post("/api/auth/logout"); + const [, init] = vi.mocked(globalThis.fetch).mock.calls[0]!; + expect(init?.headers).toMatchObject({ "X-CSRF-Token": "test-csrf-token" }); + }); + + it("attaches X-CSRF-Token to PUT", async () => { + await api.put("/api/settings/k", { value: "v" }); + const [, init] = vi.mocked(globalThis.fetch).mock.calls[0]!; + expect(init?.headers).toMatchObject({ "X-CSRF-Token": "test-csrf-token" }); + }); + + it("attaches X-CSRF-Token to DELETE", async () => { + await api.delete("/api/settings/k"); + const [, init] = vi.mocked(globalThis.fetch).mock.calls[0]!; + expect(init?.headers).toMatchObject({ "X-CSRF-Token": "test-csrf-token" }); + }); + + it("does NOT attach X-CSRF-Token to GET", async () => { + await api.get("/api/stats"); + const [, init] = vi.mocked(globalThis.fetch).mock.calls[0]!; + expect(init?.headers).not.toHaveProperty("X-CSRF-Token"); + }); + + it("uses same-origin credentials so cookies are sent", async () => { + await api.get("/api/stats"); + const [, init] = vi.mocked(globalThis.fetch).mock.calls[0]!; + expect(init?.credentials).toBe("same-origin"); + }); +}); diff --git a/dashboard/src/lib/api-client.ts b/dashboard/src/lib/api-client.ts index 8624c30..72499a5 100644 --- a/dashboard/src/lib/api-client.ts +++ b/dashboard/src/lib/api-client.ts @@ -1,5 +1,8 @@ const API_BASE = import.meta.env.VITE_API_BASE ?? ""; +const CSRF_COOKIE = "taskito_csrf"; +const CSRF_HEADER = "X-CSRF-Token"; + export class ApiError extends Error { readonly status: number; readonly body: unknown; @@ -26,6 +29,25 @@ function buildUrl(path: string, params?: Query): string { return query ? `${url}?${query}` : url; } +/** + * Read a cookie by name from ``document.cookie``. Returns ``undefined`` if + * the cookie isn't set or we're running in an environment without + * ``document`` (e.g. unit tests via jsdom may omit it). + */ +export function readCookie(name: string): string | undefined { + if (typeof document === "undefined") return undefined; + for (const part of document.cookie.split(";")) { + const [k, v] = part.trim().split("="); + if (k === name) return v; + } + return undefined; +} + +function withCsrf(headers: Record): Record { + const csrf = readCookie(CSRF_COOKIE); + return csrf ? { ...headers, [CSRF_HEADER]: csrf } : headers; +} + async function parse(response: Response): Promise { const contentType = response.headers.get("content-type") ?? ""; const payload: unknown = contentType.includes("application/json") @@ -53,6 +75,7 @@ export const api = { const response = await fetch(buildUrl(path, options.params), { method: "GET", signal: options.signal, + credentials: "same-origin", headers: { Accept: "application/json", ...options.headers }, }); return parse(response); @@ -62,11 +85,12 @@ export const api = { const response = await fetch(buildUrl(path, options.params), { method: "POST", signal: options.signal, - headers: { + credentials: "same-origin", + headers: withCsrf({ Accept: "application/json", ...(body !== undefined ? { "Content-Type": "application/json" } : {}), ...options.headers, - }, + }), body: body === undefined ? undefined : JSON.stringify(body), }); return parse(response); @@ -76,11 +100,12 @@ export const api = { const response = await fetch(buildUrl(path, options.params), { method: "PUT", signal: options.signal, - headers: { + credentials: "same-origin", + headers: withCsrf({ Accept: "application/json", ...(body !== undefined ? { "Content-Type": "application/json" } : {}), ...options.headers, - }, + }), body: body === undefined ? undefined : JSON.stringify(body), }); return parse(response); @@ -90,7 +115,8 @@ export const api = { const response = await fetch(buildUrl(path, options.params), { method: "DELETE", signal: options.signal, - headers: { Accept: "application/json", ...options.headers }, + credentials: "same-origin", + headers: withCsrf({ Accept: "application/json", ...options.headers }), }); return parse(response); }, diff --git a/dashboard/src/routes/__root.tsx b/dashboard/src/routes/__root.tsx index bb4e516..8296f15 100644 --- a/dashboard/src/routes/__root.tsx +++ b/dashboard/src/routes/__root.tsx @@ -1,8 +1,9 @@ import type { QueryClient } from "@tanstack/react-query"; -import { createRootRouteWithContext, Link, Outlet } from "@tanstack/react-router"; +import { createRootRouteWithContext, Link, Outlet, useLocation } from "@tanstack/react-router"; import { AlertTriangle, ArrowLeft, Home } from "lucide-react"; import { AppShell, BackendOffline } from "@/components/layout"; import { Button, buttonVariants } from "@/components/ui"; +import { AuthGate } from "@/features/auth"; import { cn } from "@/lib/cn"; import { isBackendUnreachable } from "@/lib/errors"; @@ -16,11 +17,23 @@ export const Route = createRootRouteWithContext()({ notFoundComponent: NotFoundView, }); +/** + * Public routes that render without the AppShell or the auth gate. The + * login route handles its own redirect when a session is already active. + */ +const UNAUTHED_ROUTES = new Set(["/login"]); + function RootLayout() { + const { pathname } = useLocation(); + if (UNAUTHED_ROUTES.has(pathname)) { + return ; + } return ( - - - + + + + + ); } diff --git a/dashboard/src/routes/login.tsx b/dashboard/src/routes/login.tsx new file mode 100644 index 0000000..cdbb321 --- /dev/null +++ b/dashboard/src/routes/login.tsx @@ -0,0 +1,51 @@ +import { createFileRoute, Navigate, useRouter } from "@tanstack/react-router"; +import { AlertOctagon } from "lucide-react"; +import { LoginForm, SetupForm, useAuthStatus, useWhoami } from "@/features/auth"; + +export const Route = createFileRoute("/login")({ + component: LoginPage, +}); + +/** + * Standalone auth route — no AppShell wrapping. Shows the setup form when + * no users exist yet, the login form otherwise. Logged-in visitors are + * bounced back to the dashboard root. + */ +function LoginPage() { + const router = useRouter(); + const status = useAuthStatus(); + const whoami = useWhoami(); + + if (whoami.data?.user) { + return ; + } + + const loading = status.isLoading || whoami.isLoading; + + return ( +
+
+
+
+ +
+ taskito +
+ {loading ? ( +
Loading…
+ ) : status.data?.setup_required ? ( + + ) : ( + + )} + +
+
+ ); +} diff --git a/dashboard/src/routes/tasks.tsx b/dashboard/src/routes/tasks.tsx new file mode 100644 index 0000000..1465ba9 --- /dev/null +++ b/dashboard/src/routes/tasks.tsx @@ -0,0 +1,31 @@ +import { createFileRoute } from "@tanstack/react-router"; +import { PageHeader } from "@/components/layout/page-header"; +import { ErrorState, Skeleton } from "@/components/ui"; +import { TaskListTable, useTasks } from "@/features/tasks"; + +export const Route = createFileRoute("/tasks")({ + component: TasksPage, +}); + +function TasksPage() { + const { data, isLoading, error } = useTasks(); + + return ( +
+ + {isLoading ? ( + + ) : error ? ( + + ) : ( + + )} +
+ ); +} diff --git a/dashboard/src/routes/webhooks.tsx b/dashboard/src/routes/webhooks.tsx new file mode 100644 index 0000000..5a220ae --- /dev/null +++ b/dashboard/src/routes/webhooks.tsx @@ -0,0 +1,32 @@ +import { createFileRoute } from "@tanstack/react-router"; +import { PageHeader } from "@/components/layout/page-header"; +import { ErrorState, Skeleton } from "@/components/ui"; +import { CreateWebhookDialog, useWebhooks, WebhookListTable } from "@/features/webhooks"; + +export const Route = createFileRoute("/webhooks")({ + component: WebhooksPage, +}); + +function WebhooksPage() { + const { data, isLoading, error } = useWebhooks(); + + return ( +
+ } + /> + {isLoading ? ( + + ) : error ? ( + + ) : ( + + )} +
+ ); +} diff --git a/dashboard/src/routes/webhooks_.$id.deliveries.tsx b/dashboard/src/routes/webhooks_.$id.deliveries.tsx new file mode 100644 index 0000000..37d52c5 --- /dev/null +++ b/dashboard/src/routes/webhooks_.$id.deliveries.tsx @@ -0,0 +1,86 @@ +import { createFileRoute, Link } from "@tanstack/react-router"; +import { ArrowLeft } from "lucide-react"; +import { useState } from "react"; +import { PageHeader } from "@/components/layout/page-header"; +import { + Button, + ErrorState, + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, + Skeleton, +} from "@/components/ui"; +import type { DeliveryStatus } from "@/features/webhooks"; +import { DeliveryListTable, useDeliveries, useWebhooks } from "@/features/webhooks"; + +export const Route = createFileRoute("/webhooks_/$id/deliveries")({ + component: DeliveriesPage, +}); + +const STATUSES: { label: string; value: DeliveryStatus | "all" }[] = [ + { label: "All statuses", value: "all" }, + { label: "Delivered", value: "delivered" }, + { label: "Failed", value: "failed" }, + { label: "Dead", value: "dead" }, +]; + +function DeliveriesPage() { + const { id } = Route.useParams(); + const [status, setStatus] = useState("all"); + + const webhooks = useWebhooks(); + const webhook = webhooks.data?.find((w) => w.id === id); + + const { data, isLoading, error, refetch } = useDeliveries(id, { + status: status === "all" ? undefined : status, + limit: 100, + }); + + return ( +
+ + + + + + +
+ } + /> + {isLoading ? ( + + ) : error ? ( + + ) : ( + + )} +
+ ); +} diff --git a/docs/content/docs/guides/dashboard/authentication.mdx b/docs/content/docs/guides/dashboard/authentication.mdx new file mode 100644 index 0000000..5fb972c --- /dev/null +++ b/docs/content/docs/guides/dashboard/authentication.mdx @@ -0,0 +1,163 @@ +--- +title: Authentication +description: "Session-based login, CSRF, env-var bootstrap, and the setup-required flow." +--- + +import { Callout } from "fumadocs-ui/components/callout"; +import { Tab, Tabs } from "fumadocs-ui/components/tabs"; + +The taskito dashboard is auth-gated by default. Until the first admin +exists, every protected API route returns `503 setup_required` and the +SPA shows the one-time setup form. Once an admin is registered the +dashboard requires a valid session cookie on every API call and a CSRF +token on every state-changing request. + +## How auth works + +- **Users + sessions** live in the existing `dashboard_settings` + key/value table — no new schema, so SQLite, PostgreSQL, and Redis + backends are supported uniformly. +- **Passwords** are hashed with stdlib `hashlib.pbkdf2_hmac` (SHA-256, + 600,000 iterations, 16-byte random salt — the OWASP 2023+ PBKDF2 + baseline). No third-party crypto dependency. +- **Sessions** are stored server-side under + `auth:session:` with a 24-hour TTL. The token rides in + an `HttpOnly` + `SameSite=Strict` cookie named `taskito_session`. +- **CSRF** uses the double-submit pattern: a non-HttpOnly cookie named + `taskito_csrf` carries a per-session token that the SPA reads and + echoes back via the `X-CSRF-Token` header on POST/PUT/DELETE. The + server rejects any state-changing request whose header doesn't match + both the cookie and the session-bound token. + +## First-run setup + +On a fresh database the dashboard refuses to do anything else until an +admin exists. + +![First-run setup form](/screenshots/dashboard/auth-setup.png) + +The form submits to `POST /api/auth/setup`, which is allowed to run +exactly once — it returns `400 setup already complete` after the first +user is created. The new admin is signed in automatically. + +### Bootstrap via environment variables + +For headless deployments (Docker, Kubernetes, systemd) you usually don't +want to visit a browser just to register the first user. Set both env +vars before starting the dashboard: + +```bash +export TASKITO_DASHBOARD_ADMIN_USER=admin +export TASKITO_DASHBOARD_ADMIN_PASSWORD='change-me-on-first-login' +taskito dashboard --app myapp:queue +``` + +The bootstrap is **idempotent** — once a user with that name exists, +subsequent dashboard restarts read the env vars but skip creation. + + + Rotate the password after first login (use ``POST /api/auth/change-password`` + or the future UI). Leaving the env var in your deployment is fine for + recovery, but anyone with access to the env can re-bootstrap a fresh + install — keep it scoped accordingly. + + +## Sign in + +After setup, every visit routes through the sign-in form. + +![Sign-in form](/screenshots/dashboard/auth-login.png) + +```bash +# Login from the CLI — note the cookie jar so subsequent requests +# carry the session. +curl -c jar -X POST http://localhost:8080/api/auth/login \ + -H 'Content-Type: application/json' \ + -d '{"username":"admin","password":"change-me-on-first-login"}' + +# The CSRF token comes back in a non-HttpOnly cookie. Read it and +# echo it on writes. +CSRF=$(grep taskito_csrf jar | awk '{print $7}') +curl -b jar -H "X-CSRF-Token: $CSRF" -X POST \ + http://localhost:8080/api/queues/default/pause +``` + +The browser SPA does this automatically — the `api-client.ts` wrapper +reads `document.cookie` and attaches the header on POST/PUT/DELETE. + +## API surface + +All routes live under `/api/auth/`: + +| Method | Path | What it does | +|---|---|---| +| `GET` | `/api/auth/status` | Public. Returns `{setup_required: bool}` | +| `POST` | `/api/auth/setup` | Public, locks itself after the first user | +| `POST` | `/api/auth/login` | Returns the user + session and sets cookies | +| `POST` | `/api/auth/logout` | Invalidates the current session, clears cookies | +| `GET` | `/api/auth/whoami` | Returns the current user + CSRF token + expiry | +| `POST` | `/api/auth/change-password` | Requires the current password | + +Every other route under `/api/` is auth-gated. Public exceptions: +`/health`, `/readiness`, `/metrics` (Prometheus), and the static SPA +assets. + +## Headless requests + +The same endpoints work for any HTTP client — Slack bots, +deployment scripts, custom dashboards. The minimal workflow: + +1. `POST /api/auth/login` — save the `Set-Cookie` values +2. For every read: send the `taskito_session` cookie +3. For every write: send the `taskito_session` cookie + `taskito_csrf` + cookie + matching `X-CSRF-Token` header + +```python +import requests + +s = requests.Session() +s.post( + "http://localhost:8080/api/auth/login", + json={"username": "admin", "password": "..."}, +) +csrf = s.cookies.get("taskito_csrf") +s.headers["X-CSRF-Token"] = csrf + +# Reads — no CSRF needed. +stats = s.get("http://localhost:8080/api/stats").json() + +# Writes — CSRF auto-attached via session headers. +s.post("http://localhost:8080/api/queues/default/pause") +``` + +## SSRF guard for outbound URLs + +Webhook URLs entered through the dashboard are vetted before any +delivery happens. By default the server rejects: + +- Non-`http`/`https` schemes +- `localhost`, `*.localhost`, `*.local`, `*.internal`, `*.intranet`, + `*.lan`, `*.private` +- Resolved addresses in any RFC1918 / loopback / link-local / + multicast range (including the AWS metadata service at + `169.254.169.254`) + +Set `TASKITO_WEBHOOKS_ALLOW_PRIVATE=1` to disable the guard for local +development against `http://localhost`. Production should keep the +guard on. + +## SSO / OAuth login + +Native sign-in with Google, GitHub, and any OIDC-compliant provider +(Okta, Auth0, Keycloak, Microsoft Entra) is available alongside +password auth — see [SSO (OAuth & OIDC)](/guides/dashboard/sso). +Operators can mix-and-match providers or run an OAuth-only deployment +by setting `TASKITO_DASHBOARD_PASSWORD_AUTH_ENABLED=false`. + +## Limitations + +- **One role** today (`admin`). Read-only viewers and per-route + permissions are planned; the column already exists on the user + record. +- **Password rotation** has an endpoint but no UI yet — invoke + `POST /api/auth/change-password` directly. diff --git a/docs/content/docs/guides/dashboard/index.mdx b/docs/content/docs/guides/dashboard/index.mdx new file mode 100644 index 0000000..4e5333c --- /dev/null +++ b/docs/content/docs/guides/dashboard/index.mdx @@ -0,0 +1,234 @@ +--- +title: Overview +description: "Zero-dependency built-in web UI for browsing jobs, configuring webhooks, tuning per-task runtime limits, and managing the queue." +--- + +import { Callout } from "fumadocs-ui/components/callout"; +import { Tab, Tabs } from "fumadocs-ui/components/tabs"; + +taskito ships with a built-in web dashboard for monitoring jobs, inspecting +dead letters, configuring webhooks, tuning per-task runtime limits, and +managing your task queue in real time. The dashboard is a single-page +application served directly from the Python package — **zero extra +dependencies required**. + +![Overview page with stats cards, throughput chart, and recent activity](/screenshots/dashboard/overview.png) + +## Launching the dashboard + + + + +```bash +taskito dashboard --app myapp:queue +``` + +The `--app` argument uses the same `module:attribute` format as the worker. + + + + +```python +from taskito.dashboard import serve_dashboard +from myapp import queue + +serve_dashboard(queue, host="0.0.0.0", port=8000) +``` + + + + +By default the dashboard starts on `http://localhost:8080`. + +### CLI options + +| Flag | Default | Description | +|---|---|---| +| `--app` | *required* | Module path to your `Queue` instance, e.g. `myapp:queue` | +| `--host` | `127.0.0.1` | Bind address | +| `--port` | `8080` | Bind port | + +```bash +# Bind to all interfaces on port 9000 +taskito dashboard --app myapp:queue --host 0.0.0.0 --port 9000 +``` + + + The dashboard reads directly from the same database as the worker. You + can run them side by side without any coordination: + + ```bash + # Terminal 1 + taskito worker --app myapp:queue + + # Terminal 2 + taskito dashboard --app myapp:queue + ``` + + + + On a fresh database the dashboard refuses every API request with + ``503 setup_required`` until you create the first admin. See + [Authentication](/guides/dashboard/authentication) for the full + flow, including the env-var bootstrap path useful for managed + deployments. + + +## Pages + +The dashboard is grouped by intent — Monitoring (what's happening), +Infrastructure (where it runs), Reliability (when it goes wrong), and +Configuration (how to change it): + +| Group | Page | What it does | +|---|---|---| +| Monitoring | **Overview** | Stats cards, throughput sparkline, queue-by-queue table | +| Monitoring | **Jobs** | Filterable job listing (status, queue, task, metadata, error, date range) | +| Monitoring | **Job Detail** | Full job info, error history, task logs, replay history, dependency DAG | +| Monitoring | **Metrics** | Per-task performance (avg, P50, P95, P99) with timeseries chart | +| Monitoring | **Logs** | Structured task execution logs with task/level filters | +| Infrastructure | **Queues** | Per-queue stats, pause and resume controls | +| Infrastructure | **Workers** | Worker cards with heartbeat status and queue assignments | +| Infrastructure | **Resources** | Worker DI runtime status — health, scope, init duration | +| Reliability | **Dead Letters** | Failed jobs that exhausted retries — retry or purge | +| Reliability | **Circuit Breakers** | Automatic failure protection state, thresholds, cooldowns | +| Reliability | **System** | Proxy reconstruction and interception strategy metrics | +| Configuration | **Tasks** | Decorator defaults + runtime overrides per task ([guide](/guides/dashboard/task-overrides)) | +| Configuration | **Webhooks** | HTTP event subscriptions with delivery history + replay ([guide](/guides/extensibility/events-webhooks)) | +| Configuration | **Settings** | Dashboard branding, external links, integrations | + +The full REST API surface is documented at +[Dashboard REST API](/guides/dashboard/rest-api). + +## Design + +The dashboard is a React 19 + Vite 8 + TypeScript SPA routed via TanStack +Router, styled with Tailwind v4 and shadcn/ui, and shipped as +hash-busted multi-file assets under `py_src/taskito/static/dashboard/`. + +- **Dark and light mode** — Toggle via the sun/moon button in the header. + Preference is stored in `localStorage` and follows the system scheme by + default. +- **Auto-refresh** — Configurable interval (2s, 5s, 10s, or off) via the + header dropdown. TanStack Query handles caching and background + revalidation. +- **Command palette** — `⌘K` / `Ctrl+K` opens a `cmdk` palette for route + navigation. +- **Toast notifications** — Every action shows a success or error toast. + Optimistic mutations update the UI immediately and roll back on error. +- **Destructive confirms** — Irreversible actions (purge, delete) use a + type-to-confirm dialog. +- **Loading + error states** — Skeleton screens for tables and cards; + error boundaries with retry. + + + The built SPA ships inside the Python wheel under + `py_src/taskito/static/dashboard/` and is served by the Python + dashboard process. No Node.js, no pnpm, no CDN at runtime — just + `pip install taskito`. Node.js and pnpm are only needed by + contributors rebuilding the dashboard source. + + +## Walkthrough + +### Sign in + +On the first visit you'll see the setup form. After you create the +first admin, every subsequent visit shows the sign-in form. + +![First-run setup form for the initial admin](/screenshots/dashboard/auth-setup.png) + +See [Authentication](/guides/dashboard/authentication) for the env +var-based bootstrap (`TASKITO_DASHBOARD_ADMIN_USER` / +`TASKITO_DASHBOARD_ADMIN_PASSWORD`) and the CSRF model. + +### Browse jobs and dig into one + +The **Jobs** page shows a filterable, paginated table. Filters live in +the sidebar panel: status, queue, task name, metadata search, error +text, date range. Click any row to open the detail view with the full +job state, error history, task logs, replay history, and a dependency +DAG for jobs with relationships. + +![Jobs page with filter panel and paginated list](/screenshots/dashboard/jobs.png) + +### Configure webhooks + +The **Webhooks** page lists every HTTP endpoint subscribed to job +events. Add new endpoints with the **+ New webhook** button. Each row +has a dropdown menu — send a test event, enable/disable, rotate the +signing secret, or view the delivery history. Full guide: +[Events & Webhooks](/guides/extensibility/events-webhooks). + +![Webhooks page with three subscriptions in different states](/screenshots/dashboard/webhooks-list.png) + +### Tune per-task limits + +The **Tasks** page lists every registered task with its decorator +defaults and any active runtime override. Click **Edit** to open a +side sheet with two tabs: **Overrides** (rate limit, concurrency, +retries, timeout, priority, paused) and **Middleware** (toggle each +middleware on or off for the task). Full guide: +[Task & Queue Overrides](/guides/dashboard/task-overrides). + +![Tasks page with one task overridden in accent](/screenshots/dashboard/tasks-list.png) + +### Manage queues + +The **Queues** page lists every queue mentioned by a registered task, +showing pending/running counts and the current pause state. Pause and +resume buttons take effect immediately on the running worker. + +![Queues page with per-queue controls](/screenshots/dashboard/queues.png) + +### Inspect workers + +The **Workers** page lists every registered worker with heartbeat +status, the queues it consumes from, tags, and registration time. Stale +workers (no heartbeat for 30s) automatically transition to "offline". + +![Workers page showing a single active worker](/screenshots/dashboard/workers.png) + +## Development + +Contributors who want to modify the dashboard source: + +```bash +# Install dependencies (pnpm is pinned via the `packageManager` field) +cd dashboard && pnpm install + +# Start Vite dev server (proxies /api/* to localhost:8080) +pnpm run dev + +# In another terminal, start the backend +taskito dashboard --app myapp:queue + +# Build and copy to Python package +pnpm run build +``` + + + Run `corepack enable` once (Node 16+) and pnpm will be provisioned + automatically from the version pinned in `dashboard/package.json`. + + +The build produces a static `index.html` plus hashed JS/CSS chunks +under `py_src/taskito/static/dashboard/`. The built assets aren't +committed — release tooling runs `pnpm -C dashboard build` before +packaging so the wheel ships them. + +### Regenerating screenshots + +Every dashboard screenshot in this documentation is produced by a +reproducible script that seeds a fresh queue, walks the UI in headless +Chrome via Playwright, and writes PNGs into `docs/public/screenshots/dashboard/`: + +```bash +uv sync --extra docs # one-time +uv run python -m playwright install chromium # one-time +uv run python scripts/capture_docs_screenshots.py +``` + +Pass `--skip-capture` to start the seeded demo dashboard in a browser +without running Playwright — useful when iterating on UI changes +locally. diff --git a/docs/content/docs/guides/dashboard/meta.json b/docs/content/docs/guides/dashboard/meta.json new file mode 100644 index 0000000..0d020d8 --- /dev/null +++ b/docs/content/docs/guides/dashboard/meta.json @@ -0,0 +1,4 @@ +{ + "title": "Dashboard", + "pages": ["index", "authentication", "sso", "task-overrides", "rest-api"] +} diff --git a/docs/content/docs/guides/dashboard/rest-api.mdx b/docs/content/docs/guides/dashboard/rest-api.mdx new file mode 100644 index 0000000..68e743e --- /dev/null +++ b/docs/content/docs/guides/dashboard/rest-api.mdx @@ -0,0 +1,494 @@ +--- +title: REST API +description: "JSON endpoints for stats, jobs, dead letters, metrics, logs, infrastructure, observability, webhooks, and runtime overrides." +--- + +import { Callout } from "fumadocs-ui/components/callout"; + +The dashboard exposes a JSON API you can use independently of the UI. +All endpoints return `application/json` and live under the same origin +as the dashboard itself. + + + Every `/api/*` endpoint except the public set (`/api/auth/status`, + `/api/auth/login`, `/api/auth/setup`) requires a valid session cookie + obtained from `POST /api/auth/login`. State-changing requests + (POST/PUT/DELETE) additionally require a CSRF header. See + [Dashboard Authentication](/guides/dashboard/authentication) for + the login flow and headless usage examples. + + +## Auth + +### `GET /api/auth/status` + +Public. Returns whether the dashboard needs first-run setup. + +```json +{ "setup_required": false } +``` + +### `POST /api/auth/setup` + +Public, but locks itself after the first user is created. Body: +`{"username": "...", "password": "..."}`. Returns the new user. + +### `POST /api/auth/login` + +Body: `{"username": "...", "password": "..."}`. Sets the +`taskito_session` (HttpOnly) and `taskito_csrf` cookies on success. +Returns `400 invalid_credentials` on failure. + +### `POST /api/auth/logout` + +Invalidates the current session and clears cookies. + +### `GET /api/auth/whoami` + +Returns the current user, CSRF token, and expiry. `401` when no session. + +### `POST /api/auth/change-password` + +Body: `{"old_password": "...", "new_password": "..."}`. + +## Stats + +### `GET /api/stats` + +Queue statistics snapshot. + +```json +{ + "pending": 12, + "running": 3, + "completed": 450, + "failed": 2, + "dead": 1, + "cancelled": 0 +} +``` + +### `GET /api/stats/queues` + +Per-queue statistics. Pass `?queue=name` for a single queue, or omit +for all queues. + +## Jobs + +### `GET /api/jobs` + +Paginated list of jobs with filtering. + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `status` | `string` | all | Filter by status | +| `queue` | `string` | all | Filter by queue name | +| `task` | `string` | all | Filter by task name | +| `metadata` | `string` | — | Search metadata (LIKE) | +| `error` | `string` | — | Search error text (LIKE) | +| `created_after` | `int` | — | Unix ms timestamp | +| `created_before` | `int` | — | Unix ms timestamp | +| `limit` | `int` | `20` | Page size | +| `offset` | `int` | `0` | Pagination offset | + +### `GET /api/jobs/{id}` + +Full detail for a single job. + +### `GET /api/jobs/{id}/errors` + +Error history for a job (one entry per failed attempt). + +### `GET /api/jobs/{id}/logs` + +Task execution logs for a specific job. + +### `GET /api/jobs/{id}/replay-history` + +Replay history for a job that has been replayed. + +### `GET /api/jobs/{id}/dag` + +Dependency graph for a job (nodes and edges). + +### `POST /api/jobs/{id}/cancel` + +Cancel a pending job. + +### `POST /api/jobs/{id}/replay` + +Replay a completed or failed job with the same payload. + +## Dead letters + +### `GET /api/dead-letters` + +Paginated list of dead letter entries. Supports `limit` and `offset`. + +### `POST /api/dead-letters/{id}/retry` + +Re-enqueue a dead letter job. + +### `POST /api/dead-letters/purge` + +Purge all dead letters. + +## Webhooks + +Full guide: [Events & Webhooks](/guides/extensibility/events-webhooks). + +### `GET /api/webhooks` + +List all subscriptions. The `secret` field is **never** returned — only +a `has_secret` boolean. The secret is only included on the response to +`POST /api/webhooks` (create) and `POST /api/webhooks/{id}/rotate-secret`, +exactly once. + +```json +[ + { + "id": "f00563cbbb1a4200bb461f83d1db47bf", + "url": "https://hooks.example.com/ops-failures", + "events": ["job.failed", "job.dead"], + "task_filter": null, + "headers": {}, + "has_secret": true, + "max_retries": 5, + "timeout_seconds": 8.0, + "retry_backoff": 2.0, + "enabled": true, + "description": "Page ops on permanent failures", + "created_at": 1716000000, + "updated_at": 1716000000 + } +] +``` + +### `POST /api/webhooks` + +Create a subscription. + +| Field | Type | Description | +|---|---|---| +| `url` | `string` | Required. http/https URL, SSRF-vetted | +| `events` | `string[]` | Event types (`job.failed`, etc.). Empty/missing → all | +| `task_filter` | `string[] \| null` | Restrict to specific task names. `null` → all tasks | +| `headers` | `object` | Extra HTTP headers | +| `secret` | `string \| null` | Explicit signing key | +| `generate_secret` | `bool` | If true, server generates a fresh secret | +| `max_retries` | `int` | Default 3 | +| `timeout_seconds` | `float` | Default 10.0 | +| `retry_backoff` | `float` | Default 2.0 | +| `description` | `string \| null` | Free-form label | + +Response includes the secret **once** if one was set or generated. + +### `GET /api/webhooks/{id}` + +Single subscription (secret redacted). + +### `PUT /api/webhooks/{id}` + +Partial update. Only fields you include are touched. Same field set as +create. + +### `DELETE /api/webhooks/{id}` + +Delete the subscription. + +### `POST /api/webhooks/{id}/test` + +Synchronously POST a synthetic `test.ping` event and return the +result inline. + +```json +{ "status": 200, "delivered": true } +``` + +### `POST /api/webhooks/{id}/rotate-secret` + +Generate a fresh HMAC secret. Returns `{id, secret}` — the only time +the new value is visible. + +### `GET /api/event-types` + +Sorted list of every valid event type value. Used by the dashboard's +event multi-select. + +```json +["job.cancelled", "job.completed", "job.dead", ...] +``` + +## Webhook deliveries + +### `GET /api/webhooks/{id}/deliveries` + +Persistent log of attempts for the subscription. Supports filters: + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `status` | `delivered \| failed \| dead \| pending` | all | Filter by outcome | +| `event` | `string` | all | Filter by event type | +| `limit` | `int` | `50` | Page size (max 200) | +| `offset` | `int` | `0` | Pagination offset | + +```json +{ + "items": [ + { + "id": "01H...", + "subscription_id": "f00563cb...", + "event": "job.failed", + "payload": { "job_id": "...", "task_name": "...", "error": "..." }, + "task_name": "myapp.tasks.process_image", + "job_id": "01H...", + "status": "dead", + "attempts": 3, + "response_code": 500, + "response_body": "Internal Server Error", + "latency_ms": 30000, + "error": null, + "created_at": 1716000000000, + "completed_at": 1716000030000 + } + ], + "total": 1, + "limit": 50, + "offset": 0 +} +``` + +### `GET /api/webhooks/{id}/deliveries/{delivery_id}` + +Single delivery record. + +### `POST /api/webhooks/{id}/deliveries/{delivery_id}/replay` + +Re-fire the stored payload synchronously. Records the outcome as a +fresh delivery on top of the original (audit trail preserved). + +```json +{ "replayed_of": "01H...", "status": 200, "delivered": true } +``` + +## Tasks and overrides + +Full guide: [Task & Queue Overrides](/guides/dashboard/task-overrides). + +### `GET /api/tasks` + +List every registered task with decorator defaults, override, and +effective values. + +```json +[ + { + "name": "myapp.tasks.send_email", + "queue": "default", + "defaults": { + "max_retries": 3, + "retry_backoff": 1.0, + "timeout": 300, + "priority": 0, + "rate_limit": null, + "max_concurrent": null + }, + "override": { "rate_limit": "200/m", "max_retries": 10 }, + "effective": { + "max_retries": 10, + "retry_backoff": 1.0, + "timeout": 300, + "priority": 0, + "rate_limit": "200/m", + "max_concurrent": null + }, + "paused": false + } +] +``` + +### `GET /api/tasks/{name}/override` + +Single task's override row. `404` if none set. + +### `PUT /api/tasks/{name}/override` + +Upsert the override. Body keys must be in the allow-list: +`rate_limit`, `max_concurrent`, `max_retries`, `retry_backoff`, +`timeout`, `priority`, `paused`. Passing `null` for a field removes +just that field. Unknown fields → `400`. + +### `DELETE /api/tasks/{name}/override` + +Remove the override entirely. Returns `{cleared: bool}`. + +### `GET /api/queues` + +List every queue mentioned by a task config with defaults, override, +effective, and paused state. + +### `GET /api/queues/{name}/override` / `PUT` / `DELETE` + +Same shape as tasks. Allow-list for queue overrides: +`rate_limit`, `max_concurrent`, `paused`. The `paused` flag also flips +the live `paused_queues` table so it takes effect on running workers +immediately. + +## Middleware + +### `GET /api/middleware` + +List every registered middleware (global + per-task) with its scopes. + +```json +[ + { "name": "sentry", "class_path": "myapp.middleware.SentryMiddleware", "scopes": [{"kind": "global"}] } +] +``` + +### `GET /api/tasks/{name}/middleware` + +The middleware chain that fires for a task, with each entry's +`disabled` and `effective` flags. + +```json +{ + "task": "myapp.tasks.send_email", + "middleware": [ + { "name": "demo.logging", "class_path": "...", "disabled": false, "effective": true }, + { "name": "demo.metrics", "class_path": "...", "disabled": true, "effective": false } + ] +} +``` + +### `PUT /api/tasks/{name}/middleware/{mw_name}` + +Body: `{"enabled": bool}`. Returns `{task, disabled: [...]}` reflecting +the new disable list. `404` if the middleware name isn't registered on +the task — typos can't write no-op disables. + +### `DELETE /api/tasks/{name}/middleware` + +Clear all middleware disables for a task — every middleware fires +again. + +## Metrics + +### `GET /api/metrics` + +Per-task execution metrics. + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `task` | `string` | all | Filter by task name | +| `since` | `int` | `3600` | Lookback window in seconds | + +### `GET /api/metrics/timeseries` + +Time-bucketed metrics for charts. + +## Logs + +### `GET /api/logs` + +Query task execution logs across all jobs. + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `task` | `string` | all | Filter by task name | +| `level` | `string` | all | Filter by log level | +| `since` | `int` | `3600` | Lookback window in seconds | +| `limit` | `int` | `100` | Max entries | + +## Infrastructure + +### `GET /api/workers` + +List registered workers with heartbeat status. + +### `GET /api/circuit-breakers` + +Current state of all circuit breakers. + +### `GET /api/resources` + +Worker resource health and pool status. + +### `GET /api/queues/paused` + +List paused queue names. + +### `POST /api/queues/{name}/pause` / `POST /api/queues/{name}/resume` + +Pause or resume a queue. Takes effect immediately. + +## Observability + +### `GET /api/proxy-stats` + +Per-handler proxy reconstruction metrics. + +### `GET /api/interception-stats` + +Interception strategy performance metrics. + +### `GET /api/scaler` + +KEDA-compatible autoscaler payload. Pass `?queue=name` for a specific queue. + +### `GET /health` + +Public liveness check. Always returns `{"status": "ok"}`. + +### `GET /readiness` + +Public readiness check with storage, worker, and resource health. + +### `GET /metrics` + +Public Prometheus metrics endpoint (requires `prometheus-client` package). + +## Settings + +### `GET /api/settings` + +Dump of every dashboard setting key/value. + +### `GET /api/settings/{key}` / `PUT` / `DELETE` + +Read, set, or delete a single dashboard setting. Used by the dashboard +itself for branding, external links, and integration URLs — but you +can write your own keys here too. Note that this is the same store +where authentication, webhook subscriptions, delivery logs, and +runtime overrides live, all under namespaced prefixes (`auth:*`, +`webhooks:*`, `overrides:*`, etc.). + +## Using the API programmatically + +```python +import requests + +s = requests.Session() +s.post( + "http://localhost:8080/api/auth/login", + json={"username": "admin", "password": "..."}, +) +csrf = s.cookies.get("taskito_csrf") +s.headers["X-CSRF-Token"] = csrf + +# Health check script. +stats = s.get("http://localhost:8080/api/stats").json() +if stats["dead"] > 0: + print(f"WARNING: {stats['dead']} dead letter(s)") + +# Tune a task's rate limit during an incident. +s.put( + "http://localhost:8080/api/tasks/myapp.tasks.send_email/override", + json={"rate_limit": "30/m"}, +) + +# Pause a queue during deployment. +s.post("http://localhost:8080/api/queues/default/pause") +# ... deploy ... +s.post("http://localhost:8080/api/queues/default/resume") +``` diff --git a/docs/content/docs/guides/dashboard/sso.mdx b/docs/content/docs/guides/dashboard/sso.mdx new file mode 100644 index 0000000..7771ac6 --- /dev/null +++ b/docs/content/docs/guides/dashboard/sso.mdx @@ -0,0 +1,284 @@ +--- +title: SSO (OAuth & OIDC) +description: "Sign in with Google, GitHub, or any OIDC provider. Per-domain / per-org allowlists, OAuth-only mode." +--- + +import { Callout } from "fumadocs-ui/components/callout"; +import { Tab, Tabs } from "fumadocs-ui/components/tabs"; + +The dashboard ships native sign-in for **Google**, **GitHub**, and any +**OIDC-compliant** provider (Okta, Auth0, Keycloak, Microsoft Entra, Dex, +…). Multiple OIDC providers can run side-by-side, each rendered as its +own button on the login screen. + +OAuth is **off by default**. Setting any provider's env vars turns it +on; password login remains enabled unless you opt out explicitly. + + + OAuth requires the `authlib` extra: + + ```bash + pip install 'taskito[oauth]' + # or with uv: + uv pip install 'taskito[oauth]' + ``` + + Skip this if you only use password login. + + +## How it works + +>D: GET /login + B->>D: GET /api/auth/providers + D-->>B: { providers, password_enabled } + + Note over B: user clicks "Continue with Google" + B->>D: GET /api/auth/oauth/start/google + Note over D: mint state + nonce + PKCE
persist state row (5-min TTL) + D-->>B: 302 to provider authorize URL + + B->>P: GET /authorize?... + Note over P: user consents + P-->>B: 302 to /api/auth/oauth/callback/google?code=...&state=... + + B->>D: GET /api/auth/oauth/callback/google + Note over D: validate + consume state (single-use) + D->>P: POST /token (code + code_verifier) + P-->>D: { id_token, access_token } + Note over D: verify JWKS / nonce / aud / iss
enforce allowlist
get_or_create User
create Session + D-->>B: 302 to /
Set-Cookie: taskito_session, taskito_csrf`} +/> + +State is **single-use** and **time-bounded** (5-min default TTL). PKCE +S256, OIDC nonce, ID-token signature (via the provider's JWKS), +`iss` / `aud` / `exp` are all enforced server-side. + +## Quick start: Google login + +1. **Create an OAuth client.** Visit the + [Google Cloud Console → APIs & Services → Credentials](https://console.cloud.google.com/apis/credentials), + create an OAuth 2.0 Client ID of type *Web application*, and register + the callback URL: + + ``` + https://taskito.your-company.com/api/auth/oauth/callback/google + ``` + + (For local development, `http://localhost:8000/api/auth/oauth/callback/google` + works without HTTPS.) + +2. **Set env vars** before starting the dashboard: + + ```bash + export TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL=https://taskito.your-company.com + export TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_ID=...apps.googleusercontent.com + export TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_SECRET=... + # Restrict logins to your Google Workspace domain: + export TASKITO_DASHBOARD_OAUTH_GOOGLE_ALLOWED_DOMAINS=your-company.com + ``` + +3. **Start the dashboard.** The login screen now shows a "Continue with + Google" button above the password form. + +## GitHub login + +GitHub is OAuth2-only (no OIDC), so the dashboard hits `/user` and +`/user/emails` to derive an identity. Org membership is verified via +`/orgs/{org}/members/{login}`. + +1. Create a [GitHub OAuth App](https://github.com/settings/developers). + Set the *Authorization callback URL* to + `https://taskito.your-company.com/api/auth/oauth/callback/github`. + +2. Env vars: + + ```bash + export TASKITO_DASHBOARD_OAUTH_GITHUB_CLIENT_ID=Iv1.xxxxx + export TASKITO_DASHBOARD_OAUTH_GITHUB_CLIENT_SECRET=... + # Restrict logins to members of these GitHub orgs: + export TASKITO_DASHBOARD_OAUTH_GITHUB_ALLOWED_ORGS=your-org,partner-org + ``` + +When `ALLOWED_ORGS` is set the OAuth scope automatically expands to +include `read:org` so the membership endpoint returns reliable results +for private orgs. Users who consent without the additional scope are +rejected at the allowlist gate. + + + GitHub accounts that have no `verified=true` primary email + (returned by `GET /user/emails`) are always assigned the `viewer` + role, even if listed in `TASKITO_DASHBOARD_OAUTH_ADMIN_EMAILS`. This + prevents privilege escalation via spoofed email claims. + + +## Generic OIDC (Okta, Auth0, Keycloak, Microsoft, …) + +Generic OIDC providers are configured as **named slots**. Each slot has +its own callback URL, own user namespace, and own button on the login +screen. + +```bash +export TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL=https://taskito.your-company.com + +# List the slots first. +export TASKITO_DASHBOARD_OAUTH_OIDC_PROVIDERS=okta,microsoft + +# Then per-slot config (slot name uppercase, separators normalised to _). +export TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_CLIENT_ID=... +export TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_CLIENT_SECRET=... +export TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_DISCOVERY_URL=https://acme.okta.com/.well-known/openid-configuration +export TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_LABEL="Acme SSO" +export TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_ALLOWED_DOMAINS=your-company.com + +export TASKITO_DASHBOARD_OAUTH_OIDC_MICROSOFT_CLIENT_ID=... +export TASKITO_DASHBOARD_OAUTH_OIDC_MICROSOFT_CLIENT_SECRET=... +export TASKITO_DASHBOARD_OAUTH_OIDC_MICROSOFT_DISCOVERY_URL=https://login.microsoftonline.com//v2.0/.well-known/openid-configuration +export TASKITO_DASHBOARD_OAUTH_OIDC_MICROSOFT_LABEL="Microsoft 365" +``` + +The callback URL for each slot is +`{REDIRECT_BASE_URL}/api/auth/oauth/callback/{slot}` — register that +exact URL with your IdP. + +Slot names must match `^[a-z][a-z0-9_-]{0,31}$` and must not collide +with `google` / `github` (the built-ins). The Taskito user generated +for an OIDC login is namespaced as `{slot}:{sub}`, so two different +Okta tenants stay distinct users even when subjects overlap. + +## Role assignment for OAuth users + +The first time someone signs in via OAuth, the dashboard decides their +role using this rule: + +1. **`TASKITO_DASHBOARD_OAUTH_ADMIN_EMAILS` match** — case-insensitive + match against a verified email → role `admin`. +2. **Empty user table fallback** — if no users (password or OAuth) exist + yet, the first OAuth user with a verified email becomes `admin`. +3. **Everyone else** → role `viewer`. + +```bash +export TASKITO_DASHBOARD_OAUTH_ADMIN_EMAILS=alice@your-company.com,bob@your-company.com +``` + +Once a user is created, their role is **not** re-evaluated on subsequent +logins (you can change it from the dashboard or via the API). Their +`email` and `display_name` are refreshed from each new login's claims. + +## OAuth-only mode + +To disable password login entirely: + +```bash +export TASKITO_DASHBOARD_PASSWORD_AUTH_ENABLED=false +``` + +The dashboard refuses to start in OAuth-only mode if no provider is +configured (you'd have no way to log in). The login page hides the +username/password form and renders only provider buttons. + +## Allowlist semantics + +| Provider | Allowlist scope | Where it's checked | +|---|---|---| +| Google | `ALLOWED_DOMAINS` — the email domain (lowercased) must be in this list. Required: `email_verified=true`. | Server-side after JWKS verification of the ID token. | +| GitHub | `ALLOWED_ORGS` — user must be a member of at least one listed org. | `GET /orgs/{org}/members/{login}` returning 204. | +| Generic OIDC | `ALLOWED_DOMAINS` — same as Google. | Server-side after ID-token JWKS verification. | + +An **empty** allowlist means "any account from this provider is welcome" +— useful for personal projects but never appropriate for a production +deployment. Configure at least the admin-email list, and ideally a +domain/org allowlist too. + + + Allowlists are not editable from the dashboard UI. Changes require + restarting the server with new env values. This keeps the security + surface in one place (your deployment config) and avoids drift across + the operator's GitOps and the database. + + +## Security model + +| Control | Implementation | +|---|---| +| **PKCE** | S256 challenge derived from a 32-byte random verifier, per RFC 7636. Required by OAuth 2.1 and most providers in 2026. | +| **State** | 32-byte URL-safe random, stored server-side in `auth:oauth_state:` with a 5-min TTL. **Single-use** — deleted on first read. | +| **Nonce** | 16-byte random, embedded in the OIDC authorize request, verified against the ID-token `nonce` claim. Replay protection. | +| **ID-token signature** | Verified against the provider's JWKS (fetched from the discovery doc and cached per-provider). | +| **iss / aud / exp** | All validated; 60-second clock skew tolerance for `exp`. | +| **Open redirect** | The `next` query param is validated against `is_safe_redirect` — relative paths only, no scheme, no `//`. Falls back to `/`. | +| **HTTPS required** | `redirect_base_url` must be `https://` unless the host is `localhost` / `127.0.0.1`. Misconfiguration aborts startup. | +| **Provider tokens** | Never persisted. Only the verified identity flows into the Taskito session. | +| **Cross-provider linking** | Disabled by design. A given `(slot, subject)` always maps to one user. Two different providers with the same email = two different users. | + +## API surface + +| Method | Path | What it does | +|---|---|---| +| `GET` | `/api/auth/providers` | Public. Returns `{password_enabled, providers: [{slot, label, type}]}` for the login UI. | +| `GET` | `/api/auth/oauth/start/{slot}` | Public. Mints state, 302s to the provider's authorize URL. Accepts `?next=/path` (validated). | +| `GET` | `/api/auth/oauth/callback/{slot}` | Public. Validates state, exchanges code, enforces allowlist, creates/refreshes the user, sets cookies, 302s to `next`. | + +The callback uses the same `taskito_session` + `taskito_csrf` cookies +as password login — every other dashboard route works identically once +you're signed in. + +## Troubleshooting + +**"oauth_state_invalid"** — the state row expired (5-min window) or +already consumed. The user pressed back / refresh after the provider +redirect; have them start over. + +**"oauth_identity_failed: id_token issuer mismatch"** — the +`TASKITO_DASHBOARD_OAUTH_OIDC__DISCOVERY_URL` points to a +different issuer than what the IdP signed. Check the `issuer` field in +the discovery doc. + +**"oauth_allowlist_denied"** — the user authenticated successfully but +isn't in your allowlist. Either widen the allowlist or remove it. + +**Provider button doesn't appear** — `GET /api/auth/providers` returns +the list the UI renders. If the button is missing, check the server +logs for an env-var parse error at startup. The dashboard falls back to +password-only auth (logged at WARN) when env parsing fails. + +**"redirect_uri_mismatch" from the provider** — the callback URL you +registered with the provider doesn't match `{REDIRECT_BASE_URL}/api/auth/oauth/callback/{slot}`. +The trailing slash and the slot value must match exactly. + +## Env var reference + +```bash +# Required when any provider is configured. +TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL=https://taskito.company.com + +# Google. +TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_ID=... +TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_SECRET=... +TASKITO_DASHBOARD_OAUTH_GOOGLE_ALLOWED_DOMAINS=company.com,partner.com # optional + +# GitHub. +TASKITO_DASHBOARD_OAUTH_GITHUB_CLIENT_ID=Iv1.xxxxx +TASKITO_DASHBOARD_OAUTH_GITHUB_CLIENT_SECRET=... +TASKITO_DASHBOARD_OAUTH_GITHUB_ALLOWED_ORGS=org1,org2 # optional + +# Generic OIDC — list slots, then config each one. +TASKITO_DASHBOARD_OAUTH_OIDC_PROVIDERS=okta,microsoft +TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_CLIENT_ID=... +TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_CLIENT_SECRET=... +TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_DISCOVERY_URL=https://acme.okta.com/.well-known/openid-configuration +TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_LABEL=Acme SSO # optional +TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_ALLOWED_DOMAINS=company.com # optional + +# Role bootstrap. +TASKITO_DASHBOARD_OAUTH_ADMIN_EMAILS=alice@company.com,bob@company.com # optional + +# Disable password login (OAuth-only mode). Defaults to true. +TASKITO_DASHBOARD_PASSWORD_AUTH_ENABLED=false # optional +``` diff --git a/docs/content/docs/guides/dashboard/task-overrides.mdx b/docs/content/docs/guides/dashboard/task-overrides.mdx new file mode 100644 index 0000000..4fe33da --- /dev/null +++ b/docs/content/docs/guides/dashboard/task-overrides.mdx @@ -0,0 +1,222 @@ +--- +title: Task & Queue Overrides +description: "Tune retry policy, concurrency, rate limits, and middleware per task from the dashboard — without redeploying." +--- + +import { Callout } from "fumadocs-ui/components/callout"; +import { Tab, Tabs } from "fumadocs-ui/components/tabs"; + +The decorator-declared values on `@queue.task(...)` are *defaults*. The +dashboard lets operators override them at runtime — adjust a rate +limit, pause a misbehaving task, lower the retry budget after an +incident — without redeploying. + +Two surfaces: + +- **Task overrides** — per-task knobs: rate limit, concurrency, + retries, retry backoff, timeout, priority, paused. +- **Queue overrides** — per-queue knobs: rate limit, concurrency, + paused. + +Plus a separate but related toggle for **middleware** on a per-task +basis (see [§ Middleware toggles](#middleware-toggles) below). + +## Tasks page + +The **Tasks** page lists every task registered on the live Queue with +its decorator defaults, any active override, and the *effective* value +(default merged with override). Overridden values render in accent so +"which knobs are pinned" is visible at a glance. + +![Tasks page with one override active](/screenshots/dashboard/tasks-list.png) + +Click **Edit** on any row to open the side sheet. The form mirrors +the decorator kwargs: + +![Override side sheet on the send_email task](/screenshots/dashboard/task-edit-overrides.png) + +- **Empty input** → inherit the decorator default +- **Value entered** → override the default +- **Clear override** → remove the row entirely; task falls back to + every decorator value + +| Field | Decorator equivalent | +|---|---| +| Rate limit | `rate_limit="100/m"` | +| Max concurrent | `max_concurrent=10` | +| Max retries | `max_retries=5` | +| Timeout | `timeout=300` (seconds) | +| Priority | `priority=2` | +| Paused | n/a — runtime-only | + +## When changes take effect + +This is the most important thing to internalize: + +| Change | Takes effect | +|---|---| +| Pausing a task | Next worker restart for the rate-limit/concurrency side effects, **but** the paused flag is plumbed through the live `paused_queues` mechanism so the scheduler stops dequeuing immediately for queue-level pauses | +| Pausing a queue | **Immediately** on running workers (writes to `paused_queues`) | +| Rate limit / max concurrent / retries / timeout / priority on a task | **Next worker restart** — the values are baked into `PyTaskConfig` at `run_worker` time and passed to the Rust scheduler | +| Rate limit / max concurrent on a queue | **Next worker restart** — same mechanism, merged into `queue_configs` JSON sent to Rust | +| Middleware on/off per task | **Next job** — middleware lookup happens at every task invocation | + +This split is intentional. Pause is a fast-path safety valve; +retry/rate-limit changes need scheduler buy-in and are deliberately +"restart to apply" so operators have a clear mental model of when the +new values take over. + + + Pulling rate-limit / retries / timeout into the Rust scheduler's + per-poll lookup would let those changes hot-reload too. The + ``PyTaskConfig`` → scheduler path would gain a cache-invalidation + counter (incremented on every override write) the poller checks + before each admission cycle. Until then, restart the worker to apply + changes to those knobs. + + +## Programmatic API + +The dashboard CRUD is a thin shell over the `Queue` API — you can +script overrides the same way: + +```python +from taskito import Queue + +queue = Queue(db_path="tasks.db") + +# Tasks +queue.set_task_override( + "myapp.tasks.send_email", + rate_limit="200/m", + max_retries=10, +) +queue.set_task_override("myapp.tasks.send_email", paused=True) # immediate-ish +queue.clear_task_override("myapp.tasks.send_email") + +# Queues +queue.set_queue_override("email", max_concurrent=5) +queue.set_queue_override("email", paused=True) # immediate +queue.clear_queue_override("email") + +# Discovery — what's registered + what's overridden +for entry in queue.registered_tasks(): + print(entry["name"], entry["effective"]) + +for entry in queue.registered_queues(): + print(entry["name"], entry["effective"]) +``` + +Allowed task override fields: `rate_limit`, `max_concurrent`, +`max_retries`, `retry_backoff`, `timeout`, `priority`, `paused`. + +Allowed queue override fields: `rate_limit`, `max_concurrent`, +`paused`. + +The store validates types and ranges before persisting — a typo (or a +typed-in `-1`) raises `ValueError` rather than writing garbage. The +dashboard handler surfaces the same errors as `400 Bad Request`. + +## Storage + +Overrides live as JSON entries under +`overrides:task:` and `overrides:queue:` keys +in the `dashboard_settings` table. SQLite, PostgreSQL, and Redis +backends all support them uniformly — no new schema. The encoded +JSON only includes fields the operator actually set, so removing a +field by passing `None` shrinks the row rather than leaving stale +data. + +## Middleware toggles + +Middleware are normally global (via `Queue(middleware=[...])`) or +per-task (via `@queue.task(middleware=[...])`). The dashboard adds a +third axis: **temporarily disable a middleware for one task** without +touching code. Useful when: + +- A logging middleware is generating noise for one chatty task +- A retry-policy middleware is interfering with a specific debug job +- You want to A/B compare runs with and without a middleware + +### Toggle from the dashboard + +Open the same side sheet as for overrides and switch to the +**Middleware** tab. Each registered middleware shows up as a pill +button — green for enabled, grey for disabled. + +![Middleware tab with one toggle disabled](/screenshots/dashboard/task-edit-middleware.png) + +Changes take effect on the **next job** — no worker restart required. +The middleware lookup runs at every task invocation, so the next time +the task is dequeued the new chain applies. + +### Naming middleware + +Every `TaskMiddleware` carries a stable `name` attribute that the +disable list keys on. By default the name is the fully-qualified class +path (e.g. `myapp.middleware.LoggingMiddleware`) so it survives +restarts. Override it to pin a shorter, user-facing name: + +```python +from taskito.middleware import TaskMiddleware + +class SentryMiddleware(TaskMiddleware): + name = "sentry" # shows up as "sentry" in the dashboard + + def before(self, ctx): + ... +``` + +The dashboard rejects toggles for unknown middleware names (`404`), so +typos can't silently write no-op disables. + +### Programmatic API + +```python +queue.list_middleware() # [{name, class_path, scopes}, ...] +queue.disable_middleware_for_task("myapp.tasks.send_email", "demo.metrics") +queue.enable_middleware_for_task("myapp.tasks.send_email", "demo.metrics") +queue.clear_middleware_disables("myapp.tasks.send_email") +queue.get_disabled_middleware_for("myapp.tasks.send_email") # ["demo.metrics"] +``` + +## Examples + +### Pause one task without redeploying + +A flaky third-party API is rate-limiting your `send_email` task and +you want to stop new sends while you investigate: + +```python +queue.set_task_override("myapp.tasks.send_email", paused=True) +# ... or from the dashboard: Tasks → send_email → Edit → check "Pause this task" +``` + +Existing in-flight jobs finish normally; nothing new dequeues until +you clear the override. + +### Lower a rate limit during an incident + +Cut `send_email` from 200/m to 30/m while a downstream is recovering: + +```python +queue.set_task_override("myapp.tasks.send_email", rate_limit="30/m") +# Restart the workers for the change to take effect on the scheduler. +``` + +### Disable a heavyweight middleware for one task + +A debug middleware is dumping payloads for every invocation, and you +want to keep it on for everything except your high-volume +`process_image` task: + +```python +queue.disable_middleware_for_task("myapp.tasks.process_image", "debug.payload") +# Takes effect on the next process_image job, no restart needed. +``` + +## Reference + +- [Dashboard REST API: Tasks & overrides](/guides/dashboard/rest-api#tasks-and-overrides) +- [Dashboard REST API: Middleware](/guides/dashboard/rest-api#middleware) +- [Tasks decorator reference](/api-reference/task) diff --git a/docs/content/docs/guides/extensibility/events-webhooks.mdx b/docs/content/docs/guides/extensibility/events-webhooks.mdx index 03d13d7..8323753 100644 --- a/docs/content/docs/guides/extensibility/events-webhooks.mdx +++ b/docs/content/docs/guides/extensibility/events-webhooks.mdx @@ -1,17 +1,28 @@ --- title: Events & Webhooks -description: "In-process event bus and HMAC-signed HTTP webhooks for job and worker lifecycle events." +description: "In-process event bus, dashboard-managed HMAC-signed webhooks, persistent delivery log, and replay." --- -taskito includes an in-process event bus and webhook delivery system for -reacting to job lifecycle events. +import { Callout } from "fumadocs-ui/components/callout"; +import { Tab, Tabs } from "fumadocs-ui/components/tabs"; + +taskito has two complementary primitives for reacting to job lifecycle: + +1. **In-process event bus** — `queue.on_event()` registers Python callbacks + dispatched in a thread pool. Same process, lowest latency, no + serialization, no HTTP. +2. **Webhooks** — HMAC-signed HTTP POSTs to external endpoints, managed + from the dashboard (or the Python API), with a persistent delivery + log and one-click replay. + +This guide covers both, starting with the events that drive them. ## Event types The `EventType` enum defines all available lifecycle events: -| Event | Fired when | Payload fields | -|-------|------------|----------------| +| Event | Fires when | Payload fields | +|---|---|---| | `JOB_ENQUEUED` | A job is added to the queue | `job_id`, `task_name`, `queue` | | `JOB_COMPLETED` | A job finishes successfully | `job_id`, `task_name`, `queue` | | `JOB_FAILED` | A job raises an exception (before retry) | `job_id`, `task_name`, `queue`, `error` | @@ -26,18 +37,11 @@ The `EventType` enum defines all available lifecycle events: | `QUEUE_PAUSED` | A named queue is paused | `queue` | | `QUEUE_RESUMED` | A paused queue is resumed | `queue` | -`JOB_RETRYING`, `JOB_DEAD`, and `JOB_CANCELLED` are emitted by the Rust -result handler immediately after the scheduler records the outcome. -Middleware hooks (`on_retry`, `on_dead_letter`, `on_cancel`) are called in -the same result-handling pass, after the event fires. +## In-process listeners -`QUEUE_PAUSED` and `QUEUE_RESUMED` are emitted synchronously by -`queue.pause()` and `queue.resume()` after the queue state is written to -storage. - -## Registering listeners - -Use `queue.on_event()` to subscribe a callback to a specific event type: +Use `queue.on_event()` to subscribe a callback. Callbacks run in a +`ThreadPoolExecutor` so they never block the worker, and exceptions are +logged but don't affect job processing. ```python from taskito import Queue @@ -45,66 +49,97 @@ from taskito.events import EventType queue = Queue(db_path="tasks.db") -def on_failure(event_type: EventType, payload: dict): +def on_failure(event_type: EventType, payload: dict) -> None: print(f"Job {payload['job_id']} failed: {payload.get('error')}") queue.on_event(EventType.JOB_FAILED, on_failure) ``` -### Callback signature +Configure the pool size via `Queue(event_workers=N)` (default 4) if +your callbacks are slow. + +## Webhooks (dashboard-managed) -All callbacks receive two arguments: +Webhook subscriptions are first-class **persisted** resources — survive +restarts, propagate across every worker pointed at the same backend, +and are fully manageable from the dashboard. The same surface is +available programmatically via `queue.add_webhook()` / +`list_webhooks()` / `update_webhook()` / etc. -- `event_type` (`EventType`) — the event that occurred -- `payload` (`dict`) — event details including `job_id`, `task_name`, `queue`, and event-specific fields +### Configure from the dashboard -### Async delivery +The Webhooks page (sidebar → Configuration → Webhooks) lists every +subscription with its URL, event filter, optional task filter, retry +policy, and status. -Callbacks are dispatched asynchronously in a `ThreadPoolExecutor`. The -thread pool size defaults to 4 and can be configured via -`Queue(event_workers=N)`. This means: +![Webhooks page with three subscriptions](/screenshots/dashboard/webhooks-list.png) -- Callbacks never block the worker -- Exceptions in callbacks are logged but do not affect job processing -- Callbacks may execute slightly after the event occurs +Click **+ New webhook** to add an endpoint. The dialog walks you +through URL, optional description, the event-type multi-select, an +optional per-task filter, and a checkbox to auto-generate an +HMAC-SHA256 signing secret. -## Webhooks +![New webhook dialog](/screenshots/dashboard/webhook-create-dialog.png) -For external systems, register webhook URLs to receive HTTP POST requests -on job events. +After save, the new secret is shown **once** in a copy-and-reveal card +— treat it like an API key. The same flow applies when you rotate the +secret later from the row-actions menu. -### Registering a webhook +### Per-row actions + +Each row has a "⋯" menu: + +| Action | Effect | +|---|---| +| **View deliveries** | Open the persistent delivery log (see below) | +| **Send test** | POST a synthetic `test.ping` event synchronously and toast the result | +| **Enable / Disable** | Flip the active flag without losing the configuration | +| **Rotate secret** | Generate a new HMAC secret. Confirm dialog prevents accidents | +| **Delete** | Type-to-confirm destructive dialog removes the subscription | + +### Configure from Python + +The same operations are available programmatically: ```python -queue.add_webhook( - url="https://example.com/hooks/taskito", +from taskito import Queue +from taskito.events import EventType + +queue = Queue(db_path="tasks.db") + +sub = queue.add_webhook( + url="https://hooks.example.com/ops-failures", events=[EventType.JOB_FAILED, EventType.JOB_DEAD], - headers={"Authorization": "Bearer mytoken"}, - secret="my-signing-secret", + secret="whsec_my_signing_secret", + description="Page ops on permanent failures", + max_retries=5, + timeout=8.0, + task_filter=["myapp.tasks.send_email"], # optional per-task gate ) +print(sub.id) # use this to update / remove later + +queue.update_webhook(sub.id, enabled=False) +queue.rotate_webhook_secret(sub.id) +queue.remove_webhook(sub.id) ``` | Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `url` | `str` | — | URL to POST event payloads to (must be `http://` or `https://`) | +|---|---|---|---| +| `url` | `str` | — | http/https URL. SSRF-guarded — see below | | `events` | `list[EventType] \| None` | `None` | Event types to subscribe to. `None` means all events | -| `headers` | `dict[str, str] \| None` | `None` | Extra HTTP headers to include in requests | -| `secret` | `str \| None` | `None` | HMAC-SHA256 signing secret | +| `task_filter` | `list[str] \| None` | `None` | Restrict to specific task names. `None` means all tasks | +| `headers` | `dict[str, str] \| None` | `None` | Extra HTTP headers (e.g. `Authorization`) | +| `secret` | `str \| None` | `None` | HMAC-SHA256 signing key | | `max_retries` | `int` | `3` | Maximum delivery attempts | | `timeout` | `float` | `10.0` | HTTP request timeout in seconds | | `retry_backoff` | `float` | `2.0` | Base for exponential backoff between retries | +| `description` | `str \| None` | `None` | Free-form label shown in the dashboard | ### HMAC signing -When a `secret` is provided, each webhook request includes an -`X-Taskito-Signature` header: - -``` -X-Taskito-Signature: sha256= -``` - -The signature is computed over the JSON request body using HMAC-SHA256. -Verify it on the receiving end: +When a `secret` is set, every webhook request includes +`X-Taskito-Signature: sha256=`. Verify it on the receiving +end: ```python import hashlib @@ -115,46 +150,93 @@ def verify_signature(body: bytes, signature: str, secret: str) -> bool: return hmac.compare_digest(f"sha256={expected}", signature) ``` -### Retry behavior +The signature is computed over the raw JSON request body. **Verify +before parsing the body** — that way a forged payload never reaches +your business logic. + + + The secret column stores the value as-is in the dashboard settings + table. The DB is already trusted with everything else taskito + persists (job payloads, error tracebacks). If you need at-rest + encryption beyond filesystem-level (e.g. SQLite encrypted with + SQLCipher), the dashboard never returns the raw secret after the + initial create / rotate response — only a ``has_secret`` indicator. + -Failed webhook deliveries are retried with exponential backoff. The number -of attempts, request timeout, and backoff base are configurable per webhook -via `max_retries`, `timeout`, and `retry_backoff`. With the defaults -(`max_retries=3`, `retry_backoff=2.0`): +### SSRF guard + +Outbound webhook URLs are validated before the manager queues any +delivery. The guard rejects: + +- Non-`http` / `https` schemes +- `localhost`, `*.local`, `*.internal`, `*.intranet`, `*.lan`, + `*.private` +- Any address that resolves to a loopback, link-local, RFC1918, + multicast, or unspecified range (including the AWS metadata service + at `169.254.169.254`) + +Set `TASKITO_WEBHOOKS_ALLOW_PRIVATE=1` to lift the guard for local +development. Keep it on in production. + +### Retry behaviour + +Failed webhook deliveries are retried with exponential backoff, +configurable per subscription via `max_retries`, `timeout`, and +`retry_backoff`. With the defaults (`max_retries=3`, +`retry_backoff=2.0`): | Attempt | Delay before next retry | -|---------|------------------------| +|---|---| | 1st retry | 1 second (`2.0 ** 0`) | | 2nd retry | 2 seconds (`2.0 ** 1`) | | 3rd retry | — (final) | -4xx responses are not retried. If all attempts fail, a warning is logged -and the event is dropped. +4xx responses are NOT retried — they're treated as client errors and +the delivery is marked `failed`. 5xx responses retry until exhausted, +at which point the delivery is marked `dead`. -### Event filtering +## Delivery log + replay -Subscribe to specific events or all events: +Every webhook attempt — successful, failed, or dead-lettered — is +recorded under the subscription so you can debug failures without +leaving the dashboard. -```python -# Only failure events -queue.add_webhook( - url="https://slack.example.com/webhook", - events=[EventType.JOB_FAILED, EventType.JOB_DEAD], -) +![Delivery log with mixed outcomes](/screenshots/dashboard/webhook-deliveries.png) -# All events -queue.add_webhook(url="https://monitoring.example.com/events") -``` +Each row carries: + +- **When** — relative timestamp ("2 minutes ago") +- **Event** — the event type the delivery was for (`job.completed`, `job.failed`, etc.) +- **Status** — `delivered` (green), `failed` (yellow), `dead` (red) +- **Code** — final HTTP status returned by the endpoint +- **Latency** — wall time of the final attempt +- **Attempts** — total retry count consumed + +Click any row to inspect the full payload, the truncated response body +(first 2 KiB), and any transport-level error. The **Replay** button +re-fires the stored payload synchronously and records the outcome as a +fresh delivery — the original record is preserved for the audit trail. + +### Retention + +Each subscription keeps the most recent 200 deliveries in a FIFO ring +buffer per webhook (configurable via the +`DeliveryStore(max_per_webhook=N)` constructor). Successful and +failed deliveries are stored uniformly so the replay view always +matches what really happened. ## Examples -### Slack notification on job failure +### Slack-on-failure (in-process listener) + +When the receiver is the same Python process and you don't need +persistence, the event bus is the cheapest path: ```python import requests from taskito.events import EventType -def notify_slack(event_type: EventType, payload: dict): +def notify_slack(event_type: EventType, payload: dict) -> None: requests.post( "https://hooks.slack.com/services/T.../B.../xxx", json={ @@ -168,7 +250,11 @@ queue.on_event(EventType.JOB_FAILED, notify_slack) queue.on_event(EventType.JOB_DEAD, notify_slack) ``` -### Webhook to external monitoring +### Persistent webhook to an external service + +When the receiver is a separate service — auditing, monitoring, a +different team's pipeline — use a webhook so the delivery log is +preserved across restarts: ```python queue.add_webhook( @@ -176,58 +262,73 @@ queue.add_webhook( events=[EventType.JOB_COMPLETED, EventType.JOB_FAILED, EventType.JOB_DEAD], secret="whsec_abc123", headers={"X-Source": "taskito-prod"}, + description="Forward terminal job outcomes to the monitoring service", ) ``` -The monitoring service receives JSON payloads like: +Payload shape received by the endpoint: ```json { - "event": "job.failed", - "job_id": "01H5K6X...", - "task_name": "myapp.tasks.process", - "queue": "default", - "error": "ConnectionError: ..." + "event": "job.failed", + "job_id": "01H5K6X...", + "task_name": "myapp.tasks.process", + "queue": "default", + "error": "ConnectionError: ..." } ``` -### Job completion tracking +### Flask receiver + +A minimal Flask app that receives and verifies taskito webhooks: ```python -from taskito.events import EventType +from flask import Flask, request, abort +import hashlib, hmac -completed_count = 0 +app = Flask(__name__) +WEBHOOK_SECRET = "whsec_my_signing_secret" -def track_completion(event_type: EventType, payload: dict): - global completed_count - completed_count += 1 - if completed_count % 100 == 0: - print(f"Milestone: {completed_count} jobs completed") +@app.route("/hooks/taskito", methods=["POST"]) +def receive_webhook(): + signature = request.headers.get("X-Taskito-Signature", "") + expected = hmac.new( + WEBHOOK_SECRET.encode(), request.data, hashlib.sha256 + ).hexdigest() -queue.on_event(EventType.JOB_COMPLETED, track_completion) + if not hmac.compare_digest(f"sha256={expected}", signature): + abort(401) + + event = request.json + print(f"Received event: {event['event']} for job {event['job_id']}") + return "", 204 ``` -### Database logging for audit trail +### Database audit trail (in-process listener) ```python from taskito.events import EventType -def audit_log(event_type: EventType, payload: dict): +def audit_log(event_type: EventType, payload: dict) -> None: db.execute( "INSERT INTO audit_log (event, job_id, task_name, timestamp) VALUES (?, ?, ?, ?)", (event_type.value, payload["job_id"], payload["task_name"], time.time()), ) -# Subscribe to all important events -for event in [EventType.JOB_ENQUEUED, EventType.JOB_COMPLETED, EventType.JOB_FAILED, EventType.JOB_DEAD]: +for event in [ + EventType.JOB_ENQUEUED, + EventType.JOB_COMPLETED, + EventType.JOB_FAILED, + EventType.JOB_DEAD, +]: queue.on_event(event, audit_log) ``` ## Event ordering -Events fire in the order the scheduler processes results — typically the -order jobs complete. For jobs that complete nearly simultaneously, ordering -is **not guaranteed** across different workers or threads. +Events fire in the order the scheduler processes results — typically +the order jobs complete. For jobs that complete nearly simultaneously, +ordering is **not guaranteed** across different workers or threads. Within a single job's lifecycle, events always fire in this order: @@ -236,50 +337,7 @@ Within a single job's lifecycle, events always fire in this order: 3. `JOB_RETRYING` (if retried, before the next attempt) 4. `JOB_DEAD` (if all retries exhausted) -## Backpressure - -Events are dispatched to a thread pool (default size: 4, configurable via -`event_workers=N`). If callbacks are slow and events arrive faster than -they can be processed, they queue in memory. - -For high-volume event scenarios: - -```python -queue = Queue(event_workers=16) # More threads for slow callbacks -``` +## Reference -If a callback raises an exception, it is logged and the event is dropped — -it does not retry or block other callbacks. - -## Webhook failure - -Webhooks retry with exponential backoff (up to `max_retries`). After all -retries are exhausted, the webhook delivery is **logged and dropped** — -there is no dead-letter queue for webhooks. Monitor webhook failures via -the `on_failure` callback or structured logging. - -### Webhook receiver (Flask) - -A minimal Flask app that receives and verifies taskito webhooks: - -```python -from flask import Flask, request, abort -import hashlib, hmac - -app = Flask(__name__) -WEBHOOK_SECRET = "my-signing-secret" - -@app.route("/hooks/taskito", methods=["POST"]) -def receive_webhook(): - signature = request.headers.get("X-Taskito-Signature", "") - expected = hmac.new( - WEBHOOK_SECRET.encode(), request.data, hashlib.sha256 - ).hexdigest() - - if not hmac.compare_digest(f"sha256={expected}", signature): - abort(401) - - event = request.json - print(f"Received event: {event['event']} for job {event['job_id']}") - return "", 204 -``` +- [Dashboard REST API for webhooks and deliveries](/guides/dashboard/rest-api#webhooks) +- [Dashboard auth — how to call these endpoints from a script](/guides/dashboard/authentication) diff --git a/docs/content/docs/guides/meta.json b/docs/content/docs/guides/meta.json index 99dfadf..aaec693 100644 --- a/docs/content/docs/guides/meta.json +++ b/docs/content/docs/guides/meta.json @@ -8,6 +8,7 @@ "advanced-execution", "operations", "observability", + "dashboard", "resources", "workflows", "integrations", diff --git a/docs/content/docs/guides/observability/dashboard-api.mdx b/docs/content/docs/guides/observability/dashboard-api.mdx deleted file mode 100644 index 2e3028e..0000000 --- a/docs/content/docs/guides/observability/dashboard-api.mdx +++ /dev/null @@ -1,234 +0,0 @@ ---- -title: Dashboard REST API -description: "JSON endpoints for stats, jobs, dead letters, metrics, logs, infrastructure, observability." ---- - -The dashboard exposes a JSON API you can use independently of the UI. All -endpoints return `application/json` with `Access-Control-Allow-Origin: *`. - -## Stats - -### `GET /api/stats` - -Queue statistics snapshot. - -```json -{ - "pending": 12, - "running": 3, - "completed": 450, - "failed": 2, - "dead": 1, - "cancelled": 0 -} -``` - -### `GET /api/stats/queues` - -Per-queue statistics. Pass `?queue=name` for a single queue, or omit for all -queues. - -```bash -curl http://localhost:8080/api/stats/queues -curl http://localhost:8080/api/stats/queues?queue=emails -``` - -## Jobs - -### `GET /api/jobs` - -Paginated list of jobs with filtering. - -| Parameter | Type | Default | Description | -|---|---|---|---| -| `status` | `string` | all | Filter by status | -| `queue` | `string` | all | Filter by queue name | -| `task` | `string` | all | Filter by task name | -| `metadata` | `string` | — | Search metadata (LIKE) | -| `error` | `string` | — | Search error text (LIKE) | -| `created_after` | `int` | — | Unix ms timestamp | -| `created_before` | `int` | — | Unix ms timestamp | -| `limit` | `int` | `20` | Page size | -| `offset` | `int` | `0` | Pagination offset | - -```bash -curl http://localhost:8080/api/jobs?status=running&limit=10 -``` - -### `GET /api/jobs/{id}` - -Full detail for a single job. - -### `GET /api/jobs/{id}/errors` - -Error history for a job (one entry per failed attempt). - -### `GET /api/jobs/{id}/logs` - -Task execution logs for a specific job. - -### `GET /api/jobs/{id}/replay-history` - -Replay history for a job that has been replayed. - -### `GET /api/jobs/{id}/dag` - -Dependency graph for a job (nodes and edges). - -### `POST /api/jobs/{id}/cancel` - -Cancel a pending job. - -```json -{ "cancelled": true } -``` - -### `POST /api/jobs/{id}/replay` - -Replay a completed or failed job with the same payload. - -```json -{ "replay_job_id": "01H5K7Y..." } -``` - -## Dead letters - -### `GET /api/dead-letters` - -Paginated list of dead letter entries. Supports `limit` and `offset` -parameters. - -### `POST /api/dead-letters/{id}/retry` - -Re-enqueue a dead letter job. - -```json -{ "new_job_id": "01H5K7Y..." } -``` - -### `POST /api/dead-letters/purge` - -Purge all dead letters. - -```json -{ "purged": 42 } -``` - -## Metrics - -### `GET /api/metrics` - -Per-task execution metrics. - -| Parameter | Type | Default | Description | -|---|---|---|---| -| `task` | `string` | all | Filter by task name | -| `since` | `int` | `3600` | Lookback window in seconds | - -### `GET /api/metrics/timeseries` - -Time-bucketed metrics for charts. - -| Parameter | Type | Default | Description | -|---|---|---|---| -| `task` | `string` | all | Filter by task name | -| `since` | `int` | `3600` | Lookback window in seconds | -| `bucket` | `int` | `60` | Bucket size in seconds | - -## Logs - -### `GET /api/logs` - -Query task execution logs across all jobs. - -| Parameter | Type | Default | Description | -|---|---|---|---| -| `task` | `string` | all | Filter by task name | -| `level` | `string` | all | Filter by log level | -| `since` | `int` | `3600` | Lookback window in seconds | -| `limit` | `int` | `100` | Max entries | - -## Infrastructure - -### `GET /api/workers` - -List registered workers with heartbeat status. - -### `GET /api/circuit-breakers` - -Current state of all circuit breakers. - -### `GET /api/resources` - -Worker resource health and pool status. - -### `GET /api/queues/paused` - -List paused queue names. - -### `POST /api/queues/{name}/pause` - -Pause a queue (jobs stop being dequeued). - -### `POST /api/queues/{name}/resume` - -Resume a paused queue. - -## Observability - -### `GET /api/proxy-stats` - -Per-handler proxy reconstruction metrics. - -### `GET /api/interception-stats` - -Interception strategy performance metrics. - -### `GET /api/scaler` - -KEDA-compatible autoscaler payload. Pass `?queue=name` for a specific queue. - -### `GET /health` - -Liveness check. Always returns `{"status": "ok"}`. - -### `GET /readiness` - -Readiness check with storage, worker, and resource health. - -### `GET /metrics` - -Prometheus metrics endpoint (requires `prometheus-client` package). - -## Using the API programmatically - -```python -import requests - -# Health check script -stats = requests.get("http://localhost:8080/api/stats").json() - -if stats["dead"] > 0: - print(f"WARNING: {stats['dead']} dead letter(s)") - -if stats["running"] > 100: - print(f"WARNING: {stats['running']} jobs running, possible backlog") -``` - -```python -# Pause a queue during deployment -requests.post("http://localhost:8080/api/queues/default/pause") - -# ... deploy ... - -# Resume after deployment -requests.post("http://localhost:8080/api/queues/default/resume") -``` - -```python -# Retry all dead letters -dead = requests.get("http://localhost:8080/api/dead-letters?limit=100").json() -for entry in dead: - requests.post(f"http://localhost:8080/api/dead-letters/{entry['id']}/retry") - print(f"Retried {entry['task_name']}") -``` diff --git a/docs/content/docs/guides/observability/dashboard.mdx b/docs/content/docs/guides/observability/dashboard.mdx deleted file mode 100644 index 192500b..0000000 --- a/docs/content/docs/guides/observability/dashboard.mdx +++ /dev/null @@ -1,265 +0,0 @@ ---- -title: Web Dashboard -description: "Zero-dependency built-in web UI for browsing jobs, metrics, workers, and managing the queue." ---- - -import { Callout } from "fumadocs-ui/components/callout"; -import { Tab, Tabs } from "fumadocs-ui/components/tabs"; - -taskito ships with a built-in web dashboard for monitoring jobs, inspecting -dead letters, and managing your task queue in real time. The dashboard is a -single-page application served directly from the Python package — **zero -extra dependencies required**. - -## Launching the dashboard - - - - -```bash -taskito dashboard --app myapp:queue -``` - -The `--app` argument uses the same `module:attribute` format as the worker. - - - - -```python -from taskito.dashboard import serve_dashboard -from myapp import queue - -serve_dashboard(queue, host="0.0.0.0", port=8000) -``` - - - - -By default the dashboard starts on `http://localhost:8080`. - -### CLI options - -| Flag | Default | Description | -|---|---|---| -| `--app` | *required* | Module path to your `Queue` instance, e.g. `myapp:queue` | -| `--host` | `127.0.0.1` | Bind address | -| `--port` | `8080` | Bind port | - -```bash -# Bind to all interfaces on port 9000 -taskito dashboard --app myapp:queue --host 0.0.0.0 --port 9000 -``` - - - The dashboard reads directly from the same SQLite database as the worker. - You can run them side by side without any coordination: - - ```bash - # Terminal 1 - taskito worker --app myapp:queue - - # Terminal 2 - taskito dashboard --app myapp:queue - ``` - - -## Dashboard features - -The dashboard is a React + Vite + TypeScript SPA routed via TanStack Router, -styled with Tailwind v4 and shadcn/ui, and shipped as hash-busted multi-file -assets under `py_src/taskito/static/dashboard/`. - -### Design - -- **Dark and light mode** — Toggle between themes via the sun/moon button in the header. Preference is stored in `localStorage` and follows the system scheme by default. -- **Auto-refresh** — Configurable refresh interval (2s, 5s, 10s, or off) via the header dropdown. All pages auto-refresh at the selected interval; TanStack Query handles caching and background revalidation. -- **Command palette** — `⌘K` / `Ctrl+K` opens a cmdk palette for route navigation and common actions. -- **Icons** — Lucide icons throughout for visual clarity. -- **Toast notifications** — Every action shows a success or error toast via sonner. Optimistic mutations update the UI immediately and roll back on error. -- **Destructive confirms** — Irreversible actions (purge, retry all) use a type-to-confirm dialog. -- **Loading states** — Skeleton screens for tables and cards, error boundaries with retry. -- **Responsive layout** — Sidebar navigation with grouped sections (Monitoring, Infrastructure, Advanced). The main content area scrolls independently. - -### Pages - -| Page | Description | -|---|---| -| **Overview** | Stats cards with status icons, throughput sparkline chart, recent jobs table | -| **Jobs** | Filterable job listing (status, queue, task, metadata, error, date range) with pagination | -| **Job Detail** | Full job info, error history, task logs, replay history, dependency DAG visualization | -| **Metrics** | Per-task performance table (avg, P50, P95, P99) with timeseries chart and time range selector | -| **Logs** | Structured task execution logs with task/level filters | -| **Workers** | Worker cards with heartbeat status, queue assignments, and tags | -| **Queues** | Per-queue stats (pending/running), pause and resume controls | -| **Resources** | Worker DI runtime status — health, scope, init duration, pool stats, dependencies | -| **Circuit Breakers** | Automatic failure protection state (closed/open/half_open), thresholds, cooldowns | -| **Dead Letters** | Failed jobs that exhausted retries — retry individual entries or purge all | -| **System** | Proxy reconstruction and interception strategy metrics | - - - The built SPA ships inside the Python wheel under - `py_src/taskito/static/dashboard/` and is served by the Python dashboard - process. No Node.js, no pnpm, no CDN at runtime — just `pip install - taskito`. Node.js and pnpm are only needed by contributors rebuilding the - dashboard source in `dashboard/`. - - -## Tutorial - -This walkthrough covers every dashboard page and how to use it. - -### Step 1: start the dashboard - -Start a worker and the dashboard in two terminals: - -```bash -# Terminal 1 — start the worker -taskito worker --app myapp:queue - -# Terminal 2 — start the dashboard -taskito dashboard --app myapp:queue -``` - -You should see: - -``` -taskito dashboard → http://127.0.0.1:8080 -Press Ctrl+C to stop -``` - -Open `http://localhost:8080` in your browser. - -### Step 2: Overview page - -The first page you see is the **Overview**. It shows: - -- **Stats cards** — Six cards at the top showing pending, running, completed, failed, dead, and cancelled job counts. -- **Throughput chart** — A green sparkline showing jobs processed per second over the last 60 refresh intervals. -- **Recent jobs table** — The 10 most recent jobs. Click any row to open its detail view. - -The stats update automatically based on the refresh interval you select in -the header (default: 5 seconds). - -### Step 3: browsing and filtering jobs - -Click **Jobs** in the sidebar. This page shows: - -- **Stats grid** — Same six stat cards as the overview. -- **Filter panel** — Status dropdown, queue, task, metadata, error text, created-after/before pickers. -- **Results table** — Paginated list showing ID, task, queue, status, priority, progress, retries, and created time. - -Use the **Prev / Next** buttons at the bottom to paginate. - -### Step 4: inspecting a job - -Click any job row to open the **Job Detail** page. The detail card shows: - -- A colored top border matching the job status (green for complete, red for failed, etc.) -- Full job ID, status badge, task name, queue, priority, progress bar, retries, timestamps -- **Error** field (if the job failed) displayed in a red-highlighted box -- Unique key and metadata (if set) - -**Actions:** - -- **Cancel Job** — Visible only for pending jobs. Sends a cancel request and shows a toast. -- **Replay** — Re-enqueue the job with the same payload. Navigates to the new job's detail page. - -**Sections below the detail card:** Error History, Task Logs, Replay -History, and a Dependency Graph visualization for jobs with dependencies. - -### Step 5: monitoring metrics - -Click **Metrics** in the sidebar. This page shows a time-range selector (1h -/ 6h / 24h), a stacked bar chart of success/failure counts per time bucket, -and a per-task table with avg / P50 / P95 / P99 / min / max latency. - -### Step 6: viewing logs - -Click **Logs** in the sidebar. Filter by task name or level. Each log entry -shows time, level badge, task name, job ID, message, and structured extra -data. - -### Step 7: workers - -Click **Workers**. Each active worker is displayed as a card showing the -green dot for liveness, worker ID, queues consumed, last heartbeat, -registration time, and tags. - -### Step 8: managing queues - -Click **Queues**. Per-queue table with pending/running counts, pause/resume -buttons, and status badges. - - - Pausing a queue prevents the scheduler from dequeuing new jobs from it. - Jobs already running will complete normally. Enqueuing new jobs still - works — they'll be picked up when the queue is resumed. - - -### Step 9: resources - -Click **Resources**. Shows registered worker DI runtime entries (name, -scope, health, init duration, recreations, dependencies, pool stats). - -### Step 10: circuit breakers - -Click **Circuit Breakers**. State badge (closed/open/half_open), failure -count, threshold, window, cooldown. - -### Step 11: dead letter queue - -Click **Dead Letters**. Retry individual entries with the **Retry** button, -or purge all with the type-to-confirm **Purge All** in the header. - -### Step 12: system internals - -Click **System**. Two tables: Proxy Reconstruction (per-handler metrics) -and Interception (per-strategy metrics). - -### Step 13: switching themes - -Click the sun/moon icon in the top-right of the header. - -### Step 14: changing refresh rate - -Use the **Refresh** dropdown in the header — 2s, 5s, 10s, or off. - - - The dashboard also exposes a full JSON API. See the - [Dashboard REST API](/guides/observability/dashboard-api) reference - for all endpoints. - - -## Development - -Contributors who want to modify the dashboard source: - -```bash -# Install dependencies (pnpm is pinned via the `packageManager` field) -cd dashboard && pnpm install - -# Start Vite dev server (proxies /api/* to localhost:8080) -pnpm run dev - -# In another terminal, start the backend -taskito dashboard --app myapp:queue - -# Build and copy to Python package -pnpm run build -``` - - - Run `corepack enable` once (Node 16+) and pnpm will be provisioned - automatically from the version pinned in `dashboard/package.json`. - - -The build produces a static `index.html` plus hashed JS/CSS chunks under -`py_src/taskito/static/dashboard/`. The built assets aren't committed — -release tooling runs `pnpm -C dashboard build` before packaging so the -wheel ships them. - - - The dashboard does not include authentication. If you expose it beyond - `localhost`, place it behind a reverse proxy with authentication (e.g. - nginx with basic auth, or an OAuth2 proxy). - diff --git a/docs/content/docs/guides/observability/index.mdx b/docs/content/docs/guides/observability/index.mdx index 9570dd9..7b2b507 100644 --- a/docs/content/docs/guides/observability/index.mdx +++ b/docs/content/docs/guides/observability/index.mdx @@ -9,5 +9,8 @@ Monitor, log, and inspect your task queue in real time. |---|---| | [Monitoring & Hooks](/guides/observability/monitoring) | Queue stats, progress tracking, worker heartbeat, and alerting hooks | | [Structured Logging](/guides/observability/logging) | Per-task structured logs with automatic context | -| [Web Dashboard](/guides/observability/dashboard) | Built-in web UI for browsing jobs, metrics, and worker status | -| [Dashboard REST API](/guides/observability/dashboard-api) | Programmatic access to all dashboard data via REST endpoints | +| [Structured Notes](/guides/observability/notes) | Operator-visible metadata attached to individual jobs | + +For the built-in web UI, see the [Dashboard](/guides/dashboard) section — +it covers the browser app, password and SSO login, runtime task/queue +overrides, and the underlying REST API. diff --git a/docs/content/docs/guides/observability/meta.json b/docs/content/docs/guides/observability/meta.json index 474b374..41daf73 100644 --- a/docs/content/docs/guides/observability/meta.json +++ b/docs/content/docs/guides/observability/meta.json @@ -1,4 +1,4 @@ { "title": "Observability", - "pages": ["monitoring", "logging", "notes", "dashboard", "dashboard-api"] + "pages": ["monitoring", "logging", "notes"] } diff --git a/docs/content/docs/guides/observability/monitoring.mdx b/docs/content/docs/guides/observability/monitoring.mdx index 29247a2..c5e1d2e 100644 --- a/docs/content/docs/guides/observability/monitoring.mdx +++ b/docs/content/docs/guides/observability/monitoring.mdx @@ -120,7 +120,7 @@ workers = await queue.aworkers() The worker heartbeat is also available via the dashboard REST API at `GET /api/workers`. See the -[Dashboard](/guides/observability/dashboard) guide for details. +[Dashboard](/guides/dashboard) guide for details. ## Events system diff --git a/docs/content/docs/guides/resources/observability.mdx b/docs/content/docs/guides/resources/observability.mdx index 35c5a19..e9a25b7 100644 --- a/docs/content/docs/guides/resources/observability.mdx +++ b/docs/content/docs/guides/resources/observability.mdx @@ -122,7 +122,7 @@ Start the dashboard: taskito dashboard --app myapp.tasks:queue ``` -See the [Web Dashboard](/guides/observability/dashboard) guide for +See the [Web Dashboard](/guides/dashboard) guide for full dashboard documentation. ## CLI commands diff --git a/docs/public/screenshots/dashboard/auth-login.png b/docs/public/screenshots/dashboard/auth-login.png new file mode 100644 index 0000000..fbcfe56 Binary files /dev/null and b/docs/public/screenshots/dashboard/auth-login.png differ diff --git a/docs/public/screenshots/dashboard/auth-setup.png b/docs/public/screenshots/dashboard/auth-setup.png new file mode 100644 index 0000000..98c0826 Binary files /dev/null and b/docs/public/screenshots/dashboard/auth-setup.png differ diff --git a/docs/public/screenshots/dashboard/jobs.png b/docs/public/screenshots/dashboard/jobs.png new file mode 100644 index 0000000..e6b9bf3 Binary files /dev/null and b/docs/public/screenshots/dashboard/jobs.png differ diff --git a/docs/public/screenshots/dashboard/overview.png b/docs/public/screenshots/dashboard/overview.png new file mode 100644 index 0000000..ef6e3a3 Binary files /dev/null and b/docs/public/screenshots/dashboard/overview.png differ diff --git a/docs/public/screenshots/dashboard/queues.png b/docs/public/screenshots/dashboard/queues.png new file mode 100644 index 0000000..8e79761 Binary files /dev/null and b/docs/public/screenshots/dashboard/queues.png differ diff --git a/docs/public/screenshots/dashboard/task-edit-middleware.png b/docs/public/screenshots/dashboard/task-edit-middleware.png new file mode 100644 index 0000000..c8498b6 Binary files /dev/null and b/docs/public/screenshots/dashboard/task-edit-middleware.png differ diff --git a/docs/public/screenshots/dashboard/task-edit-overrides.png b/docs/public/screenshots/dashboard/task-edit-overrides.png new file mode 100644 index 0000000..282c768 Binary files /dev/null and b/docs/public/screenshots/dashboard/task-edit-overrides.png differ diff --git a/docs/public/screenshots/dashboard/tasks-list.png b/docs/public/screenshots/dashboard/tasks-list.png new file mode 100644 index 0000000..6b9ad85 Binary files /dev/null and b/docs/public/screenshots/dashboard/tasks-list.png differ diff --git a/docs/public/screenshots/dashboard/webhook-create-dialog.png b/docs/public/screenshots/dashboard/webhook-create-dialog.png new file mode 100644 index 0000000..4ddc615 Binary files /dev/null and b/docs/public/screenshots/dashboard/webhook-create-dialog.png differ diff --git a/docs/public/screenshots/dashboard/webhook-deliveries.png b/docs/public/screenshots/dashboard/webhook-deliveries.png new file mode 100644 index 0000000..3a1a318 Binary files /dev/null and b/docs/public/screenshots/dashboard/webhook-deliveries.png differ diff --git a/docs/public/screenshots/dashboard/webhooks-list.png b/docs/public/screenshots/dashboard/webhooks-list.png new file mode 100644 index 0000000..af342e6 Binary files /dev/null and b/docs/public/screenshots/dashboard/webhooks-list.png differ diff --git a/docs/public/screenshots/dashboard/workers.png b/docs/public/screenshots/dashboard/workers.png new file mode 100644 index 0000000..0535e6e Binary files /dev/null and b/docs/public/screenshots/dashboard/workers.png differ diff --git a/py_src/taskito/app.py b/py_src/taskito/app.py index 8f87ff7..c8273c7 100644 --- a/py_src/taskito/app.py +++ b/py_src/taskito/app.py @@ -38,7 +38,9 @@ QueueInspectionMixin, QueueLifecycleMixin, QueueLockMixin, + QueueMiddlewareAdminMixin, QueueOperationsMixin, + QueueOverridesMixin, QueuePredicateMixin, QueueResourceMixin, QueueRuntimeConfigMixin, @@ -83,6 +85,8 @@ class Queue( QueueInspectionMixin, QueueOperationsMixin, QueueLockMixin, + QueueMiddlewareAdminMixin, + QueueOverridesMixin, QueueSettingsMixin, QueueWorkflowMixin, AsyncQueueMixin, @@ -223,7 +227,7 @@ def __init__( self._drain_timeout = drain_timeout self._queue_configs: dict[str, dict[str, Any]] = {} self._event_bus = EventBus(max_workers=event_workers) - self._webhook_manager = WebhookManager() + self._webhook_manager = WebhookManager(queue_ref=self) # Proxy handlers self._proxy_registry = ProxyRegistry() diff --git a/py_src/taskito/dashboard/_testing.py b/py_src/taskito/dashboard/_testing.py new file mode 100644 index 0000000..67f7d47 --- /dev/null +++ b/py_src/taskito/dashboard/_testing.py @@ -0,0 +1,83 @@ +"""Shared helpers for dashboard endpoint tests. + +The dashboard requires a logged-in session for every API route once setup +is complete. :class:`AuthedClient` wraps the stdlib ``urllib.request`` so +tests can issue authenticated HTTP calls without repeating the cookie / +CSRF dance. +""" + +from __future__ import annotations + +import json +import urllib.error +import urllib.request +from dataclasses import dataclass +from typing import Any + +from taskito import Queue +from taskito.dashboard.auth import AuthStore, Session + + +@dataclass(frozen=True) +class AuthedClient: + """Stateless HTTP helper that attaches session + CSRF to every call.""" + + base: str + session: Session + + @property + def _cookies(self) -> dict[str, str]: + return { + "taskito_session": self.session.token, + "taskito_csrf": self.session.csrf_token, + } + + def _cookie_header(self) -> str: + return "; ".join(f"{k}={v}" for k, v in self._cookies.items()) + + def get(self, path: str, *, raise_for_status: bool = True) -> Any: + url = self.base + path + req = urllib.request.Request(url, method="GET") + req.add_header("Cookie", self._cookie_header()) + try: + with urllib.request.urlopen(req) as resp: + return json.loads(resp.read() or b"{}") + except urllib.error.HTTPError as e: + if raise_for_status: + raise + return {"status": e.code, "body": json.loads(e.read() or b"{}")} + + def post(self, path: str, body: dict | None = None) -> Any: + return self._mutate("POST", path, body) + + def put(self, path: str, body: dict | None = None) -> Any: + return self._mutate("PUT", path, body) + + def delete(self, path: str) -> Any: + return self._mutate("DELETE", path, None) + + def _mutate(self, method: str, path: str, body: dict | None) -> Any: + url = self.base + path + data = json.dumps(body).encode() if body is not None else b"" + req = urllib.request.Request(url, method=method, data=data) + req.add_header("Cookie", self._cookie_header()) + req.add_header("X-CSRF-Token", self.session.csrf_token) + if body is not None: + req.add_header("Content-Type", "application/json") + with urllib.request.urlopen(req) as resp: + return json.loads(resp.read() or b"{}") + + +def seed_admin_and_session( + queue: Queue, + *, + username: str = "test-admin", + password: str = "test-pass-1234", +) -> Session: + """Create a one-off admin and return a fresh session for it.""" + store = AuthStore(queue) + if store.get_user(username) is None: + store.create_user(username, password, role="admin") + user = store.get_user(username) + assert user is not None + return store.create_session(user) diff --git a/py_src/taskito/dashboard/auth.py b/py_src/taskito/dashboard/auth.py new file mode 100644 index 0000000..a7f2b6b --- /dev/null +++ b/py_src/taskito/dashboard/auth.py @@ -0,0 +1,453 @@ +"""Authentication primitives for the dashboard. + +Users and sessions are persisted through ``Queue.set_setting`` / ``get_setting`` +— the same key/value store that already backs dashboard branding and +integration settings. This avoids new database tables and keeps the auth +feature working uniformly across SQLite, Postgres, and Redis backends. + +Key layout in ``dashboard_settings``: + +- ``auth:users`` — JSON object ``{username: {password_hash, role, ...}}`` +- ``auth:session:`` — JSON object describing one active session +- ``auth:csrf_secret`` — random secret used as a HMAC key for CSRF tokens + +Password hashes use PBKDF2-HMAC-SHA256 (stdlib ``hashlib``) with +600,000 iterations — the OWASP 2023+ baseline for PBKDF2. No third-party +crypto dependency is required. +""" + +from __future__ import annotations + +import hashlib +import hmac +import json +import logging +import secrets +import time +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from taskito.app import Queue + + +logger = logging.getLogger("taskito.dashboard.auth") + +# ── Storage keys ─────────────────────────────────────────────────────── + +USERS_KEY = "auth:users" +SESSION_PREFIX = "auth:session:" +CSRF_SECRET_KEY = "auth:csrf_secret" + +# ── Crypto parameters ────────────────────────────────────────────────── + +PBKDF2_ITERATIONS = 600_000 +PBKDF2_SALT_BYTES = 16 +PBKDF2_HASH_BYTES = 32 +SESSION_TOKEN_BYTES = 32 + +# ── Session lifetime ─────────────────────────────────────────────────── + +DEFAULT_SESSION_TTL_SECONDS = 24 * 60 * 60 # 24h + +# ── Validation ───────────────────────────────────────────────────────── + +USERNAME_MAX_LEN = 64 +PASSWORD_MIN_LEN = 8 +PASSWORD_MAX_LEN = 256 +VALID_ROLES = frozenset({"admin", "viewer"}) + +# Sentinel prefix used in ``password_hash`` for OAuth-only users so +# ``verify_password`` can short-circuit-reject any password attempt. +OAUTH_PASSWORD_HASH_PREFIX = "oauth:" + + +# ── Password hashing ─────────────────────────────────────────────────── + + +def hash_password(password: str) -> str: + """Hash a password with PBKDF2-HMAC-SHA256. + + Returns a self-describing string of the form + ``pbkdf2_sha256$$$`` so the verifier can + parse out the salt and iteration count without separate columns. + """ + salt = secrets.token_bytes(PBKDF2_SALT_BYTES) + digest = hashlib.pbkdf2_hmac( + "sha256", password.encode("utf-8"), salt, PBKDF2_ITERATIONS, PBKDF2_HASH_BYTES + ) + return f"pbkdf2_sha256${PBKDF2_ITERATIONS}${salt.hex()}${digest.hex()}" + + +def verify_password(password: str, encoded: str) -> bool: + """Constant-time verify a password against the encoded hash.""" + # Sentinel for OAuth-only users — they have no real password and must + # never authenticate via the password endpoint. + if encoded.startswith(OAUTH_PASSWORD_HASH_PREFIX): + return False + try: + scheme, iters_str, salt_hex, hash_hex = encoded.split("$") + except ValueError: + return False + if scheme != "pbkdf2_sha256": + return False + try: + iters = int(iters_str) + salt = bytes.fromhex(salt_hex) + expected = bytes.fromhex(hash_hex) + except ValueError: + return False + candidate = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iters, len(expected)) + return hmac.compare_digest(candidate, expected) + + +# ── Tokens ───────────────────────────────────────────────────────────── + + +def generate_session_token() -> str: + """Cryptographically secure URL-safe session token.""" + return secrets.token_urlsafe(SESSION_TOKEN_BYTES) + + +# ── Data classes ─────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class User: + """A persisted dashboard user. + + ``email`` and ``display_name`` are populated for users created via the + OAuth flow; for password users they are typically ``None`` until set + by an admin. + """ + + username: str + password_hash: str + role: str + created_at: int + last_login_at: int | None = None + email: str | None = None + display_name: str | None = None + + @property + def is_oauth(self) -> bool: + return self.password_hash.startswith(OAUTH_PASSWORD_HASH_PREFIX) + + +@dataclass(frozen=True) +class Session: + """An active dashboard session.""" + + token: str + username: str + role: str + created_at: int + expires_at: int + csrf_token: str + + def is_expired(self, now: int | None = None) -> bool: + return (now if now is not None else int(time.time())) >= self.expires_at + + +# ── Validation helpers ───────────────────────────────────────────────── + + +def _validate_username(username: str) -> None: + if not username: + raise ValueError("username must not be empty") + if len(username) > USERNAME_MAX_LEN: + raise ValueError(f"username must be <= {USERNAME_MAX_LEN} chars") + if not all(c.isalnum() or c in "._-" for c in username): + raise ValueError("username may only contain letters, digits, '.', '_', or '-'") + + +def _validate_password(password: str) -> None: + if len(password) < PASSWORD_MIN_LEN: + raise ValueError(f"password must be >= {PASSWORD_MIN_LEN} chars") + if len(password) > PASSWORD_MAX_LEN: + raise ValueError(f"password must be <= {PASSWORD_MAX_LEN} chars") + + +def _validate_role(role: str) -> None: + if role not in VALID_ROLES: + raise ValueError(f"role must be one of {sorted(VALID_ROLES)}") + + +def _oauth_bootstrap_role( + *, + email: str | None, + email_verified: bool, + admin_emails: tuple[str, ...], + user_table_empty: bool, +) -> str: + """Decide the role for a freshly-created OAuth user. + + Order: any path to ``admin`` requires a verified email (defence against + spoofed claims). If an explicit admin list is configured, only listed + emails get ``admin`` — the first-user-wins fallback is skipped. With no + admin list, the very first user (empty table) gets ``admin``, everyone + else gets ``viewer``. + """ + if not email_verified or not email: + return "viewer" + normalised = email.lower() + if admin_emails: + if normalised in {e.lower() for e in admin_emails}: + return "admin" + return "viewer" + if user_table_empty: + return "admin" + return "viewer" + + +# ── Auth store ───────────────────────────────────────────────────────── + + +class AuthStore: + """Read/write users and sessions through ``Queue``'s settings store.""" + + def __init__(self, queue: Queue) -> None: + self._queue = queue + + # ── Users ────────────────────────────────────────────────────── + + def _load_users(self) -> dict[str, dict[str, object]]: + raw = self._queue.get_setting(USERS_KEY) + if not raw: + return {} + try: + data = json.loads(raw) + except json.JSONDecodeError: + logger.warning("auth:users entry is not valid JSON; treating as empty") + return {} + return data if isinstance(data, dict) else {} + + def _save_users(self, users: dict[str, dict[str, object]]) -> None: + self._queue.set_setting(USERS_KEY, json.dumps(users, separators=(",", ":"))) + + def count_users(self) -> int: + return len(self._load_users()) + + def list_users(self) -> list[User]: + return [self._row_to_user(name, row) for name, row in self._load_users().items()] + + def get_user(self, username: str) -> User | None: + row = self._load_users().get(username) + return self._row_to_user(username, row) if row else None + + def create_user(self, username: str, password: str, role: str = "admin") -> User: + _validate_username(username) + _validate_password(password) + _validate_role(role) + users = self._load_users() + if username in users: + raise ValueError(f"user '{username}' already exists") + now = int(time.time()) + users[username] = { + "password_hash": hash_password(password), + "role": role, + "created_at": now, + "last_login_at": None, + } + self._save_users(users) + return self._row_to_user(username, users[username]) + + def update_password(self, username: str, new_password: str) -> None: + _validate_password(new_password) + users = self._load_users() + if username not in users: + raise ValueError(f"user '{username}' does not exist") + users[username]["password_hash"] = hash_password(new_password) + self._save_users(users) + + def delete_user(self, username: str) -> bool: + users = self._load_users() + if username not in users: + return False + del users[username] + self._save_users(users) + return True + + def authenticate(self, username: str, password: str) -> User | None: + """Return the user iff username+password match; updates last_login_at.""" + users = self._load_users() + row = users.get(username) + if not row: + # Run a dummy verify against a fixed hash to keep timing constant + # for unknown vs. known usernames. + verify_password(password, _DUMMY_HASH) + return None + if not verify_password(password, str(row["password_hash"])): + return None + row["last_login_at"] = int(time.time()) + users[username] = row + self._save_users(users) + return self._row_to_user(username, row) + + @staticmethod + def _row_to_user(username: str, row: dict[str, object] | None) -> User: + assert row is not None + created_raw = row["created_at"] + last_raw = row.get("last_login_at") + email_raw = row.get("email") + name_raw = row.get("display_name") + return User( + username=username, + password_hash=str(row["password_hash"]), + role=str(row["role"]), + created_at=int(created_raw) if isinstance(created_raw, (int, float, str)) else 0, + last_login_at=(int(last_raw) if isinstance(last_raw, (int, float, str)) else None), + email=str(email_raw) if isinstance(email_raw, str) and email_raw else None, + display_name=str(name_raw) if isinstance(name_raw, str) and name_raw else None, + ) + + # ── OAuth users ──────────────────────────────────────────────── + + def get_or_create_oauth_user( + self, + slot: str, + subject: str, + email: str | None, + name: str | None, + email_verified: bool, + admin_emails: tuple[str, ...] = (), + ) -> User: + """Look up or create the User row backing an OAuth identity. + + Username is ``f"{slot}:{subject}"``. On first sight, the role is + assigned by :func:`_oauth_bootstrap_role`. On subsequent logins, + the role is left alone but ``email`` / ``display_name`` are refreshed + from the latest provider claims. + """ + username = f"{slot}:{subject}" + users = self._load_users() + existing = users.get(username) + if existing is not None: + if email and existing.get("email") != email: + existing["email"] = email + if name and existing.get("display_name") != name: + existing["display_name"] = name + existing["last_login_at"] = int(time.time()) + users[username] = existing + self._save_users(users) + return self._row_to_user(username, existing) + + role = _oauth_bootstrap_role( + email=email, + email_verified=email_verified, + admin_emails=admin_emails, + user_table_empty=not users, + ) + now = int(time.time()) + users[username] = { + "password_hash": f"{OAUTH_PASSWORD_HASH_PREFIX}{slot}", + "role": role, + "created_at": now, + "last_login_at": now, + "email": email, + "display_name": name, + } + self._save_users(users) + return self._row_to_user(username, users[username]) + + # ── Sessions ─────────────────────────────────────────────────── + + def create_session( + self, user: User, ttl_seconds: int = DEFAULT_SESSION_TTL_SECONDS + ) -> Session: + now = int(time.time()) + token = generate_session_token() + session = Session( + token=token, + username=user.username, + role=user.role, + created_at=now, + expires_at=now + ttl_seconds, + csrf_token=generate_session_token(), + ) + self._queue.set_setting( + SESSION_PREFIX + token, + json.dumps( + {k: v for k, v in asdict(session).items() if k != "token"}, + separators=(",", ":"), + ), + ) + return session + + def get_session(self, token: str) -> Session | None: + if not token: + return None + raw = self._queue.get_setting(SESSION_PREFIX + token) + if not raw: + return None + try: + data = json.loads(raw) + except json.JSONDecodeError: + return None + try: + session = Session(token=token, **data) + except TypeError: + return None + if session.is_expired(): + self.delete_session(token) + return None + return session + + def delete_session(self, token: str) -> bool: + if not token: + return False + return self._queue.delete_setting(SESSION_PREFIX + token) + + def prune_expired_sessions(self) -> int: + """Best-effort cleanup of expired session entries. Returns count removed.""" + now = int(time.time()) + removed = 0 + for key, value in self._queue.list_settings().items(): + if not key.startswith(SESSION_PREFIX): + continue + try: + data = json.loads(value) + expires_at = int(data.get("expires_at", 0)) + except (json.JSONDecodeError, TypeError, ValueError): + continue + if expires_at <= now: + self._queue.delete_setting(key) + removed += 1 + return removed + + +# Fixed hash used to keep authentication timing constant for unknown users. +# Value computed once with a throw-away password — never used for real auth. +_DUMMY_HASH = ( + "pbkdf2_sha256$600000$" + "00000000000000000000000000000000$" + "0000000000000000000000000000000000000000000000000000000000000000" +) + + +# ── Bootstrap from environment ───────────────────────────────────────── + + +def bootstrap_admin_from_env(queue: Queue) -> User | None: + """Idempotently create the first admin from environment variables. + + If ``TASKITO_DASHBOARD_ADMIN_USER`` and ``TASKITO_DASHBOARD_ADMIN_PASSWORD`` + are set AND the user does not exist yet, create it. Safe to call on every + startup — does nothing if the user already exists. + """ + import os + + username = os.environ.get("TASKITO_DASHBOARD_ADMIN_USER") + password = os.environ.get("TASKITO_DASHBOARD_ADMIN_PASSWORD") + if not username or not password: + return None + store = AuthStore(queue) + if store.get_user(username): + return None + try: + user = store.create_user(username, password, role="admin") + except ValueError as e: + logger.warning("Failed to bootstrap admin %r from env: %s", username, e) + return None + logger.info("Bootstrapped admin user %r from environment", username) + return user diff --git a/py_src/taskito/dashboard/delivery_store.py b/py_src/taskito/dashboard/delivery_store.py new file mode 100644 index 0000000..0efc8d9 --- /dev/null +++ b/py_src/taskito/dashboard/delivery_store.py @@ -0,0 +1,208 @@ +"""Persistent webhook delivery log. + +Each subscription gets its own JSON list under the key +``webhooks:deliveries:{subscription_id}`` in the ``dashboard_settings`` +table. The store is append-only with FIFO eviction once the per-webhook +cap is hit (default 200 entries) — enough to debug recent activity +without unbounded growth. + +The structure: + + [ + { + "id": "uuid", + "subscription_id": "sub-uuid", + "event": "job.completed", + "task_name": "send_email" | null, + "job_id": "abc123" | null, + "payload": {...}, + "status": "delivered" | "failed" | "dead", + "attempts": 3, + "response_code": 200 | null, + "response_body": "..." | null, + "latency_ms": 42, + "error": "..." | null, + "created_at": 1234567890000, + "completed_at": 1234567890420 + }, + ... + ] + +Records are inserted in chronological order; listing reverses for newest-first. +""" + +from __future__ import annotations + +import json +import logging +import time +import uuid +from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from taskito.app import Queue + + +DELIVERY_PREFIX = "webhooks:deliveries:" +DEFAULT_MAX_PER_WEBHOOK = 200 +RESPONSE_BODY_MAX_BYTES = 2048 + +logger = logging.getLogger("taskito.dashboard.deliveries") + + +@dataclass +class DeliveryRecord: + """A single attempted webhook delivery.""" + + id: str + subscription_id: str + event: str + payload: dict[str, Any] + task_name: str | None = None + job_id: str | None = None + status: str = "pending" # "delivered" | "failed" | "dead" | "pending" + attempts: int = 0 + response_code: int | None = None + response_body: str | None = None + latency_ms: int | None = None + error: str | None = None + created_at: int = field(default_factory=lambda: int(time.time() * 1000)) + completed_at: int | None = None + + @classmethod + def from_row(cls, row: dict[str, Any]) -> DeliveryRecord: + return cls( + id=str(row["id"]), + subscription_id=str(row["subscription_id"]), + event=str(row["event"]), + payload=dict(row.get("payload") or {}), + task_name=row.get("task_name"), + job_id=row.get("job_id"), + status=str(row.get("status", "pending")), + attempts=int(row.get("attempts", 0)), + response_code=row.get("response_code"), + response_body=row.get("response_body"), + latency_ms=row.get("latency_ms"), + error=row.get("error"), + created_at=int(row.get("created_at", 0)), + completed_at=row.get("completed_at"), + ) + + +def _new_id() -> str: + return uuid.uuid4().hex + + +def _now_ms() -> int: + return int(time.time() * 1000) + + +def _truncate(body: str | None, *, max_bytes: int = RESPONSE_BODY_MAX_BYTES) -> str | None: + if body is None: + return None + encoded = body.encode("utf-8", errors="replace") + if len(encoded) <= max_bytes: + return body + return encoded[:max_bytes].decode("utf-8", errors="replace") + "…" + + +class DeliveryStore: + """List/insert/update delivery records keyed by subscription id.""" + + def __init__(self, queue: Queue, *, max_per_webhook: int = DEFAULT_MAX_PER_WEBHOOK) -> None: + self._queue = queue + self._max = max_per_webhook + + # ── Internal ──────────────────────────────────────────────── + + def _key(self, subscription_id: str) -> str: + return DELIVERY_PREFIX + subscription_id + + def _load(self, subscription_id: str) -> list[dict[str, Any]]: + raw = self._queue.get_setting(self._key(subscription_id)) + if not raw: + return [] + try: + data = json.loads(raw) + except json.JSONDecodeError: + logger.warning("delivery log for %s is corrupt; resetting", subscription_id) + return [] + return data if isinstance(data, list) else [] + + def _save(self, subscription_id: str, rows: list[dict[str, Any]]) -> None: + self._queue.set_setting( + self._key(subscription_id), + json.dumps(rows, separators=(",", ":")), + ) + + # ── Public API ───────────────────────────────────────────── + + def record_attempt( + self, + subscription_id: str, + event: str, + payload: dict[str, Any], + *, + status: str, + attempts: int, + response_code: int | None = None, + response_body: str | None = None, + latency_ms: int | None = None, + error: str | None = None, + task_name: str | None = None, + job_id: str | None = None, + ) -> DeliveryRecord: + """Append a delivery row and trim to the per-webhook cap.""" + now = _now_ms() + record = DeliveryRecord( + id=_new_id(), + subscription_id=subscription_id, + event=event, + payload=payload, + task_name=task_name, + job_id=job_id, + status=status, + attempts=attempts, + response_code=response_code, + response_body=_truncate(response_body), + latency_ms=latency_ms, + error=error, + created_at=now, + completed_at=now if status != "pending" else None, + ) + rows = self._load(subscription_id) + rows.append(asdict(record)) + if len(rows) > self._max: + rows = rows[-self._max :] + self._save(subscription_id, rows) + return record + + def list_for( + self, + subscription_id: str, + *, + status: str | None = None, + event: str | None = None, + limit: int = 50, + offset: int = 0, + ) -> list[DeliveryRecord]: + rows = list(reversed(self._load(subscription_id))) # newest first + if status: + rows = [r for r in rows if r.get("status") == status] + if event: + rows = [r for r in rows if r.get("event") == event] + page = rows[offset : offset + limit] + return [DeliveryRecord.from_row(r) for r in page] + + def get(self, subscription_id: str, delivery_id: str) -> DeliveryRecord | None: + for row in self._load(subscription_id): + if row.get("id") == delivery_id: + return DeliveryRecord.from_row(row) + return None + + def delete_for(self, subscription_id: str) -> bool: + return self._queue.delete_setting(self._key(subscription_id)) + + def count_for(self, subscription_id: str) -> int: + return len(self._load(subscription_id)) diff --git a/py_src/taskito/dashboard/handlers/auth.py b/py_src/taskito/dashboard/handlers/auth.py new file mode 100644 index 0000000..e9c4b29 --- /dev/null +++ b/py_src/taskito/dashboard/handlers/auth.py @@ -0,0 +1,128 @@ +"""Authentication route handlers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from taskito.dashboard.auth import AuthStore +from taskito.dashboard.errors import _BadRequest, _NotFound + +if TYPE_CHECKING: + from taskito.app import Queue + from taskito.dashboard.request_context import RequestContext + + +def _require_field(body: dict, key: str) -> str: + value = body.get(key) + if not isinstance(value, str) or not value: + raise _BadRequest(f"missing or empty field '{key}'") + return value + + +def _serialize_user(user: Any) -> dict[str, Any]: + return { + "username": user.username, + "role": user.role, + "created_at": user.created_at, + "last_login_at": user.last_login_at, + } + + +def _serialize_session(session: Any) -> dict[str, Any]: + return { + "username": session.username, + "role": session.role, + "expires_at": session.expires_at, + "csrf_token": session.csrf_token, + } + + +def handle_auth_status(queue: Queue, _qs: dict) -> dict[str, Any]: + """Public endpoint: tells the SPA whether setup is required. + + Returns ``{setup_required: bool}``. The SPA uses this on cold-load to + decide between showing the setup page and the login page. Provider + listing is fetched separately via ``GET /api/auth/providers`` so this + endpoint stays free of any OAuth dependency. + """ + return {"setup_required": AuthStore(queue).count_users() == 0} + + +def handle_setup(queue: Queue, body: dict) -> dict[str, Any]: + """Create the first admin user. Only callable when zero users exist.""" + store = AuthStore(queue) + if store.count_users() > 0: + raise _BadRequest("setup already complete") + username = _require_field(body, "username") + password = _require_field(body, "password") + try: + user = store.create_user(username, password, role="admin") + except ValueError as e: + raise _BadRequest(str(e)) from None + return {"user": _serialize_user(user)} + + +def handle_login(queue: Queue, body: dict) -> dict[str, Any]: + """Verify credentials and create a session. + + Returns ``{user, session}`` on success. The caller (server) reads the + session token from the returned object and sets it as an HttpOnly cookie. + On failure raises ``_BadRequest`` to drive a 400 — we intentionally + return the same generic error for unknown user / bad password to avoid + revealing which one was wrong. + """ + store = AuthStore(queue) + if store.count_users() == 0: + raise _BadRequest("setup_required") + username = _require_field(body, "username") + password = _require_field(body, "password") + user = store.authenticate(username, password) + if not user: + raise _BadRequest("invalid_credentials") + session = store.create_session(user) + return { + "user": _serialize_user(user), + "session": _serialize_session(session) | {"token": session.token}, + } + + +def handle_logout(queue: Queue, ctx: RequestContext) -> dict[str, bool]: + """Invalidate the current session. Idempotent.""" + if not ctx.session: + return {"ok": True} + AuthStore(queue).delete_session(ctx.session.token) + return {"ok": True} + + +def handle_whoami(queue: Queue, ctx: RequestContext) -> dict[str, Any]: + """Return the current user, or 401-equivalent if no session.""" + if not ctx.session: + raise _NotFound("not_authenticated") + store = AuthStore(queue) + user = store.get_user(ctx.session.username) + if not user: + # Session valid but user deleted — invalidate and treat as logged out. + store.delete_session(ctx.session.token) + raise _NotFound("not_authenticated") + return { + "user": _serialize_user(user), + "csrf_token": ctx.session.csrf_token, + "expires_at": ctx.session.expires_at, + } + + +def handle_change_password(queue: Queue, body: dict, ctx: RequestContext) -> dict[str, bool]: + """Change the current user's password. Requires the old password.""" + if not ctx.session: + raise _BadRequest("not_authenticated") + old_password = _require_field(body, "old_password") + new_password = _require_field(body, "new_password") + store = AuthStore(queue) + user = store.authenticate(ctx.session.username, old_password) + if not user: + raise _BadRequest("invalid_credentials") + try: + store.update_password(user.username, new_password) + except ValueError as e: + raise _BadRequest(str(e)) from None + return {"ok": True} diff --git a/py_src/taskito/dashboard/handlers/middleware.py b/py_src/taskito/dashboard/handlers/middleware.py new file mode 100644 index 0000000..e0fd85b --- /dev/null +++ b/py_src/taskito/dashboard/handlers/middleware.py @@ -0,0 +1,62 @@ +"""Middleware discovery + per-task enable/disable endpoints.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from taskito.dashboard.errors import _BadRequest, _NotFound + +if TYPE_CHECKING: + from taskito.app import Queue + + +def handle_list_middleware(queue: Queue, _qs: dict) -> list[dict[str, Any]]: + """Return every registered middleware with its scopes.""" + return queue.list_middleware() + + +def handle_get_task_middleware(queue: Queue, _qs: dict, task_name: str) -> dict[str, Any]: + """Return the middleware chain that fires for ``task_name`` with each + entry's enabled/disabled state.""" + chain = queue._get_middleware_chain(task_name) + disabled = set(queue.get_disabled_middleware_for(task_name)) + # Build the full would-fire chain INCLUDING disabled entries so the UI + # can render every toggle. + base_chain = queue._global_middleware + queue._task_middleware.get(task_name, []) + entries: list[dict[str, Any]] = [] + chain_names = {getattr(mw, "name", "") for mw in chain} + for mw in base_chain: + name = getattr(mw, "name", "") or f"{type(mw).__module__}.{type(mw).__qualname__}" + entries.append( + { + "name": name, + "class_path": f"{type(mw).__module__}.{type(mw).__qualname__}", + "disabled": name in disabled, + "effective": name in chain_names, + } + ) + return {"task": task_name, "middleware": entries} + + +def handle_put_task_middleware(queue: Queue, body: dict, ids: tuple[str, str]) -> dict[str, Any]: + task_name, mw_name = ids + if not isinstance(body, dict) or "enabled" not in body: + raise _BadRequest('body must include {"enabled": bool}') + if not isinstance(body["enabled"], bool): + raise _BadRequest("'enabled' must be a boolean") + # Confirm the middleware exists in the relevant chain so a typo doesn't + # silently write a no-op disable entry. + base_chain = queue._global_middleware + queue._task_middleware.get(task_name, []) + names = {getattr(mw, "name", "") for mw in base_chain} + if mw_name not in names: + raise _NotFound(f"middleware '{mw_name}' is not registered on task '{task_name}'") + if body["enabled"]: + new = queue.enable_middleware_for_task(task_name, mw_name) + else: + new = queue.disable_middleware_for_task(task_name, mw_name) + return {"task": task_name, "disabled": new} + + +def handle_delete_task_middleware(queue: Queue, task_name: str) -> dict[str, bool]: + """Clear ALL disables for a task — every middleware fires again.""" + return {"cleared": queue.clear_middleware_disables(task_name)} diff --git a/py_src/taskito/dashboard/handlers/oauth.py b/py_src/taskito/dashboard/handlers/oauth.py new file mode 100644 index 0000000..a9b3557 --- /dev/null +++ b/py_src/taskito/dashboard/handlers/oauth.py @@ -0,0 +1,115 @@ +"""HTTP handlers for the OAuth login flow. + +These handlers are not JSON-producing like the rest of ``handlers/`` — +they emit 302 redirects (and, on a successful callback, set the session +cookies). The server wires them into ``_handle_get`` directly rather +than through the generic JSON dispatcher. + +The handlers themselves are network-IO-free aside from what the wrapped +:class:`OAuthFlow` does internally. They translate provider/flow +exceptions to dashboard ``_BadRequest`` / ``_NotFound`` for the server's +error machinery to pick up, and they return :class:`OAuthRedirect` — +a tiny adapter type that tells the server "emit 302 to URL, optionally +with these cookies attached". +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from taskito.dashboard.errors import _BadRequest, _NotFound +from taskito.dashboard.oauth.identity import ( + AllowlistDenied, + IdentityFetchError, + ProviderNotConfigured, + StateValidationError, +) + +if TYPE_CHECKING: + from taskito.app import Queue + from taskito.dashboard.auth import Session + from taskito.dashboard.oauth.flow import OAuthFlow + + +@dataclass(frozen=True) +class OAuthRedirect: + """Server adapter: emit ``302 Location: url``. + + ``session`` is set on a successful callback so the server can attach + the same ``taskito_session`` + ``taskito_csrf`` cookies it sets for + password login. On the ``/start`` redirect ``session`` is ``None``. + """ + + url: str + session: Session | None = None + status: int = 302 + + +def handle_providers(queue: Queue, _qs: dict, flow: OAuthFlow | None) -> dict: + """List configured providers + whether password auth is enabled. + + Returns ``{password_enabled: bool, providers: [{slot, label, type}]}``. + Always callable; returns ``providers: []`` when OAuth is not configured. + """ + if flow is None: + return {"password_enabled": True, "providers": []} + return { + "password_enabled": flow.password_auth_enabled, + "providers": flow.providers_listing(), + } + + +def handle_start( + queue: Queue, + qs: dict[str, list[str]], + slot: str, + flow: OAuthFlow | None, +) -> OAuthRedirect: + """Begin an OAuth login: mint state, return a 302 to the provider URL.""" + if flow is None: + raise _NotFound("oauth_not_configured") + next_values = qs.get("next") or [] + next_url = next_values[0] if next_values else None + try: + provider_url = flow.start(slot, next_url) + except ProviderNotConfigured as e: + raise _NotFound(str(e)) from None + return OAuthRedirect(url=provider_url) + + +def handle_callback( + queue: Queue, + qs: dict[str, list[str]], + slot: str, + flow: OAuthFlow | None, +) -> OAuthRedirect: + """Land an OAuth login: verify state, create a session, redirect home. + + The returned :class:`OAuthRedirect` carries the new :class:`Session`; + the server attaches the standard ``taskito_session`` + ``taskito_csrf`` + cookies before sending the 302. + """ + if flow is None: + raise _NotFound("oauth_not_configured") + + def _first(name: str) -> str | None: + values = qs.get(name) or [] + return values[0] if values else None + + code = _first("code") + state_token = _first("state") + error = _first("error") + try: + session, next_url = flow.handle_callback( + slot, code=code, state_token=state_token, error=error + ) + except ProviderNotConfigured as e: + raise _NotFound(str(e)) from None + except StateValidationError as e: + raise _BadRequest(f"oauth_state_invalid: {e}") from None + except IdentityFetchError as e: + raise _BadRequest(f"oauth_identity_failed: {e}") from None + except AllowlistDenied as e: + raise _BadRequest(f"oauth_allowlist_denied: {e}") from None + return OAuthRedirect(url=next_url, session=session) diff --git a/py_src/taskito/dashboard/handlers/overrides.py b/py_src/taskito/dashboard/handlers/overrides.py new file mode 100644 index 0000000..c125441 --- /dev/null +++ b/py_src/taskito/dashboard/handlers/overrides.py @@ -0,0 +1,95 @@ +"""Task & queue override endpoints.""" + +from __future__ import annotations + +from dataclasses import asdict +from typing import TYPE_CHECKING, Any + +from taskito.dashboard.errors import _BadRequest, _NotFound +from taskito.dashboard.overrides_store import ( + QUEUE_OVERRIDE_FIELDS, + TASK_OVERRIDE_FIELDS, + OverridesStore, +) + +if TYPE_CHECKING: + from taskito.app import Queue + + +def handle_list_tasks(queue: Queue, _qs: dict) -> list[dict[str, Any]]: + """Return every registered task with decorator defaults + active override.""" + return queue.registered_tasks() + + +def handle_list_queues(queue: Queue, _qs: dict) -> list[dict[str, Any]]: + return queue.registered_queues() + + +def _coerce_override_body(body: Any, allowed: frozenset[str]) -> dict[str, Any]: + if not isinstance(body, dict): + raise _BadRequest("body must be a JSON object") + unknown = set(body) - allowed + if unknown: + raise _BadRequest( + f"unknown override fields: {sorted(unknown)}; allowed: {sorted(allowed)}" + ) + return body + + +# ── Task override endpoints ─────────────────────────────────────────── + + +def handle_get_task_override(queue: Queue, _qs: dict, task_name: str) -> dict[str, Any]: + override = OverridesStore(queue).get_task(task_name) + if override is None: + raise _NotFound(f"no override set for task '{task_name}'") + return asdict(override) + + +def handle_put_task_override(queue: Queue, body: dict, task_name: str) -> dict[str, Any]: + fields = _coerce_override_body(body, TASK_OVERRIDE_FIELDS) + try: + override = OverridesStore(queue).set_task(task_name, fields) + except ValueError as e: + raise _BadRequest(str(e)) from None + return asdict(override) + + +def handle_delete_task_override(queue: Queue, task_name: str) -> dict[str, bool]: + removed = OverridesStore(queue).clear_task(task_name) + return {"cleared": removed} + + +# ── Queue override endpoints ────────────────────────────────────────── + + +def handle_get_queue_override(queue: Queue, _qs: dict, queue_name: str) -> dict[str, Any]: + override = OverridesStore(queue).get_queue(queue_name) + if override is None: + raise _NotFound(f"no override set for queue '{queue_name}'") + return asdict(override) + + +def handle_put_queue_override(queue: Queue, body: dict, queue_name: str) -> dict[str, Any]: + fields = _coerce_override_body(body, QUEUE_OVERRIDE_FIELDS) + try: + override = OverridesStore(queue).set_queue(queue_name, fields) + except ValueError as e: + raise _BadRequest(str(e)) from None + # Reflect "paused" immediately by touching the paused_queues store + # (this state DOES propagate to a running worker — independent of the + # static override consumed at worker startup). + if "paused" in fields: + try: + if fields["paused"]: + queue.pause(queue_name) + else: + queue.resume(queue_name) + except Exception: # pragma: no cover - safety net only + pass + return asdict(override) + + +def handle_delete_queue_override(queue: Queue, queue_name: str) -> dict[str, bool]: + removed = OverridesStore(queue).clear_queue(queue_name) + return {"cleared": removed} diff --git a/py_src/taskito/dashboard/handlers/webhook_deliveries.py b/py_src/taskito/dashboard/handlers/webhook_deliveries.py new file mode 100644 index 0000000..aa5bbe4 --- /dev/null +++ b/py_src/taskito/dashboard/handlers/webhook_deliveries.py @@ -0,0 +1,111 @@ +"""Webhook delivery log endpoints (list / get / replay).""" + +from __future__ import annotations + +from dataclasses import asdict +from typing import TYPE_CHECKING, Any + +from taskito.dashboard.delivery_store import DeliveryRecord, DeliveryStore +from taskito.dashboard.errors import _BadRequest, _NotFound +from taskito.dashboard.webhook_store import WebhookSubscriptionStore + +if TYPE_CHECKING: + from taskito.app import Queue + + +_MAX_PAGE_SIZE = 200 + + +def _serialize(record: DeliveryRecord) -> dict[str, Any]: + return asdict(record) + + +def _parse_int_param(qs: dict, name: str, default: int, *, minimum: int = 0) -> int: + raw = qs.get(name, [None])[0] + if raw is None or raw == "": + return default + try: + value = int(raw) + except ValueError: + raise _BadRequest(f"{name} must be an integer") from None + if value < minimum: + raise _BadRequest(f"{name} must be >= {minimum}") + return value + + +def _ensure_subscription(queue: Queue, subscription_id: str) -> None: + sub = WebhookSubscriptionStore(queue).get(subscription_id) + if sub is None: + raise _NotFound(f"webhook '{subscription_id}' not found") + + +def handle_list_deliveries(queue: Queue, qs: dict, subscription_id: str) -> dict[str, Any]: + """List recent deliveries for a subscription. Supports ``status``, + ``event``, ``limit``, and ``offset`` query parameters.""" + _ensure_subscription(queue, subscription_id) + + status = qs.get("status", [None])[0] + if status is not None and status not in {"delivered", "failed", "dead", "pending"}: + raise _BadRequest("status must be one of: delivered, failed, dead, pending") + event = qs.get("event", [None])[0] + + limit = min(_parse_int_param(qs, "limit", 50, minimum=1), _MAX_PAGE_SIZE) + offset = _parse_int_param(qs, "offset", 0) + + store = DeliveryStore(queue) + items = store.list_for(subscription_id, status=status, event=event, limit=limit, offset=offset) + return { + "items": [_serialize(r) for r in items], + "limit": limit, + "offset": offset, + "total": store.count_for(subscription_id), + } + + +def handle_get_delivery( + queue: Queue, _qs: dict, sub_and_delivery_id: tuple[str, str] +) -> dict[str, Any]: + subscription_id, delivery_id = sub_and_delivery_id + record = DeliveryStore(queue).get(subscription_id, delivery_id) + if record is None: + raise _NotFound(f"delivery '{delivery_id}' not found") + return _serialize(record) + + +def handle_replay_delivery(queue: Queue, sub_and_delivery_id: tuple[str, str]) -> dict[str, Any]: + """Re-enqueue a stored delivery's original payload as a fresh attempt. + + The replay creates a NEW delivery record on top of the existing one + so the audit trail is preserved. Returns the new delivery's id and + the synchronous HTTP status from the first attempt. + """ + subscription_id, delivery_id = sub_and_delivery_id + sub = WebhookSubscriptionStore(queue).get(subscription_id) + if sub is None: + raise _NotFound(f"webhook '{subscription_id}' not found") + record = DeliveryStore(queue).get(subscription_id, delivery_id) + if record is None: + raise _NotFound(f"delivery '{delivery_id}' not found") + + from taskito.webhooks import WebhookManager + + runtime = WebhookManager._subscription_to_runtime(sub) + payload = {**record.payload, "replay_of": record.id} + status = queue._webhook_manager.deliver_now(runtime, payload) + # deliver_now does NOT write to the log. Record a replay entry so the + # operator can see it appear in the deliveries list. + DeliveryStore(queue).record_attempt( + subscription_id, + event=str(payload.get("event", record.event)), + payload=payload, + status="delivered" if status is not None and status < 400 else "failed", + attempts=1, + response_code=status, + task_name=record.task_name, + job_id=record.job_id, + ) + return { + "replayed_of": record.id, + "status": status, + "delivered": status is not None and status < 400, + } diff --git a/py_src/taskito/dashboard/handlers/webhooks.py b/py_src/taskito/dashboard/handlers/webhooks.py new file mode 100644 index 0000000..47a3abc --- /dev/null +++ b/py_src/taskito/dashboard/handlers/webhooks.py @@ -0,0 +1,241 @@ +"""Webhook subscription CRUD endpoints.""" + +from __future__ import annotations + +from dataclasses import asdict +from typing import TYPE_CHECKING, Any + +from taskito.dashboard.errors import _BadRequest, _NotFound +from taskito.dashboard.url_safety import UnsafeWebhookUrl, validate_webhook_url +from taskito.dashboard.webhook_store import ( + WebhookSubscription, + WebhookSubscriptionStore, + generate_secret, +) +from taskito.events import EventType + +if TYPE_CHECKING: + from taskito.app import Queue + + +_VALID_EVENT_VALUES = frozenset(e.value for e in EventType) + + +# ── Serialization ───────────────────────────────────────────────────── + + +def _serialize( + subscription: WebhookSubscription, *, reveal_secret: bool = False +) -> dict[str, Any]: + """Convert to a JSON-safe dict. The raw secret is redacted unless the + caller is ``reveal_secret``-ing (used by the create and rotate endpoints, + which need to surface the value to the user exactly once).""" + row = asdict(subscription) + secret = row.pop("secret", None) + row["has_secret"] = bool(secret) + if reveal_secret and secret: + row["secret"] = secret + return row + + +# ── Validation helpers ──────────────────────────────────────────────── + + +def _require_str(body: dict, key: str) -> str: + value = body.get(key) + if not isinstance(value, str) or not value: + raise _BadRequest(f"missing or empty field '{key}'") + return value + + +def _coerce_event_list(value: Any) -> list[str]: + if value is None: + return [] + if not isinstance(value, list): + raise _BadRequest("events must be a list of event type strings") + events: list[str] = [] + for item in value: + if not isinstance(item, str): + raise _BadRequest("events must contain only strings") + if item not in _VALID_EVENT_VALUES: + raise _BadRequest(f"unknown event type {item!r}") + events.append(item) + return events + + +def _coerce_task_filter(value: Any) -> list[str] | None: + if value is None: + return None + if not isinstance(value, list): + raise _BadRequest("task_filter must be a list of task names or null") + out: list[str] = [] + for item in value: + if not isinstance(item, str) or not item: + raise _BadRequest("task_filter entries must be non-empty strings") + out.append(item) + return out + + +def _coerce_headers(value: Any) -> dict[str, str]: + if value is None: + return {} + if not isinstance(value, dict): + raise _BadRequest("headers must be an object of string→string") + out: dict[str, str] = {} + for k, v in value.items(): + if not isinstance(k, str) or not isinstance(v, str): + raise _BadRequest("headers must map strings to strings") + out[k] = v + return out + + +def _coerce_positive_int(value: Any, name: str, default: int) -> int: + if value is None: + return default + if not isinstance(value, int) or isinstance(value, bool) or value < 0: + raise _BadRequest(f"{name} must be a non-negative integer") + return value + + +def _coerce_positive_float(value: Any, name: str, default: float) -> float: + if value is None: + return default + if isinstance(value, bool) or not isinstance(value, (int, float)) or value <= 0: + raise _BadRequest(f"{name} must be a positive number") + return float(value) + + +# ── Handlers ────────────────────────────────────────────────────────── + + +def handle_list_webhooks(queue: Queue, _qs: dict) -> list[dict[str, Any]]: + return [_serialize(s) for s in WebhookSubscriptionStore(queue).list_all()] + + +def handle_get_webhook(queue: Queue, _qs: dict, subscription_id: str) -> dict[str, Any]: + sub = WebhookSubscriptionStore(queue).get(subscription_id) + if sub is None: + raise _NotFound(f"webhook '{subscription_id}' not found") + return _serialize(sub) + + +def handle_create_webhook(queue: Queue, body: dict) -> dict[str, Any]: + if not isinstance(body, dict): + raise _BadRequest("body must be a JSON object") + url = _require_str(body, "url") + try: + validate_webhook_url(url) + except UnsafeWebhookUrl as e: + raise _BadRequest(str(e)) from None + + events = _coerce_event_list(body.get("events")) + task_filter = _coerce_task_filter(body.get("task_filter")) + headers = _coerce_headers(body.get("headers")) + max_retries = _coerce_positive_int(body.get("max_retries"), "max_retries", 3) + timeout_seconds = _coerce_positive_float(body.get("timeout_seconds"), "timeout_seconds", 10.0) + retry_backoff = _coerce_positive_float(body.get("retry_backoff"), "retry_backoff", 2.0) + + secret = body.get("secret") + if secret is not None and not isinstance(secret, str): + raise _BadRequest("secret must be a string or null") + if body.get("generate_secret"): + secret = generate_secret() + + description = body.get("description") + if description is not None and not isinstance(description, str): + raise _BadRequest("description must be a string or null") + + sub = queue.add_webhook( + url=url, + events=[EventType(v) for v in events] if events else None, + headers=headers, + secret=secret, + max_retries=max_retries, + timeout=timeout_seconds, + retry_backoff=retry_backoff, + task_filter=task_filter, + description=description, + ) + return _serialize(sub, reveal_secret=True) + + +def handle_update_webhook(queue: Queue, body: dict, subscription_id: str) -> dict[str, Any]: + if not isinstance(body, dict): + raise _BadRequest("body must be a JSON object") + sub = WebhookSubscriptionStore(queue).get(subscription_id) + if sub is None: + raise _NotFound(f"webhook '{subscription_id}' not found") + + changes: dict[str, Any] = {} + if "url" in body: + url = _require_str(body, "url") + try: + validate_webhook_url(url) + except UnsafeWebhookUrl as e: + raise _BadRequest(str(e)) from None + changes["url"] = url + if "events" in body: + changes["events"] = _coerce_event_list(body["events"]) + if "task_filter" in body: + changes["task_filter"] = _coerce_task_filter(body["task_filter"]) + if "headers" in body: + changes["headers"] = _coerce_headers(body["headers"]) + if "max_retries" in body: + changes["max_retries"] = _coerce_positive_int(body["max_retries"], "max_retries", 3) + if "timeout_seconds" in body: + changes["timeout_seconds"] = _coerce_positive_float( + body["timeout_seconds"], "timeout_seconds", 10.0 + ) + if "retry_backoff" in body: + changes["retry_backoff"] = _coerce_positive_float( + body["retry_backoff"], "retry_backoff", 2.0 + ) + if "enabled" in body: + if not isinstance(body["enabled"], bool): + raise _BadRequest("enabled must be a boolean") + changes["enabled"] = body["enabled"] + if "description" in body: + description = body["description"] + if description is not None and not isinstance(description, str): + raise _BadRequest("description must be a string or null") + changes["description"] = description + + updated = queue.update_webhook(subscription_id, **changes) + return _serialize(updated) + + +def handle_delete_webhook(queue: Queue, subscription_id: str) -> dict[str, bool]: + removed = queue.remove_webhook(subscription_id) + if not removed: + raise _NotFound(f"webhook '{subscription_id}' not found") + return {"deleted": True} + + +def handle_rotate_secret(queue: Queue, subscription_id: str) -> dict[str, Any]: + if WebhookSubscriptionStore(queue).get(subscription_id) is None: + raise _NotFound(f"webhook '{subscription_id}' not found") + secret = queue.rotate_webhook_secret(subscription_id) + return {"id": subscription_id, "secret": secret} + + +def handle_test_webhook(queue: Queue, subscription_id: str) -> dict[str, Any]: + """Synchronously POST a synthetic event and return the result inline.""" + sub = WebhookSubscriptionStore(queue).get(subscription_id) + if sub is None: + raise _NotFound(f"webhook '{subscription_id}' not found") + + from taskito.webhooks import WebhookManager + + runtime = WebhookManager._subscription_to_runtime(sub) + payload = { + "event": "test.ping", + "task_name": None, + "subscription_id": sub.id, + "message": "synthetic test event from dashboard", + } + status = queue._webhook_manager.deliver_now(runtime, payload) + return {"status": status, "delivered": status is not None and status < 400} + + +def handle_list_event_types(_queue: Queue, _qs: dict) -> list[str]: + return sorted(e.value for e in EventType) diff --git a/py_src/taskito/dashboard/middleware_store.py b/py_src/taskito/dashboard/middleware_store.py new file mode 100644 index 0000000..0c2554b --- /dev/null +++ b/py_src/taskito/dashboard/middleware_store.py @@ -0,0 +1,88 @@ +"""Per-task middleware disable list. + +Operators turn individual middlewares off for individual tasks from the +dashboard. The disable list is persisted under +``middleware:disabled:`` as a JSON array of middleware names, +read by :meth:`~taskito.mixins.decorators.QueueDecoratorMixin._get_middleware_chain` +at every task invocation so changes take effect immediately on the next +job without a worker restart. +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from taskito.app import Queue + + +DISABLE_PREFIX = "middleware:disabled:" + +logger = logging.getLogger("taskito.dashboard.middleware") + + +def _parse(raw: str | None) -> list[str]: + if not raw: + return [] + try: + data = json.loads(raw) + except json.JSONDecodeError: + logger.warning("middleware disable list is not valid JSON; treating as empty") + return [] + if not isinstance(data, list): + return [] + return [str(x) for x in data if isinstance(x, str)] + + +class MiddlewareDisableStore: + """List/set/clear per-task middleware disables.""" + + def __init__(self, queue: Queue) -> None: + self._queue = queue + + def _key(self, task_name: str) -> str: + return DISABLE_PREFIX + task_name + + def list_all(self) -> dict[str, list[str]]: + """Return ``{task_name: [disabled_mw_name, ...]}`` for every task that + has at least one disabled middleware.""" + out: dict[str, list[str]] = {} + for key, raw in self._queue.list_settings().items(): + if not key.startswith(DISABLE_PREFIX): + continue + task_name = key[len(DISABLE_PREFIX) :] + names = _parse(raw) + if names: + out[task_name] = names + return out + + def get_for(self, task_name: str) -> list[str]: + return _parse(self._queue.get_setting(self._key(task_name))) + + def is_disabled(self, task_name: str, mw_name: str) -> bool: + return mw_name in self.get_for(task_name) + + def set_disabled(self, task_name: str, mw_name: str, disabled: bool) -> list[str]: + """Flip a middleware on/off for a task and return the new disable list.""" + if not task_name: + raise ValueError("task_name must not be empty") + if not mw_name: + raise ValueError("mw_name must not be empty") + current = self.get_for(task_name) + if disabled: + if mw_name not in current: + current.append(mw_name) + else: + current = [n for n in current if n != mw_name] + if current: + self._queue.set_setting( + self._key(task_name), json.dumps(current, separators=(",", ":")) + ) + else: + self._queue.delete_setting(self._key(task_name)) + return current + + def clear_for(self, task_name: str) -> bool: + return self._queue.delete_setting(self._key(task_name)) diff --git a/py_src/taskito/dashboard/oauth/__init__.py b/py_src/taskito/dashboard/oauth/__init__.py new file mode 100644 index 0000000..28be1e4 --- /dev/null +++ b/py_src/taskito/dashboard/oauth/__init__.py @@ -0,0 +1,46 @@ +"""OAuth2 / OIDC support for the Taskito dashboard. + +Adds Google, GitHub, and one-or-more generic OIDC providers (Okta, Auth0, +Keycloak, Microsoft Entra, …) alongside the existing password login. Auth +state continues to live in ``dashboard_settings``; OAuth users are stored +in the same ``auth:users`` blob as password users, with a sentinel +``password_hash`` prefix (``oauth:{slot}``) that ``verify_password`` +refuses. +""" + +from __future__ import annotations + +from taskito.dashboard.oauth.config import ( + GitHubConfig, + GoogleConfig, + OAuthConfig, + OAuthConfigError, + OIDCConfig, +) +from taskito.dashboard.oauth.identity import ( + AllowlistDenied, + IdentityFetchError, + OAuthError, + OAuthProvider, + ProviderIdentity, + ProviderNotConfigured, + StateValidationError, +) +from taskito.dashboard.oauth.state_store import OAuthState, OAuthStateStore + +__all__ = [ + "AllowlistDenied", + "GitHubConfig", + "GoogleConfig", + "IdentityFetchError", + "OAuthConfig", + "OAuthConfigError", + "OAuthError", + "OAuthProvider", + "OAuthState", + "OAuthStateStore", + "OIDCConfig", + "ProviderIdentity", + "ProviderNotConfigured", + "StateValidationError", +] diff --git a/py_src/taskito/dashboard/oauth/config.py b/py_src/taskito/dashboard/oauth/config.py new file mode 100644 index 0000000..dfc7977 --- /dev/null +++ b/py_src/taskito/dashboard/oauth/config.py @@ -0,0 +1,257 @@ +"""Operator-facing OAuth configuration and env-var parsing. + +All settings come from environment variables (or an equivalent +:class:`OAuthConfig` instance passed programmatically — used by tests). +Secrets are never stored in the dashboard settings DB. + +See ``docs/content/docs/dashboard/oauth.mdx`` for the full env-var +reference. +""" + +from __future__ import annotations + +import os +import re +import urllib.parse +from collections.abc import Mapping +from dataclasses import dataclass, field + + +class OAuthConfigError(ValueError): + """Raised when env-var configuration is invalid.""" + + +SLOT_RE = re.compile(r"^[a-z][a-z0-9_-]{0,31}$") +RESERVED_SLOTS = frozenset({"google", "github"}) + +# Hostnames where http:// is accepted for ``redirect_base_url`` (dev only). +_LOCAL_HOSTS = frozenset({"localhost", "127.0.0.1", "::1"}) + + +def _split_csv(raw: str | None) -> tuple[str, ...]: + if not raw: + return () + return tuple(part.strip() for part in raw.split(",") if part.strip()) + + +@dataclass(frozen=True) +class GoogleConfig: + client_id: str + client_secret: str + allowed_domains: tuple[str, ...] = () + + slot: str = "google" + label: str = "Google" + type: str = "google" + + +@dataclass(frozen=True) +class GitHubConfig: + client_id: str + client_secret: str + allowed_orgs: tuple[str, ...] = () + + slot: str = "github" + label: str = "GitHub" + type: str = "github" + + +@dataclass(frozen=True) +class OIDCConfig: + slot: str + client_id: str + client_secret: str + discovery_url: str + allowed_domains: tuple[str, ...] = () + label: str = "" + type: str = "oidc" + + def __post_init__(self) -> None: + if not SLOT_RE.match(self.slot): + raise OAuthConfigError(f"OIDC slot {self.slot!r} must match {SLOT_RE.pattern}") + if self.slot in RESERVED_SLOTS: + raise OAuthConfigError(f"OIDC slot {self.slot!r} collides with built-in provider") + if not self.discovery_url: + raise OAuthConfigError(f"OIDC slot {self.slot!r}: discovery_url is required") + + +@dataclass(frozen=True) +class OAuthConfig: + """Top-level OAuth configuration. + + ``redirect_base_url`` is the public origin the dashboard is served at — + every callback URL is built from it (``{redirect_base_url}/api/auth/oauth/callback/{slot}``). + OAuth is considered disabled if no provider is configured. + """ + + redirect_base_url: str + google: GoogleConfig | None = None + github: GitHubConfig | None = None + oidc: tuple[OIDCConfig, ...] = () + password_auth_enabled: bool = True + admin_emails: tuple[str, ...] = field(default=()) + + def __post_init__(self) -> None: + _validate_redirect_base_url(self.redirect_base_url) + + @property + def is_enabled(self) -> bool: + return self.google is not None or self.github is not None or bool(self.oidc) + + def providers(self) -> tuple[GoogleConfig | GitHubConfig | OIDCConfig, ...]: + """Configured providers in display order (Google, GitHub, then OIDC slots).""" + out: list[GoogleConfig | GitHubConfig | OIDCConfig] = [] + if self.google is not None: + out.append(self.google) + if self.github is not None: + out.append(self.github) + out.extend(self.oidc) + return tuple(out) + + def find_provider(self, slot: str) -> GoogleConfig | GitHubConfig | OIDCConfig | None: + for p in self.providers(): + if p.slot == slot: + return p + return None + + def callback_url(self, slot: str) -> str: + return f"{self.redirect_base_url.rstrip('/')}/api/auth/oauth/callback/{slot}" + + +def _validate_redirect_base_url(url: str) -> None: + if not url: + raise OAuthConfigError("redirect_base_url must be set when OAuth is enabled") + parsed = urllib.parse.urlparse(url) + if parsed.scheme not in ("http", "https"): + raise OAuthConfigError(f"redirect_base_url must be http(s), got {parsed.scheme!r}") + if not parsed.hostname: + raise OAuthConfigError("redirect_base_url must include a hostname") + if parsed.scheme == "http" and parsed.hostname not in _LOCAL_HOSTS: + raise OAuthConfigError( + f"redirect_base_url must use https for non-local hosts (got http://{parsed.hostname})" + ) + + +def from_env(environ: Mapping[str, str] | None = None) -> OAuthConfig | None: + """Parse :class:`OAuthConfig` from the environment. + + Returns ``None`` when neither ``TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL`` + nor any provider client-id env var is set — i.e. OAuth is not configured. + Raises :class:`OAuthConfigError` if some but not all required vars are set + for a configured provider (fail-fast on partial configuration). + """ + env = environ if environ is not None else os.environ + + base_url = env.get("TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL", "").strip() + google_id = env.get("TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_ID", "").strip() + github_id = env.get("TASKITO_DASHBOARD_OAUTH_GITHUB_CLIENT_ID", "").strip() + oidc_slots_raw = env.get("TASKITO_DASHBOARD_OAUTH_OIDC_PROVIDERS", "").strip() + + any_provider_signal = bool(google_id or github_id or oidc_slots_raw) + if not any_provider_signal and not base_url: + return None + if any_provider_signal and not base_url: + raise OAuthConfigError( + "TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL must be set when any " + "OAuth provider is configured" + ) + + google = _parse_google(env) if google_id else None + github = _parse_github(env) if github_id else None + oidc = _parse_oidc_slots(env, oidc_slots_raw) + password_enabled = _parse_bool( + env.get("TASKITO_DASHBOARD_PASSWORD_AUTH_ENABLED", "true"), default=True + ) + admin_emails = _split_csv(env.get("TASKITO_DASHBOARD_OAUTH_ADMIN_EMAILS")) + + config = OAuthConfig( + redirect_base_url=base_url, + google=google, + github=github, + oidc=oidc, + password_auth_enabled=password_enabled, + admin_emails=admin_emails, + ) + + if not config.is_enabled and not password_enabled: + raise OAuthConfigError( + "password auth disabled but no OAuth providers configured — no way to log in" + ) + + return config + + +def _parse_google(env: Mapping[str, str]) -> GoogleConfig: + cid = env.get("TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_ID", "").strip() + secret = env.get("TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_SECRET", "").strip() + if not secret: + raise OAuthConfigError( + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_SECRET is required when google client_id is set" + ) + return GoogleConfig( + client_id=cid, + client_secret=secret, + allowed_domains=_split_csv(env.get("TASKITO_DASHBOARD_OAUTH_GOOGLE_ALLOWED_DOMAINS")), + ) + + +def _parse_github(env: Mapping[str, str]) -> GitHubConfig: + cid = env.get("TASKITO_DASHBOARD_OAUTH_GITHUB_CLIENT_ID", "").strip() + secret = env.get("TASKITO_DASHBOARD_OAUTH_GITHUB_CLIENT_SECRET", "").strip() + if not secret: + raise OAuthConfigError( + "TASKITO_DASHBOARD_OAUTH_GITHUB_CLIENT_SECRET is required when github client_id is set" + ) + return GitHubConfig( + client_id=cid, + client_secret=secret, + allowed_orgs=_split_csv(env.get("TASKITO_DASHBOARD_OAUTH_GITHUB_ALLOWED_ORGS")), + ) + + +def _parse_oidc_slots(env: Mapping[str, str], slots_raw: str) -> tuple[OIDCConfig, ...]: + slot_names = _split_csv(slots_raw) + if not slot_names: + return () + out: list[OIDCConfig] = [] + seen: set[str] = set() + for raw_slot in slot_names: + slot = raw_slot.lower() + if slot in seen: + raise OAuthConfigError( + f"OIDC slot {slot!r} listed twice in TASKITO_DASHBOARD_OAUTH_OIDC_PROVIDERS" + ) + seen.add(slot) + out.append(_parse_oidc_slot(env, slot)) + return tuple(out) + + +def _parse_oidc_slot(env: Mapping[str, str], slot: str) -> OIDCConfig: + prefix = f"TASKITO_DASHBOARD_OAUTH_OIDC_{slot.upper().replace('-', '_')}" + cid = env.get(f"{prefix}_CLIENT_ID", "").strip() + secret = env.get(f"{prefix}_CLIENT_SECRET", "").strip() + discovery = env.get(f"{prefix}_DISCOVERY_URL", "").strip() + if not cid or not secret or not discovery: + raise OAuthConfigError( + f"OIDC slot {slot!r} requires {prefix}_CLIENT_ID, _CLIENT_SECRET, and _DISCOVERY_URL" + ) + default_label = slot.replace("-", " ").replace("_", " ").title() + label = env.get(f"{prefix}_LABEL", "").strip() or default_label + allowed = _split_csv(env.get(f"{prefix}_ALLOWED_DOMAINS")) + return OIDCConfig( + slot=slot, + client_id=cid, + client_secret=secret, + discovery_url=discovery, + allowed_domains=allowed, + label=label, + ) + + +def _parse_bool(raw: str, *, default: bool) -> bool: + lowered = raw.strip().lower() + if lowered in ("1", "true", "yes", "on"): + return True + if lowered in ("0", "false", "no", "off"): + return False + return default diff --git a/py_src/taskito/dashboard/oauth/flow.py b/py_src/taskito/dashboard/oauth/flow.py new file mode 100644 index 0000000..aacf8e7 --- /dev/null +++ b/py_src/taskito/dashboard/oauth/flow.py @@ -0,0 +1,167 @@ +"""End-to-end OAuth flow orchestration. + +:class:`OAuthFlow` is the seam between the HTTP handler layer and the +provider implementations. It owns the registry of configured providers, +the state store, and the :class:`AuthStore` integration. Handlers call +``start()`` to mint a redirect URL and ``handle_callback()`` to land a +session. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from taskito.dashboard.auth import AuthStore +from taskito.dashboard.oauth.config import ( + GitHubConfig, + GoogleConfig, + OAuthConfig, + OIDCConfig, +) +from taskito.dashboard.oauth.identity import ( + IdentityFetchError, + OAuthProvider, + ProviderNotConfigured, + StateValidationError, +) +from taskito.dashboard.oauth.pkce import s256_challenge +from taskito.dashboard.oauth.providers import ( + GenericOIDCProvider, + GitHubProvider, + GoogleProvider, +) +from taskito.dashboard.oauth.state_store import OAuthStateStore +from taskito.dashboard.url_safety import is_safe_redirect + +if TYPE_CHECKING: + from taskito.app import Queue + from taskito.dashboard.auth import Session + + +def build_providers( + config: OAuthConfig, +) -> dict[str, OAuthProvider]: + """Instantiate one provider per configured slot, keyed by slot.""" + registry: dict[str, OAuthProvider] = {} + for entry in config.providers(): + if isinstance(entry, GoogleConfig): + registry[entry.slot] = GoogleProvider(entry) + elif isinstance(entry, GitHubConfig): + registry[entry.slot] = GitHubProvider(entry) + elif isinstance(entry, OIDCConfig): + registry[entry.slot] = GenericOIDCProvider(entry) + return registry + + +class OAuthFlow: + """Ties together config, providers, state store, and the auth store.""" + + def __init__( + self, + queue: Queue, + config: OAuthConfig, + *, + providers: dict[str, OAuthProvider] | None = None, + state_store: OAuthStateStore | None = None, + ) -> None: + self._queue = queue + self._config = config + self._providers: dict[str, OAuthProvider] = ( + providers if providers is not None else build_providers(config) + ) + self._state_store = state_store or OAuthStateStore(queue) + + # ── Introspection ──────────────────────────────────────────────── + + @property + def password_auth_enabled(self) -> bool: + return self._config.password_auth_enabled + + def has_provider(self, slot: str) -> bool: + return slot in self._providers + + def providers_listing(self) -> list[dict[str, str]]: + """Compact provider summary for the login UI (no secrets).""" + return [ + {"slot": p.slot, "label": p.label, "type": p.type} for p in self._providers.values() + ] + + # ── Flow ───────────────────────────────────────────────────────── + + def start(self, slot: str, next_url: str | None) -> str: + """Mint a state row and return the provider's authorize URL. + + ``next_url`` is sanitised against :func:`is_safe_redirect` and falls + back to ``"/"`` if it fails the check. + """ + provider = self._require_provider(slot) + safe_next = next_url if next_url and is_safe_redirect(next_url) else "/" + state = self._state_store.create(slot=slot, next_url=safe_next) + challenge = s256_challenge(state.code_verifier) + return provider.authorization_url( + state=state.state, + nonce=state.nonce, + code_challenge=challenge, + redirect_uri=self._config.callback_url(slot), + ) + + def handle_callback( + self, + slot: str, + *, + code: str | None, + state_token: str | None, + error: str | None, + ) -> tuple[Session, str]: + """Exchange ``code`` for an identity and create a session. + + Returns ``(session, next_url)`` on success. Raises: + + - :class:`StateValidationError` for missing/expired/replayed state + - :class:`IdentityFetchError` for any token / userinfo / claim issue + - :class:`AllowlistDenied` if the identity is outside the allowlist + """ + if error: + raise IdentityFetchError(f"provider returned error: {error}") + if not code or not state_token: + raise StateValidationError("missing code or state parameter") + + row = self._state_store.consume(state_token) + if row is None: + raise StateValidationError("state is invalid, expired, or already used") + if row.slot != slot: + raise StateValidationError("state slot does not match callback slot") + + provider = self._require_provider(slot) + identity = provider.exchange_code( + code=code, + code_verifier=row.code_verifier, + redirect_uri=self._config.callback_url(slot), + expected_nonce=row.nonce, + ) + provider.check_allowlist(identity) + + store = AuthStore(self._queue) + user = store.get_or_create_oauth_user( + slot=identity.slot, + subject=identity.subject, + email=identity.email, + name=identity.name, + email_verified=identity.email_verified, + admin_emails=self._config.admin_emails, + ) + session = store.create_session(user) + return session, row.next_url + + # ── Maintenance ────────────────────────────────────────────────── + + def prune_state(self) -> int: + return self._state_store.prune_expired() + + # ── Internal ───────────────────────────────────────────────────── + + def _require_provider(self, slot: str) -> OAuthProvider: + provider = self._providers.get(slot) + if provider is None: + raise ProviderNotConfigured(f"OAuth provider {slot!r} is not configured") + return provider diff --git a/py_src/taskito/dashboard/oauth/identity.py b/py_src/taskito/dashboard/oauth/identity.py new file mode 100644 index 0000000..59f906e --- /dev/null +++ b/py_src/taskito/dashboard/oauth/identity.py @@ -0,0 +1,89 @@ +"""Identity types and the provider contract. + +A :class:`ProviderIdentity` is the canonical shape of "who just logged in" +that every provider must return. The :class:`OAuthProvider` protocol is +the contract every concrete provider (Google, GitHub, generic OIDC) +satisfies. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol + + +class OAuthError(Exception): + """Base class for any OAuth-flow error surfaced to the handler layer.""" + + +class StateValidationError(OAuthError): + """Raised when the callback state is missing, expired, replayed, or forged.""" + + +class IdentityFetchError(OAuthError): + """Raised when the provider returns an error during token / userinfo fetch.""" + + +class AllowlistDenied(OAuthError): + """Raised when a verified identity is rejected by a configured allowlist.""" + + +class ProviderNotConfigured(OAuthError): + """Raised when a request references an OAuth slot that is not registered.""" + + +@dataclass(frozen=True) +class ProviderIdentity: + """Normalised identity returned by every provider after a successful flow. + + ``slot`` is the registry key (``google``, ``github``, or the operator- + chosen OIDC slot name). ``subject`` is the provider's stable unique ID + for the user (``sub`` claim, GitHub ``id``); never the email, which + can change. Both together form the Taskito username ``f"{slot}:{subject}"``. + """ + + slot: str + subject: str + email: str | None + email_verified: bool + name: str | None = None + picture: str | None = None + + +class OAuthProvider(Protocol): + """Contract every OAuth provider implementation must satisfy.""" + + slot: str + """URL-safe unique identifier used in the callback path.""" + + label: str + """Human-readable button label rendered by the dashboard.""" + + type: str + """One of ``"google"``, ``"github"``, ``"oidc"`` — chooses the icon.""" + + def authorization_url( + self, + *, + state: str, + nonce: str, + code_challenge: str, + redirect_uri: str, + ) -> str: + """Build the provider-side authorize URL the browser is redirected to.""" + ... + + def exchange_code( + self, + *, + code: str, + code_verifier: str, + redirect_uri: str, + expected_nonce: str | None, + ) -> ProviderIdentity: + """Exchange the auth code for an identity, raising on any failure.""" + ... + + def check_allowlist(self, identity: ProviderIdentity) -> None: + """Raise :class:`AllowlistDenied` if the identity is not permitted.""" + ... diff --git a/py_src/taskito/dashboard/oauth/pkce.py b/py_src/taskito/dashboard/oauth/pkce.py new file mode 100644 index 0000000..1d5f7f0 --- /dev/null +++ b/py_src/taskito/dashboard/oauth/pkce.py @@ -0,0 +1,16 @@ +"""PKCE S256 code-challenge derivation.""" + +from __future__ import annotations + +import base64 +import hashlib + + +def s256_challenge(verifier: str) -> str: + """Return the S256 code-challenge for ``verifier`` per RFC 7636. + + The challenge is ``base64url(sha256(verifier))`` with trailing ``=`` + padding stripped, matching every OAuth provider's PKCE implementation. + """ + digest = hashlib.sha256(verifier.encode("ascii")).digest() + return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") diff --git a/py_src/taskito/dashboard/oauth/providers.py b/py_src/taskito/dashboard/oauth/providers.py new file mode 100644 index 0000000..1df51fe --- /dev/null +++ b/py_src/taskito/dashboard/oauth/providers.py @@ -0,0 +1,441 @@ +"""Concrete provider implementations: Google, GitHub, generic OIDC. + +Every provider satisfies :class:`OAuthProvider`. The split between +``exchange_code`` (network IO + claim normalisation) and +``check_allowlist`` (pure-data permission check) is deliberate so tests +can drive either path in isolation. + +Providers depend on :class:`HttpClient`, a structural Protocol over +the small subset of ``requests.Session`` they actually use (one ``get`` +method). Production code passes a ``requests.Session``; tests pass an +in-memory stub. The Protocol keeps the provider layer framework-free +and test-friendly without runtime ``cast`` calls at either boundary. +""" + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Any, Protocol +from urllib.parse import urlencode + +import requests +from authlib.integrations.requests_client import OAuth2Session +from joserfc import jwt +from joserfc.errors import JoseError +from joserfc.jwk import KeySet + +from taskito.dashboard.oauth.identity import ( + AllowlistDenied, + IdentityFetchError, + ProviderIdentity, +) + +if TYPE_CHECKING: + from taskito.dashboard.oauth.config import ( + GitHubConfig, + GoogleConfig, + OIDCConfig, + ) + + +class HttpResponse(Protocol): + """Minimal response shape a provider consumes from its HTTP client.""" + + status_code: int + text: str + + def json(self) -> Any: ... + + def raise_for_status(self) -> None: ... + + +class HttpClient(Protocol): + """Minimal HTTP client shape — ``requests.Session`` satisfies this.""" + + def get( + self, + url: str, + *, + headers: dict[str, str] | None = ..., + timeout: float = ..., + ) -> HttpResponse: ... + + +GOOGLE_DISCOVERY_URL = "https://accounts.google.com/.well-known/openid-configuration" +GITHUB_AUTH_URL = "https://github.com/login/oauth/authorize" +GITHUB_TOKEN_URL = "https://github.com/login/oauth/access_token" +GITHUB_API_BASE = "https://api.github.com" + +_HTTP_TIMEOUT = 10.0 + + +def _email_domain(email: str | None) -> str | None: + if not email or "@" not in email: + return None + return email.rsplit("@", 1)[-1].lower() + + +def _audience_matches(aud: Any, client_id: str) -> bool: + if isinstance(aud, str): + return aud == client_id + if isinstance(aud, list): + return client_id in aud + return False + + +# ── OIDC provider (shared logic for Google + generic OIDC) ───────────── + + +class _OIDCProviderBase: + """Shared OIDC machinery: discovery, JWKS caching, ID-token decoding.""" + + slot: str + label: str + type: str + client_id: str + client_secret: str + discovery_url: str + scope: str = "openid email profile" + + def __init__(self, *, http: HttpClient | None = None) -> None: + self._http = http or requests.Session() + self._discovery: dict[str, Any] | None = None + self._jwks: dict[str, Any] | None = None + + # Sub-classes override / extend `_extra_auth_params` to add hints. + def _extra_auth_params(self) -> dict[str, str]: + return {} + + def _get_discovery(self) -> dict[str, Any]: + if self._discovery is None: + resp = self._http.get(self.discovery_url, timeout=_HTTP_TIMEOUT) + resp.raise_for_status() + self._discovery = resp.json() + return self._discovery + + def _get_jwks(self) -> dict[str, Any]: + if self._jwks is None: + resp = self._http.get(self._get_discovery()["jwks_uri"], timeout=_HTTP_TIMEOUT) + resp.raise_for_status() + self._jwks = resp.json() + return self._jwks + + def authorization_url( + self, + *, + state: str, + nonce: str, + code_challenge: str, + redirect_uri: str, + ) -> str: + params: dict[str, str] = { + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": redirect_uri, + "scope": self.scope, + "state": state, + "nonce": nonce, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } + params.update(self._extra_auth_params()) + return f"{self._get_discovery()['authorization_endpoint']}?{urlencode(params)}" + + def _fetch_token( + self, + *, + code: str, + code_verifier: str, + redirect_uri: str, + ) -> dict[str, Any]: + """POST the auth code to the token endpoint. Returns the raw token dict. + + Isolated so tests can stub it without involving Authlib's HTTP stack. + """ + client = OAuth2Session( + client_id=self.client_id, + client_secret=self.client_secret, + ) + try: + token = client.fetch_token( + self._get_discovery()["token_endpoint"], + code=code, + code_verifier=code_verifier, + redirect_uri=redirect_uri, + grant_type="authorization_code", + ) + except Exception as e: + raise IdentityFetchError(f"token exchange failed: {e}") from e + return dict(token) + + def exchange_code( + self, + *, + code: str, + code_verifier: str, + redirect_uri: str, + expected_nonce: str | None, + ) -> ProviderIdentity: + token = self._fetch_token( + code=code, code_verifier=code_verifier, redirect_uri=redirect_uri + ) + id_token = token.get("id_token") + if not id_token: + raise IdentityFetchError("no id_token in token response") + + try: + # joserfc's KeySet.import_key_set is typed as accepting its own + # KeySetSerialization TypedDict, but operationally it accepts any + # standard JWKS dict. Mypy 2.x widened the stub; 1.x still + # complains. The dual code suppresses both directions. + key_set = KeySet.import_key_set(self._get_jwks()) # type: ignore[arg-type, unused-ignore] + decoded = jwt.decode(id_token, key_set) + claims = decoded.claims + except JoseError as e: + raise IdentityFetchError(f"id_token validation failed: {e}") from e + + issuer = self._get_discovery().get("issuer") + if issuer and claims.get("iss") != issuer: + raise IdentityFetchError( + f"id_token issuer mismatch: expected {issuer!r}, got {claims.get('iss')!r}" + ) + if not _audience_matches(claims.get("aud"), self.client_id): + raise IdentityFetchError(f"id_token audience mismatch: {claims.get('aud')!r}") + if expected_nonce is not None and claims.get("nonce") != expected_nonce: + raise IdentityFetchError("id_token nonce mismatch") + + exp = claims.get("exp") + if isinstance(exp, (int, float)) and exp < int(time.time()) - 60: + # 60s clock skew tolerance. + raise IdentityFetchError("id_token expired") + + sub = claims.get("sub") + if not sub: + raise IdentityFetchError("id_token missing 'sub' claim") + + return ProviderIdentity( + slot=self.slot, + subject=str(sub), + email=claims.get("email"), + email_verified=bool(claims.get("email_verified")), + name=claims.get("name"), + picture=claims.get("picture"), + ) + + +class GoogleProvider(_OIDCProviderBase): + slot = "google" + type = "google" + discovery_url = GOOGLE_DISCOVERY_URL + + def __init__(self, config: GoogleConfig, *, http: HttpClient | None = None) -> None: + super().__init__(http=http) + self.config = config + self.label = config.label + self.client_id = config.client_id + self.client_secret = config.client_secret + + def _extra_auth_params(self) -> dict[str, str]: + params = {"prompt": "select_account"} + # When exactly one domain is allowlisted, pass it as ``hd`` so Google + # pre-selects the right account. This is a UX hint only — the real + # enforcement happens in ``check_allowlist``. + if len(self.config.allowed_domains) == 1: + params["hd"] = self.config.allowed_domains[0] + return params + + def check_allowlist(self, identity: ProviderIdentity) -> None: + if not self.config.allowed_domains: + return + if not identity.email or not identity.email_verified: + raise AllowlistDenied("verified email required for domain check") + domain = _email_domain(identity.email) + allowed = {d.lower() for d in self.config.allowed_domains} + if domain not in allowed: + raise AllowlistDenied(f"email domain {domain!r} is not in the allowed domains list") + + +class GenericOIDCProvider(_OIDCProviderBase): + type = "oidc" + + def __init__(self, config: OIDCConfig, *, http: HttpClient | None = None) -> None: + super().__init__(http=http) + self.config = config + self.slot = config.slot + self.label = config.label or config.slot.title() + self.client_id = config.client_id + self.client_secret = config.client_secret + self.discovery_url = config.discovery_url + + def check_allowlist(self, identity: ProviderIdentity) -> None: + if not self.config.allowed_domains: + return + if not identity.email or not identity.email_verified: + raise AllowlistDenied("verified email required for domain check") + domain = _email_domain(identity.email) + allowed = {d.lower() for d in self.config.allowed_domains} + if domain not in allowed: + raise AllowlistDenied(f"email domain {domain!r} is not in the allowed domains list") + + +# ── GitHub (OAuth2-only, no OIDC) ────────────────────────────────────── + + +class GitHubProvider: + slot = "github" + type = "github" + scope = "read:user user:email" + + def __init__(self, config: GitHubConfig, *, http: HttpClient | None = None) -> None: + self.config = config + self.label = config.label + self._http = http or requests.Session() + + def authorization_url( + self, + *, + state: str, + nonce: str, + code_challenge: str, + redirect_uri: str, + ) -> str: + # GitHub does not implement OIDC: ``nonce`` is unused. PKCE is honoured + # — GitHub added support for it in 2023. + params = { + "client_id": self.config.client_id, + "redirect_uri": redirect_uri, + "scope": self.scope, + "state": state, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "allow_signup": "false", + } + # Request read:org scope when allowlist is configured, so the + # membership endpoint returns reliable results. + if self.config.allowed_orgs: + params["scope"] = self.scope + " read:org" + return f"{GITHUB_AUTH_URL}?{urlencode(params)}" + + def _fetch_token( + self, + *, + code: str, + code_verifier: str, + redirect_uri: str, + ) -> dict[str, Any]: + client = OAuth2Session( + client_id=self.config.client_id, + client_secret=self.config.client_secret, + ) + try: + token = client.fetch_token( + GITHUB_TOKEN_URL, + code=code, + code_verifier=code_verifier, + redirect_uri=redirect_uri, + grant_type="authorization_code", + # GitHub returns form-encoded by default; ask for JSON. + headers={"Accept": "application/json"}, + ) + except Exception as e: + raise IdentityFetchError(f"token exchange failed: {e}") from e + return dict(token) + + def _api_get(self, path: str, access_token: str) -> Any: + resp = self._http.get( + f"{GITHUB_API_BASE}{path}", + headers={ + "Authorization": f"Bearer {access_token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + }, + timeout=_HTTP_TIMEOUT, + ) + if resp.status_code >= 400 and resp.status_code != 404: + raise IdentityFetchError( + f"GitHub API {path} returned {resp.status_code}: {resp.text[:200]}" + ) + return resp + + def exchange_code( + self, + *, + code: str, + code_verifier: str, + redirect_uri: str, + expected_nonce: str | None, + ) -> ProviderIdentity: + token = self._fetch_token( + code=code, code_verifier=code_verifier, redirect_uri=redirect_uri + ) + access_token = token.get("access_token") + if not access_token: + raise IdentityFetchError("no access_token in token response") + + user_resp = self._api_get("/user", access_token) + if user_resp.status_code != 200: + raise IdentityFetchError(f"GET /user failed: {user_resp.status_code}") + user = user_resp.json() + gh_id = user.get("id") + login = user.get("login") + if gh_id is None or not login: + raise IdentityFetchError("GitHub /user response missing 'id' or 'login'") + + primary_email, verified = self._primary_email(access_token) + + # Org membership requires the access token, so we enforce it here + # rather than in ``check_allowlist`` (which is a no-op for GitHub). + # Any denial raises :class:`AllowlistDenied` straight through. + self._verify_org_membership(access_token, str(login)) + + return ProviderIdentity( + slot=self.slot, + subject=str(gh_id), + email=primary_email, + email_verified=verified, + name=user.get("name") or user.get("login"), + picture=user.get("avatar_url"), + ) + + def _primary_email(self, access_token: str) -> tuple[str | None, bool]: + """Return ``(primary_verified_email_or_None, verified_flag)``. + + Falls back to ``None`` if no verified primary exists. We never trust + an unverified email for any access decision. + """ + resp = self._api_get("/user/emails", access_token) + if resp.status_code != 200: + return None, False + for entry in resp.json(): + if entry.get("primary") and entry.get("verified"): + return entry.get("email"), True + return None, False + + def _verify_org_membership(self, access_token: str, login: str) -> None: + if not self.config.allowed_orgs: + return + for org in self.config.allowed_orgs: + resp = self._http.get( + f"{GITHUB_API_BASE}/orgs/{org}/members/{login}", + headers={ + "Authorization": f"Bearer {access_token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + }, + timeout=_HTTP_TIMEOUT, + ) + if resp.status_code == 204: + return + if resp.status_code not in (302, 404): + raise IdentityFetchError(f"GitHub org membership check failed: {resp.status_code}") + raise AllowlistDenied( + f"user is not a member of any allowed GitHub org " + f"({', '.join(self.config.allowed_orgs)})" + ) + + def check_allowlist(self, identity: ProviderIdentity) -> None: + """No-op — GitHub's org check happens inside :meth:`exchange_code`. + + Required by the :class:`OAuthProvider` protocol for interface symmetry. + """ + return diff --git a/py_src/taskito/dashboard/oauth/state_store.py b/py_src/taskito/dashboard/oauth/state_store.py new file mode 100644 index 0000000..5eb541b --- /dev/null +++ b/py_src/taskito/dashboard/oauth/state_store.py @@ -0,0 +1,137 @@ +"""Short-lived store for in-flight OAuth flows. + +When the dashboard redirects a browser to a provider's ``/authorize`` URL, +we stash the ``state``, ``nonce``, PKCE ``code_verifier``, target slot, +and post-login ``next_url`` server-side, keyed by ``state``. On callback +we look the row up, validate ``state`` (single-use, time-bounded), and +delete it. + +Rows live in ``dashboard_settings`` under the ``auth:oauth_state:`` +key namespace alongside sessions and users, so they work uniformly across +SQLite / Postgres / Redis with no new migrations. +""" + +from __future__ import annotations + +import json +import logging +import secrets +import time +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from taskito.app import Queue + + +logger = logging.getLogger("taskito.dashboard.oauth") + +STATE_PREFIX = "auth:oauth_state:" +DEFAULT_STATE_TTL_SECONDS = 5 * 60 # 5 min — covers consent UX + reasonable network latency. + +STATE_TOKEN_BYTES = 32 +NONCE_BYTES = 16 +CODE_VERIFIER_BYTES = 32 + + +@dataclass(frozen=True) +class OAuthState: + """One in-flight OAuth flow, stored server-side until callback or expiry.""" + + state: str + nonce: str + code_verifier: str + slot: str + next_url: str + created_at: int + expires_at: int + + def is_expired(self, now: int | None = None) -> bool: + return (now if now is not None else int(time.time())) >= self.expires_at + + +def generate_state() -> str: + return secrets.token_urlsafe(STATE_TOKEN_BYTES) + + +def generate_nonce() -> str: + return secrets.token_urlsafe(NONCE_BYTES) + + +def generate_code_verifier() -> str: + # RFC 7636 section 4.1: high-entropy URL-safe string, 43-128 chars. + # 32 bytes yields 43 chars base64url, comfortably above the minimum. + return secrets.token_urlsafe(CODE_VERIFIER_BYTES) + + +class OAuthStateStore: + """Create, consume (read+delete), and prune short-lived OAuth state rows.""" + + def __init__(self, queue: Queue) -> None: + self._queue = queue + + def create( + self, + slot: str, + next_url: str, + ttl_seconds: int = DEFAULT_STATE_TTL_SECONDS, + ) -> OAuthState: + """Mint a fresh state/nonce/verifier triple and persist it.""" + now = int(time.time()) + state = OAuthState( + state=generate_state(), + nonce=generate_nonce(), + code_verifier=generate_code_verifier(), + slot=slot, + next_url=next_url, + created_at=now, + expires_at=now + ttl_seconds, + ) + payload = {k: v for k, v in asdict(state).items() if k != "state"} + self._queue.set_setting( + STATE_PREFIX + state.state, json.dumps(payload, separators=(",", ":")) + ) + return state + + def consume(self, state_token: str) -> OAuthState | None: + """Look up ``state_token`` and atomically delete it. Returns ``None`` + if the row is missing, malformed, or expired. Single-use — the row + is always deleted, so a replayed state never re-validates. + """ + if not state_token: + return None + key = STATE_PREFIX + state_token + raw = self._queue.get_setting(key) + if not raw: + return None + # Always delete first so any subsequent request with the same state + # sees a missing row, even if parsing fails below. + self._queue.delete_setting(key) + try: + data = json.loads(raw) + except json.JSONDecodeError: + return None + try: + row = OAuthState(state=state_token, **data) + except TypeError: + return None + if row.is_expired(): + return None + return row + + def prune_expired(self) -> int: + """Best-effort sweep of expired state rows. Returns count removed.""" + now = int(time.time()) + removed = 0 + for key, value in self._queue.list_settings().items(): + if not key.startswith(STATE_PREFIX): + continue + try: + data = json.loads(value) + expires_at = int(data.get("expires_at", 0)) + except (json.JSONDecodeError, TypeError, ValueError): + continue + if expires_at <= now: + self._queue.delete_setting(key) + removed += 1 + return removed diff --git a/py_src/taskito/dashboard/overrides_store.py b/py_src/taskito/dashboard/overrides_store.py new file mode 100644 index 0000000..d5d70f1 --- /dev/null +++ b/py_src/taskito/dashboard/overrides_store.py @@ -0,0 +1,341 @@ +"""Persistent task & queue runtime overrides. + +Operators tune individual task or queue behaviour (rate limits, concurrency +caps, retry policy, timeouts, priority, paused state) at runtime via the +dashboard. The decorator-declared values become the *defaults* — any override +recorded here wins. + +Storage layout in ``dashboard_settings``: + +- ``overrides:task:`` — JSON of overridden fields for that task +- ``overrides:queue:`` — JSON of overridden fields for that queue + +Overrides are applied at worker startup (see +:meth:`taskito.mixins.lifecycle.QueueLifecycleMixin.start_worker`). +Changes to the store DO NOT take effect on a running worker until it is +restarted — the dashboard surfaces this so operators aren't surprised. + +The contract is intentionally minimal: only the fields below can be +overridden. The store rejects anything else so a typo can't write garbage +through the dashboard. +""" + +from __future__ import annotations + +import json +import logging +import time +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from taskito.app import Queue + + +TASK_PREFIX = "overrides:task:" +QUEUE_PREFIX = "overrides:queue:" + +logger = logging.getLogger("taskito.dashboard.overrides") + + +# ── Allowed override fields ──────────────────────────────────────────── + + +TASK_OVERRIDE_FIELDS: frozenset[str] = frozenset( + { + "rate_limit", + "max_concurrent", + "max_retries", + "retry_backoff", + "timeout", + "priority", + "paused", + } +) + +QUEUE_OVERRIDE_FIELDS: frozenset[str] = frozenset( + { + "rate_limit", + "max_concurrent", + "paused", + } +) + + +# ── Data classes ─────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class TaskOverride: + """An operator-set override for a registered task.""" + + task_name: str + rate_limit: str | None = None + max_concurrent: int | None = None + max_retries: int | None = None + retry_backoff: float | None = None + timeout: int | None = None + priority: int | None = None + paused: bool = False + updated_at: int = 0 + + def as_patch(self) -> dict[str, Any]: + """Return a dict of only the non-default fields (those the operator + actually set). The empty/default values are NOT patched onto the + underlying ``PyTaskConfig`` — they continue to use the decorator + value.""" + patch: dict[str, Any] = {} + for field in TASK_OVERRIDE_FIELDS: + if field == "paused": + continue # handled separately; not a PyTaskConfig field + value = getattr(self, field) + if value is not None: + patch[field] = value + return patch + + +@dataclass(frozen=True) +class QueueOverride: + """An operator-set override for a queue.""" + + queue_name: str + rate_limit: str | None = None + max_concurrent: int | None = None + paused: bool = False + updated_at: int = 0 + + +# ── Validation ───────────────────────────────────────────────────────── + + +def _validate_task_fields(fields: dict[str, Any]) -> None: + unknown = set(fields) - TASK_OVERRIDE_FIELDS + if unknown: + raise ValueError(f"unknown task override fields: {sorted(unknown)}") + _validate_rate_limit(fields.get("rate_limit")) + _validate_max_concurrent(fields.get("max_concurrent")) + _validate_int_field(fields, "max_retries", minimum=0) + _validate_float_field(fields, "retry_backoff", minimum=0) + _validate_int_field(fields, "timeout", minimum=1) + _validate_int_field(fields, "priority") + _validate_bool_field(fields, "paused") + + +def _validate_queue_fields(fields: dict[str, Any]) -> None: + unknown = set(fields) - QUEUE_OVERRIDE_FIELDS + if unknown: + raise ValueError(f"unknown queue override fields: {sorted(unknown)}") + _validate_rate_limit(fields.get("rate_limit")) + _validate_max_concurrent(fields.get("max_concurrent")) + _validate_bool_field(fields, "paused") + + +def _validate_rate_limit(value: Any) -> None: + if value is None: + return + if not isinstance(value, str) or not value: + raise ValueError("rate_limit must be a non-empty string like '100/m'") + # Cheap shape check; rate-limit parsing happens in Rust. + if "/" not in value: + raise ValueError("rate_limit must contain a unit, e.g. '10/s', '100/m', '3600/h'") + + +def _validate_max_concurrent(value: Any) -> None: + if value is None: + return + if not isinstance(value, int) or isinstance(value, bool) or value < 0: + raise ValueError("max_concurrent must be a non-negative integer") + + +def _validate_int_field(fields: dict[str, Any], name: str, *, minimum: int | None = None) -> None: + value = fields.get(name) + if value is None: + return + if not isinstance(value, int) or isinstance(value, bool): + raise ValueError(f"{name} must be an integer") + if minimum is not None and value < minimum: + raise ValueError(f"{name} must be >= {minimum}") + + +def _validate_float_field( + fields: dict[str, Any], name: str, *, minimum: float | None = None +) -> None: + value = fields.get(name) + if value is None: + return + if isinstance(value, bool) or not isinstance(value, (int, float)): + raise ValueError(f"{name} must be a number") + if minimum is not None and value < minimum: + raise ValueError(f"{name} must be >= {minimum}") + + +def _validate_bool_field(fields: dict[str, Any], name: str) -> None: + value = fields.get(name) + if value is not None and not isinstance(value, bool): + raise ValueError(f"{name} must be a boolean") + + +# ── Store ────────────────────────────────────────────────────────────── + + +def _now() -> int: + return int(time.time()) + + +def _parse_json(raw: str | None) -> dict[str, Any]: + if not raw: + return {} + try: + data = json.loads(raw) + except json.JSONDecodeError: + logger.warning("overrides entry is not valid JSON; treating as empty") + return {} + return data if isinstance(data, dict) else {} + + +class OverridesStore: + """CRUD for per-task and per-queue runtime overrides.""" + + def __init__(self, queue: Queue) -> None: + self._queue = queue + + # ── Tasks ────────────────────────────────────────────────── + + def list_tasks(self) -> dict[str, TaskOverride]: + """Return ``{task_name: TaskOverride}`` for every task with an override.""" + out: dict[str, TaskOverride] = {} + for key, raw in self._queue.list_settings().items(): + if not key.startswith(TASK_PREFIX): + continue + task_name = key[len(TASK_PREFIX) :] + out[task_name] = self._row_to_task(task_name, _parse_json(raw)) + return out + + def get_task(self, task_name: str) -> TaskOverride | None: + raw = self._queue.get_setting(TASK_PREFIX + task_name) + if not raw: + return None + return self._row_to_task(task_name, _parse_json(raw)) + + def set_task(self, task_name: str, fields: dict[str, Any]) -> TaskOverride: + _validate_task_fields(fields) + if not task_name: + raise ValueError("task_name must not be empty") + existing = self.get_task(task_name) + merged: dict[str, Any] = {} + if existing is not None: + merged.update({k: v for k, v in asdict(existing).items() if v is not None}) + merged.pop("task_name", None) + merged.pop("updated_at", None) + for k, v in fields.items(): + if v is None: + merged.pop(k, None) + else: + merged[k] = v + merged["updated_at"] = _now() + self._queue.set_setting(TASK_PREFIX + task_name, json.dumps(merged, separators=(",", ":"))) + return self._row_to_task(task_name, merged) + + def clear_task(self, task_name: str) -> bool: + return self._queue.delete_setting(TASK_PREFIX + task_name) + + @staticmethod + def _row_to_task(task_name: str, row: dict[str, Any]) -> TaskOverride: + return TaskOverride( + task_name=task_name, + rate_limit=row.get("rate_limit"), + max_concurrent=row.get("max_concurrent"), + max_retries=row.get("max_retries"), + retry_backoff=row.get("retry_backoff"), + timeout=row.get("timeout"), + priority=row.get("priority"), + paused=bool(row.get("paused", False)), + updated_at=int(row.get("updated_at", 0)), + ) + + # ── Queues ───────────────────────────────────────────────── + + def list_queues(self) -> dict[str, QueueOverride]: + out: dict[str, QueueOverride] = {} + for key, raw in self._queue.list_settings().items(): + if not key.startswith(QUEUE_PREFIX): + continue + queue_name = key[len(QUEUE_PREFIX) :] + out[queue_name] = self._row_to_queue(queue_name, _parse_json(raw)) + return out + + def get_queue(self, queue_name: str) -> QueueOverride | None: + raw = self._queue.get_setting(QUEUE_PREFIX + queue_name) + if not raw: + return None + return self._row_to_queue(queue_name, _parse_json(raw)) + + def set_queue(self, queue_name: str, fields: dict[str, Any]) -> QueueOverride: + _validate_queue_fields(fields) + if not queue_name: + raise ValueError("queue_name must not be empty") + existing = self.get_queue(queue_name) + merged: dict[str, Any] = {} + if existing is not None: + merged.update({k: v for k, v in asdict(existing).items() if v is not None}) + merged.pop("queue_name", None) + merged.pop("updated_at", None) + for k, v in fields.items(): + if v is None: + merged.pop(k, None) + else: + merged[k] = v + merged["updated_at"] = _now() + self._queue.set_setting( + QUEUE_PREFIX + queue_name, json.dumps(merged, separators=(",", ":")) + ) + return self._row_to_queue(queue_name, merged) + + def clear_queue(self, queue_name: str) -> bool: + return self._queue.delete_setting(QUEUE_PREFIX + queue_name) + + @staticmethod + def _row_to_queue(queue_name: str, row: dict[str, Any]) -> QueueOverride: + return QueueOverride( + queue_name=queue_name, + rate_limit=row.get("rate_limit"), + max_concurrent=row.get("max_concurrent"), + paused=bool(row.get("paused", False)), + updated_at=int(row.get("updated_at", 0)), + ) + + # ── Apply (used at worker startup) ───────────────────────── + + def apply_task_overrides(self, configs: list[Any]) -> list[str]: + """Mutate each :class:`PyTaskConfig` in ``configs`` with any matching + task override. Returns a list of task names that are paused (so the + caller can skip enqueuing them). + """ + overrides = self.list_tasks() + paused: list[str] = [] + for config in configs: + override = overrides.get(config.name) + if override is None: + continue + for field, value in override.as_patch().items(): + if hasattr(config, field): + setattr(config, field, value) + if override.paused: + paused.append(config.name) + return paused + + def apply_queue_overrides( + self, queue_configs: dict[str, dict[str, Any]] + ) -> dict[str, dict[str, Any]]: + """Merge queue overrides into ``queue_configs``. Returns the merged + dict (a copy).""" + merged: dict[str, dict[str, Any]] = {k: dict(v) for k, v in queue_configs.items()} + for queue_name, override in self.list_queues().items(): + slot = merged.setdefault(queue_name, {}) + if override.rate_limit is not None: + slot["rate_limit"] = override.rate_limit + if override.max_concurrent is not None: + slot["max_concurrent"] = override.max_concurrent + if override.paused: + slot["paused"] = True + return merged diff --git a/py_src/taskito/dashboard/request_context.py b/py_src/taskito/dashboard/request_context.py new file mode 100644 index 0000000..03010c9 --- /dev/null +++ b/py_src/taskito/dashboard/request_context.py @@ -0,0 +1,94 @@ +"""Per-request authentication context for the dashboard. + +The HTTP server populates a :class:`RequestContext` for every request and +hands it to the dispatcher. Handlers that need the calling user (login, +logout, whoami, etc.) accept it as a keyword argument; pure-data handlers +ignore it. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from email.message import Message as _EmailMessage + +from taskito.dashboard.auth import Session + +# Cookie name used for the session token. HttpOnly + SameSite=Strict — the +# session cookie must never be readable from JavaScript or sent on +# third-party requests. +SESSION_COOKIE = "taskito_session" + +# Cookie name for the CSRF token. NOT HttpOnly — the SPA reads it and +# echoes it in the X-CSRF-Token header on state-changing requests. +CSRF_COOKIE = "taskito_csrf" +CSRF_HEADER = "X-CSRF-Token" + + +@dataclass(frozen=True) +class RequestContext: + """Auth state attached to a single HTTP request.""" + + session: Session | None + csrf_cookie: str | None + csrf_header: str | None + + @property + def is_authenticated(self) -> bool: + return self.session is not None + + @property + def username(self) -> str | None: + return self.session.username if self.session else None + + @property + def role(self) -> str | None: + return self.session.role if self.session else None + + def csrf_valid(self) -> bool: + """Double-submit cookie check. + + For state-changing requests we require: + - a non-empty CSRF cookie + - an ``X-CSRF-Token`` header that equals it byte-for-byte + - the value matches the session's stored CSRF token (defends + against an attacker who pre-seeds the cookie) + """ + if not self.session: + return False + if not self.csrf_cookie or not self.csrf_header: + return False + if self.csrf_cookie != self.csrf_header: + return False + return self.csrf_cookie == self.session.csrf_token + + +def parse_cookies(header: str | None) -> dict[str, str]: + """Parse a raw ``Cookie:`` header into a ``{name: value}`` dict. + + Empty or malformed cookies are silently skipped; only the first value + is kept for any duplicated cookie name. + """ + if not header: + return {} + cookies: dict[str, str] = {} + for part in header.split(";"): + if "=" not in part: + continue + name, _, value = part.strip().partition("=") + name = name.strip() + value = value.strip() + if name and name not in cookies: + cookies[name] = value + return cookies + + +def build_context(headers: _EmailMessage, session: Session | None) -> RequestContext: + """Construct a :class:`RequestContext` from raw HTTP headers and the + session resolved by the server. ``headers`` is the email.message-style + ``http.client.HTTPMessage`` exposed by :class:`BaseHTTPRequestHandler`.""" + cookies = parse_cookies(headers.get("Cookie")) + return RequestContext( + session=session, + csrf_cookie=cookies.get(CSRF_COOKIE), + csrf_header=headers.get(CSRF_HEADER), + ) diff --git a/py_src/taskito/dashboard/routes.py b/py_src/taskito/dashboard/routes.py index 1b45981..7034250 100644 --- a/py_src/taskito/dashboard/routes.py +++ b/py_src/taskito/dashboard/routes.py @@ -4,6 +4,17 @@ JSON-serializable data. Handlers may raise :class:`~taskito.dashboard.errors._BadRequest` (→ 400) or :class:`~taskito.dashboard.errors._NotFound` (→ 404). + +Authentication and authorization: + +- ``PUBLIC_PATHS`` — exact paths that bypass auth entirely. Used for the + setup/login/status endpoints, health checks, and Prometheus metrics. +- Routes outside ``PUBLIC_PATHS`` require a valid session cookie when at + least one user exists in the auth store. Without users, the server + returns ``503 setup_required`` for every API route so the SPA can show + the setup flow. +- State-changing routes (POST/PUT/DELETE) additionally require a valid + CSRF token. Login and setup are exempt because no session exists yet. """ from __future__ import annotations @@ -11,6 +22,14 @@ import re from typing import Any +from taskito.dashboard.handlers.auth import ( + handle_auth_status, + handle_change_password, + handle_login, + handle_logout, + handle_setup, + handle_whoami, +) from taskito.dashboard.handlers.dead_letters import _handle_dead_letters from taskito.dashboard.handlers.jobs import ( _handle_get_job, @@ -19,6 +38,22 @@ ) from taskito.dashboard.handlers.logs import _handle_logs from taskito.dashboard.handlers.metrics import _handle_metrics, _handle_metrics_timeseries +from taskito.dashboard.handlers.middleware import ( + handle_delete_task_middleware, + handle_get_task_middleware, + handle_list_middleware, + handle_put_task_middleware, +) +from taskito.dashboard.handlers.overrides import ( + handle_delete_queue_override, + handle_delete_task_override, + handle_get_queue_override, + handle_get_task_override, + handle_list_queues, + handle_list_tasks, + handle_put_queue_override, + handle_put_task_override, +) from taskito.dashboard.handlers.queues import _handle_stats_queues from taskito.dashboard.handlers.scaler import build_scaler_response from taskito.dashboard.handlers.settings import ( @@ -27,6 +62,58 @@ _handle_list_settings, _handle_set_setting, ) +from taskito.dashboard.handlers.webhook_deliveries import ( + handle_get_delivery, + handle_list_deliveries, + handle_replay_delivery, +) +from taskito.dashboard.handlers.webhooks import ( + handle_create_webhook, + handle_delete_webhook, + handle_get_webhook, + handle_list_event_types, + handle_list_webhooks, + handle_rotate_secret, + handle_test_webhook, + handle_update_webhook, +) + +# ── Auth-exempt paths ────────────────────────────────────────────────── +# +# These bypass the session check. Static SPA files are also exempt but +# they are served outside the API dispatcher. +PUBLIC_PATHS: frozenset[str] = frozenset( + { + "/api/auth/status", + "/api/auth/login", + "/api/auth/setup", + "/api/auth/providers", + "/health", + "/readiness", + "/metrics", + } +) + +# Path prefixes that bypass auth — used by the OAuth flow whose paths +# contain a provider slot in the URL (e.g. ``/api/auth/oauth/start/google``). +PUBLIC_PATH_PREFIXES: tuple[str, ...] = ( + "/api/auth/oauth/start/", + "/api/auth/oauth/callback/", +) + + +def is_public_path(path: str) -> bool: + """Whether ``path`` should bypass the session/CSRF gate.""" + return path in PUBLIC_PATHS or any(path.startswith(p) for p in PUBLIC_PATH_PREFIXES) + + +# Paths handled directly by the server (live outside the regular dispatch +# tables because they take a RequestContext as well as the queue). +AUTH_CONTEXT_GET_PATHS: frozenset[str] = frozenset({"/api/auth/whoami"}) +AUTH_CONTEXT_POST_PATHS: frozenset[str] = frozenset( + {"/api/auth/logout", "/api/auth/change-password"} +) + # ── Exact-match GET routes: path → handler(queue, qs) → JSON data ── GET_ROUTES: dict[str, Any] = { @@ -45,6 +132,12 @@ "/api/stats/queues": _handle_stats_queues, "/api/scaler": lambda q, qs: build_scaler_response(q, queue_name=qs.get("queue", [None])[0]), "/api/settings": _handle_list_settings, + "/api/auth/status": handle_auth_status, + "/api/webhooks": handle_list_webhooks, + "/api/event-types": handle_list_event_types, + "/api/tasks": handle_list_tasks, + "/api/queues": handle_list_queues, + "/api/middleware": handle_list_middleware, } # ── Parameterized GET routes: regex → handler(queue, qs, captured_id) ── @@ -59,6 +152,22 @@ (re.compile(r"^/api/jobs/([^/]+)/dag$"), lambda q, qs, jid: q.job_dag(jid)), (re.compile(r"^/api/jobs/([^/]+)$"), _handle_get_job), (re.compile(r"^/api/settings/(.+)$"), _handle_get_setting), + ( + re.compile(r"^/api/webhooks/([^/]+)/deliveries$"), + handle_list_deliveries, + ), + (re.compile(r"^/api/webhooks/([^/]+)$"), handle_get_webhook), + (re.compile(r"^/api/tasks/([^/]+)/override$"), handle_get_task_override), + (re.compile(r"^/api/queues/([^/]+)/override$"), handle_get_queue_override), + (re.compile(r"^/api/tasks/([^/]+)/middleware$"), handle_get_task_middleware), +] + +# GET routes with 2 captured groups (handler signature: queue, qs, (g1, g2)) +GET_PARAM2_ROUTES: list[tuple[re.Pattern, Any]] = [ + ( + re.compile(r"^/api/webhooks/([^/]+)/deliveries/([^/]+)$"), + handle_get_delivery, + ), ] # ── Exact-match POST routes: path → handler(queue) → JSON data ── @@ -66,6 +175,28 @@ "/api/dead-letters/purge": lambda q: {"purged": q.purge_dead(0)}, } +# Exact-match POST routes that take a body (path → handler(queue, body)) +POST_BODY_ROUTES: dict[str, Any] = { + "/api/auth/login": handle_login, + "/api/auth/setup": handle_setup, + "/api/webhooks": handle_create_webhook, +} + +# Auth-context POST routes: path → handler(queue, ctx) — no body +POST_CTX_ROUTES: dict[str, Any] = { + "/api/auth/logout": handle_logout, +} + +# Auth-context POST routes with body: path → handler(queue, body, ctx) +POST_CTX_BODY_ROUTES: dict[str, Any] = { + "/api/auth/change-password": handle_change_password, +} + +# Auth-context GET routes: path → handler(queue, ctx) +GET_CTX_ROUTES: dict[str, Any] = { + "/api/auth/whoami": handle_whoami, +} + # ── Parameterized POST routes: regex → handler(queue, captured_id) ── POST_PARAM_ROUTES: list[tuple[re.Pattern, Any]] = [ ( @@ -82,14 +213,54 @@ re.compile(r"^/api/queues/([^/]+)/resume$"), lambda q, n: (q.resume(n), {"resumed": n})[1], ), + (re.compile(r"^/api/webhooks/([^/]+)/test$"), handle_test_webhook), + (re.compile(r"^/api/webhooks/([^/]+)/rotate-secret$"), handle_rotate_secret), +] + +# Routes with two captures (sub_id + delivery_id) — handled by the POST +# dispatcher when patterns yield 2 groups. +POST_PARAM2_ROUTES: list[tuple[re.Pattern, Any]] = [ + ( + re.compile(r"^/api/webhooks/([^/]+)/deliveries/([^/]+)/replay$"), + handle_replay_delivery, + ), ] # ── Parameterized PUT routes: regex → handler(queue, body, captured_id) ── PUT_PARAM_ROUTES: list[tuple[re.Pattern, Any]] = [ (re.compile(r"^/api/settings/(.+)$"), _handle_set_setting), + (re.compile(r"^/api/webhooks/([^/]+)$"), handle_update_webhook), + (re.compile(r"^/api/tasks/([^/]+)/override$"), handle_put_task_override), + (re.compile(r"^/api/queues/([^/]+)/override$"), handle_put_queue_override), +] + +# PUT routes with 2 captured groups (handler signature: queue, body, (g1, g2)) +PUT_PARAM2_ROUTES: list[tuple[re.Pattern, Any]] = [ + ( + re.compile(r"^/api/tasks/([^/]+)/middleware/([^/]+)$"), + handle_put_task_middleware, + ), ] # ── Parameterized DELETE routes: regex → handler(queue, captured_id) ── DELETE_PARAM_ROUTES: list[tuple[re.Pattern, Any]] = [ (re.compile(r"^/api/settings/(.+)$"), _handle_delete_setting), + (re.compile(r"^/api/webhooks/([^/]+)$"), handle_delete_webhook), + (re.compile(r"^/api/tasks/([^/]+)/override$"), handle_delete_task_override), + (re.compile(r"^/api/queues/([^/]+)/override$"), handle_delete_queue_override), + (re.compile(r"^/api/tasks/([^/]+)/middleware$"), handle_delete_task_middleware), ] + + +def is_state_changing_method(method: str) -> bool: + """POST/PUT/DELETE/PATCH all require a CSRF token.""" + return method in {"POST", "PUT", "DELETE", "PATCH"} + + +def is_csrf_exempt(path: str) -> bool: + """Login and setup happen before a session exists, so they're CSRF-exempt. + + Every other state-changing endpoint requires a valid CSRF token even + though the session cookie is enforced — defense in depth. + """ + return path in {"/api/auth/login", "/api/auth/setup"} diff --git a/py_src/taskito/dashboard/server.py b/py_src/taskito/dashboard/server.py index 2b4f1d9..f5d239a 100644 --- a/py_src/taskito/dashboard/server.py +++ b/py_src/taskito/dashboard/server.py @@ -1,4 +1,12 @@ -"""HTTP server that wires routes to a Queue instance and serves the SPA.""" +"""HTTP server that wires routes to a Queue instance and serves the SPA. + +The server enforces dashboard authentication when at least one user has been +registered with :class:`taskito.dashboard.auth.AuthStore`. Until the first +user is created, all API routes return ``503 setup_required`` so the SPA can +guide the operator through one-time setup. ``TASKITO_DASHBOARD_ADMIN_USER`` / +``TASKITO_DASHBOARD_ADMIN_PASSWORD`` environment variables bootstrap a user +idempotently on server start. +""" from __future__ import annotations @@ -6,16 +14,49 @@ import logging from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from typing import TYPE_CHECKING, Any -from urllib.parse import parse_qs, urlparse +from urllib.parse import parse_qs, unquote, urlparse +from taskito.dashboard.auth import ( + DEFAULT_SESSION_TTL_SECONDS, + AuthStore, + bootstrap_admin_from_env, +) from taskito.dashboard.errors import _BadRequest, _NotFound +from taskito.dashboard.handlers.oauth import ( + OAuthRedirect, + handle_providers, +) +from taskito.dashboard.handlers.oauth import ( + handle_callback as handle_oauth_callback, +) +from taskito.dashboard.handlers.oauth import ( + handle_start as handle_oauth_start, +) +from taskito.dashboard.request_context import ( + CSRF_COOKIE, + SESSION_COOKIE, + RequestContext, + build_context, +) from taskito.dashboard.routes import ( + AUTH_CONTEXT_GET_PATHS, + AUTH_CONTEXT_POST_PATHS, DELETE_PARAM_ROUTES, + GET_CTX_ROUTES, + GET_PARAM2_ROUTES, GET_PARAM_ROUTES, GET_ROUTES, + POST_BODY_ROUTES, + POST_CTX_BODY_ROUTES, + POST_CTX_ROUTES, + POST_PARAM2_ROUTES, POST_PARAM_ROUTES, POST_ROUTES, + PUT_PARAM2_ROUTES, PUT_PARAM_ROUTES, + is_csrf_exempt, + is_public_path, + is_state_changing_method, ) from taskito.dashboard.static import ( IMMUTABLE_PREFIX, @@ -28,6 +69,7 @@ if TYPE_CHECKING: from taskito.app import Queue + from taskito.dashboard.oauth.flow import OAuthFlow logger = logging.getLogger("taskito.dashboard") @@ -40,20 +82,27 @@ _LOG_UNSAFE_CHARS[127] = None _LOG_PATH_MAX = 256 -# Hard cap on the request body we'll parse for PUT requests. Settings and -# other config writes are tiny; anything larger is almost certainly an -# attacker probing for memory exhaustion. +# Hard cap on the request body we'll parse for PUT/POST requests. _MAX_BODY_BYTES = 1 * 1024 * 1024 # 1 MiB def _safe_path(path: str) -> str: - """Return ``path`` with control characters stripped and length capped. + """Return ``path`` with control characters stripped and length capped.""" + return path.translate(_LOG_UNSAFE_CHARS)[:_LOG_PATH_MAX] + + +def _session_cookies(session: Any) -> tuple[str, ...]: + """Build the standard ``Set-Cookie`` headers for a freshly-created session. - Used when including the request URI in log messages — never trust - user-controlled strings to be free of CR/LF/null bytes that would let - an attacker forge fake log lines. + Used by both password login and OAuth callback so the cookie shape + stays in lockstep across login methods. """ - return path.translate(_LOG_UNSAFE_CHARS)[:_LOG_PATH_MAX] + return ( + f"{SESSION_COOKIE}={session.token}; HttpOnly; SameSite=Strict; Path=/; " + f"Max-Age={DEFAULT_SESSION_TTL_SECONDS}", + f"{CSRF_COOKIE}={session.csrf_token}; SameSite=Strict; Path=/; " + f"Max-Age={DEFAULT_SESSION_TTL_SECONDS}", + ) def serve_dashboard( @@ -62,6 +111,7 @@ def serve_dashboard( port: int = 8080, *, static_assets: StaticAssets | None = None, + oauth_flow: OAuthFlow | None = None, ) -> None: """Start the dashboard HTTP server (blocking). @@ -72,8 +122,14 @@ def serve_dashboard( static_assets: Override the default SPA asset source. Mainly a test seam; downstream embedders can also use it to ship a customised dashboard bundle from a different location. + oauth_flow: Configured :class:`OAuthFlow` to enable social login. + When unset, OAuth endpoints respond 404 and the providers list + is empty. """ - handler = _make_handler(queue, static_assets=static_assets) + bootstrap_admin_from_env(queue) + if oauth_flow is None: + oauth_flow = _build_oauth_flow_from_env(queue) + handler = _make_handler(queue, static_assets=static_assets, oauth_flow=oauth_flow) server = ThreadingHTTPServer((host, port), handler) print(f"taskito dashboard → http://{host}:{port}") print("Press Ctrl+C to stop") @@ -86,17 +142,37 @@ def serve_dashboard( server.server_close() -def _make_handler(queue: Queue, *, static_assets: StaticAssets | None = None) -> type: - """Create a request handler class bound to the given queue. +def _build_oauth_flow_from_env(queue: Queue) -> OAuthFlow | None: + """Build :class:`OAuthFlow` from environment variables, or ``None``. - Args: - queue: Queue inspected by the JSON routes. - static_assets: SPA asset source. Defaults to the package-bundled - assets resolved once per process; tests inject their own. + Failures in the env-var config are logged and treated as "OAuth not + configured" — the dashboard still starts with password auth only. """ + try: + from taskito.dashboard.oauth.config import from_env as oauth_from_env + from taskito.dashboard.oauth.flow import OAuthFlow + + config = oauth_from_env() + if config is None or not config.is_enabled: + return None + return OAuthFlow(queue, config) + except Exception: + logger.exception("OAuth env-var configuration is invalid; OAuth disabled") + return None + + +def _make_handler( + queue: Queue, + *, + static_assets: StaticAssets | None = None, + oauth_flow: OAuthFlow | None = None, +) -> type: + """Create a request handler class bound to the given queue.""" assets = static_assets if static_assets is not None else _get_default_assets() class DashboardHandler(BaseHTTPRequestHandler): + # ── Entry points ──────────────────────────────────────────── + def do_GET(self) -> None: try: self._handle_get() @@ -106,35 +182,87 @@ def do_GET(self) -> None: logger.exception("Error handling GET %s", _safe_path(self.path)) self._json_response({"error": "Internal server error"}, status=500) + def do_POST(self) -> None: + try: + self._handle_post() + except BrokenPipeError: + pass + except Exception: + logger.exception("Error handling POST %s", _safe_path(self.path)) + self._json_response({"error": "Internal server error"}, status=500) + + def do_PUT(self) -> None: + try: + self._handle_put() + except BrokenPipeError: + pass + except Exception: + logger.exception("Error handling PUT %s", _safe_path(self.path)) + self._json_response({"error": "Internal server error"}, status=500) + + def do_DELETE(self) -> None: + try: + self._handle_delete() + except BrokenPipeError: + pass + except Exception: + logger.exception("Error handling DELETE %s", _safe_path(self.path)) + self._json_response({"error": "Internal server error"}, status=500) + + # ── Per-method dispatchers ────────────────────────────────── + def _handle_get(self) -> None: parsed = urlparse(self.path) path = parsed.path qs = parse_qs(parsed.query) - # Exact-match API routes + if not path.startswith("/api/") and path not in {"/health", "/readiness", "/metrics"}: + self._serve_spa(path) + return + + ctx, denied = self._authorize(path, "GET") + if denied: + return + + # ── OAuth flow paths (public, redirect-emitting) ──────── + if path == "/api/auth/providers": + self._dispatch_with_handler(handle_providers, lambda h: h(queue, qs, oauth_flow)) + return + if path.startswith("/api/auth/oauth/start/"): + slot = unquote(path[len("/api/auth/oauth/start/") :]) + self._dispatch_oauth_redirect(handle_oauth_start, queue, qs, slot, oauth_flow) + return + if path.startswith("/api/auth/oauth/callback/"): + slot = unquote(path[len("/api/auth/oauth/callback/") :]) + self._dispatch_oauth_redirect(handle_oauth_callback, queue, qs, slot, oauth_flow) + return + + if path in AUTH_CONTEXT_GET_PATHS: + self._dispatch_with_handler(GET_CTX_ROUTES.get(path), lambda h: h(queue, ctx)) + return + handler = GET_ROUTES.get(path) if handler: - try: - self._json_response(handler(queue, qs)) - except _BadRequest as e: - self._json_response({"error": e.message}, status=400) - except _NotFound as e: - self._json_response({"error": e.message}, status=404) + self._dispatch_with_handler(handler, lambda h: h(queue, qs)) return - # Parameterized API routes for pattern, param_handler in GET_PARAM_ROUTES: m = pattern.match(path) if m: - try: - self._json_response(param_handler(queue, qs, m.group(1))) - except _BadRequest as e: - self._json_response({"error": e.message}, status=400) - except _NotFound as e: - self._json_response({"error": e.message}, status=404) + g1 = unquote(m.group(1)) + self._dispatch_with_handler(param_handler, lambda h, g1=g1: h(queue, qs, g1)) + return + + for pattern, param_handler in GET_PARAM2_ROUTES: + m = pattern.match(path) + if m: + g1, g2 = unquote(m.group(1)), unquote(m.group(2)) + self._dispatch_with_handler( + param_handler, + lambda h, g1=g1, g2=g2: h(queue, qs, (g1, g2)), + ) return - # Non-JSON routes if path == "/health": self._json_response(check_health()) elif path == "/readiness": @@ -142,90 +270,259 @@ def _handle_get(self) -> None: elif path == "/metrics": self._serve_prometheus_metrics() else: - self._serve_spa(path) - - def do_POST(self) -> None: - try: - self._handle_post() - except BrokenPipeError: - pass - except Exception: - logger.exception("Error handling POST %s", _safe_path(self.path)) - self._json_response({"error": "Internal server error"}, status=500) + self._json_response({"error": "Not found"}, status=404) def _handle_post(self) -> None: path = urlparse(self.path).path + ctx, denied = self._authorize(path, "POST") + if denied: + return + + if path == "/api/auth/login": + body = self._read_json_body() + if body is None: + return + self._dispatch_with_handler( + POST_BODY_ROUTES[path], + lambda h: h(queue, body), + on_success=lambda resp: self._set_login_cookies(resp), + ) + return + + if path == "/api/auth/setup": + body = self._read_json_body() + if body is None: + return + self._dispatch_with_handler(POST_BODY_ROUTES[path], lambda h: h(queue, body)) + return + + if path in AUTH_CONTEXT_POST_PATHS: + if path in POST_CTX_BODY_ROUTES: + body = self._read_json_body() + if body is None: + return + self._dispatch_with_handler( + POST_CTX_BODY_ROUTES[path], + lambda h: h(queue, body, ctx), + ) + else: + self._dispatch_with_handler( + POST_CTX_ROUTES[path], + lambda h: h(queue, ctx), + on_success=lambda _resp: ( + self._clear_login_cookies() if path == "/api/auth/logout" else None + ), + ) + return - # Exact-match POST routes handler = POST_ROUTES.get(path) if handler: - self._json_response(handler(queue)) + self._dispatch_with_handler(handler, lambda h: h(queue)) + return + + body_handler = POST_BODY_ROUTES.get(path) + if body_handler: + body = self._read_json_body() + if body is None: + return + self._dispatch_with_handler(body_handler, lambda h, body=body: h(queue, body)) return - # Parameterized POST routes for pattern, param_handler in POST_PARAM_ROUTES: m = pattern.match(path) if m: - self._json_response(param_handler(queue, m.group(1))) + g1 = unquote(m.group(1)) + self._dispatch_with_handler(param_handler, lambda h, g1=g1: h(queue, g1)) return - self._json_response({"error": "Not found"}, status=404) + for pattern, param_handler in POST_PARAM2_ROUTES: + m = pattern.match(path) + if m: + g1, g2 = unquote(m.group(1)), unquote(m.group(2)) + self._dispatch_with_handler( + param_handler, + lambda h, g1=g1, g2=g2: h(queue, (g1, g2)), + ) + return - def do_PUT(self) -> None: - try: - self._handle_put() - except BrokenPipeError: - pass - except Exception: - logger.exception("Error handling PUT %s", _safe_path(self.path)) - self._json_response({"error": "Internal server error"}, status=500) + self._json_response({"error": "Not found"}, status=404) def _handle_put(self) -> None: path = urlparse(self.path).path + _ctx, denied = self._authorize(path, "PUT") + if denied: + return + for pattern, param_handler in PUT_PARAM_ROUTES: m = pattern.match(path) if m: body = self._read_json_body() if body is None: return - try: - self._json_response(param_handler(queue, body, m.group(1))) - except _BadRequest as e: - self._json_response({"error": e.message}, status=400) - except _NotFound as e: - self._json_response({"error": e.message}, status=404) + g1 = unquote(m.group(1)) + self._dispatch_with_handler( + param_handler, lambda h, g1=g1, body=body: h(queue, body, g1) + ) + return + for pattern, param_handler in PUT_PARAM2_ROUTES: + m = pattern.match(path) + if m: + body = self._read_json_body() + if body is None: + return + g1, g2 = unquote(m.group(1)), unquote(m.group(2)) + self._dispatch_with_handler( + param_handler, + lambda h, g1=g1, g2=g2, body=body: h(queue, body, (g1, g2)), + ) return self._json_response({"error": "Not found"}, status=404) - def do_DELETE(self) -> None: - try: - self._handle_delete() - except BrokenPipeError: - pass - except Exception: - logger.exception("Error handling DELETE %s", _safe_path(self.path)) - self._json_response({"error": "Internal server error"}, status=500) - def _handle_delete(self) -> None: path = urlparse(self.path).path + _ctx, denied = self._authorize(path, "DELETE") + if denied: + return + for pattern, param_handler in DELETE_PARAM_ROUTES: m = pattern.match(path) if m: - try: - self._json_response(param_handler(queue, m.group(1))) - except _BadRequest as e: - self._json_response({"error": e.message}, status=400) - except _NotFound as e: - self._json_response({"error": e.message}, status=404) + g1 = unquote(m.group(1)) + self._dispatch_with_handler(param_handler, lambda h, g1=g1: h(queue, g1)) return self._json_response({"error": "Not found"}, status=404) - def _read_json_body(self) -> Any | None: - """Read and parse the request body as JSON. + # ── Auth gating ───────────────────────────────────────────── + + def _authorize(self, path: str, method: str) -> tuple[RequestContext, bool]: + """Return ``(ctx, denied)``. When ``denied`` is true a response + has already been written and the caller must return.""" + ctx = self._build_context() + + # Setup-required short-circuit: before the first user is created + # every API endpoint (except the public ones) returns 503 so the + # SPA can show the setup page. + if ( + path.startswith("/api/") + and not is_public_path(path) + and AuthStore(queue).count_users() == 0 + ): + self._json_response({"error": "setup_required"}, status=503) + return ctx, True + + if is_public_path(path) or not path.startswith("/api/"): + # CSRF still applies to public state-changing routes that are + # NOT exempt — but login/setup are the only public POSTs and + # they're exempt. + return ctx, False + + if not ctx.is_authenticated: + self._json_response({"error": "not_authenticated"}, status=401) + return ctx, True + + if ( + is_state_changing_method(method) + and not is_csrf_exempt(path) + and not ctx.csrf_valid() + ): + self._json_response({"error": "csrf_failed"}, status=403) + return ctx, True + + return ctx, False + + def _build_context(self) -> RequestContext: + cookies_header = self.headers.get("Cookie") + session = None + if cookies_header: + from taskito.dashboard.request_context import parse_cookies + + cookies = parse_cookies(cookies_header) + token = cookies.get(SESSION_COOKIE) + if token: + session = AuthStore(queue).get_session(token) + return build_context(self.headers, session) + + # ── Cookie management ─────────────────────────────────────── + + def _set_login_cookies(self, response: dict[str, Any]) -> None: + """Set HttpOnly session cookie and CSRF cookie on a login response.""" + session = response.get("session") or {} + token = session.get("token") + csrf = session.get("csrf_token") + if not token or not csrf: + return + # 24-hour Max-Age matches the session TTL. + self._extra_set_cookies = [ + f"{SESSION_COOKIE}={token}; HttpOnly; SameSite=Strict; Path=/; " + f"Max-Age={DEFAULT_SESSION_TTL_SECONDS}", + f"{CSRF_COOKIE}={csrf}; SameSite=Strict; Path=/; " + f"Max-Age={DEFAULT_SESSION_TTL_SECONDS}", + ] + # Don't leak the raw token in the JSON body — the cookie holds it. + response["session"] = {k: v for k, v in session.items() if k != "token"} + + def _clear_login_cookies(self) -> None: + self._extra_set_cookies = [ + f"{SESSION_COOKIE}=; HttpOnly; SameSite=Strict; Path=/; Max-Age=0", + f"{CSRF_COOKIE}=; SameSite=Strict; Path=/; Max-Age=0", + ] + + # ── Dispatch helper ───────────────────────────────────────── + + def _dispatch_with_handler( + self, + handler: Any, + invoke: Any, + *, + on_success: Any | None = None, + ) -> None: + if handler is None: + self._json_response({"error": "Not found"}, status=404) + return + try: + result = invoke(handler) + except _BadRequest as e: + self._json_response({"error": e.message}, status=400) + return + except _NotFound as e: + self._json_response({"error": e.message}, status=404) + return + if on_success is not None: + on_success(result) + self._json_response(result) + + def _dispatch_oauth_redirect( + self, + handler: Any, + queue: Any, + qs: dict[str, list[str]], + slot: str, + flow: OAuthFlow | None, + ) -> None: + try: + redirect: OAuthRedirect = handler(queue, qs, slot, flow) + except _BadRequest as e: + self._json_response({"error": e.message}, status=400) + return + except _NotFound as e: + self._json_response({"error": e.message}, status=404) + return + cookies: list[str] = [] + if redirect.session is not None: + cookies = list(_session_cookies(redirect.session)) + self.send_response(redirect.status) + self.send_header("Location", redirect.url) + self.send_header("Content-Length", "0") + self.send_header("Cache-Control", "no-store") + for cookie in cookies: + self.send_header("Set-Cookie", cookie) + self.end_headers() - Returns ``None`` after writing the appropriate error response - (400/413) when the body is missing, malformed, or oversized. - """ + # ── Body / response helpers ───────────────────────────────── + + def _read_json_body(self) -> Any | None: + """Read and parse the request body as JSON. Returns ``None`` after + writing the appropriate error response (400/413).""" length_header = self.headers.get("Content-Length") try: length = int(length_header) if length_header is not None else 0 @@ -252,7 +549,10 @@ def _json_response(self, data: Any, status: int = 200) -> None: self.send_response(status) self.send_header("Content-Type", "application/json") self.send_header("Content-Length", str(len(body))) - self.send_header("Access-Control-Allow-Origin", "*") + # Cookies are first-party only — no wildcard CORS. The SPA is + # served from the same origin as the API. + for cookie in getattr(self, "_extra_set_cookies", ()): + self.send_header("Set-Cookie", cookie) self.end_headers() self.wfile.write(body) @@ -270,9 +570,6 @@ def _serve_prometheus_metrics(self) -> None: self._json_response({"error": "prometheus-client not installed"}, status=501) def _serve_spa(self, req_path: str) -> None: - """Serve a static asset from the SPA bundle, falling back to - ``index.html`` so client-side routes deep-link correctly. - """ if not assets.available: self._serve_missing_assets() return @@ -317,3 +614,6 @@ def log_message(self, format: str, *args: Any) -> None: pass return DashboardHandler + + +__all__ = ["_make_handler", "serve_dashboard"] diff --git a/py_src/taskito/dashboard/url_safety.py b/py_src/taskito/dashboard/url_safety.py new file mode 100644 index 0000000..a0db4e3 --- /dev/null +++ b/py_src/taskito/dashboard/url_safety.py @@ -0,0 +1,115 @@ +"""Outbound URL safety checks for dashboard-configured webhooks. + +We refuse to deliver to loopback, link-local, and RFC1918 addresses by +default — an operator who can write to ``dashboard_settings`` could +otherwise turn the worker into an SSRF proxy. The ``TASKITO_WEBHOOKS_ALLOW_PRIVATE`` +environment variable disables the guard for local development. +""" + +from __future__ import annotations + +import ipaddress +import os +import socket +import urllib.parse + +# Hostnames that always resolve to loopback / never-leave-this-host regardless +# of DNS, but might be missed by a strict ``ipaddress.is_private`` check. +_BLOCKED_HOSTNAME_SUFFIXES = ( + ".localhost", + ".local", + ".internal", + ".intranet", + ".lan", + ".private", +) +_BLOCKED_HOSTNAMES = frozenset( + {"localhost", "localhost.localdomain", "ip6-localhost", "ip6-loopback"} +) + +_ALLOW_ENV_VAR = "TASKITO_WEBHOOKS_ALLOW_PRIVATE" + + +class UnsafeWebhookUrl(ValueError): + """Raised when a webhook URL targets an address we won't deliver to.""" + + +def _is_private_ip(ip: str) -> bool: + try: + address = ipaddress.ip_address(ip) + except ValueError: + return False + return ( + address.is_private + or address.is_loopback + or address.is_link_local + or address.is_multicast + or address.is_reserved + or address.is_unspecified + ) + + +def _hostname_is_blocked(hostname: str) -> bool: + lowered = hostname.lower() + if lowered in _BLOCKED_HOSTNAMES: + return True + return any(lowered.endswith(suffix) for suffix in _BLOCKED_HOSTNAME_SUFFIXES) + + +def validate_webhook_url(url: str) -> None: + """Reject ``url`` if it targets a private/loopback/link-local destination. + + Set ``TASKITO_WEBHOOKS_ALLOW_PRIVATE=1`` in the environment to disable + the guard (intended for local development against ``http://localhost``). + + Raises: + UnsafeWebhookUrl: on scheme other than http/https, missing host, or + a host that resolves to a private/loopback IP. + """ + parsed = urllib.parse.urlparse(url) + if parsed.scheme not in ("http", "https"): + raise UnsafeWebhookUrl(f"URL scheme must be http or https, got {parsed.scheme!r}") + if not parsed.hostname: + raise UnsafeWebhookUrl("URL must include a hostname") + + if os.environ.get(_ALLOW_ENV_VAR): + return + + hostname = parsed.hostname + if _hostname_is_blocked(hostname): + raise UnsafeWebhookUrl(f"URL host {hostname!r} resolves to a private network") + + # Literal IPs are checked directly; named hosts are resolved. + try: + ipaddress.ip_address(hostname) + addresses: list[str] = [hostname] + except ValueError: + try: + addresses = [ + str(info[4][0]) + for info in socket.getaddrinfo(hostname, None, type=socket.SOCK_STREAM) + ] + except OSError as e: + raise UnsafeWebhookUrl(f"could not resolve {hostname!r}: {e}") from None + + for ip in addresses: + if _is_private_ip(ip): + raise UnsafeWebhookUrl(f"URL host {hostname!r} resolves to private address {ip}") + + +def is_safe_redirect(path: str | None) -> bool: + """Whether ``path`` is safe to use as a post-login same-origin redirect target. + + Accepts only relative paths rooted at ``/``. Rejects absolute URLs + (``http://evil.com/x``), protocol-relative URLs (``//evil.com/x``), + and anything without a leading slash. Empty / ``None`` is rejected so + the caller can fall back to a default explicitly. + """ + if not path: + return False + if not path.startswith("/"): + return False + if path.startswith("//") or path.startswith("/\\"): + return False + parsed = urllib.parse.urlparse(path) + return not (parsed.scheme or parsed.netloc) diff --git a/py_src/taskito/dashboard/webhook_store.py b/py_src/taskito/dashboard/webhook_store.py new file mode 100644 index 0000000..7793d69 --- /dev/null +++ b/py_src/taskito/dashboard/webhook_store.py @@ -0,0 +1,204 @@ +"""Persistent webhook subscription store. + +Webhook subscriptions are stored as a JSON list under the +``webhooks:subscriptions`` key in the ``dashboard_settings`` table. This +gives us cross-backend persistence (SQLite, Postgres, Redis) without +adding new tables, while keeping the data structured enough for the +dashboard CRUD UI. + +Each entry is fully described by :class:`WebhookSubscription`. The +``secret`` field stores the HMAC signing secret in plaintext (the +storage backend is already trusted with everything else taskito +persists); the dashboard API NEVER returns the raw secret — only a +``has_secret`` indicator. Use :meth:`WebhookSubscriptionStore.rotate_secret` +to generate a new value and surface it once on rotation. +""" + +from __future__ import annotations + +import json +import logging +import secrets +import time +import uuid +from dataclasses import asdict, dataclass, field, replace +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from taskito.app import Queue + + +SUBSCRIPTIONS_KEY = "webhooks:subscriptions" +SECRET_BYTES = 32 + +logger = logging.getLogger("taskito.dashboard.webhooks") + + +@dataclass(frozen=True) +class WebhookSubscription: + """A single persisted webhook subscription.""" + + id: str + url: str + events: list[str] = field(default_factory=list) # empty = all + task_filter: list[str] | None = None # None = all tasks + headers: dict[str, str] = field(default_factory=dict) + secret: str | None = None + max_retries: int = 3 + timeout_seconds: float = 10.0 + retry_backoff: float = 2.0 + enabled: bool = True + description: str | None = None + created_at: int = 0 + updated_at: int = 0 + + def matches(self, event: str, task_name: str | None) -> bool: + """Return True iff this subscription should fire for the event.""" + if not self.enabled: + return False + if self.events and event not in self.events: + return False + return not (self.task_filter is not None and task_name not in self.task_filter) + + +def _new_id() -> str: + return uuid.uuid4().hex + + +def _now() -> int: + return int(time.time()) + + +def generate_secret() -> str: + """Return a fresh URL-safe webhook signing secret.""" + return secrets.token_urlsafe(SECRET_BYTES) + + +class WebhookSubscriptionStore: + """CRUD for webhook subscriptions backed by ``Queue``'s settings store.""" + + def __init__(self, queue: Queue) -> None: + self._queue = queue + + # ── Internal load/save ─────────────────────────────────────── + + def _load_raw(self) -> list[dict[str, Any]]: + raw = self._queue.get_setting(SUBSCRIPTIONS_KEY) + if not raw: + return [] + try: + data = json.loads(raw) + except json.JSONDecodeError: + logger.warning("webhooks:subscriptions is not valid JSON; treating as empty") + return [] + return data if isinstance(data, list) else [] + + def _save_raw(self, items: list[dict[str, Any]]) -> None: + self._queue.set_setting(SUBSCRIPTIONS_KEY, json.dumps(items, separators=(",", ":"))) + + @staticmethod + def _row_to_subscription(row: dict[str, Any]) -> WebhookSubscription: + return WebhookSubscription( + id=str(row["id"]), + url=str(row["url"]), + events=list(row.get("events") or []), + task_filter=(list(row["task_filter"]) if row.get("task_filter") is not None else None), + headers=dict(row.get("headers") or {}), + secret=row.get("secret"), + max_retries=int(row.get("max_retries", 3)), + timeout_seconds=float(row.get("timeout_seconds", 10.0)), + retry_backoff=float(row.get("retry_backoff", 2.0)), + enabled=bool(row.get("enabled", True)), + description=row.get("description"), + created_at=int(row.get("created_at", 0)), + updated_at=int(row.get("updated_at", 0)), + ) + + # ── Public API ─────────────────────────────────────────────── + + def list_all(self) -> list[WebhookSubscription]: + return [self._row_to_subscription(r) for r in self._load_raw()] + + def get(self, subscription_id: str) -> WebhookSubscription | None: + for row in self._load_raw(): + if row.get("id") == subscription_id: + return self._row_to_subscription(row) + return None + + def create( + self, + *, + url: str, + events: list[str] | None = None, + task_filter: list[str] | None = None, + headers: dict[str, str] | None = None, + secret: str | None = None, + max_retries: int = 3, + timeout_seconds: float = 10.0, + retry_backoff: float = 2.0, + enabled: bool = True, + description: str | None = None, + ) -> WebhookSubscription: + now = _now() + sub = WebhookSubscription( + id=_new_id(), + url=url, + events=list(events or []), + task_filter=list(task_filter) if task_filter is not None else None, + headers=dict(headers or {}), + secret=secret, + max_retries=max_retries, + timeout_seconds=timeout_seconds, + retry_backoff=retry_backoff, + enabled=enabled, + description=description, + created_at=now, + updated_at=now, + ) + rows = self._load_raw() + rows.append(asdict(sub)) + self._save_raw(rows) + return sub + + def update(self, subscription_id: str, **changes: Any) -> WebhookSubscription: + """Patch a subscription. Pass only the fields you want to change. + + Raises ``KeyError`` if the subscription does not exist. + """ + rows = self._load_raw() + for idx, row in enumerate(rows): + if row.get("id") != subscription_id: + continue + existing = self._row_to_subscription(row) + allowed = { + "url", + "events", + "task_filter", + "headers", + "secret", + "max_retries", + "timeout_seconds", + "retry_backoff", + "enabled", + "description", + } + patch = {k: v for k, v in changes.items() if k in allowed} + updated = replace(existing, updated_at=_now(), **patch) + rows[idx] = asdict(updated) + self._save_raw(rows) + return updated + raise KeyError(subscription_id) + + def delete(self, subscription_id: str) -> bool: + rows = self._load_raw() + remaining = [r for r in rows if r.get("id") != subscription_id] + if len(remaining) == len(rows): + return False + self._save_raw(remaining) + return True + + def rotate_secret(self, subscription_id: str) -> str: + """Generate a fresh secret for a subscription. Returns the new value.""" + secret = generate_secret() + self.update(subscription_id, secret=secret) + return secret diff --git a/py_src/taskito/middleware.py b/py_src/taskito/middleware.py index 8650641..077ff33 100644 --- a/py_src/taskito/middleware.py +++ b/py_src/taskito/middleware.py @@ -55,12 +55,20 @@ def after(self, ctx, result, error): print(f"Finished {ctx.task_name}: {status}") """ + #: Stable identifier used to refer to this middleware from the dashboard + #: when toggling it on/off per task. Defaults to the class' fully-qualified + #: name so it survives restarts. Override on a subclass to pin a + #: shorter / more user-facing name. + name: str = "" + def __init__( self, *, predicate: Predicate | Callable[..., Any] | None = None, ) -> None: self._predicate = coerce_predicate(predicate) + if not type(self).name: + type(self).name = f"{type(self).__module__}.{type(self).__qualname__}" def _should_apply(self, ctx: JobContext | None, task_name: str = "") -> bool: """Decide whether this middleware's hooks should fire for ``ctx``. diff --git a/py_src/taskito/mixins/__init__.py b/py_src/taskito/mixins/__init__.py index f9b07ba..2d54c05 100644 --- a/py_src/taskito/mixins/__init__.py +++ b/py_src/taskito/mixins/__init__.py @@ -5,7 +5,9 @@ from taskito.mixins.inspection import QueueInspectionMixin from taskito.mixins.lifecycle import QueueLifecycleMixin from taskito.mixins.locks import QueueLockMixin +from taskito.mixins.middleware_admin import QueueMiddlewareAdminMixin from taskito.mixins.operations import QueueOperationsMixin +from taskito.mixins.overrides import QueueOverridesMixin from taskito.mixins.predicates import QueuePredicateMixin from taskito.mixins.resources import QueueResourceMixin from taskito.mixins.runtime_config import QueueRuntimeConfigMixin @@ -17,7 +19,9 @@ "QueueInspectionMixin", "QueueLifecycleMixin", "QueueLockMixin", + "QueueMiddlewareAdminMixin", "QueueOperationsMixin", + "QueueOverridesMixin", "QueuePredicateMixin", "QueueResourceMixin", "QueueRuntimeConfigMixin", diff --git a/py_src/taskito/mixins/decorators.py b/py_src/taskito/mixins/decorators.py index 671e9a2..e33c940 100644 --- a/py_src/taskito/mixins/decorators.py +++ b/py_src/taskito/mixins/decorators.py @@ -16,6 +16,7 @@ from taskito._taskito import PyTaskConfig from taskito.async_support.helpers import run_maybe_async from taskito.context import _clear_context, current_job +from taskito.dashboard.middleware_store import MiddlewareDisableStore from taskito.events import EventType from taskito.exceptions import TaskCancelledError from taskito.inject import Inject, _InjectAlias @@ -111,9 +112,18 @@ class QueueDecoratorMixin: _apply_dispatch_predicate: Callable[..., None] def _get_middleware_chain(self, task_name: str) -> list[TaskMiddleware]: - """Get the combined global + per-task middleware list.""" + """Get the combined global + per-task middleware list, minus any + middleware the operator has disabled for this task from the dashboard.""" per_task = self._task_middleware.get(task_name, []) - return self._global_middleware + per_task + chain = self._global_middleware + per_task + try: + disabled = MiddlewareDisableStore(self).get_for(task_name) # type: ignore[arg-type] + except Exception: # pragma: no cover - storage read failure is non-fatal + disabled = [] + if not disabled: + return chain + disabled_set = set(disabled) + return [mw for mw in chain if getattr(mw, "name", "") not in disabled_set] def _wrap_task( self, fn: Callable, task_name: str, soft_timeout: float | None = None diff --git a/py_src/taskito/mixins/events.py b/py_src/taskito/mixins/events.py index 936c4c1..1aa6e6c 100644 --- a/py_src/taskito/mixins/events.py +++ b/py_src/taskito/mixins/events.py @@ -5,6 +5,9 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any +from taskito.dashboard.url_safety import validate_webhook_url +from taskito.dashboard.webhook_store import WebhookSubscription, WebhookSubscriptionStore + if TYPE_CHECKING: from taskito.events import EventBus, EventType from taskito.webhooks import WebhookManager @@ -30,6 +33,8 @@ def on_event(self, event_type: EventType, callback: Callable[..., Any]) -> None: """ self._event_bus.on(event_type, callback) + # ── Webhook subscriptions (persistent) ──────────────────────── + def add_webhook( self, url: str, @@ -39,24 +44,74 @@ def add_webhook( max_retries: int = 3, timeout: float = 10.0, retry_backoff: float = 2.0, - ) -> None: + task_filter: list[str] | None = None, + description: str | None = None, + ) -> WebhookSubscription: """Register a webhook endpoint for job events. + Persisted through the dashboard settings store, so the subscription + survives restarts and is shared across every worker pointed at the + same backend. + Args: url: URL to POST event payloads to. - events: Event types to subscribe to (None = all). + events: Event types to subscribe to (``None`` = all). headers: Extra HTTP headers. - secret: HMAC-SHA256 signing secret. - max_retries: Maximum delivery attempts (default 3). - timeout: HTTP request timeout in seconds (default 10.0). - retry_backoff: Base for exponential backoff between retries (default 2.0). + secret: HMAC-SHA256 signing secret. Stored as plaintext; rotate + via :meth:`rotate_webhook_secret`. + max_retries: Maximum delivery attempts. + timeout: HTTP request timeout in seconds. + retry_backoff: Base for exponential backoff between retries. + task_filter: When set, deliver only when the event's + ``task_name`` is in this list. + description: Free-form label shown in the dashboard. + + Returns: + The persisted :class:`WebhookSubscription`. """ - self._webhook_manager.add_webhook( - url, - events, - headers, - secret, + validate_webhook_url(url) + store = WebhookSubscriptionStore(self) # type: ignore[arg-type] + sub = store.create( + url=url, + events=[e.value for e in events] if events else None, + task_filter=task_filter, + headers=headers, + secret=secret, max_retries=max_retries, - timeout=timeout, + timeout_seconds=timeout, retry_backoff=retry_backoff, + description=description, ) + self._webhook_manager.reload() + return sub + + def list_webhooks(self) -> list[WebhookSubscription]: + """Return every persisted webhook subscription.""" + return WebhookSubscriptionStore(self).list_all() # type: ignore[arg-type] + + def get_webhook(self, subscription_id: str) -> WebhookSubscription | None: + return WebhookSubscriptionStore(self).get(subscription_id) # type: ignore[arg-type] + + def update_webhook(self, subscription_id: str, **changes: Any) -> WebhookSubscription: + """Patch fields of an existing subscription. Reloads the manager.""" + if "url" in changes: + validate_webhook_url(changes["url"]) + store = WebhookSubscriptionStore(self) # type: ignore[arg-type] + updated = store.update(subscription_id, **changes) + self._webhook_manager.reload() + return updated + + def remove_webhook(self, subscription_id: str) -> bool: + """Delete a subscription. Returns ``True`` if it existed.""" + store = WebhookSubscriptionStore(self) # type: ignore[arg-type] + removed = store.delete(subscription_id) + if removed: + self._webhook_manager.reload() + return removed + + def rotate_webhook_secret(self, subscription_id: str) -> str: + """Generate a fresh signing secret. Returns the new value.""" + store = WebhookSubscriptionStore(self) # type: ignore[arg-type] + secret = store.rotate_secret(subscription_id) + self._webhook_manager.reload() + return secret diff --git a/py_src/taskito/mixins/lifecycle.py b/py_src/taskito/mixins/lifecycle.py index 874d553..9b912e7 100644 --- a/py_src/taskito/mixins/lifecycle.py +++ b/py_src/taskito/mixins/lifecycle.py @@ -16,6 +16,7 @@ import taskito from taskito._taskito import PyQueue, PyTaskConfig from taskito.context import _set_queue_ref +from taskito.dashboard.overrides_store import OverridesStore from taskito.events import EventType from taskito.log_config import configure as configure_logging from taskito.log_config import restore_asyncio_pipe_noise, silence_asyncio_pipe_noise @@ -231,7 +232,24 @@ def sighup_handler(signum: int, frame: Any) -> None: ) try: - queue_configs_json = json.dumps(self._queue_configs) if self._queue_configs else None + overrides = OverridesStore(self) # type: ignore[arg-type] + # Mutate the in-memory PyTaskConfig list so the Rust scheduler + # sees the override values; merge queue-level overrides into + # the JSON blob passed to run_worker. Paused tasks/queues get + # their pause state propagated to the existing paused_queues + # mechanism for tasks-by-queue, but per-task pause is left to + # the application-level guard in enqueue (out of scope here). + paused_tasks = overrides.apply_task_overrides(self._task_configs) + if paused_tasks: + logger.info("Paused task overrides in effect: %s", paused_tasks) + merged_queue_configs = overrides.apply_queue_overrides(self._queue_configs) + for queue_name, slot in merged_queue_configs.items(): + if slot.get("paused"): + try: + self.pause(queue_name) # type: ignore[attr-defined] + except Exception: + logger.exception("Failed to apply paused state for queue %s", queue_name) + queue_configs_json = json.dumps(merged_queue_configs) if merged_queue_configs else None self._inner.run_worker( task_registry=self._task_registry, task_configs=self._task_configs, diff --git a/py_src/taskito/mixins/middleware_admin.py b/py_src/taskito/mixins/middleware_admin.py new file mode 100644 index 0000000..dd9af80 --- /dev/null +++ b/py_src/taskito/mixins/middleware_admin.py @@ -0,0 +1,70 @@ +"""Middleware discovery and per-task disable management on :class:`Queue`.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from taskito.dashboard.middleware_store import MiddlewareDisableStore + +if TYPE_CHECKING: + from taskito.middleware import TaskMiddleware + + +class QueueMiddlewareAdminMixin: + """Discovery + per-task enable/disable for registered middlewares.""" + + _global_middleware: list[TaskMiddleware] + _task_middleware: dict[str, list[TaskMiddleware]] + + # ── Discovery ────────────────────────────────────────────────── + + def list_middleware(self) -> list[dict[str, Any]]: + """Return every registered middleware (global + per-task) with its + name, source ("global" or task name), and Python class path. The + ``name`` is the value the disable list keys on.""" + seen: dict[str, dict[str, Any]] = {} + for mw in self._global_middleware: + name = getattr(mw, "name", "") or f"{type(mw).__module__}.{type(mw).__qualname__}" + seen.setdefault( + name, + { + "name": name, + "class_path": f"{type(mw).__module__}.{type(mw).__qualname__}", + "scopes": [], + }, + )["scopes"].append({"kind": "global"}) + for task_name, mws in self._task_middleware.items(): + for mw in mws: + name = getattr(mw, "name", "") or f"{type(mw).__module__}.{type(mw).__qualname__}" + entry = seen.setdefault( + name, + { + "name": name, + "class_path": f"{type(mw).__module__}.{type(mw).__qualname__}", + "scopes": [], + }, + ) + entry["scopes"].append({"kind": "task", "task": task_name}) + return sorted(seen.values(), key=lambda x: x["name"]) + + # ── Disable management ───────────────────────────────────────── + + def list_middleware_disables(self) -> dict[str, list[str]]: + """Return every task that has at least one disabled middleware.""" + return MiddlewareDisableStore(self).list_all() # type: ignore[arg-type] + + def get_disabled_middleware_for(self, task_name: str) -> list[str]: + return MiddlewareDisableStore(self).get_for(task_name) # type: ignore[arg-type] + + def disable_middleware_for_task(self, task_name: str, mw_name: str) -> list[str]: + return MiddlewareDisableStore(self).set_disabled( # type: ignore[arg-type] + task_name, mw_name, disabled=True + ) + + def enable_middleware_for_task(self, task_name: str, mw_name: str) -> list[str]: + return MiddlewareDisableStore(self).set_disabled( # type: ignore[arg-type] + task_name, mw_name, disabled=False + ) + + def clear_middleware_disables(self, task_name: str) -> bool: + return MiddlewareDisableStore(self).clear_for(task_name) # type: ignore[arg-type] diff --git a/py_src/taskito/mixins/overrides.py b/py_src/taskito/mixins/overrides.py new file mode 100644 index 0000000..aae9ace --- /dev/null +++ b/py_src/taskito/mixins/overrides.py @@ -0,0 +1,151 @@ +"""Task & queue runtime override management on :class:`taskito.app.Queue`. + +These knobs let operators tune retry policy, concurrency caps, rate +limits, timeouts, priority, and pause/resume state without touching +code. Overrides land in the dashboard settings store and apply on the +next worker startup. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from taskito.dashboard.overrides_store import ( + OverridesStore, + QueueOverride, + TaskOverride, +) + +if TYPE_CHECKING: + from taskito._taskito import PyTaskConfig + + +class QueueOverridesMixin: + """CRUD for task + queue overrides, plus a task-discovery API for the UI.""" + + _task_configs: list[PyTaskConfig] + _queue_configs: dict[str, dict[str, Any]] + + # ── Task overrides ───────────────────────────────────────────── + + def list_task_overrides(self) -> dict[str, TaskOverride]: + """Return every persisted task override keyed by task name.""" + return OverridesStore(self).list_tasks() # type: ignore[arg-type] + + def get_task_override(self, task_name: str) -> TaskOverride | None: + return OverridesStore(self).get_task(task_name) # type: ignore[arg-type] + + def set_task_override(self, task_name: str, **fields: Any) -> TaskOverride: + """Set or update an override. Pass ``None`` for a field to clear it. + + Allowed fields: ``rate_limit``, ``max_concurrent``, ``max_retries``, + ``retry_backoff``, ``timeout``, ``priority``, ``paused``. + """ + return OverridesStore(self).set_task(task_name, fields) # type: ignore[arg-type] + + def clear_task_override(self, task_name: str) -> bool: + return OverridesStore(self).clear_task(task_name) # type: ignore[arg-type] + + # ── Queue overrides ──────────────────────────────────────────── + + def list_queue_overrides(self) -> dict[str, QueueOverride]: + return OverridesStore(self).list_queues() # type: ignore[arg-type] + + def get_queue_override(self, queue_name: str) -> QueueOverride | None: + return OverridesStore(self).get_queue(queue_name) # type: ignore[arg-type] + + def set_queue_override(self, queue_name: str, **fields: Any) -> QueueOverride: + """Set or update a queue override. Allowed fields: ``rate_limit``, + ``max_concurrent``, ``paused``.""" + return OverridesStore(self).set_queue(queue_name, fields) # type: ignore[arg-type] + + def clear_queue_override(self, queue_name: str) -> bool: + return OverridesStore(self).clear_queue(queue_name) # type: ignore[arg-type] + + # ── Task discovery (for the dashboard) ───────────────────────── + + def registered_tasks(self) -> list[dict[str, Any]]: + """Return every registered task with its decorator defaults and any + active override. Each entry contains: + + - ``name``, ``queue``, ``priority`` + - ``defaults``: the decorator-declared values + - ``override``: the override fields (or ``None`` if no override exists) + - ``effective``: the values that will be used on the next worker start + """ + overrides = self.list_task_overrides() + out: list[dict[str, Any]] = [] + for config in self._task_configs: + defaults = { + "max_retries": config.max_retries, + "retry_backoff": config.retry_backoff, + "timeout": config.timeout, + "priority": config.priority, + "rate_limit": config.rate_limit, + "max_concurrent": config.max_concurrent, + } + override = overrides.get(config.name) + override_dict: dict[str, Any] | None + if override is None: + override_dict = None + effective = dict(defaults) + paused = False + else: + patch = override.as_patch() + override_dict = dict(patch) + if override.paused: + override_dict["paused"] = True + effective = {**defaults, **patch} + paused = override.paused + out.append( + { + "name": config.name, + "queue": config.queue, + "defaults": defaults, + "override": override_dict, + "effective": effective, + "paused": paused, + } + ) + return out + + def registered_queues(self) -> list[dict[str, Any]]: + """Return every queue mentioned by a task config plus any + configured-from-Python queue, with its current overrides + paused + state.""" + queue_names: set[str] = set() + queue_names.update(self._queue_configs.keys()) + for config in self._task_configs: + queue_names.add(config.queue) + overrides = self.list_queue_overrides() + paused_set = set( + self.paused_queues() # type: ignore[attr-defined] + ) + out: list[dict[str, Any]] = [] + for name in sorted(queue_names): + base = dict(self._queue_configs.get(name, {})) + override = overrides.get(name) + override_dict: dict[str, Any] | None + if override is None: + override_dict = None + effective = dict(base) + else: + patch: dict[str, Any] = {} + if override.rate_limit is not None: + patch["rate_limit"] = override.rate_limit + if override.max_concurrent is not None: + patch["max_concurrent"] = override.max_concurrent + override_dict = dict(patch) + if override.paused: + override_dict["paused"] = True + effective = {**base, **patch} + out.append( + { + "name": name, + "defaults": base, + "override": override_dict, + "effective": effective, + "paused": name in paused_set or (override.paused if override else False), + } + ) + return out diff --git a/py_src/taskito/proxies/handlers/requests_session.py b/py_src/taskito/proxies/handlers/requests_session.py index 489c7ea..f372776 100644 --- a/py_src/taskito/proxies/handlers/requests_session.py +++ b/py_src/taskito/proxies/handlers/requests_session.py @@ -5,7 +5,7 @@ from typing import Any try: - import requests # type: ignore[import-untyped] + import requests _HAS_REQUESTS = True except ImportError: diff --git a/py_src/taskito/serializers.py b/py_src/taskito/serializers.py index baa4c76..9233b80 100644 --- a/py_src/taskito/serializers.py +++ b/py_src/taskito/serializers.py @@ -88,12 +88,15 @@ def __init__(self, inner: Serializer, key: bytes): f"key must be 16, 24, or 32 bytes for AES-128/192/256, got {len(key)} bytes" ) + from cryptography.exceptions import InvalidTag from cryptography.hazmat.primitives.ciphers.aead import ( AESGCM, ) self._inner = inner self._aesgcm = AESGCM(key) + # Cache the exception class so ``loads`` doesn't re-import per call. + self._invalid_tag = InvalidTag def dumps(self, obj: Any) -> bytes: import os @@ -108,6 +111,9 @@ def loads(self, data: bytes) -> Any: nonce, ciphertext = data[:12], data[12:] try: plaintext = self._aesgcm.decrypt(nonce, ciphertext, None) - except Exception as exc: - raise ValueError(f"Decryption failed: {exc}") from exc + except self._invalid_tag as exc: + # Wrap so callers don't need to import cryptography.exceptions + # to handle decryption failures. The original ``InvalidTag`` is + # preserved in ``__cause__`` for debugging. + raise ValueError("Decryption failed: invalid authentication tag") from exc return self._inner.loads(plaintext) diff --git a/py_src/taskito/webhooks.py b/py_src/taskito/webhooks.py index 9eb9109..1afd2a4 100644 --- a/py_src/taskito/webhooks.py +++ b/py_src/taskito/webhooks.py @@ -1,4 +1,17 @@ -"""Webhook delivery for job events.""" +"""Webhook delivery for job events. + +The manager keeps an in-memory snapshot of the active subscriptions for +fast dispatch and rehydrates that snapshot from +:class:`~taskito.dashboard.webhook_store.WebhookSubscriptionStore` on +start (and on demand via :meth:`reload`). All add/update/delete writes +go through the DB-backed store so changes survive restarts and propagate +to every worker. + +In-memory subscriptions registered through the legacy +``add_webhook(url, ...)`` API continue to work but are not persisted — +that path is kept for backward compatibility with code that constructs +a ``Queue`` without a settings store yet (rare in practice). +""" from __future__ import annotations @@ -9,12 +22,18 @@ import queue import threading import time +import urllib.error import urllib.parse import urllib.request -from typing import Any +from typing import TYPE_CHECKING, Any +from taskito.dashboard.delivery_store import DeliveryStore from taskito.events import EventType +if TYPE_CHECKING: + from taskito.app import Queue + from taskito.dashboard.webhook_store import WebhookSubscription + logger = logging.getLogger("taskito.webhooks") @@ -22,14 +41,55 @@ class WebhookManager: """Delivers webhook POST requests for job events. Uses a background daemon thread with a queue for non-blocking delivery. - Each webhook is retried up to 3 times with exponential backoff. + Each webhook is retried up to its configured ``max_retries`` with + exponential backoff. """ - def __init__(self) -> None: + def __init__(self, queue_ref: Queue | None = None) -> None: + # ``queue_ref`` is the parent :class:`taskito.app.Queue`. Optional + # so legacy in-process tests can construct a bare manager. + self._queue: Queue | None = queue_ref + # In-memory subscription list. Each entry is a dict shaped like a + # legacy ``add_webhook`` call so both code paths share a single + # delivery loop. self._webhooks: list[dict[str, Any]] = [] - self._queue: queue.Queue[tuple[dict[str, Any], dict[str, Any]]] = queue.Queue() + self._delivery_queue: queue.Queue[tuple[dict[str, Any], dict[str, Any]]] = queue.Queue() self._thread: threading.Thread | None = None - self._thread_lock = threading.Lock() + self._lock = threading.Lock() + if queue_ref is not None: + self.reload() + + # ── Snapshot management ─────────────────────────────────────── + + def reload(self) -> None: + """Refresh the in-memory snapshot from the persistent store.""" + if self._queue is None: + return + from taskito.dashboard.webhook_store import WebhookSubscriptionStore + + store = WebhookSubscriptionStore(self._queue) + snapshot = [self._subscription_to_runtime(s) for s in store.list_all()] + with self._lock: + self._webhooks = snapshot + if snapshot: + self._ensure_thread() + + @staticmethod + def _subscription_to_runtime(sub: WebhookSubscription) -> dict[str, Any]: + return { + "subscription_id": sub.id, + "url": sub.url, + "events": set(sub.events) if sub.events else None, + "task_filter": set(sub.task_filter) if sub.task_filter is not None else None, + "headers": dict(sub.headers), + "secret": sub.secret.encode() if sub.secret else None, + "max_retries": sub.max_retries, + "timeout": sub.timeout_seconds, + "retry_backoff": sub.retry_backoff, + "enabled": sub.enabled, + } + + # ── Public API (legacy + new) ───────────────────────────────── def add_webhook( self, @@ -41,44 +101,60 @@ def add_webhook( timeout: float = 10.0, retry_backoff: float = 2.0, ) -> None: - """Register a webhook endpoint. - - Args: - url: URL to POST event payloads to. - events: List of event types to subscribe to. None means all events. - headers: Extra HTTP headers to include. - secret: HMAC-SHA256 signing secret for the ``X-Taskito-Signature`` header. - max_retries: Maximum delivery attempts (default 3). - timeout: HTTP request timeout in seconds (default 10.0). - retry_backoff: Base for exponential backoff between retries (default 2.0). + """Register a webhook endpoint (in-memory; not persisted). + + Prefer :meth:`Queue.create_webhook` for new code — it persists + through the dashboard-managed store and survives restarts. """ parsed = urllib.parse.urlparse(url) if parsed.scheme not in ("http", "https"): raise ValueError(f"Webhook URL must use http:// or https://, got {parsed.scheme!r}") - with self._thread_lock: + with self._lock: self._webhooks.append( { + "subscription_id": None, "url": url, "events": {e.value for e in events} if events else None, + "task_filter": None, "headers": headers or {}, "secret": secret.encode() if secret else None, "max_retries": max_retries, "timeout": timeout, "retry_backoff": retry_backoff, + "enabled": True, } ) self._ensure_thread() def notify(self, event_type: EventType, payload: dict[str, Any]) -> None: """Queue an event for delivery to matching webhooks.""" - with self._thread_lock: + with self._lock: webhooks = list(self._webhooks) + task_name = payload.get("task_name") + wire_event = event_type.value for wh in webhooks: - if wh["events"] is None or event_type.value in wh["events"]: - self._queue.put((wh, {"event": event_type.value, **payload})) + if not wh.get("enabled", True): + continue + if wh["events"] is not None and wire_event not in wh["events"]: + continue + task_filter = wh.get("task_filter") + if task_filter is not None and task_name not in task_filter: + continue + self._delivery_queue.put((wh, {"event": wire_event, **payload})) + + def deliver_now(self, wh: dict[str, Any], payload: dict[str, Any]) -> int | None: + """Synchronously deliver one payload. Returns the final HTTP status or + ``None`` if every attempt failed at the transport level. + + Used by the dashboard "send test event" endpoint so the operator + sees the result inline. Does NOT add to the retry queue. + """ + return self._send(wh, payload, write_to_log=False) + + # ── Delivery loop ───────────────────────────────────────────── def _ensure_thread(self) -> None: - with self._thread_lock: + with self._lock: if self._thread is None or not self._thread.is_alive(): self._thread = threading.Thread( target=self._deliver_loop, daemon=True, name="taskito-webhooks" @@ -88,14 +164,25 @@ def _ensure_thread(self) -> None: def _deliver_loop(self) -> None: while True: try: - wh, payload = self._queue.get(timeout=10) + wh, payload = self._delivery_queue.get(timeout=10) self._send(wh, payload) except queue.Empty: continue except Exception: logger.exception("Webhook delivery error") - def _send(self, wh: dict[str, Any], payload: dict[str, Any]) -> None: + def _send( + self, wh: dict[str, Any], payload: dict[str, Any], *, write_to_log: bool = True + ) -> int | None: + """Deliver ``payload`` to ``wh`` with retries. Returns the last HTTP + status code observed (after retries) or ``None`` if every attempt + failed at the transport level. + + When ``write_to_log`` is true AND the subscription is persisted + (``wh["subscription_id"]`` is not ``None``), a record of the final + outcome is appended to the delivery log so the dashboard can + replay it later. + """ body = json.dumps(payload, default=str).encode("utf-8") headers: dict[str, str] = { @@ -111,25 +198,134 @@ def _send(self, wh: dict[str, Any], payload: dict[str, Any]) -> None: timeout: float = wh.get("timeout", 10.0) retry_backoff: float = wh.get("retry_backoff", 2.0) + last_status: int | None = None + last_response_body: str | None = None + last_error: str | None = None + started_at = time.monotonic() + attempt_count = 0 + for attempt in range(max_retries): + attempt_count = attempt + 1 try: req = urllib.request.Request(wh["url"], data=body, headers=headers, method="POST") with urllib.request.urlopen(req, timeout=timeout) as resp: - if resp.status < 400: - return - if resp.status < 500: + last_status = int(resp.status) + last_response_body = self._read_response_body(resp) + if last_status < 400: + self._record( + wh, + payload, + status="delivered", + attempts=attempt_count, + response_code=last_status, + response_body=last_response_body, + latency_ms=int((time.monotonic() - started_at) * 1000), + write_to_log=write_to_log, + ) + return last_status + if write_to_log: + logger.warning( + "Webhook %s returned server error %d", wh["url"], resp.status + ) + except urllib.error.HTTPError as e: + last_status = e.code + last_response_body = self._read_response_body(e) + if e.code < 500: + if write_to_log: logger.warning( "Webhook %s returned client error %d, not retrying", wh["url"], - resp.status, + e.code, ) - return - logger.warning("Webhook %s returned server error %d", wh["url"], resp.status) - except Exception: - logger.debug("Webhook %s attempt %d failed", wh["url"], attempt + 1, exc_info=True) + self._record( + wh, + payload, + status="failed", + attempts=attempt_count, + response_code=last_status, + response_body=last_response_body, + latency_ms=int((time.monotonic() - started_at) * 1000), + write_to_log=write_to_log, + ) + return e.code + if write_to_log: + logger.warning("Webhook %s returned server error %d", wh["url"], e.code) + except Exception as e: + last_error = f"{type(e).__name__}: {e}" + if write_to_log: + logger.debug( + "Webhook %s attempt %d failed", + wh["url"], + attempt + 1, + exc_info=True, + ) if attempt == max_retries - 1: - logger.warning( - "Webhook delivery failed after %d attempts: %s", max_retries, wh["url"] - ) + if write_to_log: + logger.warning( + "Webhook delivery failed after %d attempts: %s", + max_retries, + wh["url"], + ) else: time.sleep(retry_backoff**attempt) + + # Out of retries — record as dead. + self._record( + wh, + payload, + status="dead", + attempts=attempt_count, + response_code=last_status, + response_body=last_response_body, + latency_ms=int((time.monotonic() - started_at) * 1000), + error=last_error, + write_to_log=write_to_log, + ) + return last_status + + def _record( + self, + wh: dict[str, Any], + payload: dict[str, Any], + *, + status: str, + attempts: int, + response_code: int | None = None, + response_body: str | None = None, + latency_ms: int | None = None, + error: str | None = None, + write_to_log: bool = True, + ) -> None: + """Persist a delivery outcome to the dashboard log.""" + if not write_to_log: + return + subscription_id = wh.get("subscription_id") + if not subscription_id or self._queue is None: + return + try: + DeliveryStore(self._queue).record_attempt( + subscription_id, + event=str(payload.get("event", "")), + payload=payload, + status=status, + attempts=attempts, + response_code=response_code, + response_body=response_body, + latency_ms=latency_ms, + error=error, + task_name=payload.get("task_name"), + job_id=payload.get("job_id"), + ) + except Exception: + logger.exception("Failed to record webhook delivery") + + @staticmethod + def _read_response_body(resp: Any) -> str | None: + """Read up to a few KiB from a response/HTTPError object.""" + try: + data = resp.read(4096) # limit even before truncation in DeliveryStore + except Exception: + return None + if not data: + return None + return str(data.decode("utf-8", errors="replace")) diff --git a/pyproject.toml b/pyproject.toml index 9fdfddd..91e36c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,8 @@ encryption = ["cryptography"] flask = ["flask>=3.0"] aws = ["boto3>=1.34"] gcs = ["google-cloud-storage>=2.10"] +docs = ["playwright>=1.59"] +oauth = ["authlib>=1.7,<2", "requests>=2.31"] [tool.maturin] manifest-path = "crates/taskito-python/Cargo.toml" @@ -148,3 +150,16 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = "click" ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = ["authlib", "authlib.*"] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = ["joserfc", "joserfc.*"] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = ["requests", "requests.*"] +ignore_missing_imports = true + diff --git a/scripts/capture_docs_screenshots.py b/scripts/capture_docs_screenshots.py new file mode 100644 index 0000000..8f4c3cc --- /dev/null +++ b/scripts/capture_docs_screenshots.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python3 +"""Reproducible screenshot capture for the documentation site. + +Spins up a fresh Taskito Queue, seeds it with deterministic demo data +(admin user, sample tasks/queues, webhooks with mixed delivery outcomes, +runtime overrides, a middleware disable), starts the dashboard on a +random port, drives a headless Chromium through every screen, and saves +PNGs under ``docs/public/screenshots/dashboard/``. + +Run from the repo root: + + uv run --with playwright python scripts/capture_docs_screenshots.py + # First time only: + uv run --with playwright python -m playwright install chromium + +The script is **idempotent**: every run overwrites the previous PNGs and +starts from an empty SQLite DB in a temp directory, so there is no +"works on my machine" drift. +""" + +from __future__ import annotations + +import argparse +import contextlib +import json +import socket +import sys +import tempfile +import threading +import time +from http.server import BaseHTTPRequestHandler, HTTPServer, ThreadingHTTPServer +from pathlib import Path +from typing import Any + +from taskito import Queue +from taskito.dashboard import _make_handler +from taskito.dashboard.auth import AuthStore +from taskito.dashboard.delivery_store import DeliveryStore +from taskito.events import EventType +from taskito.middleware import TaskMiddleware + +ADMIN_USER = "demo-admin" +ADMIN_PASSWORD = "demo-pass-1234" + +REPO_ROOT = Path(__file__).resolve().parent.parent +SCREENSHOT_DIR = REPO_ROOT / "docs" / "public" / "screenshots" / "dashboard" + + +# ── Demo middleware ────────────────────────────────────────────────── + + +class LoggingMiddleware(TaskMiddleware): + """Demo middleware that the screenshots reference.""" + + name = "demo.logging" + + +class MetricsMiddleware(TaskMiddleware): + name = "demo.metrics" + + +# ── Seed data ──────────────────────────────────────────────────────── + + +# Task body functions are defined at module level so their qualnames stay +# clean (``demo_tasks.send_email`` rather than +# ``capture_docs_screenshots.seed_queue..send_email``) — much nicer +# in screenshots. ``seed_queue`` registers them with the live Queue. + + +def send_email(to: str) -> str: + return f"sent:{to}" + + +def deliver_message(message: str) -> str: + return message + + +def sync_metrics() -> None: + pass + + +send_email.__module__ = "myapp.tasks" +send_email.__qualname__ = "send_email" +deliver_message.__module__ = "myapp.tasks" +deliver_message.__qualname__ = "deliver_message" +sync_metrics.__module__ = "myapp.tasks" +sync_metrics.__qualname__ = "sync_metrics" + + +def seed_queue(queue: Queue) -> None: + """Populate the demo queue with realistic data for the screenshots. + + Returns nothing — the dashboard reads everything from storage. + """ + # Admin user — header drop-down shows "demo-admin" in the screenshots. + AuthStore(queue).create_user(ADMIN_USER, ADMIN_PASSWORD, role="admin") + + # Tasks — defaults vary so the Tasks table has visual variety. + queue.task()(send_email) + queue.task( + queue="email", + max_retries=5, + timeout=120, + rate_limit="100/m", + max_concurrent=10, + )(deliver_message) + queue.task(queue="metrics", priority=2)(sync_metrics) + + # Queue-level configuration. set_queue_concurrency goes through the + # same code path as the dashboard override apply. + queue.set_queue_concurrency("email", 10) + + # Override the send_email task — Tasks page should show this in accent. + queue.set_task_override( + next(c.name for c in queue._task_configs if c.name.endswith("send_email")), + rate_limit="200/m", + max_retries=10, + ) + + # Webhook subscriptions — one fully configured, one disabled, one + # filtered to a specific task. + import os + + os.environ["TASKITO_WEBHOOKS_ALLOW_PRIVATE"] = "1" # echo server is loopback + + sub1 = queue.add_webhook( + url="https://hooks.example.com/ops-failures", + events=[EventType.JOB_FAILED, EventType.JOB_DEAD], + secret="whsec_demo_signing_secret", + description="Page ops on permanent job failures", + max_retries=5, + timeout=8.0, + ) + # Second subscription: filters by task name. Captured for visual + # contrast in the webhooks-list screenshot; we don't reuse its id. + queue.add_webhook( + url="https://audit.internal.example.com/taskito-events", + events=None, + task_filter=["myapp.tasks.send_email"], + description="Audit log for send_email only", + ) + sub3 = queue.add_webhook( + url="https://staging-hooks.example.com/all-events", + description="Staging echo — disabled", + ) + queue.update_webhook(sub3.id, enabled=False) + + # Synthesize delivery history for sub1 so the Deliveries page has rows. + store = DeliveryStore(queue) + base_time = int(time.time() * 1000) + deliveries = [ + ("delivered", 200, 42, "job.completed", "myapp.tasks.process_image"), + ("delivered", 200, 38, "job.completed", "myapp.tasks.send_email"), + ("delivered", 200, 51, "job.completed", "myapp.tasks.send_email"), + ("failed", 504, 9500, "job.failed", "myapp.tasks.process_image"), + ("delivered", 200, 44, "job.completed", "myapp.tasks.send_email"), + ("dead", 500, 30000, "job.dead", "myapp.tasks.process_image"), + ("delivered", 200, 39, "job.completed", "myapp.tasks.send_email"), + ] + for i, (status, code, lat, event, task_name) in enumerate(deliveries): + record = store.record_attempt( + sub1.id, + event=event, + payload={ + "task_name": task_name, + "job_id": f"01H{i:02d}DEMOXYZ{i}", + "queue": "default", + }, + status=status, + attempts=3 if status == "dead" else 1, + response_code=code if status != "delivered" or code == 200 else None, + latency_ms=lat, + response_body=( + None + if status == "delivered" + else "Internal Server Error\nstack trace here..." + ), + task_name=task_name, + job_id=f"01H{i:02d}DEMOXYZ{i}", + ) + # Backdate so the "When" column shows a range of relative times. + _backdate_delivery(queue, sub1.id, record.id, base_time - i * 600_000) + + # Disable one middleware on one task so the Middleware tab has a + # mix of green / grey toggles. + send_email_full = next( + c.name for c in queue._task_configs if c.name.endswith("send_email") + ) + queue.disable_middleware_for_task(send_email_full, "demo.metrics") + + +def _backdate_delivery(queue: Queue, sub_id: str, record_id: str, ts: int) -> None: + """Rewrite a delivery's ``created_at`` so the deliveries table shows + a believable range of relative times in the screenshot rather than + a clump of "just now" rows.""" + key = f"webhooks:deliveries:{sub_id}" + raw = queue.get_setting(key) + if not raw: + return + rows = json.loads(raw) + for row in rows: + if row.get("id") == record_id: + row["created_at"] = ts + if row.get("completed_at") is not None: + row["completed_at"] = ts + 50 + queue.set_setting(key, json.dumps(rows, separators=(",", ":"))) + return + + +# ── Webhook echo server (for the live send-test screenshot) ────────── + + +def start_echo_server() -> tuple[str, HTTPServer]: + """Local server the test webhook delivers to during the captures.""" + + class Handler(BaseHTTPRequestHandler): + def do_POST(self) -> None: + self.send_response(200) + self.end_headers() + self.wfile.write(b"ok") + + def log_message(self, *args: Any) -> None: + pass + + server = HTTPServer(("127.0.0.1", 0), Handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + return f"http://127.0.0.1:{server.server_address[1]}", server + + +# ── Dashboard process ──────────────────────────────────────────────── + + +def start_dashboard(queue: Queue) -> tuple[str, ThreadingHTTPServer]: + """Start the dashboard on a random localhost port.""" + handler = _make_handler(queue) + # Bind to port 0 → kernel picks a free port. + server = ThreadingHTTPServer(("127.0.0.1", 0), handler) + threading.Thread(target=server.serve_forever, daemon=True).start() + base = f"http://127.0.0.1:{server.server_address[1]}" + _wait_for_port(server.server_address[1]) + return base, server + + +def _wait_for_port(port: int, timeout: float = 5.0) -> None: + deadline = time.time() + timeout + while time.time() < deadline: + with socket.socket() as sock: + try: + sock.settimeout(0.2) + sock.connect(("127.0.0.1", port)) + return + except OSError: + time.sleep(0.05) + raise RuntimeError(f"dashboard did not bind on port {port}") + + +# ── Capture flow ───────────────────────────────────────────────────── + + +def capture_all(base_url: str) -> None: + """Walk every dashboard page and save a PNG per screen. + + Uses Playwright's sync API. Each screenshot is named for its target + location in the docs so the MDX side can reference them by stable path. + """ + from playwright.sync_api import sync_playwright + + SCREENSHOT_DIR.mkdir(parents=True, exist_ok=True) + print(f"Capturing screenshots into {SCREENSHOT_DIR}") + + with sync_playwright() as p: + # Use the system Chrome so no extra browser download is needed. + browser = p.chromium.launch(headless=True, channel="chrome") + # Width 1280 matches the dashboard's ``max-w-[1400px]`` content + # area; height 800 keeps each capture above-the-fold without huge + # screenshots. ``deviceScaleFactor=2`` gives crisp HiDPI output. + context = browser.new_context( + viewport={"width": 1280, "height": 800}, + device_scale_factor=2, + ) + page = context.new_page() + + # ── Phase 1: auth ───────────────────────────────────────── + login_and_screenshot_setup_then_login(page, base_url) + + # Pre-fetch the cookie for the rest of the flow — the dashboard + # uses session cookies, and after the login flow above we already + # have them. So the page context is now authenticated. + + # ── Main pages ─────────────────────────────────────────── + capture_each( + page, + base_url, + [ + ("/", "overview"), + ("/jobs", "jobs"), + ("/queues", "queues"), + ("/workers", "workers"), + ], + wait_for_text={"/": "Overview", "/jobs": "Jobs", "/queues": "Queues"}, + ) + + # ── Phase 2/3: webhooks ────────────────────────────────── + capture_page(page, f"{base_url}/webhooks", "webhooks-list") + # Drive the deliveries view via the visible UI — same path an + # operator would take. + page.locator('button[aria-label="Webhook actions"]').first.click() + page.get_by_role("menuitem", name="View deliveries").click() + page.wait_for_url("**/deliveries", timeout=5000) + page.wait_for_load_state("networkidle") + time.sleep(1.2) + screenshot(page, "webhook-deliveries") + + # Open the create-webhook dialog for the form screenshot. + page.goto(f"{base_url}/webhooks", wait_until="networkidle") + page.get_by_role("button", name="New webhook").click() + page.wait_for_selector("text=Subscribe an HTTP endpoint", timeout=3000) + time.sleep(0.3) # let the dialog finish animating in + screenshot(page, "webhook-create-dialog") + + # ── Phase 4/5: tasks + middleware ──────────────────────── + capture_page(page, f"{base_url}/tasks", "tasks-list") + + page.goto(f"{base_url}/tasks", wait_until="networkidle") + # First Edit button opens the side sheet. + page.get_by_role("button", name="Edit").first.click() + page.wait_for_selector("text=Overrides", timeout=3000) + time.sleep(0.3) + screenshot(page, "task-edit-overrides") + + # Switch to the Middleware tab inside the same sheet. + page.get_by_role("tab", name="Middleware").click() + page.wait_for_selector("text=demo.logging", timeout=3000) + time.sleep(0.3) + screenshot(page, "task-edit-middleware") + + context.close() + browser.close() + print(f"OK — captured {len(list(SCREENSHOT_DIR.glob('*.png')))} screenshots") + + +def login_and_screenshot_setup_then_login(page: Any, base_url: str) -> None: + """Capture the setup page on a fresh dashboard, then the login page, + then sign in so subsequent captures are authenticated.""" + # Use a *separate* throwaway DB just for the setup screenshot — the + # main demo queue already has a user, so /login would show the sign-in + # form. Easier than tearing down and re-seeding mid-run. + setup_url = _start_throwaway_dashboard() + page.goto(setup_url + "/login", wait_until="networkidle") + page.wait_for_selector("text=Create the first admin", timeout=3000) + time.sleep(0.3) + screenshot(page, "auth-setup") + + # Now the real login page on the seeded dashboard. + page.goto(base_url + "/login", wait_until="networkidle") + page.wait_for_selector("text=Sign in", timeout=3000) + time.sleep(0.3) + screenshot(page, "auth-login") + + # Authenticate so the rest of the captures run inside the AppShell. + page.fill('input[id="login-username"]', ADMIN_USER) + page.fill('input[id="login-password"]', ADMIN_PASSWORD) + page.get_by_role("button", name="Sign in").click() + page.wait_for_url(f"{base_url}/", timeout=5000) + + +def _start_throwaway_dashboard() -> str: + """Spin up a second dashboard against a fresh empty DB just for the + setup screenshot.""" + tmpdir = tempfile.mkdtemp(prefix="taskito-docs-") + q = Queue(db_path=f"{tmpdir}/setup.db") + url, _server = start_dashboard(q) + return url + + +def capture_each( + page: Any, + base_url: str, + routes: list[tuple[str, str]], + *, + wait_for_text: dict[str, str] | None = None, +) -> None: + for route, name in routes: + url = base_url + route + page.goto(url, wait_until="networkidle") + expected = (wait_for_text or {}).get(route) + if expected is not None: + with contextlib.suppress(Exception): + page.wait_for_selector(f"text={expected}", timeout=3000) + time.sleep(0.3) + screenshot(page, name) + + +def capture_page(page: Any, url: str, name: str) -> None: + page.goto(url, wait_until="networkidle") + time.sleep(0.4) + screenshot(page, name) + + +def screenshot(page: Any, name: str) -> None: + out = SCREENSHOT_DIR / f"{name}.png" + page.screenshot(path=str(out), full_page=False) + print(f" • {out.name}") + + +def _first_webhook_id(base_url: str) -> str: + """Pull the demo webhook id straight from the API so the deep link + in the screenshot script stays in sync with whatever seed_queue + produced.""" + # NB: the dashboard is auth-gated, so we need the cookie. Simplest + # approach is a synchronous login via stdlib urllib. + import urllib.parse + + login = json.dumps({"username": ADMIN_USER, "password": ADMIN_PASSWORD}).encode() + req = urllib.request.Request( + f"{base_url}/api/auth/login", + method="POST", + data=login, + headers={"Content-Type": "application/json"}, + ) + with urllib.request.urlopen(req) as resp: + cookies = "; ".join( + urllib.parse.unquote(c.split(";", 1)[0]) + for c in resp.headers.get_all("Set-Cookie") or [] + ) + list_req = urllib.request.Request( + f"{base_url}/api/webhooks", headers={"Cookie": cookies} + ) + with urllib.request.urlopen(list_req) as resp: + items = json.loads(resp.read()) + return str(items[0]["id"]) + + +# ── Entry point ────────────────────────────────────────────────────── + + +def main(argv: list[str]) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--skip-capture", + action="store_true", + help="Seed the demo queue and start the dashboard but skip the " + "Playwright run — useful for poking the seeded data in a browser.", + ) + args = parser.parse_args(argv) + + tmpdir = tempfile.mkdtemp(prefix="taskito-docs-") + print(f"Demo DB: {tmpdir}/demo.db") + queue = Queue( + db_path=f"{tmpdir}/demo.db", + middleware=[LoggingMiddleware(), MetricsMiddleware()], + ) + seed_queue(queue) + + echo_url, _echo = start_echo_server() + print(f"Echo server: {echo_url}") + + base_url, _dash = start_dashboard(queue) + print(f"Dashboard: {base_url}") + + if args.skip_capture: + print("\nSkipping Playwright. Open the dashboard in a browser. Ctrl+C to exit.") + try: + threading.Event().wait() + except KeyboardInterrupt: + return 0 + return 0 + + try: + capture_all(base_url) + except Exception: + import traceback + + traceback.print_exc() + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main(sys.argv[1:])) diff --git a/scripts/optimize_screenshots.py b/scripts/optimize_screenshots.py new file mode 100644 index 0000000..2310477 --- /dev/null +++ b/scripts/optimize_screenshots.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +"""Crunch the dashboard screenshots so the docs site stays light. + +The capture script writes 2x HiDPI PNGs straight from Playwright, which +range from 50 KB to 350 KB each. This pass: + +1. Converts to ``P`` mode (256-colour palette) where appropriate — + screenshots are dominated by flat UI panels and large solid regions, + so palette quantisation typically halves the file size with no + visible loss. +2. Falls back to the original RGBA encoding if quantisation actually + *grows* the file (rare on dense screenshots). +3. Reports before/after sizes so we can spot regressions. + +Run after every screenshot regen: + + uv run --with pillow python scripts/optimize_screenshots.py +""" + +from __future__ import annotations + +import sys +from pathlib import Path + +from PIL import Image + +REPO_ROOT = Path(__file__).resolve().parent.parent +SCREENSHOT_DIR = REPO_ROOT / "docs" / "public" / "screenshots" / "dashboard" + + +def optimize(path: Path) -> tuple[int, int]: + """Return ``(before_bytes, after_bytes)``.""" + before = path.stat().st_size + with Image.open(path) as img: + img.load() + # Quantise to a 256-colour palette. ``method=2`` (median cut) is + # better for synthetic UI screenshots than the default libimagequant + # path Pillow uses on lossier paths. + quantised = img.convert("RGB").quantize(colors=256, method=2, dither=Image.Dither.NONE) + tmp = path.with_suffix(".opt.png") + quantised.save(tmp, format="PNG", optimize=True) + after = tmp.stat().st_size + if after < before: + tmp.replace(path) + return before, after + tmp.unlink() + return before, before + + +def main() -> int: + if not SCREENSHOT_DIR.exists(): + print(f"No screenshots at {SCREENSHOT_DIR}", file=sys.stderr) + return 1 + total_before = 0 + total_after = 0 + for png in sorted(SCREENSHOT_DIR.glob("*.png")): + before, after = optimize(png) + total_before += before + total_after += after + delta = (after - before) / before * 100 if before else 0.0 + print( + f"{png.name:35s} {before / 1024:7.1f} KB → {after / 1024:7.1f} KB" + f" ({delta:+5.1f}%)" + ) + print( + f"{'TOTAL':35s} {total_before / 1024:7.1f} KB → {total_after / 1024:7.1f} KB" + f" ({(total_after - total_before) / total_before * 100:+5.1f}%)" + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/core/test_serializers.py b/tests/core/test_serializers.py index 977a718..c78f839 100644 --- a/tests/core/test_serializers.py +++ b/tests/core/test_serializers.py @@ -103,27 +103,48 @@ def test_wrong_key_fails(self) -> None: pytest.importorskip("cryptography") import os + from cryptography.exceptions import InvalidTag + from taskito.serializers import EncryptedSerializer s1 = EncryptedSerializer(JsonSerializer(), os.urandom(32)) s2 = EncryptedSerializer(JsonSerializer(), os.urandom(32)) - from cryptography.exceptions import InvalidTag encrypted = s1.dumps({"data": 1}) - with pytest.raises(InvalidTag): + with pytest.raises(ValueError, match="Decryption failed") as excinfo: s2.loads(encrypted) + # The original cryptography exception is preserved on the cause + # chain so debugging surfaces still know it was a tag-validation + # failure rather than a malformed-input ValueError. + assert isinstance(excinfo.value.__cause__, InvalidTag) def test_tampered_ciphertext_fails(self) -> None: pytest.importorskip("cryptography") import os + from cryptography.exceptions import InvalidTag + from taskito.serializers import EncryptedSerializer key = os.urandom(32) s = EncryptedSerializer(JsonSerializer(), key) encrypted = s.dumps("hello") - from cryptography.exceptions import InvalidTag tampered = encrypted[:-1] + bytes([encrypted[-1] ^ 0xFF]) - with pytest.raises(InvalidTag): + with pytest.raises(ValueError, match="Decryption failed") as excinfo: s.loads(tampered) + assert isinstance(excinfo.value.__cause__, InvalidTag) + + def test_short_ciphertext_fails(self) -> None: + """Inputs shorter than the AES-GCM nonce (12B) + tag (≥1B) are + rejected before the cipher is ever consulted, with a distinct + message so operators can tell parsing errors from key/tag failures. + """ + pytest.importorskip("cryptography") + import os + + from taskito.serializers import EncryptedSerializer + + s = EncryptedSerializer(JsonSerializer(), os.urandom(32)) + with pytest.raises(ValueError, match="too short"): + s.loads(b"only-twelve-") diff --git a/tests/dashboard/test_auth.py b/tests/dashboard/test_auth.py new file mode 100644 index 0000000..599e689 --- /dev/null +++ b/tests/dashboard/test_auth.py @@ -0,0 +1,588 @@ +"""Tests for dashboard authentication. + +Covers the auth helpers in :mod:`taskito.dashboard.auth` and the HTTP +endpoints under ``/api/auth/*``, plus the session-gating behaviour the +server applies to every other API route. +""" + +from __future__ import annotations + +import json +import threading +import urllib.error +import urllib.request +from collections.abc import Generator +from http.server import ThreadingHTTPServer +from pathlib import Path +from typing import Any + +import pytest + +from taskito import Queue +from taskito.dashboard import _make_handler +from taskito.dashboard.auth import ( + AuthStore, + bootstrap_admin_from_env, + hash_password, + verify_password, +) + + +@pytest.fixture +def queue(tmp_path: Path) -> Queue: + return Queue(db_path=str(tmp_path / "auth.db")) + + +# ── Password hashing primitives ───────────────────────────────────────── + + +def test_hash_password_round_trip() -> None: + encoded = hash_password("hunter2-correct-horse") + assert verify_password("hunter2-correct-horse", encoded) is True + assert verify_password("wrong", encoded) is False + + +def test_hash_password_produces_unique_salts() -> None: + a = hash_password("same-password") + b = hash_password("same-password") + assert a != b, "different salts must produce different hashes" + assert verify_password("same-password", a) + assert verify_password("same-password", b) + + +def test_verify_password_rejects_malformed_encoding() -> None: + assert verify_password("anything", "not-a-real-hash") is False + assert verify_password("anything", "scrypt$xxx$yyy$zzz") is False + assert verify_password("anything", "pbkdf2_sha256$abc$def$ghi") is False + + +# ── AuthStore: users ──────────────────────────────────────────────────── + + +def test_count_users_starts_at_zero(queue: Queue) -> None: + assert AuthStore(queue).count_users() == 0 + + +def test_create_user_persists(queue: Queue) -> None: + store = AuthStore(queue) + user = store.create_user("alice", "hunter2-secret") + assert user.username == "alice" + assert user.role == "admin" + assert store.count_users() == 1 + assert store.get_user("alice") is not None + assert store.get_user("missing") is None + + +def test_create_user_rejects_duplicate(queue: Queue) -> None: + store = AuthStore(queue) + store.create_user("alice", "hunter2-secret") + with pytest.raises(ValueError, match="already exists"): + store.create_user("alice", "another-pass") + + +def test_create_user_validates_username(queue: Queue) -> None: + store = AuthStore(queue) + with pytest.raises(ValueError, match="empty"): + store.create_user("", "hunter2-secret") + with pytest.raises(ValueError, match="may only contain"): + store.create_user("alice bob", "hunter2-secret") + + +def test_create_user_validates_password(queue: Queue) -> None: + store = AuthStore(queue) + with pytest.raises(ValueError, match=">= 8 chars"): + store.create_user("alice", "short") + + +def test_authenticate(queue: Queue) -> None: + store = AuthStore(queue) + store.create_user("alice", "hunter2-secret") + assert store.authenticate("alice", "hunter2-secret") is not None + assert store.authenticate("alice", "wrong") is None + # Unknown username also returns None, without timing leak (we don't + # assert timing here, just behaviour). + assert store.authenticate("bob", "anything") is None + + +def test_authenticate_updates_last_login(queue: Queue) -> None: + store = AuthStore(queue) + store.create_user("alice", "hunter2-secret") + assert store.get_user("alice").last_login_at is None # type: ignore[union-attr] + store.authenticate("alice", "hunter2-secret") + assert store.get_user("alice").last_login_at is not None # type: ignore[union-attr] + + +def test_delete_user(queue: Queue) -> None: + store = AuthStore(queue) + store.create_user("alice", "hunter2-secret") + assert store.delete_user("alice") is True + assert store.delete_user("alice") is False + assert store.get_user("alice") is None + + +def test_update_password(queue: Queue) -> None: + store = AuthStore(queue) + store.create_user("alice", "hunter2-secret") + store.update_password("alice", "new-secure-pass") + assert store.authenticate("alice", "new-secure-pass") is not None + assert store.authenticate("alice", "hunter2-secret") is None + + +# ── AuthStore: sessions ──────────────────────────────────────────────── + + +def test_create_and_get_session(queue: Queue) -> None: + store = AuthStore(queue) + user = store.create_user("alice", "hunter2-secret") + session = store.create_session(user) + fetched = store.get_session(session.token) + assert fetched is not None + assert fetched.username == "alice" + assert fetched.csrf_token == session.csrf_token + assert not fetched.is_expired() + + +def test_get_session_unknown_token_returns_none(queue: Queue) -> None: + assert AuthStore(queue).get_session("nope") is None + assert AuthStore(queue).get_session("") is None + + +def test_delete_session(queue: Queue) -> None: + store = AuthStore(queue) + user = store.create_user("alice", "hunter2-secret") + session = store.create_session(user) + assert store.delete_session(session.token) is True + assert store.get_session(session.token) is None + assert store.delete_session(session.token) is False + + +def test_expired_sessions_pruned_on_lookup(queue: Queue) -> None: + store = AuthStore(queue) + user = store.create_user("alice", "hunter2-secret") + session = store.create_session(user, ttl_seconds=0) + # ttl_seconds=0 means it expires immediately. + assert store.get_session(session.token) is None + + +def test_prune_expired_sessions(queue: Queue) -> None: + store = AuthStore(queue) + user = store.create_user("alice", "hunter2-secret") + long_lived = store.create_session(user, ttl_seconds=3600) + short_lived = store.create_session(user, ttl_seconds=0) + removed = store.prune_expired_sessions() + assert removed >= 1 + assert store.get_session(long_lived.token) is not None + assert store.get_session(short_lived.token) is None + + +# ── Env bootstrap ────────────────────────────────────────────────────── + + +def test_bootstrap_admin_from_env(queue: Queue, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("TASKITO_DASHBOARD_ADMIN_USER", "envadmin") + monkeypatch.setenv("TASKITO_DASHBOARD_ADMIN_PASSWORD", "from-environ-pass") + user = bootstrap_admin_from_env(queue) + assert user is not None + assert user.username == "envadmin" + + # Idempotent — second call is a no-op. + again = bootstrap_admin_from_env(queue) + assert again is None + + +def test_bootstrap_admin_noop_without_env(queue: Queue, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("TASKITO_DASHBOARD_ADMIN_USER", raising=False) + monkeypatch.delenv("TASKITO_DASHBOARD_ADMIN_PASSWORD", raising=False) + assert bootstrap_admin_from_env(queue) is None + assert AuthStore(queue).count_users() == 0 + + +# ── OAuth users ──────────────────────────────────────────────────────── + + +def test_verify_password_rejects_oauth_sentinel_hash() -> None: + assert verify_password("anything", "oauth:google") is False + assert verify_password("anything", "oauth:okta") is False + + +def test_get_or_create_oauth_user_creates_user_with_admin_when_table_empty( + queue: Queue, +) -> None: + store = AuthStore(queue) + user = store.get_or_create_oauth_user( + slot="google", + subject="1184283742", + email="alice@acme.com", + name="Alice Example", + email_verified=True, + ) + assert user.username == "google:1184283742" + assert user.role == "admin" + assert user.email == "alice@acme.com" + assert user.display_name == "Alice Example" + assert user.is_oauth is True + + +def test_get_or_create_oauth_user_subsequent_user_is_viewer(queue: Queue) -> None: + store = AuthStore(queue) + store.get_or_create_oauth_user( + slot="google", + subject="111", + email="alice@acme.com", + name="Alice", + email_verified=True, + ) + second = store.get_or_create_oauth_user( + slot="google", + subject="222", + email="bob@acme.com", + name="Bob", + email_verified=True, + ) + assert second.role == "viewer" + + +def test_get_or_create_oauth_user_admin_emails_take_precedence(queue: Queue) -> None: + store = AuthStore(queue) + # Pre-seed a password user so the table is not empty. + store.create_user("primary", "hunter2-secret") + listed = store.get_or_create_oauth_user( + slot="google", + subject="111", + email="alice@acme.com", + name="Alice", + email_verified=True, + admin_emails=("alice@acme.com",), + ) + assert listed.role == "admin" + + unlisted = store.get_or_create_oauth_user( + slot="google", + subject="222", + email="eve@evil.com", + name="Eve", + email_verified=True, + admin_emails=("alice@acme.com",), + ) + assert unlisted.role == "viewer" + + +def test_get_or_create_oauth_user_unverified_email_never_gets_admin(queue: Queue) -> None: + store = AuthStore(queue) + # Even on empty table, an unverified email cannot become admin. + user = store.get_or_create_oauth_user( + slot="github", + subject="42", + email="claimed@acme.com", + name=None, + email_verified=False, + admin_emails=("claimed@acme.com",), + ) + assert user.role == "viewer" + + +def test_get_or_create_oauth_user_email_match_is_case_insensitive(queue: Queue) -> None: + store = AuthStore(queue) + store.create_user("primary", "hunter2-secret") + user = store.get_or_create_oauth_user( + slot="google", + subject="123", + email="Alice@ACME.com", + name=None, + email_verified=True, + admin_emails=("alice@acme.com",), + ) + assert user.role == "admin" + + +def test_get_or_create_oauth_user_returning_user_refreshes_display_fields( + queue: Queue, +) -> None: + store = AuthStore(queue) + first = store.get_or_create_oauth_user( + slot="google", + subject="555", + email="alice@acme.com", + name="Alice", + email_verified=True, + ) + again = store.get_or_create_oauth_user( + slot="google", + subject="555", + email="alice-new@acme.com", + name="Alice Renamed", + email_verified=True, + ) + assert again.username == first.username + assert again.email == "alice-new@acme.com" + assert again.display_name == "Alice Renamed" + # Role is not re-evaluated on subsequent logins. + assert again.role == first.role + + +def test_oauth_users_namespace_by_slot(queue: Queue) -> None: + store = AuthStore(queue) + a = store.get_or_create_oauth_user( + slot="okta", subject="abc", email=None, name=None, email_verified=False + ) + b = store.get_or_create_oauth_user( + slot="microsoft", subject="abc", email=None, name=None, email_verified=False + ) + assert a.username != b.username + assert store.count_users() == 2 + + +# ── HTTP endpoints ───────────────────────────────────────────────────── + + +@pytest.fixture +def dashboard_server(queue: Queue) -> Generator[tuple[str, Queue]]: + handler = _make_handler(queue) + server = ThreadingHTTPServer(("127.0.0.1", 0), handler) + port = server.server_address[1] + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + try: + yield f"http://127.0.0.1:{port}", queue + finally: + server.shutdown() + + +def _get(url: str, *, cookies: dict[str, str] | None = None) -> tuple[int, Any, dict[str, str]]: + req = urllib.request.Request(url, method="GET") + if cookies: + req.add_header("Cookie", "; ".join(f"{k}={v}" for k, v in cookies.items())) + try: + resp = urllib.request.urlopen(req) + except urllib.error.HTTPError as e: + return e.code, json.loads(e.read() or b"{}"), dict(e.headers or {}) + body = json.loads(resp.read() or b"{}") + set_cookies = resp.headers.get_all("Set-Cookie") or [] + return resp.status, body, {"Set-Cookie": "\n".join(set_cookies)} + + +def _post( + url: str, + body: dict | None = None, + *, + cookies: dict[str, str] | None = None, + headers: dict[str, str] | None = None, +) -> tuple[int, Any, dict[str, str]]: + data = json.dumps(body or {}).encode() + req = urllib.request.Request(url, method="POST", data=data) + req.add_header("Content-Type", "application/json") + if cookies: + req.add_header("Cookie", "; ".join(f"{k}={v}" for k, v in cookies.items())) + for k, v in (headers or {}).items(): + req.add_header(k, v) + try: + resp = urllib.request.urlopen(req) + except urllib.error.HTTPError as e: + return e.code, json.loads(e.read() or b"{}"), dict(e.headers or {}) + parsed = json.loads(resp.read() or b"{}") + set_cookies = resp.headers.get_all("Set-Cookie") or [] + return resp.status, parsed, {"Set-Cookie": "\n".join(set_cookies)} + + +def _parse_set_cookie(raw: str) -> dict[str, str]: + """Pull out the cookie name→value pairs from one or more Set-Cookie lines.""" + out: dict[str, str] = {} + for line in raw.splitlines(): + if not line: + continue + nv = line.split(";", 1)[0] + if "=" in nv: + name, value = nv.split("=", 1) + out[name.strip()] = value.strip() + return out + + +def test_auth_status_before_setup(dashboard_server: tuple[str, Queue]) -> None: + base, _ = dashboard_server + status, body, _ = _get(f"{base}/api/auth/status") + assert status == 200 + assert body == {"setup_required": True} + + +def test_protected_route_returns_503_before_setup(dashboard_server: tuple[str, Queue]) -> None: + base, _ = dashboard_server + status, body, _ = _get(f"{base}/api/stats") + assert status == 503 + assert body == {"error": "setup_required"} + + +def test_setup_creates_first_admin(dashboard_server: tuple[str, Queue]) -> None: + base, queue = dashboard_server + status, body, _ = _post( + f"{base}/api/auth/setup", + {"username": "alice", "password": "hunter2-secret"}, + ) + assert status == 200 + assert body["user"]["username"] == "alice" + assert AuthStore(queue).count_users() == 1 + + +def test_setup_blocked_after_first_user(dashboard_server: tuple[str, Queue]) -> None: + base, queue = dashboard_server + AuthStore(queue).create_user("alice", "hunter2-secret") + status, body, _ = _post( + f"{base}/api/auth/setup", + {"username": "mallory", "password": "hijack-attempt"}, + ) + assert status == 400 + assert "setup already complete" in body["error"] + + +def test_login_and_session_cookie(dashboard_server: tuple[str, Queue]) -> None: + base, queue = dashboard_server + AuthStore(queue).create_user("alice", "hunter2-secret") + status, body, headers = _post( + f"{base}/api/auth/login", + {"username": "alice", "password": "hunter2-secret"}, + ) + assert status == 200 + assert body["user"]["username"] == "alice" + # Token must NOT leak in the body — it lives only in the HttpOnly cookie. + assert "token" not in body["session"] + + cookies = _parse_set_cookie(headers["Set-Cookie"]) + assert "taskito_session" in cookies + assert "taskito_csrf" in cookies + # HttpOnly must be set on the session cookie. + assert "HttpOnly" in headers["Set-Cookie"] + # CSRF cookie value must match what whoami says. + sess_token = cookies["taskito_session"] + csrf = cookies["taskito_csrf"] + status, body, _ = _get(f"{base}/api/auth/whoami", cookies={"taskito_session": sess_token}) + assert status == 200 + assert body["user"]["username"] == "alice" + assert body["csrf_token"] == csrf + + +def test_login_with_wrong_password(dashboard_server: tuple[str, Queue]) -> None: + base, queue = dashboard_server + AuthStore(queue).create_user("alice", "hunter2-secret") + status, body, _ = _post( + f"{base}/api/auth/login", + {"username": "alice", "password": "nope"}, + ) + assert status == 400 + assert body["error"] == "invalid_credentials" + + +def test_whoami_without_session_returns_404(dashboard_server: tuple[str, Queue]) -> None: + base, queue = dashboard_server + AuthStore(queue).create_user("alice", "hunter2-secret") + status, body, _ = _get(f"{base}/api/auth/whoami") + assert status == 401 + assert body["error"] == "not_authenticated" + + +def test_protected_get_requires_session(dashboard_server: tuple[str, Queue]) -> None: + base, queue = dashboard_server + AuthStore(queue).create_user("alice", "hunter2-secret") + status, _, _ = _get(f"{base}/api/stats") + assert status == 401 + + +def test_protected_get_works_with_session(dashboard_server: tuple[str, Queue]) -> None: + base, queue = dashboard_server + store = AuthStore(queue) + user = store.create_user("alice", "hunter2-secret") + session = store.create_session(user) + status, _, _ = _get(f"{base}/api/stats", cookies={"taskito_session": session.token}) + assert status == 200 + + +def test_post_requires_csrf(dashboard_server: tuple[str, Queue]) -> None: + base, queue = dashboard_server + store = AuthStore(queue) + user = store.create_user("alice", "hunter2-secret") + session = store.create_session(user) + # POST with only the session cookie but no CSRF → 403. + status, body, _ = _post( + f"{base}/api/dead-letters/purge", + {}, + cookies={"taskito_session": session.token}, + ) + assert status == 403 + assert body["error"] == "csrf_failed" + + +def test_post_succeeds_with_valid_csrf(dashboard_server: tuple[str, Queue]) -> None: + base, queue = dashboard_server + store = AuthStore(queue) + user = store.create_user("alice", "hunter2-secret") + session = store.create_session(user) + status, _, _ = _post( + f"{base}/api/dead-letters/purge", + {}, + cookies={ + "taskito_session": session.token, + "taskito_csrf": session.csrf_token, + }, + headers={"X-CSRF-Token": session.csrf_token}, + ) + assert status == 200 + + +def test_post_rejected_when_csrf_mismatched(dashboard_server: tuple[str, Queue]) -> None: + base, queue = dashboard_server + store = AuthStore(queue) + user = store.create_user("alice", "hunter2-secret") + session = store.create_session(user) + status, body, _ = _post( + f"{base}/api/dead-letters/purge", + {}, + cookies={ + "taskito_session": session.token, + "taskito_csrf": session.csrf_token, + }, + headers={"X-CSRF-Token": "different-value"}, + ) + assert status == 403 + assert body["error"] == "csrf_failed" + + +def test_logout_invalidates_session(dashboard_server: tuple[str, Queue]) -> None: + base, queue = dashboard_server + store = AuthStore(queue) + user = store.create_user("alice", "hunter2-secret") + session = store.create_session(user) + status, _, _ = _post( + f"{base}/api/auth/logout", + {}, + cookies={ + "taskito_session": session.token, + "taskito_csrf": session.csrf_token, + }, + headers={"X-CSRF-Token": session.csrf_token}, + ) + assert status == 200 + assert AuthStore(queue).get_session(session.token) is None + + +def test_change_password_flow(dashboard_server: tuple[str, Queue]) -> None: + base, queue = dashboard_server + store = AuthStore(queue) + user = store.create_user("alice", "hunter2-secret") + session = store.create_session(user) + status, _, _ = _post( + f"{base}/api/auth/change-password", + {"old_password": "hunter2-secret", "new_password": "brand-new-secure"}, + cookies={ + "taskito_session": session.token, + "taskito_csrf": session.csrf_token, + }, + headers={"X-CSRF-Token": session.csrf_token}, + ) + assert status == 200 + assert store.authenticate("alice", "brand-new-secure") is not None + assert store.authenticate("alice", "hunter2-secret") is None + + +def test_health_endpoint_is_public(dashboard_server: tuple[str, Queue]) -> None: + base, queue = dashboard_server + AuthStore(queue).create_user("alice", "hunter2-secret") + status, _, _ = _get(f"{base}/health") + assert status == 200 diff --git a/tests/dashboard/test_dashboard.py b/tests/dashboard/test_dashboard.py index de651b0..f096738 100644 --- a/tests/dashboard/test_dashboard.py +++ b/tests/dashboard/test_dashboard.py @@ -11,6 +11,7 @@ import pytest from taskito import Queue +from taskito.dashboard._testing import AuthedClient, seed_admin_and_session @pytest.fixture @@ -189,100 +190,103 @@ def _start_dashboard(queue: Queue, *, static_assets: Any = None) -> tuple[str, A @pytest.fixture def dashboard_server( populated_queue: tuple[Queue, list[Any]], -) -> Generator[tuple[str, Queue, list[Any]]]: - """Start a dashboard server on a random port.""" +) -> Generator[tuple[AuthedClient, Queue, list[Any]]]: + """Start a dashboard server on a random port and pre-seed an admin session. + + Yields ``(client, queue, jobs)`` — the client transparently attaches the + session cookie and CSRF header to every request. + """ queue, jobs = populated_queue url, server = _start_dashboard(queue) + session = seed_admin_and_session(queue) + client = AuthedClient(base=url, session=session) try: - yield url, queue, jobs + yield client, queue, jobs finally: server.shutdown() -def _get(url: str) -> Any: - """GET request and parse JSON.""" - with urllib.request.urlopen(url) as resp: - return json.loads(resp.read()) - - -def _post(url: str) -> Any: - """POST request and parse JSON.""" - req = urllib.request.Request(url, method="POST", data=b"") - with urllib.request.urlopen(req) as resp: - return json.loads(resp.read()) - - -def test_api_stats(dashboard_server: tuple[str, Queue, list[Any]]) -> None: +def test_api_stats(dashboard_server: tuple[AuthedClient, Queue, list[Any]]) -> None: """GET /api/stats returns valid stats dict.""" - base, _, __ = dashboard_server - data = _get(f"{base}/api/stats") + client, _, __ = dashboard_server + data = client.get("/api/stats") assert "pending" in data assert data["pending"] == 8 -def test_api_jobs_list(dashboard_server: tuple[str, Queue, list[Any]]) -> None: +def test_api_jobs_list(dashboard_server: tuple[AuthedClient, Queue, list[Any]]) -> None: """GET /api/jobs returns job list.""" - base, _, __ = dashboard_server - data = _get(f"{base}/api/jobs") + client, _, __ = dashboard_server + data = client.get("/api/jobs") assert isinstance(data, list) assert len(data) == 8 -def test_api_jobs_filter_status(dashboard_server: tuple[str, Queue, list[Any]]) -> None: +def test_api_jobs_filter_status( + dashboard_server: tuple[AuthedClient, Queue, list[Any]], +) -> None: """GET /api/jobs?status=pending filters correctly.""" - base, _, __ = dashboard_server - data = _get(f"{base}/api/jobs?status=pending") + client, _, __ = dashboard_server + data = client.get("/api/jobs?status=pending") assert len(data) == 8 - data = _get(f"{base}/api/jobs?status=running") + data = client.get("/api/jobs?status=running") assert len(data) == 0 -def test_api_jobs_filter_queue(dashboard_server: tuple[str, Queue, list[Any]]) -> None: +def test_api_jobs_filter_queue( + dashboard_server: tuple[AuthedClient, Queue, list[Any]], +) -> None: """GET /api/jobs?queue=email filters correctly.""" - base, _, __ = dashboard_server - data = _get(f"{base}/api/jobs?queue=email") + client, _, __ = dashboard_server + data = client.get("/api/jobs?queue=email") assert len(data) == 3 -def test_api_jobs_pagination(dashboard_server: tuple[str, Queue, list[Any]]) -> None: +def test_api_jobs_pagination( + dashboard_server: tuple[AuthedClient, Queue, list[Any]], +) -> None: """GET /api/jobs?limit=3&offset=0 paginates.""" - base, _, __ = dashboard_server - data = _get(f"{base}/api/jobs?limit=3&offset=0") + client, _, __ = dashboard_server + data = client.get("/api/jobs?limit=3&offset=0") assert len(data) == 3 -def test_api_job_detail(dashboard_server: tuple[str, Queue, list[Any]]) -> None: +def test_api_job_detail(dashboard_server: tuple[AuthedClient, Queue, list[Any]]) -> None: """GET /api/jobs/{id} returns job dict.""" - base, _, jobs = dashboard_server + client, _, jobs = dashboard_server job_id = jobs[0].id - data = _get(f"{base}/api/jobs/{job_id}") + data = client.get(f"/api/jobs/{job_id}") assert data["id"] == job_id assert "status" in data -def test_api_job_not_found(dashboard_server: tuple[str, Queue, list[Any]]) -> None: +def test_api_job_not_found( + dashboard_server: tuple[AuthedClient, Queue, list[Any]], +) -> None: """GET /api/jobs/nonexistent returns 404.""" - base, _, __ = dashboard_server + client, _, __ = dashboard_server try: - _get(f"{base}/api/jobs/nonexistent-id") + client.get("/api/jobs/nonexistent-id") raise AssertionError("Expected 404") except urllib.error.HTTPError as e: assert e.code == 404 -def test_api_cancel_job(dashboard_server: tuple[str, Queue, list[Any]]) -> None: +def test_api_cancel_job(dashboard_server: tuple[AuthedClient, Queue, list[Any]]) -> None: """POST /api/jobs/{id}/cancel cancels a pending job.""" - base, _, jobs = dashboard_server + client, _, jobs = dashboard_server job_id = jobs[0].id - data = _post(f"{base}/api/jobs/{job_id}/cancel") + data = client.post(f"/api/jobs/{job_id}/cancel") assert data["cancelled"] is True -def test_api_dead_letters_empty(dashboard_server: tuple[str, Queue, list[Any]]) -> None: +def test_api_dead_letters_empty( + dashboard_server: tuple[AuthedClient, Queue, list[Any]], +) -> None: """GET /api/dead-letters returns empty list initially.""" - base, _, __ = dashboard_server - data = _get(f"{base}/api/dead-letters") + client, _, __ = dashboard_server + data = client.get("/api/dead-letters") assert data == [] diff --git a/tests/dashboard/test_dashboard_settings.py b/tests/dashboard/test_dashboard_settings.py index ee9888d..9ba4ba2 100644 --- a/tests/dashboard/test_dashboard_settings.py +++ b/tests/dashboard/test_dashboard_settings.py @@ -13,11 +13,11 @@ import urllib.request from collections.abc import Generator from pathlib import Path -from typing import Any import pytest from taskito import Queue +from taskito.dashboard._testing import AuthedClient, seed_admin_and_session @pytest.fixture @@ -25,28 +25,6 @@ def queue(tmp_path: Path) -> Queue: return Queue(db_path=str(tmp_path / "settings.db")) -def _put(url: str, body: dict) -> Any: - req = urllib.request.Request( - url, - method="PUT", - data=json.dumps(body).encode(), - headers={"Content-Type": "application/json"}, - ) - with urllib.request.urlopen(req) as resp: - return json.loads(resp.read()) - - -def _delete(url: str) -> Any: - req = urllib.request.Request(url, method="DELETE") - with urllib.request.urlopen(req) as resp: - return json.loads(resp.read()) - - -def _get(url: str) -> Any: - with urllib.request.urlopen(url) as resp: - return json.loads(resp.read()) - - # ── Python API ────────────────────────────────────────── @@ -98,7 +76,7 @@ def test_setting_preserves_json(queue: Queue) -> None: @pytest.fixture -def dashboard_server(queue: Queue) -> Generator[tuple[str, Queue]]: +def dashboard_server(queue: Queue) -> Generator[tuple[AuthedClient, Queue]]: from http.server import ThreadingHTTPServer from taskito.dashboard import _make_handler @@ -108,65 +86,79 @@ def dashboard_server(queue: Queue) -> Generator[tuple[str, Queue]]: port = server.server_address[1] thread = threading.Thread(target=server.serve_forever, daemon=True) thread.start() + session = seed_admin_and_session(queue) + client = AuthedClient(base=f"http://127.0.0.1:{port}", session=session) try: - yield f"http://127.0.0.1:{port}", queue + yield client, queue finally: server.shutdown() -def test_get_settings_returns_empty_dict(dashboard_server: tuple[str, Queue]) -> None: - base, _ = dashboard_server - assert _get(f"{base}/api/settings") == {} +def test_get_settings_returns_empty_dict(dashboard_server: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard_server + # The admin user setting is the only one populated by the seed helper. + snapshot = client.get("/api/settings") + assert "auth:users" in snapshot + # No dashboard.* keys yet. + assert not any(k.startswith("dashboard.") for k in snapshot) -def test_put_then_get_setting(dashboard_server: tuple[str, Queue]) -> None: - base, _ = dashboard_server - _put(f"{base}/api/settings/dashboard.title", {"value": "My Queue"}) +def test_put_then_get_setting(dashboard_server: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard_server + client.put("/api/settings/dashboard.title", {"value": "My Queue"}) - data = _get(f"{base}/api/settings/dashboard.title") + data = client.get("/api/settings/dashboard.title") assert data == {"key": "dashboard.title", "value": "My Queue"} - snapshot = _get(f"{base}/api/settings") - assert snapshot == {"dashboard.title": "My Queue"} + snapshot = client.get("/api/settings") + assert snapshot["dashboard.title"] == "My Queue" -def test_put_setting_with_json_value(dashboard_server: tuple[str, Queue]) -> None: +def test_put_setting_with_json_value(dashboard_server: tuple[AuthedClient, Queue]) -> None: """Non-string ``value`` is JSON-encoded before persistence.""" - base, queue = dashboard_server + client, queue = dashboard_server payload = [ {"label": "Grafana", "url": "https://grafana.example/d/abc"}, {"label": "Sentry", "url": "https://sentry.example/issues"}, ] - _put(f"{base}/api/settings/dashboard.external_links", {"value": payload}) + client.put("/api/settings/dashboard.external_links", {"value": payload}) stored = queue.get_setting("dashboard.external_links") assert stored is not None assert json.loads(stored) == payload -def test_get_unknown_setting_returns_404(dashboard_server: tuple[str, Queue]) -> None: - base, _ = dashboard_server +def test_get_unknown_setting_returns_404(dashboard_server: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard_server with pytest.raises(urllib.error.HTTPError) as exc_info: - _get(f"{base}/api/settings/missing.key") + client.get("/api/settings/missing.key") assert exc_info.value.code == 404 def test_put_setting_with_missing_value_field_returns_400( - dashboard_server: tuple[str, Queue], + dashboard_server: tuple[AuthedClient, Queue], ) -> None: - base, _ = dashboard_server + client, _ = dashboard_server with pytest.raises(urllib.error.HTTPError) as exc_info: - _put(f"{base}/api/settings/k", {"not_value": 1}) + client.put("/api/settings/k", {"not_value": 1}) assert exc_info.value.code == 400 -def test_put_setting_rejects_invalid_json_body(dashboard_server: tuple[str, Queue]) -> None: - base, _ = dashboard_server +def test_put_setting_rejects_invalid_json_body( + dashboard_server: tuple[AuthedClient, Queue], +) -> None: + client, _ = dashboard_server req = urllib.request.Request( - f"{base}/api/settings/k", + f"{client.base}/api/settings/k", method="PUT", data=b"{not json", - headers={"Content-Type": "application/json"}, + headers={ + "Content-Type": "application/json", + "Cookie": ( + f"taskito_session={client.session.token}; taskito_csrf={client.session.csrf_token}" + ), + "X-CSRF-Token": client.session.csrf_token, + }, ) with pytest.raises(urllib.error.HTTPError) as exc_info: urllib.request.urlopen(req) @@ -174,19 +166,19 @@ def test_put_setting_rejects_invalid_json_body(dashboard_server: tuple[str, Queu def test_delete_setting_returns_true_when_exists( - dashboard_server: tuple[str, Queue], + dashboard_server: tuple[AuthedClient, Queue], ) -> None: - base, queue = dashboard_server + client, queue = dashboard_server queue.set_setting("k", "v") - assert _delete(f"{base}/api/settings/k") == {"deleted": True} + assert client.delete("/api/settings/k") == {"deleted": True} assert queue.get_setting("k") is None def test_delete_missing_setting_returns_false( - dashboard_server: tuple[str, Queue], + dashboard_server: tuple[AuthedClient, Queue], ) -> None: - base, _ = dashboard_server - assert _delete(f"{base}/api/settings/missing") == {"deleted": False} + client, _ = dashboard_server + assert client.delete("/api/settings/missing") == {"deleted": False} def test_settings_persist_across_queue_instances(tmp_path: Path) -> None: diff --git a/tests/dashboard/test_middleware_toggles.py b/tests/dashboard/test_middleware_toggles.py new file mode 100644 index 0000000..cc93159 --- /dev/null +++ b/tests/dashboard/test_middleware_toggles.py @@ -0,0 +1,256 @@ +"""Tests for per-task middleware enable/disable from the dashboard.""" + +from __future__ import annotations + +import threading +import urllib.error +from collections.abc import Generator +from http.server import ThreadingHTTPServer +from pathlib import Path +from typing import Any + +import pytest + +from taskito import Queue +from taskito.context import JobContext +from taskito.dashboard import _make_handler +from taskito.dashboard._testing import AuthedClient, seed_admin_and_session +from taskito.dashboard.middleware_store import MiddlewareDisableStore +from taskito.middleware import TaskMiddleware + + +class RecordingMiddleware(TaskMiddleware): + """Captures every ``before`` invocation so the test can assert which + tasks the middleware fired for.""" + + name = "test.recording" + + def __init__(self) -> None: + super().__init__() + self.invocations: list[str] = [] + + def before(self, ctx: JobContext) -> None: + self.invocations.append(ctx.task_name) + + +class OtherMiddleware(TaskMiddleware): + name = "test.other" + + def __init__(self) -> None: + super().__init__() + self.invocations: list[str] = [] + + def before(self, ctx: JobContext) -> None: + self.invocations.append(ctx.task_name) + + +@pytest.fixture +def middleware_pair() -> tuple[RecordingMiddleware, OtherMiddleware]: + return RecordingMiddleware(), OtherMiddleware() + + +@pytest.fixture +def queue(tmp_path: Path, middleware_pair: tuple[RecordingMiddleware, OtherMiddleware]) -> Queue: + rec, other = middleware_pair + q = Queue(db_path=str(tmp_path / "mw.db"), middleware=[rec, other]) + + @q.task() + def alpha() -> str: + return "a" + + @q.task() + def beta() -> str: + return "b" + + return q + + +@pytest.fixture +def dashboard(queue: Queue) -> Generator[tuple[AuthedClient, Queue]]: + handler = _make_handler(queue) + server = ThreadingHTTPServer(("127.0.0.1", 0), handler) + threading.Thread(target=server.serve_forever, daemon=True).start() + session = seed_admin_and_session(queue) + client = AuthedClient(base=f"http://127.0.0.1:{server.server_address[1]}", session=session) + try: + yield client, queue + finally: + server.shutdown() + + +# ── Store ────────────────────────────────────────────────────────────── + + +def test_store_starts_empty(queue: Queue) -> None: + store = MiddlewareDisableStore(queue) + assert store.list_all() == {} + assert store.get_for("alpha") == [] + + +def test_set_disabled_adds_and_removes(queue: Queue) -> None: + store = MiddlewareDisableStore(queue) + store.set_disabled("alpha", "test.other", True) + assert store.get_for("alpha") == ["test.other"] + # Idempotent — same disable twice still has just one entry. + store.set_disabled("alpha", "test.other", True) + assert store.get_for("alpha") == ["test.other"] + # Re-enable clears just that one. + store.set_disabled("alpha", "test.other", False) + assert store.get_for("alpha") == [] + + +def test_clear_for_drops_setting_key(queue: Queue) -> None: + store = MiddlewareDisableStore(queue) + store.set_disabled("alpha", "test.other", True) + assert store.clear_for("alpha") is True + assert store.clear_for("alpha") is False + assert store.get_for("alpha") == [] + + +# ── Wiring into the middleware chain ────────────────────────────────── + + +def test_chain_skips_disabled_middleware(queue: Queue) -> None: + """``_get_middleware_chain`` returns a chain that respects the disable + list at lookup time — no worker restart required.""" + full = queue._get_middleware_chain("alpha") + assert {mw.name for mw in full} == {"test.recording", "test.other"} + queue.disable_middleware_for_task("alpha", "test.other") + filtered = queue._get_middleware_chain("alpha") + assert {mw.name for mw in filtered} == {"test.recording"} + # Other tasks unaffected. + assert {mw.name for mw in queue._get_middleware_chain("beta")} == { + "test.recording", + "test.other", + } + + +def test_clear_re_enables_all(queue: Queue) -> None: + queue.disable_middleware_for_task("alpha", "test.other") + queue.disable_middleware_for_task("alpha", "test.recording") + assert queue._get_middleware_chain("alpha") == [] + queue.clear_middleware_disables("alpha") + assert len(queue._get_middleware_chain("alpha")) == 2 + + +# ── Discovery ───────────────────────────────────────────────────────── + + +def test_list_middleware_reports_globals(queue: Queue) -> None: + items = queue.list_middleware() + names = {item["name"] for item in items} + assert {"test.recording", "test.other"} <= names + for entry in items: + assert any(scope["kind"] == "global" for scope in entry["scopes"]) + + +# ── HTTP endpoints ──────────────────────────────────────────────────── + + +def test_list_middleware_endpoint(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + items = client.get("/api/middleware") + names = {item["name"] for item in items} + assert {"test.recording", "test.other"} <= names + + +def test_get_task_middleware_endpoint(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + result = client.get("/api/tasks/alpha/middleware") + by_name = {entry["name"]: entry for entry in result["middleware"]} + assert by_name["test.recording"]["disabled"] is False + assert by_name["test.recording"]["effective"] is True + + +def test_put_task_middleware_disables(dashboard: tuple[AuthedClient, Queue]) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + result = client.put(f"/api/tasks/{name}/middleware/test.other", {"enabled": False}) + assert "test.other" in result["disabled"] + # Reflected in the chain. + chain_names = {mw.name for mw in queue._get_middleware_chain(name)} + assert "test.other" not in chain_names + # Re-enabling clears it. + client.put(f"/api/tasks/{name}/middleware/test.other", {"enabled": True}) + chain_names = {mw.name for mw in queue._get_middleware_chain(name)} + assert "test.other" in chain_names + + +def test_put_task_middleware_rejects_unknown_middleware( + dashboard: tuple[AuthedClient, Queue], +) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.put(f"/api/tasks/{name}/middleware/not.a.real.mw", {"enabled": False}) + assert exc_info.value.code == 404 + + +def test_put_task_middleware_rejects_bad_body( + dashboard: tuple[AuthedClient, Queue], +) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.put(f"/api/tasks/{name}/middleware/test.other", {"enabled": "yes"}) + assert exc_info.value.code == 400 + + +def test_put_task_middleware_handles_url_encoded_name( + dashboard: tuple[AuthedClient, Queue], +) -> None: + """Browser clients ``encodeURIComponent`` task names containing + ``<``, ``>``, ``/`` etc. The server has to decode the captured + group before looking up the disable list — otherwise the toggle + silently no-ops because the disable is keyed by one name but the + chain lookup uses another. Regression test for that path.""" + import urllib.parse + + client, queue = dashboard + name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + encoded = urllib.parse.quote(name, safe="") + assert "%" in encoded, "test setup: pick a task whose qualname needs encoding" + result = client.put(f"/api/tasks/{encoded}/middleware/test.other", {"enabled": False}) + assert "test.other" in result["disabled"] + # The chain lookup uses the decoded name, so it must reflect the + # disable that was written by the encoded URL. + chain_names = {mw.name for mw in queue._get_middleware_chain(name)} + assert "test.other" not in chain_names + + +def test_delete_task_middleware_clears_all( + dashboard: tuple[AuthedClient, Queue], +) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + client.put(f"/api/tasks/{name}/middleware/test.other", {"enabled": False}) + client.put(f"/api/tasks/{name}/middleware/test.recording", {"enabled": False}) + assert queue._get_middleware_chain(name) == [] + result = client.delete(f"/api/tasks/{name}/middleware") + assert result == {"cleared": True} + assert len(queue._get_middleware_chain(name)) == 2 + + +# ── End-to-end: disabled middleware doesn't fire ───────────────────── + + +def test_disabled_middleware_does_not_fire( + queue: Queue, + middleware_pair: tuple[RecordingMiddleware, OtherMiddleware], + poll_until: Any, +) -> None: + rec, other = middleware_pair + alpha_name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + queue.disable_middleware_for_task(alpha_name, "test.other") + + thread = threading.Thread(target=queue.run_worker, daemon=True) + thread.start() + try: + queue.enqueue(alpha_name) + poll_until(lambda: alpha_name in rec.invocations, message="task didn't run") + finally: + queue._inner.request_shutdown() + thread.join(timeout=5) + + assert alpha_name in rec.invocations # global fired + assert alpha_name not in other.invocations # disabled for this task diff --git a/tests/dashboard/test_oauth_config.py b/tests/dashboard/test_oauth_config.py new file mode 100644 index 0000000..16cbe5a --- /dev/null +++ b/tests/dashboard/test_oauth_config.py @@ -0,0 +1,209 @@ +"""Tests for OAuth config parsing from env vars.""" + +from __future__ import annotations + +import pytest + +from taskito.dashboard.oauth.config import ( + GitHubConfig, + GoogleConfig, + OAuthConfig, + OAuthConfigError, + OIDCConfig, + from_env, +) + + +def test_from_env_returns_none_when_unconfigured() -> None: + assert from_env({}) is None + + +def test_from_env_requires_base_url_when_any_provider_set() -> None: + with pytest.raises(OAuthConfigError, match="REDIRECT_BASE_URL"): + from_env( + { + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_ID": "gid", + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_SECRET": "gsec", + } + ) + + +def test_from_env_parses_google_provider() -> None: + config = from_env( + { + "TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL": "https://taskito.acme.com", + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_ID": "gid", + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_SECRET": "gsec", + "TASKITO_DASHBOARD_OAUTH_GOOGLE_ALLOWED_DOMAINS": "acme.com, partner.com", + } + ) + assert config is not None + assert config.is_enabled + assert isinstance(config.google, GoogleConfig) + assert config.google.client_id == "gid" + assert config.google.client_secret == "gsec" + assert config.google.allowed_domains == ("acme.com", "partner.com") + assert config.github is None + assert config.oidc == () + + +def test_from_env_partial_google_config_raises() -> None: + with pytest.raises(OAuthConfigError, match="CLIENT_SECRET"): + from_env( + { + "TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL": "https://taskito.acme.com", + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_ID": "gid", + } + ) + + +def test_from_env_parses_github_provider() -> None: + config = from_env( + { + "TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL": "https://taskito.acme.com", + "TASKITO_DASHBOARD_OAUTH_GITHUB_CLIENT_ID": "hid", + "TASKITO_DASHBOARD_OAUTH_GITHUB_CLIENT_SECRET": "hsec", + "TASKITO_DASHBOARD_OAUTH_GITHUB_ALLOWED_ORGS": "acme,partner", + } + ) + assert config is not None + assert isinstance(config.github, GitHubConfig) + assert config.github.allowed_orgs == ("acme", "partner") + + +def test_from_env_parses_multiple_oidc_slots() -> None: + config = from_env( + { + "TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL": "https://taskito.acme.com", + "TASKITO_DASHBOARD_OAUTH_OIDC_PROVIDERS": "okta,microsoft", + "TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_CLIENT_ID": "oid", + "TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_CLIENT_SECRET": "osec", + "TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_DISCOVERY_URL": "https://acme.okta.com/.well-known/openid-configuration", + "TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_LABEL": "Acme SSO", + "TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_ALLOWED_DOMAINS": "acme.com", + "TASKITO_DASHBOARD_OAUTH_OIDC_MICROSOFT_CLIENT_ID": "mid", + "TASKITO_DASHBOARD_OAUTH_OIDC_MICROSOFT_CLIENT_SECRET": "msec", + "TASKITO_DASHBOARD_OAUTH_OIDC_MICROSOFT_DISCOVERY_URL": "https://login.microsoftonline.com/x/v2.0/.well-known/openid-configuration", + } + ) + assert config is not None + assert [p.slot for p in config.oidc] == ["okta", "microsoft"] + okta = config.oidc[0] + assert isinstance(okta, OIDCConfig) + assert okta.label == "Acme SSO" + assert okta.allowed_domains == ("acme.com",) + microsoft = config.oidc[1] + assert microsoft.label == "Microsoft" # default = title-cased slot + + +def test_from_env_rejects_duplicate_oidc_slot() -> None: + with pytest.raises(OAuthConfigError, match="twice"): + from_env( + { + "TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL": "https://taskito.acme.com", + "TASKITO_DASHBOARD_OAUTH_OIDC_PROVIDERS": "okta,okta", + "TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_CLIENT_ID": "oid", + "TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_CLIENT_SECRET": "osec", + "TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_DISCOVERY_URL": "https://x/y", + } + ) + + +def test_oidc_slot_must_not_collide_with_reserved_name() -> None: + with pytest.raises(OAuthConfigError, match=r"reserved|collides|built-in"): + OIDCConfig( + slot="google", + client_id="x", + client_secret="y", + discovery_url="https://x/y", + ) + + +def test_oidc_slot_must_be_url_safe() -> None: + with pytest.raises(OAuthConfigError): + OIDCConfig( + slot="Has Spaces", + client_id="x", + client_secret="y", + discovery_url="https://x/y", + ) + + +def test_redirect_base_url_must_be_https_for_remote_hosts() -> None: + with pytest.raises(OAuthConfigError, match="https"): + OAuthConfig(redirect_base_url="http://taskito.acme.com") + + +def test_redirect_base_url_allows_http_for_localhost() -> None: + # No exception. + OAuthConfig(redirect_base_url="http://localhost:8000") + OAuthConfig(redirect_base_url="http://127.0.0.1:8000") + + +def test_password_auth_flag_parses() -> None: + config = from_env( + { + "TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL": "https://taskito.acme.com", + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_ID": "gid", + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_SECRET": "gsec", + "TASKITO_DASHBOARD_PASSWORD_AUTH_ENABLED": "false", + } + ) + assert config is not None + assert config.password_auth_enabled is False + + +def test_disabling_password_without_providers_is_an_error() -> None: + with pytest.raises(OAuthConfigError, match="no way to log in"): + from_env( + { + "TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL": "https://taskito.acme.com", + "TASKITO_DASHBOARD_PASSWORD_AUTH_ENABLED": "false", + } + ) + + +def test_admin_emails_parsed_as_tuple() -> None: + config = from_env( + { + "TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL": "https://taskito.acme.com", + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_ID": "gid", + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_SECRET": "gsec", + "TASKITO_DASHBOARD_OAUTH_ADMIN_EMAILS": " alice@acme.com , bob@acme.com ", + } + ) + assert config is not None + assert config.admin_emails == ("alice@acme.com", "bob@acme.com") + + +def test_callback_url_built_from_base_url_and_slot() -> None: + config = from_env( + { + "TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL": "https://taskito.acme.com/", + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_ID": "gid", + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_SECRET": "gsec", + } + ) + assert config is not None + assert ( + config.callback_url("google") == "https://taskito.acme.com/api/auth/oauth/callback/google" + ) + + +def test_find_provider_returns_matching_slot() -> None: + config = from_env( + { + "TASKITO_DASHBOARD_OAUTH_REDIRECT_BASE_URL": "https://taskito.acme.com", + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_ID": "gid", + "TASKITO_DASHBOARD_OAUTH_GOOGLE_CLIENT_SECRET": "gsec", + "TASKITO_DASHBOARD_OAUTH_OIDC_PROVIDERS": "okta", + "TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_CLIENT_ID": "oid", + "TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_CLIENT_SECRET": "osec", + "TASKITO_DASHBOARD_OAUTH_OIDC_OKTA_DISCOVERY_URL": "https://acme.okta.com/.well-known/openid-configuration", + } + ) + assert config is not None + assert config.find_provider("google") is config.google + okta = config.find_provider("okta") + assert isinstance(okta, OIDCConfig) + assert config.find_provider("does-not-exist") is None diff --git a/tests/dashboard/test_oauth_endpoints.py b/tests/dashboard/test_oauth_endpoints.py new file mode 100644 index 0000000..d6b95ba --- /dev/null +++ b/tests/dashboard/test_oauth_endpoints.py @@ -0,0 +1,400 @@ +"""HTTP-level integration tests for the OAuth endpoints. + +Spins up a real :class:`ThreadingHTTPServer` with a stubbed +:class:`OAuthFlow` so we can drive the full request → 302-redirect → +cookies path without making real provider calls. +""" + +from __future__ import annotations + +import contextlib +import json +import threading +import urllib.error +import urllib.request +from collections.abc import Callable, Generator +from http.server import ThreadingHTTPServer +from pathlib import Path +from typing import Any + +import pytest + +from taskito import Queue +from taskito.dashboard import _make_handler +from taskito.dashboard.auth import AuthStore +from taskito.dashboard.oauth.config import ( + GitHubConfig, + GoogleConfig, + OAuthConfig, + OIDCConfig, +) +from taskito.dashboard.oauth.flow import OAuthFlow +from taskito.dashboard.oauth.identity import ( + AllowlistDenied, + IdentityFetchError, + ProviderIdentity, +) + + +@pytest.fixture +def queue(tmp_path: Path) -> Queue: + return Queue(db_path=str(tmp_path / "oauth_endpoints.db")) + + +class _FakeProvider: + """Programmable provider used by the integration tests.""" + + def __init__(self, slot: str, *, label: str = "Test", ptype: str = "google") -> None: + self.slot = slot + self.label = label + self.type = ptype + self.identity: ProviderIdentity | None = None + self.allow = True + self.start_called_with: dict[str, str] | None = None + + def authorization_url( + self, + *, + state: str, + nonce: str, + code_challenge: str, + redirect_uri: str, + ) -> str: + self.start_called_with = { + "state": state, + "nonce": nonce, + "code_challenge": code_challenge, + "redirect_uri": redirect_uri, + } + return f"https://idp.example.com/authorize?state={state}" + + def exchange_code( + self, + *, + code: str, + code_verifier: str, + redirect_uri: str, + expected_nonce: str | None, + ) -> ProviderIdentity: + if self.identity is None: + raise IdentityFetchError("no identity configured") + return self.identity + + def check_allowlist(self, identity: ProviderIdentity) -> None: + if not self.allow: + raise AllowlistDenied("denied") + + +@pytest.fixture +def google_provider() -> _FakeProvider: + return _FakeProvider("google", label="Google", ptype="google") + + +@pytest.fixture +def okta_provider() -> _FakeProvider: + return _FakeProvider("okta", label="Acme SSO", ptype="oidc") + + +def _make_flow( + queue: Queue, + providers: dict[str, _FakeProvider], + *, + password_enabled: bool = True, + admin_emails: tuple[str, ...] = (), +) -> OAuthFlow: + google_cfg = GoogleConfig(client_id="gid", client_secret="gsec") + github_cfg = GitHubConfig(client_id="hid", client_secret="hsec") + config = OAuthConfig( + redirect_base_url="http://127.0.0.1", + google=google_cfg if "google" in providers else None, + github=github_cfg if "github" in providers else None, + oidc=tuple( + OIDCConfig( + slot=slot, + client_id="x", + client_secret="y", + discovery_url=f"https://idp/{slot}/.well-known/openid-configuration", + ) + for slot in providers + if slot not in ("google", "github") + ), + password_auth_enabled=password_enabled, + admin_emails=admin_emails, + ) + return OAuthFlow(queue, config, providers=providers) # type: ignore[arg-type] + + +@pytest.fixture +def server_factory( + queue: Queue, +) -> Generator[Callable[[OAuthFlow | None], str]]: + """Spawns dashboard servers with the requested OAuthFlow.""" + handles: list[ThreadingHTTPServer] = [] + + def _factory(flow: OAuthFlow | None) -> str: + handler = _make_handler(queue, oauth_flow=flow) + server = ThreadingHTTPServer(("127.0.0.1", 0), handler) + handles.append(server) + port = server.server_address[1] + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + return f"http://127.0.0.1:{port}" + + yield _factory + + for server in handles: + server.shutdown() + + +def _get_no_redirect( + url: str, *, cookies: dict[str, str] | None = None +) -> tuple[int, Any, dict[str, list[str]]]: + """GET without following redirects, returning (status, body, headers).""" + + class _NoRedirect(urllib.request.HTTPRedirectHandler): + def redirect_request(self, *_a: Any, **_k: Any) -> None: + return None + + opener = urllib.request.build_opener(_NoRedirect()) + req = urllib.request.Request(url, method="GET") + if cookies: + req.add_header("Cookie", "; ".join(f"{k}={v}" for k, v in cookies.items())) + try: + resp = opener.open(req) + body: Any = None + try: + raw = resp.read() + body = json.loads(raw) if raw else None + except (ValueError, json.JSONDecodeError): + body = None + headers = {k: resp.headers.get_all(k) or [] for k in set(resp.headers.keys())} + return resp.status, body, headers + except urllib.error.HTTPError as e: + body = None + with contextlib.suppress(ValueError, json.JSONDecodeError): + body = json.loads(e.read() or b"{}") + headers = {k: e.headers.get_all(k) or [] for k in set(e.headers.keys())} + return e.code, body, headers + + +def _parse_set_cookies(raw: list[str]) -> dict[str, str]: + out: dict[str, str] = {} + for line in raw: + nv = line.split(";", 1)[0] + if "=" in nv: + name, value = nv.split("=", 1) + out[name.strip()] = value.strip() + return out + + +# ── /api/auth/providers ────────────────────────────────────────────── + + +def test_providers_endpoint_returns_empty_list_when_no_flow( + server_factory: Any, +) -> None: + base = server_factory(None) + status, body, _ = _get_no_redirect(f"{base}/api/auth/providers") + assert status == 200 + assert body == {"password_enabled": True, "providers": []} + + +def test_providers_endpoint_lists_configured_providers( + server_factory: Any, + queue: Queue, + google_provider: _FakeProvider, + okta_provider: _FakeProvider, +) -> None: + flow = _make_flow(queue, {"google": google_provider, "okta": okta_provider}) + base = server_factory(flow) + status, body, _ = _get_no_redirect(f"{base}/api/auth/providers") + assert status == 200 + assert body == { + "password_enabled": True, + "providers": [ + {"slot": "google", "label": "Google", "type": "google"}, + {"slot": "okta", "label": "Acme SSO", "type": "oidc"}, + ], + } + + +def test_providers_endpoint_reflects_password_disabled( + server_factory: Any, queue: Queue, google_provider: _FakeProvider +) -> None: + flow = _make_flow(queue, {"google": google_provider}, password_enabled=False) + base = server_factory(flow) + _, body, _ = _get_no_redirect(f"{base}/api/auth/providers") + assert body["password_enabled"] is False + + +# ── /api/auth/oauth/start/{slot} ───────────────────────────────────── + + +def test_start_returns_302_to_provider( + server_factory: Any, queue: Queue, google_provider: _FakeProvider +) -> None: + flow = _make_flow(queue, {"google": google_provider}) + base = server_factory(flow) + status, _, headers = _get_no_redirect(f"{base}/api/auth/oauth/start/google") + assert status == 302 + locations = headers.get("Location") or [] + assert len(locations) == 1 + assert locations[0].startswith("https://idp.example.com/authorize?state=") + assert google_provider.start_called_with is not None + assert google_provider.start_called_with["redirect_uri"].endswith( + "/api/auth/oauth/callback/google" + ) + + +def test_start_returns_404_for_unknown_slot( + server_factory: Any, queue: Queue, google_provider: _FakeProvider +) -> None: + flow = _make_flow(queue, {"google": google_provider}) + base = server_factory(flow) + status, body, _ = _get_no_redirect(f"{base}/api/auth/oauth/start/azure") + assert status == 404 + assert body is not None and "azure" in body.get("error", "") + + +def test_start_returns_404_when_oauth_not_configured( + server_factory: Any, +) -> None: + base = server_factory(None) + status, body, _ = _get_no_redirect(f"{base}/api/auth/oauth/start/google") + assert status == 404 + assert body is not None + assert body.get("error") == "oauth_not_configured" + + +# ── /api/auth/oauth/callback/{slot} ────────────────────────────────── + + +def test_callback_creates_session_and_sets_cookies( + server_factory: Any, queue: Queue, google_provider: _FakeProvider +) -> None: + google_provider.identity = ProviderIdentity( + slot="google", + subject="118420987654321", + email="alice@acme.com", + email_verified=True, + name="Alice", + ) + flow = _make_flow(queue, {"google": google_provider}) + base = server_factory(flow) + + # First /start to mint state. + start_status, _, headers = _get_no_redirect( + f"{base}/api/auth/oauth/start/google?next=/dashboard" + ) + assert start_status == 302 + location = headers["Location"][0] + state = location.split("state=")[-1] + + cb_status, _, cb_headers = _get_no_redirect( + f"{base}/api/auth/oauth/callback/google?code=abc&state={state}" + ) + assert cb_status == 302 + # Redirected to the safe ``next`` URL. + assert cb_headers["Location"] == ["/dashboard"] + + cookies = _parse_set_cookies(cb_headers.get("Set-Cookie", [])) + assert "taskito_session" in cookies + assert "taskito_csrf" in cookies + assert cookies["taskito_session"] + + # A user was created in the AuthStore with the OAuth username scheme. + user = AuthStore(queue).get_user("google:118420987654321") + assert user is not None + assert user.email == "alice@acme.com" + assert user.is_oauth + + +def test_callback_rejects_unsafe_next_via_fallback_root( + server_factory: Any, queue: Queue, google_provider: _FakeProvider +) -> None: + google_provider.identity = ProviderIdentity( + slot="google", subject="2", email="bob@acme.com", email_verified=True + ) + flow = _make_flow(queue, {"google": google_provider}) + base = server_factory(flow) + _, _, headers = _get_no_redirect( + f"{base}/api/auth/oauth/start/google?next=https://evil.com/take" + ) + state = headers["Location"][0].split("state=")[-1] + _, _, cb_headers = _get_no_redirect( + f"{base}/api/auth/oauth/callback/google?code=abc&state={state}" + ) + # Unsafe next was scrubbed to "/" before being persisted with the state. + assert cb_headers["Location"] == ["/"] + + +def test_callback_replayed_state_is_rejected( + server_factory: Any, queue: Queue, google_provider: _FakeProvider +) -> None: + google_provider.identity = ProviderIdentity( + slot="google", subject="3", email="c@acme.com", email_verified=True + ) + flow = _make_flow(queue, {"google": google_provider}) + base = server_factory(flow) + _, _, headers = _get_no_redirect(f"{base}/api/auth/oauth/start/google") + state = headers["Location"][0].split("state=")[-1] + # First callback succeeds. + first_status, _, _ = _get_no_redirect( + f"{base}/api/auth/oauth/callback/google?code=abc&state={state}" + ) + assert first_status == 302 + # Replay is a 400. + replay_status, body, _ = _get_no_redirect( + f"{base}/api/auth/oauth/callback/google?code=abc&state={state}" + ) + assert replay_status == 400 + assert body is not None + assert "oauth_state_invalid" in body.get("error", "") + + +def test_callback_with_provider_error_returns_400( + server_factory: Any, queue: Queue, google_provider: _FakeProvider +) -> None: + flow = _make_flow(queue, {"google": google_provider}) + base = server_factory(flow) + status, body, _ = _get_no_redirect( + f"{base}/api/auth/oauth/callback/google?error=access_denied" + ) + assert status == 400 + assert body is not None + assert "oauth_state_invalid" in body.get("error", "") or "identity" in body.get("error", "") + + +def test_callback_blocked_by_allowlist( + server_factory: Any, queue: Queue, google_provider: _FakeProvider +) -> None: + google_provider.identity = ProviderIdentity( + slot="google", subject="4", email="eve@evil.com", email_verified=True + ) + google_provider.allow = False + flow = _make_flow(queue, {"google": google_provider}) + base = server_factory(flow) + _, _, headers = _get_no_redirect(f"{base}/api/auth/oauth/start/google") + state = headers["Location"][0].split("state=")[-1] + status, body, _ = _get_no_redirect( + f"{base}/api/auth/oauth/callback/google?code=abc&state={state}" + ) + assert status == 400 + assert body is not None + assert "allowlist_denied" in body.get("error", "") + + +def test_oauth_paths_bypass_setup_required_gate( + server_factory: Any, queue: Queue, google_provider: _FakeProvider +) -> None: + """Even before the first user exists, the OAuth flow paths must answer. + + Otherwise a fresh deployment using OAuth-only mode could never bootstrap. + """ + flow = _make_flow(queue, {"google": google_provider}) + base = server_factory(flow) + assert AuthStore(queue).count_users() == 0 + status, _, _ = _get_no_redirect(f"{base}/api/auth/providers") + assert status == 200 + status, _, _ = _get_no_redirect(f"{base}/api/auth/oauth/start/google") + assert status == 302 diff --git a/tests/dashboard/test_oauth_flow.py b/tests/dashboard/test_oauth_flow.py new file mode 100644 index 0000000..825b9d6 --- /dev/null +++ b/tests/dashboard/test_oauth_flow.py @@ -0,0 +1,208 @@ +"""Tests for :class:`OAuthFlow` — state + provider + auth-store orchestration.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from taskito import Queue +from taskito.dashboard.auth import AuthStore +from taskito.dashboard.oauth.config import GoogleConfig, OAuthConfig +from taskito.dashboard.oauth.flow import OAuthFlow +from taskito.dashboard.oauth.identity import ( + AllowlistDenied, + IdentityFetchError, + ProviderIdentity, + ProviderNotConfigured, + StateValidationError, +) + + +@pytest.fixture +def queue(tmp_path: Path) -> Queue: + return Queue(db_path=str(tmp_path / "oauth_flow.db")) + + +@pytest.fixture +def config() -> OAuthConfig: + return OAuthConfig( + redirect_base_url="https://taskito.example.com", + google=GoogleConfig( + client_id="cid", + client_secret="csec", + allowed_domains=("acme.com",), + ), + admin_emails=("alice@acme.com",), + ) + + +class FakeProvider: + """In-memory provider with programmable identity / allowlist behaviour.""" + + type = "google" + label = "Test" + + def __init__(self, slot: str, identity: ProviderIdentity | None = None) -> None: + self.slot = slot + self.identity = identity + self.allow = True + self.last_authorization_args: dict | None = None + + def authorization_url( + self, + *, + state: str, + nonce: str, + code_challenge: str, + redirect_uri: str, + ) -> str: + self.last_authorization_args = { + "state": state, + "nonce": nonce, + "code_challenge": code_challenge, + "redirect_uri": redirect_uri, + } + return f"https://idp.example.com/authorize?state={state}" + + def exchange_code( + self, + *, + code: str, + code_verifier: str, + redirect_uri: str, + expected_nonce: str | None, + ) -> ProviderIdentity: + if self.identity is None: + raise IdentityFetchError("test stub: no identity configured") + return self.identity + + def check_allowlist(self, identity: ProviderIdentity) -> None: + if not self.allow: + raise AllowlistDenied("test stub: denied") + + +def test_start_returns_provider_url_with_safe_next(queue: Queue, config: OAuthConfig) -> None: + fake = FakeProvider("google") + flow = OAuthFlow(queue, config, providers={"google": fake}) + url = flow.start("google", next_url="/dashboard/jobs") + assert url.startswith("https://idp.example.com/authorize?state=") + args = fake.last_authorization_args + assert args is not None + assert args["redirect_uri"] == "https://taskito.example.com/api/auth/oauth/callback/google" + assert len(args["state"]) >= 32 + + +def test_start_falls_back_to_root_when_next_unsafe(queue: Queue, config: OAuthConfig) -> None: + fake = FakeProvider("google") + flow = OAuthFlow(queue, config, providers={"google": fake}) + flow.start("google", next_url="https://evil.com/x") + # We can't read state.next_url back without inspecting the store — + # but we can confirm the callback rejects it via a separate test. + + +def test_start_raises_for_unknown_slot(queue: Queue, config: OAuthConfig) -> None: + flow = OAuthFlow(queue, config, providers={}) + with pytest.raises(ProviderNotConfigured): + flow.start("nonexistent", next_url="/") + + +def test_handle_callback_creates_user_and_session(queue: Queue, config: OAuthConfig) -> None: + identity = ProviderIdentity( + slot="google", + subject="100200300", + email="alice@acme.com", + email_verified=True, + name="Alice", + ) + fake = FakeProvider("google", identity=identity) + flow = OAuthFlow(queue, config, providers={"google": fake}) + + # Mint state then handle the callback. + flow.start("google", next_url="/dashboard") + state_token = next(iter(_state_tokens(queue))) + + session, next_url = flow.handle_callback( + "google", code="abc", state_token=state_token, error=None + ) + assert session.username == "google:100200300" + assert session.role == "admin" # alice is in admin_emails + assert next_url == "/dashboard" + + # Replay attempt fails because state is single-use. + with pytest.raises(StateValidationError, match="invalid"): + flow.handle_callback("google", code="abc", state_token=state_token, error=None) + + +def test_handle_callback_rejects_slot_mismatch(queue: Queue, config: OAuthConfig) -> None: + identity = ProviderIdentity(slot="google", subject="x", email=None, email_verified=False) + fake = FakeProvider("google", identity=identity) + flow = OAuthFlow(queue, config, providers={"google": fake}) + flow.start("google", next_url="/") + state_token = next(iter(_state_tokens(queue))) + with pytest.raises(StateValidationError, match="slot"): + flow.handle_callback("github", code="abc", state_token=state_token, error=None) + + +def test_handle_callback_propagates_provider_error(queue: Queue, config: OAuthConfig) -> None: + flow = OAuthFlow(queue, config, providers={"google": FakeProvider("google")}) + flow.start("google", next_url="/") + state_token = next(iter(_state_tokens(queue))) + with pytest.raises(IdentityFetchError): + flow.handle_callback("google", code="abc", state_token=state_token, error=None) + + +def test_handle_callback_propagates_allowlist_denied(queue: Queue, config: OAuthConfig) -> None: + identity = ProviderIdentity( + slot="google", + subject="1", + email="eve@evil.com", + email_verified=True, + ) + fake = FakeProvider("google", identity=identity) + fake.allow = False + flow = OAuthFlow(queue, config, providers={"google": fake}) + flow.start("google", next_url="/") + state_token = next(iter(_state_tokens(queue))) + with pytest.raises(AllowlistDenied): + flow.handle_callback("google", code="abc", state_token=state_token, error=None) + + +def test_handle_callback_with_provider_error_raises(queue: Queue, config: OAuthConfig) -> None: + fake = FakeProvider("google") + flow = OAuthFlow(queue, config, providers={"google": fake}) + with pytest.raises(IdentityFetchError, match="provider returned error"): + flow.handle_callback("google", code=None, state_token=None, error="access_denied") + + +def test_providers_listing_returns_visible_metadata(queue: Queue, config: OAuthConfig) -> None: + fake = FakeProvider("google") + flow = OAuthFlow(queue, config, providers={"google": fake}) + listing = flow.providers_listing() + assert listing == [{"slot": "google", "label": "Test", "type": "google"}] + + +def test_admin_emails_promote_first_user(queue: Queue, config: OAuthConfig) -> None: + identity = ProviderIdentity( + slot="google", + subject="alice-sub", + email="alice@acme.com", + email_verified=True, + ) + fake = FakeProvider("google", identity=identity) + flow = OAuthFlow(queue, config, providers={"google": fake}) + flow.start("google", next_url="/") + state_token = next(iter(_state_tokens(queue))) + session, _ = flow.handle_callback("google", code="x", state_token=state_token, error=None) + user = AuthStore(queue).get_user(session.username) + assert user is not None + assert user.role == "admin" + assert user.email == "alice@acme.com" + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def _state_tokens(queue: Queue) -> list[str]: + prefix = "auth:oauth_state:" + return [k[len(prefix) :] for k in queue.list_settings() if k.startswith(prefix)] diff --git a/tests/dashboard/test_oauth_providers.py b/tests/dashboard/test_oauth_providers.py new file mode 100644 index 0000000..6d29aa0 --- /dev/null +++ b/tests/dashboard/test_oauth_providers.py @@ -0,0 +1,575 @@ +"""Unit tests for the concrete OAuth provider implementations. + +These tests stub every HTTP boundary so they run without network access. +The end-to-end "real flow" test lives in ``test_oauth_endpoints.py``. +""" + +from __future__ import annotations + +import json +import time +from typing import Any +from urllib.parse import parse_qs, urlparse + +import pytest +from joserfc import jwt as joserfc_jwt +from joserfc.jwk import RSAKey + +from taskito.dashboard.oauth.config import ( + GitHubConfig, + GoogleConfig, + OIDCConfig, +) +from taskito.dashboard.oauth.identity import ( + AllowlistDenied, + IdentityFetchError, +) +from taskito.dashboard.oauth.providers import ( + GenericOIDCProvider, + GitHubProvider, + GoogleProvider, + _audience_matches, + _email_domain, +) + +# ── HTTP stub helpers ──────────────────────────────────────────────── + + +class StubResponse: + """Minimal stand-in for a ``requests.Response`` object.""" + + def __init__(self, *, status_code: int = 200, payload: Any = None, text: str = "") -> None: + self.status_code = status_code + self._payload = payload + self.text = text or json.dumps(payload) + + def json(self) -> Any: + return self._payload + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + +class StubSession: + """Replaces ``requests.Session`` with a programmable URL → response map.""" + + def __init__(self, routes: dict[str, StubResponse]) -> None: + self._routes = routes + self.calls: list[tuple[str, dict[str, str]]] = [] + + def get(self, url: str, *, headers: dict[str, str] | None = None, **_: Any) -> StubResponse: + self.calls.append((url, headers or {})) + if url in self._routes: + return self._routes[url] + # Wildcard fallback: match by prefix to support .../members/. + for prefix, response in self._routes.items(): + if prefix.endswith("*") and url.startswith(prefix[:-1]): + return response + return StubResponse(status_code=404, payload={"error": "not found"}) + + +# ── Test fixtures ──────────────────────────────────────────────────── + + +@pytest.fixture +def rsa_key() -> RSAKey: + """A fresh RSA keypair used to sign + verify test ID tokens.""" + return RSAKey.generate_key(2048, parameters={"kid": "test-kid"}, private=True) + + +@pytest.fixture +def google_discovery() -> dict[str, str]: + return { + "issuer": "https://accounts.google.com", + "authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth", + "token_endpoint": "https://oauth2.googleapis.com/token", + "jwks_uri": "https://www.googleapis.com/oauth2/v3/certs", + } + + +def _make_google_provider( + *, + allowed_domains: tuple[str, ...] = (), + discovery: dict[str, str], + jwks_payload: dict, +) -> GoogleProvider: + routes = { + "https://accounts.google.com/.well-known/openid-configuration": StubResponse( + payload=discovery + ), + discovery["jwks_uri"]: StubResponse(payload=jwks_payload), + } + provider = GoogleProvider( + GoogleConfig( + client_id="test-client-id", + client_secret="test-client-secret", + allowed_domains=allowed_domains, + ), + http=StubSession(routes), + ) + return provider + + +def _make_id_token( + *, + key: RSAKey, + issuer: str, + audience: str, + subject: str, + email: str, + email_verified: bool, + nonce: str | None, + name: str | None = "Alice Example", + extra_claims: dict[str, Any] | None = None, +) -> str: + claims: dict[str, Any] = { + "iss": issuer, + "aud": audience, + "sub": subject, + "email": email, + "email_verified": email_verified, + "name": name, + "iat": int(time.time()), + "exp": int(time.time()) + 600, + } + if nonce is not None: + claims["nonce"] = nonce + if extra_claims: + claims.update(extra_claims) + header = {"alg": "RS256", "kid": key.kid} + encoded: str = joserfc_jwt.encode(header, claims, key) + return encoded + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def test_email_domain_extracts_lowercase() -> None: + assert _email_domain("Alice@ACME.com") == "acme.com" + assert _email_domain(None) is None + assert _email_domain("not-an-email") is None + + +def test_audience_matches_string_and_list() -> None: + assert _audience_matches("cid", "cid") + assert _audience_matches(["cid", "other"], "cid") + assert not _audience_matches("other", "cid") + assert not _audience_matches([], "cid") + assert not _audience_matches(None, "cid") + + +# ── Google: authorization URL ──────────────────────────────────────── + + +def test_google_authorization_url_includes_required_params( + google_discovery: dict[str, str], +) -> None: + provider = _make_google_provider( + discovery=google_discovery, + jwks_payload={"keys": []}, + ) + url = provider.authorization_url( + state="STATE", + nonce="NONCE", + code_challenge="CHALLENGE", + redirect_uri="https://taskito.example.com/api/auth/oauth/callback/google", + ) + parsed = urlparse(url) + qs = parse_qs(parsed.query) + assert parsed.scheme == "https" + assert qs["response_type"] == ["code"] + assert qs["client_id"] == ["test-client-id"] + assert qs["scope"] == ["openid email profile"] + assert qs["state"] == ["STATE"] + assert qs["nonce"] == ["NONCE"] + assert qs["code_challenge"] == ["CHALLENGE"] + assert qs["code_challenge_method"] == ["S256"] + assert qs["prompt"] == ["select_account"] + assert "hd" not in qs # no allowed_domains configured + + +def test_google_authorization_url_sets_hd_hint_for_single_domain( + google_discovery: dict[str, str], +) -> None: + provider = _make_google_provider( + discovery=google_discovery, + jwks_payload={"keys": []}, + allowed_domains=("acme.com",), + ) + url = provider.authorization_url( + state="s", nonce="n", code_challenge="c", redirect_uri="https://x/y" + ) + qs = parse_qs(urlparse(url).query) + assert qs["hd"] == ["acme.com"] + + +def test_google_authorization_url_omits_hd_for_multi_domain( + google_discovery: dict[str, str], +) -> None: + provider = _make_google_provider( + discovery=google_discovery, + jwks_payload={"keys": []}, + allowed_domains=("acme.com", "partner.com"), + ) + url = provider.authorization_url( + state="s", nonce="n", code_challenge="c", redirect_uri="https://x/y" + ) + qs = parse_qs(urlparse(url).query) + assert "hd" not in qs # ambiguous, do not preselect + + +# ── Google: exchange_code → identity ───────────────────────────────── + + +def test_google_exchange_code_returns_identity_for_valid_id_token( + google_discovery: dict[str, str], rsa_key: RSAKey, monkeypatch: pytest.MonkeyPatch +) -> None: + id_token = _make_id_token( + key=rsa_key, + issuer=google_discovery["issuer"], + audience="test-client-id", + subject="118420987654321", + email="alice@acme.com", + email_verified=True, + nonce="EXPECTED_NONCE", + ) + provider = _make_google_provider( + discovery=google_discovery, + jwks_payload={"keys": [rsa_key.as_dict(private=False)]}, + ) + monkeypatch.setattr( + provider, "_fetch_token", lambda **_: {"id_token": id_token, "access_token": "AT"} + ) + identity = provider.exchange_code( + code="abc", + code_verifier="verifier", + redirect_uri="https://x", + expected_nonce="EXPECTED_NONCE", + ) + assert identity.slot == "google" + assert identity.subject == "118420987654321" + assert identity.email == "alice@acme.com" + assert identity.email_verified is True + assert identity.name == "Alice Example" + + +def test_google_exchange_code_rejects_wrong_nonce( + google_discovery: dict[str, str], rsa_key: RSAKey, monkeypatch: pytest.MonkeyPatch +) -> None: + id_token = _make_id_token( + key=rsa_key, + issuer=google_discovery["issuer"], + audience="test-client-id", + subject="111", + email="x@y.com", + email_verified=True, + nonce="WRONG", + ) + provider = _make_google_provider( + discovery=google_discovery, + jwks_payload={"keys": [rsa_key.as_dict(private=False)]}, + ) + monkeypatch.setattr(provider, "_fetch_token", lambda **_: {"id_token": id_token}) + with pytest.raises(IdentityFetchError, match="nonce mismatch"): + provider.exchange_code( + code="abc", code_verifier="v", redirect_uri="https://x", expected_nonce="EXPECTED" + ) + + +def test_google_exchange_code_rejects_wrong_audience( + google_discovery: dict[str, str], rsa_key: RSAKey, monkeypatch: pytest.MonkeyPatch +) -> None: + id_token = _make_id_token( + key=rsa_key, + issuer=google_discovery["issuer"], + audience="DIFFERENT-CLIENT", + subject="111", + email="x@y.com", + email_verified=True, + nonce=None, + ) + provider = _make_google_provider( + discovery=google_discovery, + jwks_payload={"keys": [rsa_key.as_dict(private=False)]}, + ) + monkeypatch.setattr(provider, "_fetch_token", lambda **_: {"id_token": id_token}) + with pytest.raises(IdentityFetchError, match="audience mismatch"): + provider.exchange_code( + code="abc", code_verifier="v", redirect_uri="https://x", expected_nonce=None + ) + + +def test_google_exchange_code_rejects_wrong_issuer( + google_discovery: dict[str, str], rsa_key: RSAKey, monkeypatch: pytest.MonkeyPatch +) -> None: + id_token = _make_id_token( + key=rsa_key, + issuer="https://evil.com", + audience="test-client-id", + subject="111", + email="x@y.com", + email_verified=True, + nonce=None, + ) + provider = _make_google_provider( + discovery=google_discovery, + jwks_payload={"keys": [rsa_key.as_dict(private=False)]}, + ) + monkeypatch.setattr(provider, "_fetch_token", lambda **_: {"id_token": id_token}) + with pytest.raises(IdentityFetchError, match="issuer mismatch"): + provider.exchange_code( + code="abc", code_verifier="v", redirect_uri="https://x", expected_nonce=None + ) + + +def test_google_exchange_code_rejects_missing_id_token( + google_discovery: dict[str, str], monkeypatch: pytest.MonkeyPatch +) -> None: + provider = _make_google_provider( + discovery=google_discovery, + jwks_payload={"keys": []}, + ) + monkeypatch.setattr(provider, "_fetch_token", lambda **_: {"access_token": "AT"}) + with pytest.raises(IdentityFetchError, match="no id_token"): + provider.exchange_code( + code="abc", code_verifier="v", redirect_uri="https://x", expected_nonce=None + ) + + +# ── Google: allowlist ───────────────────────────────────────────────── + + +def test_google_check_allowlist_passes_when_no_restriction( + google_discovery: dict[str, str], +) -> None: + provider = _make_google_provider(discovery=google_discovery, jwks_payload={"keys": []}) + from taskito.dashboard.oauth.identity import ProviderIdentity + + identity = ProviderIdentity(slot="google", subject="x", email="x@y.com", email_verified=True) + # Should not raise. + provider.check_allowlist(identity) + + +def test_google_check_allowlist_rejects_unverified_email( + google_discovery: dict[str, str], +) -> None: + from taskito.dashboard.oauth.identity import ProviderIdentity + + provider = _make_google_provider( + discovery=google_discovery, + jwks_payload={"keys": []}, + allowed_domains=("acme.com",), + ) + with pytest.raises(AllowlistDenied, match="verified email"): + provider.check_allowlist( + ProviderIdentity( + slot="google", + subject="x", + email="user@acme.com", + email_verified=False, + ) + ) + + +def test_google_check_allowlist_rejects_out_of_domain_email( + google_discovery: dict[str, str], +) -> None: + from taskito.dashboard.oauth.identity import ProviderIdentity + + provider = _make_google_provider( + discovery=google_discovery, + jwks_payload={"keys": []}, + allowed_domains=("acme.com",), + ) + with pytest.raises(AllowlistDenied, match="not in the allowed domains"): + provider.check_allowlist( + ProviderIdentity( + slot="google", + subject="x", + email="user@gmail.com", + email_verified=True, + ) + ) + + +def test_google_check_allowlist_accepts_listed_domain( + google_discovery: dict[str, str], +) -> None: + from taskito.dashboard.oauth.identity import ProviderIdentity + + provider = _make_google_provider( + discovery=google_discovery, + jwks_payload={"keys": []}, + allowed_domains=("acme.com",), + ) + provider.check_allowlist( + ProviderIdentity(slot="google", subject="x", email="USER@Acme.COM", email_verified=True) + ) + + +# ── Generic OIDC ────────────────────────────────────────────────────── + + +def test_generic_oidc_uses_provided_discovery_url(rsa_key: RSAKey) -> None: + discovery = { + "issuer": "https://acme.okta.com", + "authorization_endpoint": "https://acme.okta.com/oauth2/authorize", + "token_endpoint": "https://acme.okta.com/oauth2/token", + "jwks_uri": "https://acme.okta.com/oauth2/jwks", + } + routes = { + "https://acme.okta.com/.well-known/openid-configuration": StubResponse(payload=discovery), + discovery["jwks_uri"]: StubResponse(payload={"keys": [rsa_key.as_dict(private=False)]}), + } + provider = GenericOIDCProvider( + OIDCConfig( + slot="okta", + client_id="cid", + client_secret="csec", + discovery_url="https://acme.okta.com/.well-known/openid-configuration", + label="Acme SSO", + ), + http=StubSession(routes), + ) + url = provider.authorization_url( + state="s", nonce="n", code_challenge="c", redirect_uri="https://taskito.x/cb" + ) + assert url.startswith("https://acme.okta.com/oauth2/authorize?") + assert provider.slot == "okta" + assert provider.label == "Acme SSO" + assert provider.type == "oidc" + + +# ── GitHub ──────────────────────────────────────────────────────────── + + +def _gh_provider( + *, + allowed_orgs: tuple[str, ...] = (), + routes: dict[str, StubResponse] | None = None, +) -> GitHubProvider: + return GitHubProvider( + GitHubConfig( + client_id="gh-client", + client_secret="gh-secret", + allowed_orgs=allowed_orgs, + ), + http=StubSession(routes or {}), + ) + + +def test_github_authorization_url_includes_pkce_and_state() -> None: + provider = _gh_provider() + url = provider.authorization_url( + state="STATE", nonce="UNUSED", code_challenge="CHL", redirect_uri="https://x/cb" + ) + parsed = urlparse(url) + qs = parse_qs(parsed.query) + assert parsed.netloc == "github.com" + assert qs["client_id"] == ["gh-client"] + assert qs["state"] == ["STATE"] + assert qs["code_challenge"] == ["CHL"] + assert qs["code_challenge_method"] == ["S256"] + assert "nonce" not in qs # GitHub does not implement OIDC + + +def test_github_authorization_url_adds_read_org_when_allowlist_configured() -> None: + provider = _gh_provider(allowed_orgs=("acme",)) + url = provider.authorization_url( + state="s", nonce="n", code_challenge="c", redirect_uri="https://x/cb" + ) + qs = parse_qs(urlparse(url).query) + assert "read:org" in qs["scope"][0] + + +def test_github_exchange_code_returns_verified_primary_email( + monkeypatch: pytest.MonkeyPatch, +) -> None: + routes = { + "https://api.github.com/user": StubResponse( + payload={"id": 584213, "login": "alice", "name": "Alice", "avatar_url": "https://x/y"} + ), + "https://api.github.com/user/emails": StubResponse( + payload=[ + {"email": "alt@x.com", "primary": False, "verified": True}, + {"email": "alice@acme.com", "primary": True, "verified": True}, + ] + ), + } + provider = _gh_provider(routes=routes) + monkeypatch.setattr(provider, "_fetch_token", lambda **_: {"access_token": "AT"}) + identity = provider.exchange_code( + code="abc", code_verifier="v", redirect_uri="https://x", expected_nonce=None + ) + assert identity.slot == "github" + assert identity.subject == "584213" + assert identity.email == "alice@acme.com" + assert identity.email_verified is True + assert identity.name == "Alice" + + +def test_github_exchange_code_returns_none_email_when_no_verified_primary( + monkeypatch: pytest.MonkeyPatch, +) -> None: + routes = { + "https://api.github.com/user": StubResponse( + payload={"id": 1, "login": "u", "name": None, "avatar_url": None} + ), + "https://api.github.com/user/emails": StubResponse( + payload=[ + {"email": "claimed@x.com", "primary": True, "verified": False}, + ] + ), + } + provider = _gh_provider(routes=routes) + monkeypatch.setattr(provider, "_fetch_token", lambda **_: {"access_token": "AT"}) + identity = provider.exchange_code( + code="abc", code_verifier="v", redirect_uri="https://x", expected_nonce=None + ) + assert identity.email is None + assert identity.email_verified is False + + +def test_github_exchange_code_enforces_org_membership( + monkeypatch: pytest.MonkeyPatch, +) -> None: + routes = { + "https://api.github.com/user": StubResponse( + payload={"id": 1, "login": "alice", "name": "A", "avatar_url": "x"} + ), + "https://api.github.com/user/emails": StubResponse( + payload=[{"email": "a@x.com", "primary": True, "verified": True}] + ), + "https://api.github.com/orgs/acme/members/alice": StubResponse( + status_code=204, payload=None, text="" + ), + } + provider = _gh_provider(allowed_orgs=("acme",), routes=routes) + monkeypatch.setattr(provider, "_fetch_token", lambda **_: {"access_token": "AT"}) + identity = provider.exchange_code( + code="abc", code_verifier="v", redirect_uri="https://x", expected_nonce=None + ) + assert identity.email == "a@x.com" + + +def test_github_exchange_code_rejects_non_member( + monkeypatch: pytest.MonkeyPatch, +) -> None: + routes = { + "https://api.github.com/user": StubResponse( + payload={"id": 1, "login": "eve", "name": "E", "avatar_url": "x"} + ), + "https://api.github.com/user/emails": StubResponse( + payload=[{"email": "e@x.com", "primary": True, "verified": True}] + ), + "https://api.github.com/orgs/acme/members/eve": StubResponse( + status_code=404, payload={"message": "Not Found"} + ), + } + provider = _gh_provider(allowed_orgs=("acme",), routes=routes) + monkeypatch.setattr(provider, "_fetch_token", lambda **_: {"access_token": "AT"}) + with pytest.raises(AllowlistDenied, match="not a member"): + provider.exchange_code( + code="abc", code_verifier="v", redirect_uri="https://x", expected_nonce=None + ) diff --git a/tests/dashboard/test_oauth_state_store.py b/tests/dashboard/test_oauth_state_store.py new file mode 100644 index 0000000..d20d764 --- /dev/null +++ b/tests/dashboard/test_oauth_state_store.py @@ -0,0 +1,98 @@ +"""Tests for the short-lived OAuth state store.""" + +from __future__ import annotations + +import time +from pathlib import Path + +import pytest + +from taskito import Queue +from taskito.dashboard.oauth.state_store import ( + DEFAULT_STATE_TTL_SECONDS, + STATE_PREFIX, + OAuthStateStore, +) + + +@pytest.fixture +def queue(tmp_path: Path) -> Queue: + return Queue(db_path=str(tmp_path / "oauth_state.db")) + + +def test_create_persists_row_and_returns_state(queue: Queue) -> None: + store = OAuthStateStore(queue) + row = store.create(slot="google", next_url="/dashboard") + + assert row.slot == "google" + assert row.next_url == "/dashboard" + assert len(row.state) >= 32 + assert len(row.nonce) >= 16 + assert len(row.code_verifier) >= 32 + # state, nonce, and verifier must each be unique tokens. + assert row.state != row.nonce != row.code_verifier + # Row is in the settings store under the expected prefix. + assert queue.get_setting(STATE_PREFIX + row.state) is not None + + +def test_consume_returns_row_then_invalidates_it(queue: Queue) -> None: + store = OAuthStateStore(queue) + row = store.create(slot="github", next_url="/") + + first = store.consume(row.state) + assert first is not None + assert first.slot == "github" + assert first.code_verifier == row.code_verifier + + # Second consume is a replay attempt — must fail. + assert store.consume(row.state) is None + + +def test_consume_rejects_empty_and_unknown_tokens(queue: Queue) -> None: + store = OAuthStateStore(queue) + assert store.consume("") is None + assert store.consume("never-issued") is None + + +def test_consume_expired_row_returns_none(queue: Queue) -> None: + store = OAuthStateStore(queue) + row = store.create(slot="google", next_url="/", ttl_seconds=0) + # Even at TTL=0 we deliberately treat the row as immediately expired + # — but it's still single-use (deleted on consume). + assert store.consume(row.state) is None + # Underlying entry is gone after the consume. + assert queue.get_setting(STATE_PREFIX + row.state) is None + + +def test_consume_strips_malformed_rows(queue: Queue) -> None: + store = OAuthStateStore(queue) + # Inject a garbage row directly into the settings store. + queue.set_setting(STATE_PREFIX + "broken", "not-json-{}") + assert store.consume("broken") is None + assert queue.get_setting(STATE_PREFIX + "broken") is None + + +def test_prune_expired_removes_only_old_rows(queue: Queue) -> None: + store = OAuthStateStore(queue) + fresh = store.create(slot="google", next_url="/") + stale = store.create(slot="github", next_url="/", ttl_seconds=0) + + # Simulate the prune sweep. + removed = store.prune_expired() + assert removed >= 1 + # Fresh row survives. + assert queue.get_setting(STATE_PREFIX + fresh.state) is not None + # Stale row is gone. + assert queue.get_setting(STATE_PREFIX + stale.state) is None + + +def test_default_ttl_is_five_minutes() -> None: + assert DEFAULT_STATE_TTL_SECONDS == 300 + + +def test_create_sets_expected_expiry(queue: Queue) -> None: + store = OAuthStateStore(queue) + before = int(time.time()) + row = store.create(slot="google", next_url="/", ttl_seconds=120) + after = int(time.time()) + assert before + 120 <= row.expires_at <= after + 120 diff --git a/tests/dashboard/test_task_overrides.py b/tests/dashboard/test_task_overrides.py new file mode 100644 index 0000000..ba01e2f --- /dev/null +++ b/tests/dashboard/test_task_overrides.py @@ -0,0 +1,234 @@ +"""Tests for task & queue runtime overrides.""" + +from __future__ import annotations + +import threading +import urllib.error +from collections.abc import Generator +from http.server import ThreadingHTTPServer +from pathlib import Path + +import pytest + +from taskito import Queue +from taskito.dashboard import _make_handler +from taskito.dashboard._testing import AuthedClient, seed_admin_and_session +from taskito.dashboard.overrides_store import OverridesStore + + +@pytest.fixture +def queue(tmp_path: Path) -> Queue: + q = Queue(db_path=str(tmp_path / "overrides.db")) + + @q.task(queue="default", max_retries=3, timeout=300) + def send_email(to: str) -> str: + return to + + @q.task(queue="email", max_retries=5, rate_limit="100/m", max_concurrent=10) + def deliver(message: str) -> str: + return message + + return q + + +@pytest.fixture +def dashboard(queue: Queue) -> Generator[tuple[AuthedClient, Queue]]: + handler = _make_handler(queue) + server = ThreadingHTTPServer(("127.0.0.1", 0), handler) + threading.Thread(target=server.serve_forever, daemon=True).start() + session = seed_admin_and_session(queue) + client = AuthedClient(base=f"http://127.0.0.1:{server.server_address[1]}", session=session) + try: + yield client, queue + finally: + server.shutdown() + + +# ── Store ────────────────────────────────────────────────────────────── + + +def test_overrides_store_starts_empty(queue: Queue) -> None: + store = OverridesStore(queue) + assert store.list_tasks() == {} + assert store.list_queues() == {} + + +def test_set_task_override_persists(queue: Queue) -> None: + store = OverridesStore(queue) + override = store.set_task("foo", {"max_retries": 7, "rate_limit": "50/s"}) + assert override.max_retries == 7 + assert override.rate_limit == "50/s" + fetched = store.get_task("foo") + assert fetched is not None and fetched.max_retries == 7 + + +def test_set_task_override_validates(queue: Queue) -> None: + store = OverridesStore(queue) + with pytest.raises(ValueError, match="rate_limit"): + store.set_task("foo", {"rate_limit": "no-slash"}) + with pytest.raises(ValueError, match="max_concurrent"): + store.set_task("foo", {"max_concurrent": -1}) + with pytest.raises(ValueError, match="unknown task override"): + store.set_task("foo", {"not_a_field": 1}) + + +def test_set_task_override_merges_with_existing(queue: Queue) -> None: + store = OverridesStore(queue) + store.set_task("foo", {"max_retries": 7}) + store.set_task("foo", {"rate_limit": "50/s"}) + merged = store.get_task("foo") + assert merged is not None + assert merged.max_retries == 7 + assert merged.rate_limit == "50/s" + + +def test_set_task_override_clears_field_with_none(queue: Queue) -> None: + store = OverridesStore(queue) + store.set_task("foo", {"max_retries": 7, "rate_limit": "50/s"}) + store.set_task("foo", {"max_retries": None}) + fetched = store.get_task("foo") + assert fetched is not None + assert fetched.max_retries is None + assert fetched.rate_limit == "50/s" + + +def test_clear_task_override(queue: Queue) -> None: + store = OverridesStore(queue) + store.set_task("foo", {"max_retries": 7}) + assert store.clear_task("foo") is True + assert store.clear_task("foo") is False + assert store.get_task("foo") is None + + +def test_queue_override_basics(queue: Queue) -> None: + store = OverridesStore(queue) + store.set_queue("default", {"max_concurrent": 5, "paused": True}) + fetched = store.get_queue("default") + assert fetched is not None + assert fetched.max_concurrent == 5 + assert fetched.paused is True + + +def test_apply_task_overrides_mutates_configs(queue: Queue) -> None: + """Mutating the in-memory PyTaskConfig is what makes overrides reach the + Rust scheduler at worker start.""" + store = OverridesStore(queue) + send_email = next(c for c in queue._task_configs if "send_email" in c.name) + store.set_task(send_email.name, {"max_retries": 99, "rate_limit": "1/s"}) + store.apply_task_overrides(queue._task_configs) + assert send_email.max_retries == 99 + assert send_email.rate_limit == "1/s" + + +def test_apply_task_overrides_reports_paused(queue: Queue) -> None: + store = OverridesStore(queue) + send_email = next(c for c in queue._task_configs if "send_email" in c.name) + store.set_task(send_email.name, {"paused": True}) + paused = store.apply_task_overrides(queue._task_configs) + assert send_email.name in paused + + +def test_apply_queue_overrides_merges(queue: Queue) -> None: + store = OverridesStore(queue) + queue.set_queue_concurrency("email", 10) # configured-from-Python + store.set_queue("email", {"rate_limit": "200/m"}) + merged = store.apply_queue_overrides(queue._queue_configs) + assert merged["email"]["max_concurrent"] == 10 # decorator-set survives + assert merged["email"]["rate_limit"] == "200/m" # override wins + + +# ── Queue.registered_tasks() ────────────────────────────────────────── + + +def test_registered_tasks_lists_defaults_and_overrides(queue: Queue) -> None: + tasks = queue.registered_tasks() + assert len(tasks) == 2 + by_name = {t["name"]: t for t in tasks} + deliver = next(t for n, t in by_name.items() if "deliver" in n) + assert deliver["defaults"]["rate_limit"] == "100/m" + assert deliver["defaults"]["max_retries"] == 5 + assert deliver["override"] is None + assert deliver["effective"]["rate_limit"] == "100/m" + + +def test_registered_tasks_reflects_override(queue: Queue) -> None: + send_email = next(t for t in queue.registered_tasks() if "send_email" in t["name"]) + queue.set_task_override(send_email["name"], max_retries=99) + fresh = next(t for t in queue.registered_tasks() if t["name"] == send_email["name"]) + assert fresh["override"] == {"max_retries": 99} + assert fresh["effective"]["max_retries"] == 99 + assert fresh["defaults"]["max_retries"] == 3 # original decorator value + + +# ── HTTP endpoints ──────────────────────────────────────────────────── + + +def test_list_tasks_endpoint(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + tasks = client.get("/api/tasks") + assert len(tasks) == 2 + for entry in tasks: + assert "name" in entry and "defaults" in entry and "effective" in entry + + +def test_put_task_override(dashboard: tuple[AuthedClient, Queue]) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if "send_email" in c.name) + result = client.put( + f"/api/tasks/{name}/override", + {"max_retries": 7, "rate_limit": "50/s"}, + ) + assert result["max_retries"] == 7 + assert result["rate_limit"] == "50/s" + + fetched = client.get(f"/api/tasks/{name}/override") + assert fetched["max_retries"] == 7 + + +def test_put_task_override_rejects_unknown_field( + dashboard: tuple[AuthedClient, Queue], +) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if "send_email" in c.name) + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.put(f"/api/tasks/{name}/override", {"made_up": 1}) + assert exc_info.value.code == 400 + + +def test_delete_task_override(dashboard: tuple[AuthedClient, Queue]) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if "send_email" in c.name) + client.put(f"/api/tasks/{name}/override", {"max_retries": 7}) + assert client.delete(f"/api/tasks/{name}/override") == {"cleared": True} + assert client.delete(f"/api/tasks/{name}/override") == {"cleared": False} + + +def test_get_task_override_404_when_none(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.get("/api/tasks/nonexistent/override") + assert exc_info.value.code == 404 + + +def test_put_queue_override_pauses_queue(dashboard: tuple[AuthedClient, Queue]) -> None: + """Pausing via queue override must also update the live paused_queues + state so a running worker stops dequeueing immediately.""" + client, queue = dashboard + client.put("/api/queues/email/override", {"paused": True}) + assert "email" in queue.paused_queues() + client.put("/api/queues/email/override", {"paused": False}) + assert "email" not in queue.paused_queues() + + +def test_list_queues_endpoint(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + queues = client.get("/api/queues") + names = {q["name"] for q in queues} + assert {"default", "email"} <= names + + +def test_put_queue_override_validates(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.put("/api/queues/default/override", {"max_concurrent": -1}) + assert exc_info.value.code == 400 diff --git a/tests/dashboard/test_url_safety.py b/tests/dashboard/test_url_safety.py new file mode 100644 index 0000000..c2f4d2d --- /dev/null +++ b/tests/dashboard/test_url_safety.py @@ -0,0 +1,40 @@ +"""Tests for the URL-safety helpers used by the dashboard.""" + +from __future__ import annotations + +import pytest + +from taskito.dashboard.url_safety import is_safe_redirect + + +@pytest.mark.parametrize( + "path", + [ + "/", + "/dashboard", + "/dashboard/jobs", + "/dashboard?tab=overview", + "/dashboard/jobs#section", + ], +) +def test_is_safe_redirect_accepts_relative_paths(path: str) -> None: + assert is_safe_redirect(path) is True + + +@pytest.mark.parametrize( + "path", + [ + "", + None, + "dashboard", # no leading slash + "//evil.com/x", # protocol-relative URL + "/\\evil.com", # backslash variant + "http://evil.com/x", + "https://evil.com/x", + "javascript:alert(1)", + "data:text/html,xss", + "\\\\evil.com", + ], +) +def test_is_safe_redirect_rejects_unsafe(path: str | None) -> None: + assert is_safe_redirect(path) is False diff --git a/tests/dashboard/test_webhook_deliveries.py b/tests/dashboard/test_webhook_deliveries.py new file mode 100644 index 0000000..16f5e83 --- /dev/null +++ b/tests/dashboard/test_webhook_deliveries.py @@ -0,0 +1,303 @@ +"""Tests for the webhook delivery log + replay endpoints.""" + +from __future__ import annotations + +import json +import threading +import urllib.error +import urllib.request +from collections.abc import Generator +from http.server import BaseHTTPRequestHandler, HTTPServer, ThreadingHTTPServer +from pathlib import Path +from typing import Any + +import pytest + +from taskito import Queue +from taskito.dashboard import _make_handler +from taskito.dashboard._testing import AuthedClient, seed_admin_and_session +from taskito.dashboard.delivery_store import DeliveryStore +from taskito.events import EventType + + +@pytest.fixture +def queue(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Queue: + monkeypatch.setenv("TASKITO_WEBHOOKS_ALLOW_PRIVATE", "1") + return Queue(db_path=str(tmp_path / "deliveries.db")) + + +@pytest.fixture +def echo_server() -> Generator[tuple[str, list[dict[str, Any]]]]: + """A local server that captures the bodies it receives.""" + received: list[dict[str, Any]] = [] + + class Handler(BaseHTTPRequestHandler): + def do_POST(self) -> None: + length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(length) + received.append({"body": json.loads(body), "headers": dict(self.headers)}) + self.send_response(200) + self.end_headers() + self.wfile.write(b"ok") + + def log_message(self, *args: Any) -> None: + pass + + server = HTTPServer(("127.0.0.1", 0), Handler) + threading.Thread(target=server.serve_forever, daemon=True).start() + try: + yield f"http://127.0.0.1:{server.server_address[1]}", received + finally: + server.shutdown() + + +@pytest.fixture +def fail_server() -> Generator[str]: + """Always returns 500 to exercise the dead-letter path.""" + + class Handler(BaseHTTPRequestHandler): + def do_POST(self) -> None: + self.send_response(500) + self.end_headers() + self.wfile.write(b"server error") + + def log_message(self, *args: Any) -> None: + pass + + server = HTTPServer(("127.0.0.1", 0), Handler) + threading.Thread(target=server.serve_forever, daemon=True).start() + try: + yield f"http://127.0.0.1:{server.server_address[1]}" + finally: + server.shutdown() + + +@pytest.fixture +def dashboard(queue: Queue) -> Generator[tuple[AuthedClient, Queue]]: + handler = _make_handler(queue) + server = ThreadingHTTPServer(("127.0.0.1", 0), handler) + threading.Thread(target=server.serve_forever, daemon=True).start() + session = seed_admin_and_session(queue) + client = AuthedClient(base=f"http://127.0.0.1:{server.server_address[1]}", session=session) + try: + yield client, queue + finally: + server.shutdown() + + +# ── DeliveryStore ────────────────────────────────────────────────────── + + +def test_delivery_store_starts_empty(queue: Queue) -> None: + store = DeliveryStore(queue) + assert store.list_for("missing") == [] + assert store.count_for("missing") == 0 + + +def test_record_attempt_appends(queue: Queue) -> None: + store = DeliveryStore(queue) + record = store.record_attempt( + "sub1", + event="job.completed", + payload={"job_id": "x"}, + status="delivered", + attempts=1, + response_code=200, + latency_ms=10, + ) + assert record.subscription_id == "sub1" + assert record.status == "delivered" + assert store.count_for("sub1") == 1 + listed = store.list_for("sub1") + assert len(listed) == 1 + assert listed[0].id == record.id + + +def test_record_attempt_caps_history(queue: Queue) -> None: + store = DeliveryStore(queue, max_per_webhook=3) + for i in range(5): + store.record_attempt( + "sub1", + event="job.completed", + payload={"job_id": str(i)}, + status="delivered", + attempts=1, + ) + items = store.list_for("sub1") + assert len(items) == 3 + # Newest first; oldest (i=0, i=1) evicted. + assert items[0].payload["job_id"] == "4" + assert items[-1].payload["job_id"] == "2" + + +def test_record_attempt_truncates_response_body(queue: Queue) -> None: + store = DeliveryStore(queue) + big = "x" * 100_000 + record = store.record_attempt( + "sub1", + event="job.completed", + payload={}, + status="failed", + attempts=1, + response_body=big, + ) + assert record.response_body is not None + assert len(record.response_body.encode("utf-8")) <= 2048 + 4 # +ellipsis + + +def test_list_for_filters_by_status_and_event(queue: Queue) -> None: + store = DeliveryStore(queue) + store.record_attempt("sub1", event="job.completed", payload={}, status="delivered", attempts=1) + store.record_attempt("sub1", event="job.failed", payload={}, status="failed", attempts=1) + store.record_attempt("sub1", event="job.completed", payload={}, status="failed", attempts=1) + + delivered = store.list_for("sub1", status="delivered") + assert len(delivered) == 1 + failed = store.list_for("sub1", status="failed") + assert len(failed) == 2 + completed_event = store.list_for("sub1", event="job.completed") + assert len(completed_event) == 2 + + +# ── End-to-end delivery recording ────────────────────────────────────── + + +def test_successful_delivery_recorded( + queue: Queue, echo_server: tuple[str, list[dict[str, Any]]], poll_until: Any +) -> None: + url, _ = echo_server + sub = queue.add_webhook(url, events=[EventType.JOB_COMPLETED]) + queue._webhook_manager.notify(EventType.JOB_COMPLETED, {"job_id": "abc"}) + poll_until( + lambda: DeliveryStore(queue).count_for(sub.id) >= 1, + message="delivery not recorded", + ) + items = DeliveryStore(queue).list_for(sub.id) + assert len(items) == 1 + assert items[0].status == "delivered" + assert items[0].response_code == 200 + assert items[0].latency_ms is not None + + +def test_failed_delivery_marked_dead(queue: Queue, fail_server: str, poll_until: Any) -> None: + sub = queue.add_webhook(fail_server, events=[EventType.JOB_FAILED], max_retries=2) + queue._webhook_manager.notify(EventType.JOB_FAILED, {"job_id": "x", "error": "boom"}) + poll_until( + lambda: DeliveryStore(queue).count_for(sub.id) >= 1, + message="delivery never recorded", + ) + items = DeliveryStore(queue).list_for(sub.id) + assert len(items) == 1 + assert items[0].status == "dead" + assert items[0].attempts == 2 + assert items[0].response_code == 500 + + +# ── Dashboard endpoints ──────────────────────────────────────────────── + + +def test_list_deliveries_endpoint( + dashboard: tuple[AuthedClient, Queue], + echo_server: tuple[str, list[dict[str, Any]]], + poll_until: Any, +) -> None: + client, queue = dashboard + url, _ = echo_server + sub = queue.add_webhook(url, events=[EventType.JOB_COMPLETED]) + queue._webhook_manager.notify(EventType.JOB_COMPLETED, {"job_id": "1"}) + poll_until(lambda: DeliveryStore(queue).count_for(sub.id) >= 1) + + page = client.get(f"/api/webhooks/{sub.id}/deliveries") + assert page["total"] == 1 + assert page["items"][0]["status"] == "delivered" + + +def test_list_deliveries_filters_by_status( + dashboard: tuple[AuthedClient, Queue], fail_server: str, poll_until: Any +) -> None: + client, queue = dashboard + sub = queue.add_webhook(fail_server, max_retries=1) + queue._webhook_manager.notify(EventType.JOB_COMPLETED, {"job_id": "1"}) + poll_until(lambda: DeliveryStore(queue).count_for(sub.id) >= 1) + + only_failed = client.get(f"/api/webhooks/{sub.id}/deliveries?status=dead") + assert only_failed["total"] >= 1 + assert all(r["status"] == "dead" for r in only_failed["items"]) + + delivered = client.get(f"/api/webhooks/{sub.id}/deliveries?status=delivered") + assert delivered["items"] == [] + + +def test_get_delivery_endpoint( + dashboard: tuple[AuthedClient, Queue], + echo_server: tuple[str, list[dict[str, Any]]], + poll_until: Any, +) -> None: + client, queue = dashboard + url, _ = echo_server + sub = queue.add_webhook(url) + queue._webhook_manager.notify(EventType.JOB_COMPLETED, {"job_id": "x"}) + poll_until(lambda: DeliveryStore(queue).count_for(sub.id) >= 1) + record_id = DeliveryStore(queue).list_for(sub.id)[0].id + + record = client.get(f"/api/webhooks/{sub.id}/deliveries/{record_id}") + assert record["id"] == record_id + assert record["status"] == "delivered" + + +def test_replay_delivery_endpoint( + dashboard: tuple[AuthedClient, Queue], + echo_server: tuple[str, list[dict[str, Any]]], + poll_until: Any, +) -> None: + client, queue = dashboard + url, received = echo_server + sub = queue.add_webhook(url) + queue._webhook_manager.notify(EventType.JOB_COMPLETED, {"job_id": "x"}) + poll_until(lambda: len(received) >= 1) + poll_until(lambda: DeliveryStore(queue).count_for(sub.id) >= 1) + delivery_id = DeliveryStore(queue).list_for(sub.id)[0].id + + result = client.post(f"/api/webhooks/{sub.id}/deliveries/{delivery_id}/replay") + assert result["delivered"] is True + assert result["status"] == 200 + assert result["replayed_of"] == delivery_id + + # Replay produces a NEW delivery record AND a new POST. + poll_until(lambda: len(received) >= 2) + poll_until(lambda: DeliveryStore(queue).count_for(sub.id) >= 2) + items = DeliveryStore(queue).list_for(sub.id) + assert any(r.payload.get("replay_of") == delivery_id for r in items) + + +def test_list_deliveries_404_for_unknown_subscription( + dashboard: tuple[AuthedClient, Queue], +) -> None: + client, _ = dashboard + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.get("/api/webhooks/nope/deliveries") + assert exc_info.value.code == 404 + + +def test_get_delivery_404_when_missing( + dashboard: tuple[AuthedClient, Queue], + echo_server: tuple[str, list[dict[str, Any]]], +) -> None: + client, queue = dashboard + url, _ = echo_server + sub = queue.add_webhook(url) + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.get(f"/api/webhooks/{sub.id}/deliveries/nonexistent") + assert exc_info.value.code == 404 + + +def test_list_deliveries_rejects_bad_status( + dashboard: tuple[AuthedClient, Queue], + echo_server: tuple[str, list[dict[str, Any]]], +) -> None: + client, queue = dashboard + url, _ = echo_server + sub = queue.add_webhook(url) + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.get(f"/api/webhooks/{sub.id}/deliveries?status=not-real") + assert exc_info.value.code == 400 diff --git a/tests/dashboard/test_webhooks_endpoints.py b/tests/dashboard/test_webhooks_endpoints.py new file mode 100644 index 0000000..a28cb7f --- /dev/null +++ b/tests/dashboard/test_webhooks_endpoints.py @@ -0,0 +1,378 @@ +"""Tests for the persistent webhook subscription store + dashboard CRUD endpoints.""" + +from __future__ import annotations + +import hashlib +import hmac +import json +import threading +import urllib.error +import urllib.request +from collections.abc import Generator +from http.server import BaseHTTPRequestHandler, HTTPServer, ThreadingHTTPServer +from pathlib import Path +from typing import Any + +import pytest + +from taskito import Queue +from taskito.dashboard import _make_handler +from taskito.dashboard._testing import AuthedClient, seed_admin_and_session +from taskito.dashboard.url_safety import UnsafeWebhookUrl, validate_webhook_url +from taskito.dashboard.webhook_store import WebhookSubscriptionStore +from taskito.events import EventType + +# ── Fixtures ─────────────────────────────────────────────────────────── + + +@pytest.fixture +def queue(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Queue: + # Tests in this file create webhooks against 127.0.0.1 servers, which the + # SSRF guard would otherwise reject. + monkeypatch.setenv("TASKITO_WEBHOOKS_ALLOW_PRIVATE", "1") + return Queue(db_path=str(tmp_path / "webhooks.db")) + + +@pytest.fixture +def echo_server() -> Generator[tuple[str, list[dict[str, Any]]]]: + """A local HTTP server that captures incoming webhook bodies.""" + received: list[dict[str, Any]] = [] + + class Handler(BaseHTTPRequestHandler): + def do_POST(self) -> None: + length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(length) + received.append({"body": json.loads(body), "headers": dict(self.headers)}) + self.send_response(200) + self.end_headers() + + def log_message(self, *args: Any) -> None: + pass + + server = HTTPServer(("127.0.0.1", 0), Handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + try: + yield f"http://127.0.0.1:{server.server_address[1]}", received + finally: + server.shutdown() + + +@pytest.fixture +def dashboard(queue: Queue) -> Generator[tuple[AuthedClient, Queue]]: + handler = _make_handler(queue) + server = ThreadingHTTPServer(("127.0.0.1", 0), handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + session = seed_admin_and_session(queue) + client = AuthedClient(base=f"http://127.0.0.1:{server.server_address[1]}", session=session) + try: + yield client, queue + finally: + server.shutdown() + + +# ── SSRF guard ───────────────────────────────────────────────────────── + + +def test_url_safety_rejects_loopback(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("TASKITO_WEBHOOKS_ALLOW_PRIVATE", raising=False) + with pytest.raises(UnsafeWebhookUrl): + validate_webhook_url("http://127.0.0.1:8080/x") + with pytest.raises(UnsafeWebhookUrl): + validate_webhook_url("http://localhost/x") + with pytest.raises(UnsafeWebhookUrl): + validate_webhook_url("http://something.internal/x") + + +def test_url_safety_rejects_private_ranges(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("TASKITO_WEBHOOKS_ALLOW_PRIVATE", raising=False) + with pytest.raises(UnsafeWebhookUrl): + validate_webhook_url("http://10.0.0.5/x") + with pytest.raises(UnsafeWebhookUrl): + validate_webhook_url("http://192.168.1.1/x") + with pytest.raises(UnsafeWebhookUrl): + validate_webhook_url("http://169.254.169.254/latest/meta-data") + + +def test_url_safety_rejects_bad_scheme(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("TASKITO_WEBHOOKS_ALLOW_PRIVATE", raising=False) + with pytest.raises(UnsafeWebhookUrl): + validate_webhook_url("ftp://example.com/x") + with pytest.raises(UnsafeWebhookUrl): + validate_webhook_url("javascript:alert(1)") + + +def test_url_safety_allows_private_with_env(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("TASKITO_WEBHOOKS_ALLOW_PRIVATE", "1") + # No exception + validate_webhook_url("http://127.0.0.1:8080/x") + validate_webhook_url("http://10.0.0.5/x") + + +# ── Store / Python API ───────────────────────────────────────────────── + + +def test_store_starts_empty(queue: Queue) -> None: + assert WebhookSubscriptionStore(queue).list_all() == [] + + +def test_create_and_get_subscription(queue: Queue) -> None: + sub = queue.add_webhook( + "http://127.0.0.1:9999/x", events=[EventType.JOB_FAILED], secret="topsecret" + ) + fetched = queue.get_webhook(sub.id) + assert fetched is not None + assert fetched.url == "http://127.0.0.1:9999/x" + assert fetched.events == ["job.failed"] + assert fetched.secret == "topsecret" + + +def test_subscriptions_persist_across_queue_instances(tmp_path: Path) -> None: + """A fresh Queue against the same DB sees prior subscriptions.""" + import os + + os.environ["TASKITO_WEBHOOKS_ALLOW_PRIVATE"] = "1" + try: + db = str(tmp_path / "persist.db") + q1 = Queue(db_path=db) + sub = q1.add_webhook("http://127.0.0.1:9999/x") + + q2 = Queue(db_path=db) + all_subs = q2.list_webhooks() + assert any(s.id == sub.id for s in all_subs) + finally: + del os.environ["TASKITO_WEBHOOKS_ALLOW_PRIVATE"] + + +def test_update_webhook(queue: Queue) -> None: + sub = queue.add_webhook("http://127.0.0.1:9999/x", max_retries=3) + updated = queue.update_webhook(sub.id, max_retries=7, enabled=False) + assert updated.max_retries == 7 + assert updated.enabled is False + fresh = queue.get_webhook(sub.id) + assert fresh is not None and fresh.max_retries == 7 + + +def test_remove_webhook(queue: Queue) -> None: + sub = queue.add_webhook("http://127.0.0.1:9999/x") + assert queue.remove_webhook(sub.id) is True + assert queue.remove_webhook(sub.id) is False + assert queue.get_webhook(sub.id) is None + + +def test_rotate_secret(queue: Queue) -> None: + sub = queue.add_webhook("http://127.0.0.1:9999/x", secret="old") + new_secret = queue.rotate_webhook_secret(sub.id) + assert new_secret != "old" + fresh = queue.get_webhook(sub.id) + assert fresh is not None and fresh.secret == new_secret + + +def test_disabled_webhook_does_not_deliver( + queue: Queue, echo_server: tuple[str, list[dict[str, Any]]], poll_until: Any +) -> None: + url, received = echo_server + sub = queue.add_webhook(url, events=[EventType.JOB_COMPLETED]) + queue.update_webhook(sub.id, enabled=False) + queue._webhook_manager.notify(EventType.JOB_COMPLETED, {"job_id": "1"}) + # Give the dispatcher a chance. + import time + + time.sleep(0.3) + assert received == [] + + +def test_task_filter_restricts_delivery( + queue: Queue, echo_server: tuple[str, list[dict[str, Any]]], poll_until: Any +) -> None: + url, received = echo_server + queue.add_webhook(url, task_filter=["only_me"]) + queue._webhook_manager.notify(EventType.JOB_COMPLETED, {"job_id": "1", "task_name": "other"}) + queue._webhook_manager.notify(EventType.JOB_COMPLETED, {"job_id": "2", "task_name": "only_me"}) + poll_until(lambda: len(received) >= 1, message="task-filtered webhook not delivered") + assert len(received) == 1 + assert received[0]["body"]["task_name"] == "only_me" + + +def test_manager_reload_picks_up_new_subscription( + queue: Queue, echo_server: tuple[str, list[dict[str, Any]]], poll_until: Any +) -> None: + """Subscriptions written by another worker show up after reload.""" + url, received = echo_server + # Bypass the Queue API and write directly to the store to simulate a peer. + WebhookSubscriptionStore(queue).create(url=url) + queue._webhook_manager.reload() + queue._webhook_manager.notify(EventType.JOB_COMPLETED, {"job_id": "1"}) + poll_until(lambda: len(received) >= 1, message="reloaded webhook not delivered") + + +def test_subscription_secret_signs_payload( + queue: Queue, echo_server: tuple[str, list[dict[str, Any]]], poll_until: Any +) -> None: + url, received = echo_server + sub = queue.add_webhook(url, secret="signing-key") + queue._webhook_manager.notify(EventType.JOB_COMPLETED, {"job_id": "x"}) + poll_until(lambda: len(received) >= 1) + + sig_header = received[0]["headers"].get("X-Taskito-Signature") + assert sig_header is not None + body_bytes = json.dumps(received[0]["body"], default=str).encode("utf-8") + expected = hmac.new(b"signing-key", body_bytes, hashlib.sha256).hexdigest() + assert sig_header == f"sha256={expected}" + assert sub.secret == "signing-key" + + +# ── Dashboard HTTP endpoints ────────────────────────────────────────── + + +def test_list_webhooks_returns_empty(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + assert client.get("/api/webhooks") == [] + + +def test_event_types_listing(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + events = client.get("/api/event-types") + assert "job.completed" in events + assert "job.failed" in events + assert sorted(events) == events # always sorted + + +def test_create_webhook_endpoint( + dashboard: tuple[AuthedClient, Queue], echo_server: tuple[str, list[dict[str, Any]]] +) -> None: + client, _queue = dashboard + url, _ = echo_server + created = client.post( + "/api/webhooks", + { + "url": url, + "events": ["job.failed"], + "task_filter": ["send_email"], + "max_retries": 5, + "description": "ops failures", + "generate_secret": True, + }, + ) + assert created["url"] == url + assert created["events"] == ["job.failed"] + assert created["task_filter"] == ["send_email"] + assert created["max_retries"] == 5 + # Secret is revealed exactly once on create. + assert "secret" in created + assert created["has_secret"] is True + + listed = client.get("/api/webhooks") + assert len(listed) == 1 + # ``secret`` is redacted from list/get responses. + assert "secret" not in listed[0] + assert listed[0]["has_secret"] is True + + +def test_create_webhook_rejects_unsafe_url(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + # The fixture has TASKITO_WEBHOOKS_ALLOW_PRIVATE=1; remove it for this test only. + import os + + saved = os.environ.pop("TASKITO_WEBHOOKS_ALLOW_PRIVATE", None) + try: + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.post("/api/webhooks", {"url": "http://127.0.0.1/x"}) + assert exc_info.value.code == 400 + finally: + if saved is not None: + os.environ["TASKITO_WEBHOOKS_ALLOW_PRIVATE"] = saved + + +def test_create_webhook_rejects_unknown_event( + dashboard: tuple[AuthedClient, Queue], echo_server: tuple[str, list[dict[str, Any]]] +) -> None: + client, _ = dashboard + url, _ = echo_server + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.post("/api/webhooks", {"url": url, "events": ["not.a.real.event"]}) + assert exc_info.value.code == 400 + + +def test_update_webhook_endpoint( + dashboard: tuple[AuthedClient, Queue], echo_server: tuple[str, list[dict[str, Any]]] +) -> None: + client, _ = dashboard + url, _ = echo_server + created = client.post("/api/webhooks", {"url": url}) + + updated = client.put( + f"/api/webhooks/{created['id']}", + {"max_retries": 10, "enabled": False, "description": "paused"}, + ) + assert updated["max_retries"] == 10 + assert updated["enabled"] is False + assert updated["description"] == "paused" + + +def test_delete_webhook_endpoint( + dashboard: tuple[AuthedClient, Queue], echo_server: tuple[str, list[dict[str, Any]]] +) -> None: + client, _ = dashboard + url, _ = echo_server + created = client.post("/api/webhooks", {"url": url}) + assert client.delete(f"/api/webhooks/{created['id']}") == {"deleted": True} + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.delete(f"/api/webhooks/{created['id']}") + assert exc_info.value.code == 404 + + +def test_rotate_secret_endpoint( + dashboard: tuple[AuthedClient, Queue], echo_server: tuple[str, list[dict[str, Any]]] +) -> None: + client, _ = dashboard + url, _ = echo_server + created = client.post("/api/webhooks", {"url": url, "secret": "old"}) + rotated = client.post(f"/api/webhooks/{created['id']}/rotate-secret") + assert rotated["secret"] != "old" + assert rotated["id"] == created["id"] + + +def test_test_webhook_endpoint_returns_status( + dashboard: tuple[AuthedClient, Queue], echo_server: tuple[str, list[dict[str, Any]]] +) -> None: + client, _ = dashboard + url, received = echo_server + created = client.post("/api/webhooks", {"url": url}) + result = client.post(f"/api/webhooks/{created['id']}/test") + assert result["delivered"] is True + assert result["status"] == 200 + # A test event landed at the echo server. + assert any(r["body"].get("event") == "test.ping" for r in received) + + +def test_test_webhook_endpoint_reports_failure( + dashboard: tuple[AuthedClient, Queue], +) -> None: + """When the target server returns 4xx, the test endpoint surfaces it.""" + received_count = [0] + + class FailHandler(BaseHTTPRequestHandler): + def do_POST(self) -> None: + received_count[0] += 1 + self.send_response(418) + self.end_headers() + + def log_message(self, *args: Any) -> None: + pass + + server = HTTPServer(("127.0.0.1", 0), FailHandler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + try: + client, _ = dashboard + created = client.post( + "/api/webhooks", + {"url": f"http://127.0.0.1:{server.server_address[1]}/x"}, + ) + result = client.post(f"/api/webhooks/{created['id']}/test") + assert result["delivered"] is False + assert result["status"] == 418 + finally: + server.shutdown()