Skip to content

Commit 729857c

Browse files
Edition-Xforge-code-agentamitksingh1490
authored
fix(auth): preserve Codex id_token during OAuth exchange (#2946)
Co-authored-by: ForgeCode <noreply@forgecode.dev> Co-authored-by: Amit Singh <amitksingh1490@gmail.com>
1 parent 7c430bb commit 729857c

3 files changed

Lines changed: 107 additions & 85 deletions

File tree

crates/forge_infra/src/auth/http/standard.rs

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
use forge_app::OAuthHttpProvider;
22
use forge_domain::{AuthCodeParams, OAuthConfig, OAuthTokenResponse};
3-
use oauth2::{
4-
AuthorizationCode as OAuth2AuthCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope,
5-
};
3+
use oauth2::{CsrfToken, PkceCodeChallenge, Scope};
4+
use serde::Serialize;
65

76
use crate::auth::util::*;
87

98
/// Standard RFC-compliant OAuth provider
109
pub struct StandardHttpProvider;
1110

11+
#[derive(Debug, Serialize)]
12+
struct StandardTokenRequest<'a> {
13+
grant_type: &'static str,
14+
code: &'a str,
15+
client_id: &'a str,
16+
#[serde(skip_serializing_if = "Option::is_none")]
17+
redirect_uri: Option<&'a str>,
18+
#[serde(skip_serializing_if = "Option::is_none")]
19+
code_verifier: Option<&'a str>,
20+
}
21+
1222
#[async_trait::async_trait]
1323
impl OAuthHttpProvider for StandardHttpProvider {
1424
async fn build_auth_url(&self, config: &OAuthConfig) -> anyhow::Result<AuthCodeParams> {
@@ -58,27 +68,33 @@ impl OAuthHttpProvider for StandardHttpProvider {
5868
code: &str,
5969
verifier: Option<&str>,
6070
) -> anyhow::Result<OAuthTokenResponse> {
61-
use oauth2::{AuthUrl, ClientId, TokenUrl};
62-
63-
let mut client =
64-
oauth2::basic::BasicClient::new(ClientId::new(config.client_id.to_string()))
65-
.set_auth_uri(AuthUrl::new(config.auth_url.to_string())?)
66-
.set_token_uri(TokenUrl::new(config.token_url.to_string())?);
67-
68-
if let Some(redirect_uri) = &config.redirect_uri {
69-
client = client.set_redirect_uri(oauth2::RedirectUrl::new(redirect_uri.clone())?);
70-
}
71-
7271
let http_client = self.build_http_client(config)?;
72+
let request_body = StandardTokenRequest {
73+
grant_type: "authorization_code",
74+
code,
75+
client_id: config.client_id.as_ref(),
76+
redirect_uri: config.redirect_uri.as_deref(),
77+
code_verifier: verifier,
78+
};
79+
80+
let response = http_client
81+
.post(config.token_url.as_str())
82+
.header("Content-Type", "application/x-www-form-urlencoded")
83+
.header("Accept", "application/json")
84+
.body(serde_urlencoded::to_string(&request_body)?)
85+
.send()
86+
.await?;
7387

74-
let mut request = client.exchange_code(OAuth2AuthCode::new(code.to_string()));
88+
let status = response.status();
89+
let body = response.text().await?;
7590

76-
if let Some(v) = verifier {
77-
request = request.set_pkce_verifier(PkceCodeVerifier::new(v.to_string()));
91+
if !status.is_success() {
92+
anyhow::bail!("OAuth token exchange failed ({status}): {body}");
7893
}
7994

80-
let token_result = request.request_async(&http_client).await?;
81-
Ok(into_domain(token_result))
95+
// Parse the raw token payload so provider-specific fields like
96+
// `id_token` are preserved instead of being dropped by generic helpers.
97+
Ok(parse_token_response(&body)?)
8298
}
8399

84100
/// Create HTTP client with provider-specific headers/behavior

crates/forge_infra/src/auth/strategy.rs

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -772,23 +772,12 @@ async fn poll_for_tokens(
772772
}
773773

774774
// No error field - parse as success
775-
let (access_token, refresh_token, expires_in) = parse_token_response(&body_text)?;
776-
777-
return Ok(build_token_response(
778-
access_token,
779-
refresh_token,
780-
expires_in,
781-
));
775+
return Ok(parse_token_response(&body_text)?);
782776
}
783777

784778
// Standard OAuth: HTTP success means tokens
785779
if !github_compatible && status.is_success() {
786-
let (access_token, refresh_token, expires_in) = parse_token_response(&body_text)?;
787-
return Ok(build_token_response(
788-
access_token,
789-
refresh_token,
790-
expires_in,
791-
));
780+
return Ok(parse_token_response(&body_text)?);
792781
}
793782

794783
// Handle error responses (non-200 status for standard OAuth)
@@ -911,16 +900,11 @@ async fn codex_poll_for_tokens(
911900
.into());
912901
}
913902

914-
let (access_token, refresh_token, expires_in) =
915-
parse_token_response(&token_response.text().await.map_err(|e| {
903+
return Ok(parse_token_response(
904+
&token_response.text().await.map_err(|e| {
916905
AuthError::PollFailed(format!("Failed to read token response: {e}"))
917-
})?)?;
918-
919-
return Ok(build_token_response(
920-
access_token,
921-
refresh_token,
922-
expires_in,
923-
));
906+
})?,
907+
)?);
924908
}
925909

926910
// 403/404 means authorization pending (user hasn't entered code yet)
@@ -1323,6 +1307,49 @@ mod tests {
13231307
assert_eq!(actual, None);
13241308
}
13251309

1310+
#[test]
1311+
fn test_enrich_codex_oauth_credential_uses_id_token_claims() {
1312+
let fixture_id_token = build_jwt(&serde_json::json!({
1313+
"chatgpt_account_id": "acct_from_id_token"
1314+
}));
1315+
let fixture_access_token = "not-a-jwt";
1316+
let mut actual = AuthCredential::new_oauth(
1317+
ProviderId::CODEX,
1318+
OAuthTokens::new(
1319+
fixture_access_token,
1320+
None::<String>,
1321+
chrono::Utc::now() + chrono::Duration::hours(1),
1322+
),
1323+
OAuthConfig {
1324+
client_id: "test".to_string().into(),
1325+
auth_url: Url::parse("https://example.com/auth").unwrap(),
1326+
token_url: Url::parse("https://example.com/token").unwrap(),
1327+
scopes: vec![],
1328+
redirect_uri: Some("http://localhost:1455/auth/callback".to_string()),
1329+
use_pkce: true,
1330+
token_refresh_url: None,
1331+
extra_auth_params: None,
1332+
custom_headers: None,
1333+
},
1334+
);
1335+
1336+
enrich_codex_oauth_credential(
1337+
&ProviderId::CODEX,
1338+
&mut actual,
1339+
Some(&fixture_id_token),
1340+
fixture_access_token,
1341+
);
1342+
1343+
let actual = actual
1344+
.url_params
1345+
.get(&URLParam::from("chatgpt_account_id".to_string()));
1346+
let expected = Some(&forge_domain::URLParamValue::from(
1347+
"acct_from_id_token".to_string(),
1348+
));
1349+
1350+
assert_eq!(actual, expected);
1351+
}
1352+
13261353
#[tokio::test]
13271354
async fn test_refresh_oauth_credential_preserves_url_params() {
13281355
let fixture_config = OAuthConfig {

crates/forge_infra/src/auth/util.rs

Lines changed: 23 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -86,23 +86,6 @@ pub(crate) fn build_oauth_credential(
8686
))
8787
}
8888

89-
/// Build OAuthTokenResponse with standard defaults
90-
pub(crate) fn build_token_response(
91-
access_token: String,
92-
refresh_token: Option<String>,
93-
expires_in: Option<u64>,
94-
) -> OAuthTokenResponse {
95-
OAuthTokenResponse {
96-
access_token,
97-
refresh_token,
98-
expires_in,
99-
expires_at: None,
100-
token_type: "Bearer".to_string(),
101-
scope: None,
102-
id_token: None,
103-
}
104-
}
105-
10689
/// Extract OAuth tokens from any credential type
10790
pub(crate) fn extract_oauth_tokens(credential: &AuthCredential) -> anyhow::Result<&OAuthTokens> {
10891
match &credential.auth_details {
@@ -217,25 +200,18 @@ pub(crate) fn handle_oauth_error(error_code: &str) -> Result<(), Error> {
217200
}
218201
}
219202

220-
/// Parse token response from JSON
221-
pub(crate) fn parse_token_response(
222-
body: &str,
223-
) -> Result<(String, Option<String>, Option<u64>), Error> {
224-
let token_response: serde_json::Value = serde_json::from_str(body)
203+
/// Parse token response from JSON.
204+
pub(crate) fn parse_token_response(body: &str) -> Result<OAuthTokenResponse, Error> {
205+
let token_response: OAuthTokenResponse = serde_json::from_str(body)
225206
.map_err(|e| Error::PollFailed(format!("Failed to parse token response: {e}")))?;
226207

227-
let access_token = token_response["access_token"]
228-
.as_str()
229-
.ok_or_else(|| Error::PollFailed("Missing access_token in response".to_string()))?
230-
.to_string();
231-
232-
let refresh_token = token_response["refresh_token"]
233-
.as_str()
234-
.map(|s| s.to_string());
235-
236-
let expires_in = token_response["expires_in"].as_u64();
208+
if token_response.access_token.trim().is_empty() {
209+
return Err(Error::PollFailed(
210+
"Missing access_token in response".to_string(),
211+
));
212+
}
237213

238-
Ok((access_token, refresh_token, expires_in))
214+
Ok(token_response)
239215
}
240216

241217
#[cfg(test)]
@@ -265,17 +241,20 @@ mod tests {
265241
}
266242

267243
#[test]
268-
fn test_build_token_response() {
269-
let response = build_token_response(
270-
"test_token".to_string(),
271-
Some("refresh_token".to_string()),
272-
Some(3600),
273-
);
274-
275-
assert_eq!(response.access_token, "test_token");
276-
assert_eq!(response.refresh_token, Some("refresh_token".to_string()));
277-
assert_eq!(response.expires_in, Some(3600));
278-
assert_eq!(response.token_type, "Bearer");
244+
fn test_parse_token_response_preserves_id_token() {
245+
let fixture = r#"{
246+
"access_token": "test_token",
247+
"refresh_token": "refresh_token",
248+
"expires_in": 3600,
249+
"id_token": "test_id_token"
250+
}"#;
251+
252+
let actual = parse_token_response(fixture).unwrap();
253+
254+
assert_eq!(actual.access_token, "test_token");
255+
assert_eq!(actual.refresh_token, Some("refresh_token".to_string()));
256+
assert_eq!(actual.expires_in, Some(3600));
257+
assert_eq!(actual.id_token, Some("test_id_token".to_string()));
279258
}
280259

281260
#[test]

0 commit comments

Comments
 (0)