diff --git a/src/api.rs b/src/api.rs new file mode 100644 index 0000000..6c2a298 --- /dev/null +++ b/src/api.rs @@ -0,0 +1,570 @@ +use std::fs; +use std::io::{self, Read, Write}; +use std::path::PathBuf; + +use anyhow::{bail, Context, Result}; +use clap::{Args, Subcommand, ValueEnum}; +use reqwest::Method; +use serde::Serialize; +use serde_json::Value; + +use crate::args::BaseArgs; +use crate::auth::login; +use crate::http::{ApiClient, HttpError, RawRequestBody, ServiceBase}; + +const OPENAPI_SPEC_URL: &str = + "https://raw.githubusercontent.com/braintrustdata/braintrust-openapi/main/openapi/spec.json"; + +#[derive(Debug, Clone, Args)] +#[command(after_help = "\ +Examples: + bt api get /v1/project + bt api post /v1/project_score --body '{\"name\":\"example\"}' + bt api post /api/project_score/register --base app --body-file payload.json + bt api spec --filter project_score +")] +pub struct ApiArgs { + #[command(subcommand)] + command: ApiCommand, +} + +#[derive(Debug, Clone, Subcommand)] +enum ApiCommand { + /// Send an authenticated GET request + Get(ReadRequestArgs), + /// Send an authenticated POST request + Post(WriteRequestArgs), + /// Send an authenticated PUT request + Put(WriteRequestArgs), + /// Send an authenticated PATCH request + Patch(WriteRequestArgs), + /// Send an authenticated DELETE request + Delete(ReadRequestArgs), + /// Fetch and inspect the Braintrust OpenAPI spec + Spec(SpecArgs), +} + +#[derive(Debug, Clone, Args)] +struct ReadRequestArgs { + #[command(flatten)] + target: RequestTargetArgs, +} + +#[derive(Debug, Clone, Args)] +struct WriteRequestArgs { + #[command(flatten)] + target: RequestTargetArgs, + + #[command(flatten)] + body: RequestBodyArgs, +} + +#[derive(Debug, Clone, Args)] +struct RequestTargetArgs { + /// Relative path to request (for example: /v1/project or /api/project_score/register) + #[arg(value_name = "PATH")] + path: String, + + /// Which Braintrust base URL to target + #[arg(long, value_enum, default_value_t = ApiBase::Auto)] + base: ApiBase, + + /// Extra request header in Name: Value form + #[arg(long = "header", value_parser = parse_header_arg, value_name = "NAME:VALUE")] + headers: Vec, +} + +#[derive(Debug, Clone, Args)] +struct RequestBodyArgs { + /// Inline request body + #[arg(long, conflicts_with = "body_file", value_name = "BODY")] + body: Option, + + /// Read request body from a file, or '-' for stdin + #[arg(long, conflicts_with = "body", value_name = "FILE")] + body_file: Option, + + /// Content-Type to send when a body is present + #[arg(long, value_name = "MIME_TYPE")] + content_type: Option, +} + +#[derive(Debug, Clone, Args)] +struct SpecArgs { + /// Print the raw JSON OpenAPI document + #[arg(long)] + raw: bool, + + /// Filter operations by path, method, summary, or operation id + #[arg(long, value_name = "TEXT")] + filter: Option, + + /// Override the source URL for the spec + #[arg(long, hide = true, value_name = "URL")] + source_url: Option, +} + +#[derive(Debug, Clone, Copy, ValueEnum, Eq, PartialEq)] +enum ApiBase { + Auto, + Api, + App, +} + +#[derive(Debug, Clone, Eq, PartialEq)] +struct HeaderArg { + name: String, + value: String, +} + +pub async fn run(base: BaseArgs, args: ApiArgs) -> Result<()> { + match args.command { + ApiCommand::Spec(args) => run_spec(base.json, args).await, + command => { + let ctx = login(&base).await?; + let client = ApiClient::new(&ctx)?; + run_authenticated_command(&client, command).await + } + } +} + +async fn run_authenticated_command(client: &ApiClient, command: ApiCommand) -> Result<()> { + match command { + ApiCommand::Get(args) => run_request(client, Method::GET, args.target, None).await, + ApiCommand::Post(args) => { + run_request( + client, + Method::POST, + args.target, + load_request_body(args.body)?, + ) + .await + } + ApiCommand::Put(args) => { + run_request( + client, + Method::PUT, + args.target, + load_request_body(args.body)?, + ) + .await + } + ApiCommand::Patch(args) => { + run_request( + client, + Method::PATCH, + args.target, + load_request_body(args.body)?, + ) + .await + } + ApiCommand::Delete(args) => run_request(client, Method::DELETE, args.target, None).await, + ApiCommand::Spec(_) => unreachable!("spec commands are handled before auth is loaded"), + } +} + +async fn run_request( + client: &ApiClient, + method: Method, + args: RequestTargetArgs, + body: Option, +) -> Result<()> { + if is_absolute_url(&args.path) { + bail!("absolute URLs are not supported; pass a relative path such as /v1/project"); + } + + let service = resolve_service_base(&args.path, args.base); + let headers = args + .headers + .into_iter() + .map(|header| (header.name, header.value)) + .collect::>(); + + let response = client + .request_raw(method, service, &args.path, &headers, body) + .await?; + let status = response.status(); + let bytes = response + .bytes() + .await + .context("failed to read response body")?; + if !status.is_success() { + let body = String::from_utf8_lossy(&bytes).into_owned(); + return Err(HttpError { status, body }.into()); + } + + if !bytes.is_empty() { + io::stdout() + .write_all(&bytes) + .context("failed to write response body")?; + } + Ok(()) +} + +async fn run_spec(json: bool, args: SpecArgs) -> Result<()> { + let source_url = args + .source_url + .unwrap_or_else(|| OPENAPI_SPEC_URL.to_string()); + let spec = fetch_openapi_spec(&source_url).await?; + if args.raw { + print!("{spec}"); + return Ok(()); + } + + let operations = filter_operations(&extract_operations(&spec)?, args.filter.as_deref()); + if json { + println!( + "{}", + serde_json::to_string(&SpecOutput { + source_url, + operations, + })? + ); + return Ok(()); + } + + for operation in operations { + println!("{}", format_operation_line(&operation)); + } + Ok(()) +} + +fn load_request_body(args: RequestBodyArgs) -> Result> { + let RequestBodyArgs { + body, + body_file, + content_type, + } = args; + + let body_bytes = if let Some(body) = body { + Some(body.into_bytes()) + } else if let Some(path) = body_file { + Some(read_body_source(&path)?) + } else { + None + }; + + if body_bytes.is_none() { + if content_type.is_some() { + bail!("--content-type requires --body or --body-file"); + } + return Ok(None); + } + + Ok(body_bytes.map(|bytes| RawRequestBody { + bytes, + content_type: Some(content_type.unwrap_or_else(|| "application/json".to_string())), + })) +} + +fn read_body_source(path: &PathBuf) -> Result> { + if path.as_os_str() == "-" { + let mut bytes = Vec::new(); + io::stdin() + .read_to_end(&mut bytes) + .context("failed to read request body from stdin")?; + return Ok(bytes); + } + + fs::read(path).with_context(|| format!("failed to read request body from {}", path.display())) +} + +fn parse_header_arg(value: &str) -> Result { + let Some((name, raw_value)) = value.split_once(':') else { + return Err("header must be in Name: Value form".to_string()); + }; + + let name = name.trim(); + if name.is_empty() { + return Err("header name cannot be empty".to_string()); + } + + Ok(HeaderArg { + name: name.to_string(), + value: raw_value.trim().to_string(), + }) +} + +fn resolve_service_base(path: &str, base: ApiBase) -> ServiceBase { + match base { + ApiBase::Api => ServiceBase::Api, + ApiBase::App => ServiceBase::App, + ApiBase::Auto => { + if is_app_path(path) { + ServiceBase::App + } else { + ServiceBase::Api + } + } + } +} + +fn is_app_path(path: &str) -> bool { + let trimmed = path.trim_start_matches('/'); + trimmed == "api" || trimmed.starts_with("api/") +} + +fn is_absolute_url(path: &str) -> bool { + path.starts_with("http://") || path.starts_with("https://") +} + +#[derive(Debug, Clone, Serialize, Eq, PartialEq)] +struct SpecOperation { + method: String, + path: String, + operation_id: Option, + summary: Option, +} + +#[derive(Debug, Serialize)] +struct SpecOutput { + source_url: String, + operations: Vec, +} + +async fn fetch_openapi_spec(source_url: &str) -> Result { + let response = reqwest::Client::new() + .get(source_url) + .send() + .await + .with_context(|| format!("failed to fetch OpenAPI spec from {source_url}"))?; + let status = response.status(); + let body = response + .text() + .await + .context("failed to read OpenAPI spec response body")?; + if !status.is_success() { + bail!("failed to fetch OpenAPI spec ({status}): {body}"); + } + Ok(body) +} + +fn extract_operations(spec_json: &str) -> Result> { + let spec: Value = serde_json::from_str(spec_json).context("failed to parse OpenAPI spec")?; + let paths = spec + .get("paths") + .and_then(Value::as_object) + .ok_or_else(|| anyhow::anyhow!("OpenAPI spec is missing the top-level `paths` object"))?; + + let mut operations = Vec::new(); + for (path, path_item) in paths { + let Some(path_item) = path_item.as_object() else { + continue; + }; + + for method in ["get", "post", "put", "patch", "delete", "options", "head"] { + let Some(operation) = path_item.get(method).and_then(Value::as_object) else { + continue; + }; + operations.push(SpecOperation { + method: method.to_uppercase(), + path: path.clone(), + operation_id: operation + .get("operationId") + .and_then(Value::as_str) + .map(ToOwned::to_owned), + summary: operation + .get("summary") + .and_then(Value::as_str) + .map(ToOwned::to_owned), + }); + } + } + + operations.sort_by(|left, right| { + left.path + .cmp(&right.path) + .then_with(|| left.method.cmp(&right.method)) + }); + Ok(operations) +} + +fn filter_operations(operations: &[SpecOperation], filter: Option<&str>) -> Vec { + let Some(filter) = filter.map(str::trim).filter(|value| !value.is_empty()) else { + return operations.to_vec(); + }; + let needle = filter.to_ascii_lowercase(); + + operations + .iter() + .filter(|operation| operation_matches_filter(operation, &needle)) + .cloned() + .collect() +} + +fn operation_matches_filter(operation: &SpecOperation, needle: &str) -> bool { + [ + operation.method.as_str(), + operation.path.as_str(), + operation.operation_id.as_deref().unwrap_or(""), + operation.summary.as_deref().unwrap_or(""), + ] + .into_iter() + .any(|value| value.to_ascii_lowercase().contains(needle)) +} + +fn format_operation_line(operation: &SpecOperation) -> String { + let mut line = format!("{} {}", operation.method, operation.path); + let detail = operation + .summary + .as_deref() + .or(operation.operation_id.as_deref()) + .unwrap_or(""); + if !detail.is_empty() { + line.push_str(" "); + line.push_str(detail); + } + line +} + +#[cfg(test)] +mod tests { + use super::*; + + const SAMPLE_SPEC: &str = r#"{ + "paths": { + "/v1/project_score": { + "get": { + "summary": "List project scores", + "operationId": "listProjectScores" + }, + "post": { + "summary": "Create a project score", + "operationId": "createProjectScore" + } + }, + "/brainstore/automation/reset-cursors": { + "post": { + "summary": "Reset automation cursors", + "operationId": "resetAutomationCursors" + } + } + } + }"#; + + #[test] + fn parse_header_arg_splits_name_and_value() { + let header = parse_header_arg("X-Test: hello").expect("parse header"); + assert_eq!( + header, + HeaderArg { + name: "X-Test".to_string(), + value: "hello".to_string(), + } + ); + } + + #[test] + fn parse_header_arg_rejects_invalid_input() { + let err = parse_header_arg("invalid").expect_err("expected parse failure"); + assert!(err.contains("Name: Value")); + } + + #[test] + fn resolve_service_base_uses_app_for_next_api_routes() { + assert_eq!( + resolve_service_base("/api/project_score/register", ApiBase::Auto), + ServiceBase::App + ); + assert_eq!( + resolve_service_base("api/foo", ApiBase::Auto), + ServiceBase::App + ); + } + + #[test] + fn resolve_service_base_uses_api_for_non_app_routes() { + assert_eq!( + resolve_service_base("/v1/project_score", ApiBase::Auto), + ServiceBase::Api + ); + assert_eq!( + resolve_service_base("/brainstore/automation/reset-cursors", ApiBase::Auto), + ServiceBase::Api + ); + } + + #[test] + fn load_request_body_defaults_to_json_content_type() { + let body = load_request_body(RequestBodyArgs { + body: Some("{\"ok\":true}".to_string()), + body_file: None, + content_type: None, + }) + .expect("load request body") + .expect("expected request body"); + + assert_eq!(body.content_type.as_deref(), Some("application/json")); + assert_eq!(body.bytes, br#"{"ok":true}"#); + } + + #[test] + fn load_request_body_rejects_content_type_without_body() { + let err = load_request_body(RequestBodyArgs { + body: None, + body_file: None, + content_type: Some("application/json".to_string()), + }) + .expect_err("expected failure"); + + assert!(err.to_string().contains("--content-type requires")); + } + + #[test] + fn extract_operations_reads_methods_and_metadata() { + let operations = extract_operations(SAMPLE_SPEC).expect("extract operations"); + assert_eq!( + operations, + vec![ + SpecOperation { + method: "POST".to_string(), + path: "/brainstore/automation/reset-cursors".to_string(), + operation_id: Some("resetAutomationCursors".to_string()), + summary: Some("Reset automation cursors".to_string()), + }, + SpecOperation { + method: "GET".to_string(), + path: "/v1/project_score".to_string(), + operation_id: Some("listProjectScores".to_string()), + summary: Some("List project scores".to_string()), + }, + SpecOperation { + method: "POST".to_string(), + path: "/v1/project_score".to_string(), + operation_id: Some("createProjectScore".to_string()), + summary: Some("Create a project score".to_string()), + }, + ] + ); + } + + #[test] + fn filter_operations_matches_path_method_summary_and_operation_id() { + let operations = extract_operations(SAMPLE_SPEC).expect("extract operations"); + + let by_path = filter_operations(&operations, Some("project_score")); + assert_eq!(by_path.len(), 2); + + let by_method = filter_operations(&operations, Some("post")); + assert_eq!(by_method.len(), 2); + + let by_summary = filter_operations(&operations, Some("reset automation")); + assert_eq!(by_summary.len(), 1); + + let by_operation_id = filter_operations(&operations, Some("createProjectScore")); + assert_eq!(by_operation_id.len(), 1); + assert_eq!(by_operation_id[0].path, "/v1/project_score"); + } + + #[test] + fn format_operation_line_prefers_summary() { + let line = format_operation_line(&SpecOperation { + method: "POST".to_string(), + path: "/v1/project_score".to_string(), + operation_id: Some("createProjectScore".to_string()), + summary: Some("Create a project score".to_string()), + }); + + assert_eq!(line, "POST /v1/project_score Create a project score"); + } +} diff --git a/src/auth.rs b/src/auth.rs index 384fd98..eaa4cb1 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -254,12 +254,30 @@ struct OAuthTokenResponse { expires_in: Option, } +#[derive(Debug, Clone, Deserialize)] +struct OrgScopedJwtResponse { + token: String, + expires_at: String, +} + +#[derive(Debug, Clone, Serialize)] +struct OrgScopedJwtRequest<'a> { + org_name: &'a str, +} + +#[derive(Debug, Clone)] +struct ResolvedBearerToken { + token: String, + expires_at: Option, +} + #[derive(Debug, Clone, Args)] #[command(after_help = "\ Examples: bt auth login bt auth profiles bt auth refresh + bt auth token --header bt auth logout --profile work ")] pub struct AuthArgs { @@ -273,6 +291,8 @@ enum AuthCommand { Login(AuthLoginArgs), /// Force-refresh OAuth access token for a profile Refresh, + /// Print the resolved bearer token for the current auth context (org-scoped for OAuth logins) + Token(AuthTokenArgs), /// List auth profiles and check connection status Profiles(AuthProfilesArgs), /// Log out by removing a saved profile @@ -285,6 +305,13 @@ struct AuthProfilesArgs { verbose: bool, } +#[derive(Debug, Clone, Args)] +struct AuthTokenArgs { + /// Print a full Authorization header instead of only the token + #[arg(long)] + header: bool, +} + #[derive(Debug, Clone, Args)] struct AuthLoginArgs { /// Use OAuth login instead of API key login @@ -319,19 +346,129 @@ pub async fn run(base: BaseArgs, args: AuthArgs) -> Result<()> { match args.command { AuthCommand::Login(login_args) => run_login_set(&base, login_args).await, AuthCommand::Refresh => run_login_refresh(&base).await, + AuthCommand::Token(token_args) => run_token(&base, token_args).await, AuthCommand::Profiles(profile_args) => run_profiles(&base, profile_args).await, AuthCommand::Logout(logout_args) => run_login_logout(base, logout_args), } } -pub async fn login(base: &BaseArgs) -> Result { - maybe_warn_api_key_override(base); +#[derive(Debug, Clone, Serialize)] +struct AuthTokenOutput { + token: String, + authorization_header: String, + api_url: Option, + app_url: Option, + org_name: Option, + expires_at: Option, +} + +async fn run_token(base: &BaseArgs, args: AuthTokenArgs) -> Result<()> { let auth = resolve_auth(base).await?; - let api_key = auth.api_key.clone().ok_or_else(|| { + let token = resolve_api_bearer_token(&auth).await?; + + if base.json { + println!( + "{}", + serde_json::to_string(&build_auth_token_output(&token, &auth))? + ); + return Ok(()); + } + + if args.header { + println!("{}", format_authorization_header(&token.token)); + } else { + println!("{}", token.token); + } + Ok(()) +} + +fn format_authorization_header(token: &str) -> String { + format!("Authorization: Bearer {token}") +} + +fn build_auth_token_output(token: &ResolvedBearerToken, auth: &ResolvedAuth) -> AuthTokenOutput { + AuthTokenOutput { + token: token.token.clone(), + authorization_header: format_authorization_header(&token.token), + api_url: auth.api_url.clone(), + app_url: auth.app_url.clone(), + org_name: auth.org_name.clone(), + expires_at: token.expires_at.clone(), + } +} + +async fn resolve_api_bearer_token(auth: &ResolvedAuth) -> Result { + let token = auth.api_key.clone().ok_or_else(|| { anyhow::anyhow!( "no login credentials found; set BRAINTRUST_API_KEY, pass --api-key, or run `bt auth login`" ) })?; + if !auth.is_oauth { + return Ok(ResolvedBearerToken { + token, + expires_at: None, + }); + } + + let org_name = auth + .org_name + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .ok_or_else(|| { + anyhow::anyhow!( + "oauth login is missing an active org; pass --org or re-run `bt auth login --oauth`" + ) + })?; + let app_url = auth + .app_url + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .unwrap_or(DEFAULT_APP_URL); + + exchange_oauth_token_for_org_scoped_jwt(&token, app_url, org_name).await +} + +async fn exchange_oauth_token_for_org_scoped_jwt( + login_token: &str, + app_url: &str, + org_name: &str, +) -> Result { + let token_url = format!("{}/api/self/org_scoped_jwt", app_url.trim_end_matches('/')); + let client = Client::builder() + .timeout(crate::http::DEFAULT_HTTP_TIMEOUT) + .build() + .context("failed to initialize HTTP client")?; + let response = client + .post(&token_url) + .bearer_auth(login_token) + .json(&OrgScopedJwtRequest { org_name }) + .send() + .await + .with_context(|| format!("failed to call org-scoped jwt endpoint {token_url}"))?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(crate::http::HttpError { status, body }.into()); + } + + let payload: OrgScopedJwtResponse = response + .json() + .await + .context("failed to parse org-scoped jwt response")?; + + Ok(ResolvedBearerToken { + token: payload.token, + expires_at: Some(payload.expires_at), + }) +} + +pub async fn login(base: &BaseArgs) -> Result { + maybe_warn_api_key_override(base); + let auth = resolve_auth(base).await?; + let api_key = resolve_api_bearer_token(&auth).await?.token; let mut builder = BraintrustClient::builder() .blocking_login(true) @@ -490,8 +627,11 @@ pub async fn resolved_auth_env(base: &BaseArgs) -> Result> let auth = resolve_auth(base).await?; let mut envs = Vec::new(); - if let Some(api_key) = auth.api_key { - envs.push(("BRAINTRUST_API_KEY".to_string(), api_key)); + let token = resolve_api_bearer_token(&auth).await?; + envs.push(("BRAINTRUST_API_KEY".to_string(), token.token)); + + if let Some(expires_at) = token.expires_at { + envs.push(("BRAINTRUST_TOKEN_EXPIRES_AT".to_string(), expires_at)); } if let Some(api_url) = auth.api_url { envs.push(("BRAINTRUST_API_URL".to_string(), api_url)); @@ -2575,8 +2715,12 @@ fn auth_store_path() -> Result { #[cfg(test)] mod tests { use super::*; + use std::net::TcpListener; + use std::sync::{Arc, Mutex}; use std::time::{SystemTime, UNIX_EPOCH}; + use actix_web::{web, App, HttpRequest, HttpResponse, HttpServer}; + fn make_base() -> BaseArgs { BaseArgs { json: false, @@ -2594,6 +2738,59 @@ mod tests { } } + #[derive(Debug, Default)] + struct MockOrgScopedJwtState { + authorization_header: Mutex>, + org_name: Mutex>, + } + + #[derive(Debug, Clone, Deserialize)] + struct MockOrgScopedJwtRequest { + org_name: String, + } + + async fn mock_org_scoped_jwt( + state: web::Data>, + request: HttpRequest, + body: web::Json, + ) -> HttpResponse { + let authorization_header = request + .headers() + .get("authorization") + .and_then(|value| value.to_str().ok()) + .map(|value| value.to_string()); + *state + .authorization_header + .lock() + .expect("authorization_header lock") = authorization_header; + *state.org_name.lock().expect("org_name lock") = Some(body.org_name.clone()); + HttpResponse::Ok().json(serde_json::json!({ + "token": "scoped-token", + "expires_at": "2026-04-05T12:00:00Z" + })) + } + + fn start_mock_org_scoped_jwt_server() -> (String, Arc) { + let state = Arc::new(MockOrgScopedJwtState::default()); + let listener = TcpListener::bind(("127.0.0.1", 0)).expect("bind mock server"); + let addr = listener.local_addr().expect("mock server addr"); + let server_state = state.clone(); + let server = HttpServer::new(move || { + App::new() + .app_data(web::Data::new(server_state.clone())) + .route( + "/api/self/org_scoped_jwt", + web::post().to(mock_org_scoped_jwt), + ) + }) + .listen(listener) + .expect("listen mock server") + .run(); + tokio::spawn(server); + std::thread::sleep(Duration::from_millis(25)); + (format!("http://{addr}"), state) + } + #[test] fn default_app_url_is_www() { assert_eq!(DEFAULT_APP_URL, "https://www.braintrust.dev"); @@ -2983,6 +3180,86 @@ mod tests { assert_eq!(id.email, None); } + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn resolve_api_bearer_token_keeps_api_keys_unchanged() { + let token = resolve_api_bearer_token(&ResolvedAuth { + api_key: Some("api-key".to_string()), + api_url: Some("https://api.example.com".to_string()), + app_url: Some("https://www.example.com".to_string()), + org_name: Some("acme".to_string()), + is_oauth: false, + }) + .await + .expect("resolve"); + + assert_eq!(token.token, "api-key"); + assert_eq!(token.expires_at, None); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn resolve_api_bearer_token_exchanges_oauth_for_org_scoped_jwt() { + let (app_url, state) = start_mock_org_scoped_jwt_server(); + let token = resolve_api_bearer_token(&ResolvedAuth { + api_key: Some("oauth-access-token".to_string()), + api_url: Some("https://api.example.com".to_string()), + app_url: Some(app_url), + org_name: Some("Acme".to_string()), + is_oauth: true, + }) + .await + .expect("resolve"); + + assert_eq!(token.token, "scoped-token"); + assert_eq!(token.expires_at.as_deref(), Some("2026-04-05T12:00:00Z")); + assert_eq!( + state + .authorization_header + .lock() + .expect("authorization_header lock") + .as_deref(), + Some("Bearer oauth-access-token") + ); + assert_eq!( + state.org_name.lock().expect("org_name lock").as_deref(), + Some("Acme") + ); + } + + #[test] + fn format_authorization_header_uses_standard_bearer_scheme() { + assert_eq!( + format_authorization_header("secret-token"), + "Authorization: Bearer secret-token" + ); + } + + #[test] + fn build_auth_token_output_includes_context() { + let output = build_auth_token_output( + &ResolvedBearerToken { + token: "secret-token".to_string(), + expires_at: Some("2026-04-05T12:00:00Z".to_string()), + }, + &ResolvedAuth { + api_key: Some("secret-token".to_string()), + api_url: Some("https://api.example.com".to_string()), + app_url: Some("https://www.example.com".to_string()), + org_name: Some("acme".to_string()), + is_oauth: true, + }, + ); + + assert_eq!(output.token, "secret-token"); + assert_eq!( + output.authorization_header, + "Authorization: Bearer secret-token" + ); + assert_eq!(output.api_url.as_deref(), Some("https://api.example.com")); + assert_eq!(output.app_url.as_deref(), Some("https://www.example.com")); + assert_eq!(output.org_name.as_deref(), Some("acme")); + assert_eq!(output.expires_at.as_deref(), Some("2026-04-05T12:00:00Z")); + } + #[test] fn format_verification_line_ok_with_identity() { let v = ProfileVerification { diff --git a/src/http.rs b/src/http.rs index e077add..638e66d 100644 --- a/src/http.rs +++ b/src/http.rs @@ -9,10 +9,23 @@ use crate::auth::LoginContext; pub const DEFAULT_HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ServiceBase { + Api, + App, +} + +#[derive(Debug, Clone)] +pub struct RawRequestBody { + pub bytes: Vec, + pub content_type: Option, +} + #[derive(Clone)] pub struct ApiClient { http: Client, - base_url: String, + api_url: String, + app_url: String, api_key: String, org_name: String, } @@ -45,15 +58,24 @@ impl ApiClient { Ok(Self { http, - base_url: ctx.api_url.trim_end_matches('/').to_string(), + api_url: ctx.api_url.trim_end_matches('/').to_string(), + app_url: ctx.app_url.trim_end_matches('/').to_string(), api_key: ctx.login.api_key.clone(), org_name: ctx.login.org_name.clone(), }) } pub fn url(&self, path: &str) -> String { + self.url_for_service(ServiceBase::Api, path) + } + + pub fn url_for_service(&self, service: ServiceBase, path: &str) -> String { let path = path.trim_start_matches('/'); - format!("{}/{}", self.base_url, path) + let base_url = match service { + ServiceBase::Api => &self.api_url, + ServiceBase::App => &self.app_url, + }; + format!("{}/{}", base_url, path) } pub fn api_key(&self) -> &str { @@ -64,6 +86,30 @@ impl ApiClient { &self.org_name } + pub async fn request_raw( + &self, + method: reqwest::Method, + service: ServiceBase, + path: &str, + headers: &[(String, String)], + body: Option, + ) -> Result { + let url = self.url_for_service(service, path); + let mut request = self.http.request(method, &url).bearer_auth(&self.api_key); + + for (key, value) in headers { + request = request.header(key, value); + } + if let Some(body) = body { + if let Some(content_type) = body.content_type { + request = request.header("Content-Type", content_type); + } + request = request.body(body.bytes); + } + + request.send().await.context("request failed") + } + pub async fn get(&self, path: &str) -> Result { let url = self.url(path); let response = self diff --git a/src/main.rs b/src/main.rs index d5b838b..86762fb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ use anyhow::Result; use clap::{Parser, Subcommand}; use std::ffi::{OsStr, OsString}; +mod api; mod args; mod auth; #[allow(dead_code)] @@ -66,6 +67,7 @@ Projects & resources experiments Manage experiments Data & evaluation + api Send authenticated HTTP requests eval Run eval files sql Run SQL queries against Braintrust sync Synchronize project logs between Braintrust and local NDJSON files @@ -118,6 +120,8 @@ enum Commands { Setup(CLIArgs), /// Manage workflow docs for coding agents Docs(CLIArgs), + /// Send authenticated HTTP requests + Api(CLIArgs), /// Run SQL queries against Braintrust Sql(CLIArgs), /// Authenticate bt with Braintrust @@ -160,6 +164,7 @@ impl Commands { Commands::Init(cmd) => &cmd.base, Commands::Setup(cmd) => &cmd.base, Commands::Docs(cmd) => &cmd.base, + Commands::Api(cmd) => &cmd.base, Commands::Sql(cmd) => &cmd.base, Commands::Auth(cmd) => &cmd.base, Commands::View(cmd) => &cmd.base, @@ -214,6 +219,7 @@ async fn try_main() -> Result<()> { Commands::Auth(cmd) => auth::run(cmd.base, cmd.args).await?, Commands::View(cmd) => traces::run(cmd.base, cmd.args).await?, Commands::Init(cmd) => init::run(cmd.base, cmd.args).await?, + Commands::Api(cmd) => api::run(cmd.base, cmd.args).await?, Commands::Sql(cmd) => sql::run(cmd.base, cmd.args).await?, Commands::Setup(cmd) => setup::run_setup_top(cmd.base, cmd.args).await?, Commands::Docs(cmd) => setup::run_docs_top(cmd.base, cmd.args).await?,