diff --git a/src/CommitStatus.hs b/src/CommitStatus.hs index 2ce6806..3a6f839 100644 --- a/src/CommitStatus.hs +++ b/src/CommitStatus.hs @@ -20,7 +20,7 @@ import qualified Data.Text as Text import qualified Data.ByteString.Lazy as BL import System.FileLock (withFileLock, SharedExclusive(..)) import System.Directory (doesFileExist) -import Utils (getCurrentCommit, logError, logDebug) +import Utils (getCurrentCommit, logError, logDebug, logWarn) import Types (AppState(..), GithubClient(..), Settings(..)) -- Define the data types for the status update @@ -118,6 +118,33 @@ loadOrRefreshClient appState = do writeIORef appState.githubClient (Just client) pure client +-- Force-refresh the token, ignoring both in-memory and file caches +forceRefreshClient :: AppState -> IO GithubClient +forceRefreshClient appState = do + let cacheFile = credentialsCacheFile appState.settings + let lockFile = cacheFile <> ".lock" + + writeIORef appState.githubClient Nothing + + client <- withFileLock lockFile Exclusive \_ -> + refreshToken appState cacheFile + + writeIORef appState.githubClient (Just client) + pure client + +-- Execute a GitHub API request, retrying once with a fresh token on 401 +withFreshClient :: AppState -> (GithubClient -> IO (HTTP.Response BL.ByteString)) -> IO (HTTP.Response BL.ByteString) +withFreshClient appState doRequest = do + client <- getClient appState + response <- doRequest client + if response.responseStatus.statusCode == 401 + then do + logWarn appState "GitHub API returned 401, force-refreshing token and retrying..." + freshClient <- forceRefreshClient appState + doRequest freshClient + else + pure response + -- Create new token and write to cache (caller should hold EXCLUSIVE lock) refreshToken :: AppState -> FilePath -> IO GithubClient refreshToken appState cacheFile = do @@ -209,23 +236,23 @@ createTokenFromGitHub appState = do updateCommitStatus :: MonadIO m => AppState -> StatusRequest -> m () updateCommitStatus appState statusRequest = liftIO do - client <- getClient appState sha <- getCurrentCommit appState - -- Prepare the status update request - let statusUrl = toString client.apiUrl <> "/repos/" ++ toString client.owner ++ "/" ++ toString client.repo ++ "/statuses/" ++ toString sha - initStatusRequest <- HTTP.parseRequest statusUrl - let statusReq = initStatusRequest - { HTTP.method = "POST" - , HTTP.requestHeaders = - [ ("Authorization", "Bearer " <> TE.encodeUtf8 client.accessToken) - , ("Accept", "application/vnd.github.v3+json") - , ("Content-Type", "application/json") - , ("User-Agent", "restaumatic-bot") - ] - , HTTP.requestBody = HTTP.RequestBodyLBS $ encode statusRequest - } - statusResponse <- HTTP.httpLbs statusReq client.manager + statusResponse <- withFreshClient appState \client -> do + let statusUrl = toString client.apiUrl <> "/repos/" ++ toString client.owner ++ "/" ++ toString client.repo ++ "/statuses/" ++ toString sha + initStatusRequest <- HTTP.parseRequest statusUrl + let statusReq = initStatusRequest + { HTTP.method = "POST" + , HTTP.requestHeaders = + [ ("Authorization", "Bearer " <> TE.encodeUtf8 client.accessToken) + , ("Accept", "application/vnd.github.v3+json") + , ("Content-Type", "application/json") + , ("User-Agent", "restaumatic-bot") + ] + , HTTP.requestBody = HTTP.RequestBodyLBS $ encode statusRequest + } + HTTP.httpLbs statusReq client.manager + if statusResponse.responseStatus.statusCode == 201 then logDebug appState "Commit status updated successfully" @@ -237,21 +264,21 @@ updateCommitStatus appState statusRequest = liftIO do -- Check if a status exists for the current commit and context checkExistingStatus :: MonadIO m => AppState -> T.Text -> m Bool checkExistingStatus appState contextName = liftIO do - client <- getClient appState sha <- getCurrentCommit appState - -- Prepare the GET request for statuses - let statusUrl = toString client.apiUrl <> "/repos/" ++ toString client.owner ++ "/" ++ toString client.repo ++ "/commits/" ++ toString sha ++ "/statuses" - initStatusRequest <- HTTP.parseRequest statusUrl - let statusReq = initStatusRequest - { HTTP.method = "GET" - , HTTP.requestHeaders = - [ ("Authorization", "Bearer " <> TE.encodeUtf8 client.accessToken) - , ("Accept", "application/vnd.github.v3+json") - , ("User-Agent", "restaumatic-bot") - ] - } - statusResponse <- HTTP.httpLbs statusReq client.manager + statusResponse <- withFreshClient appState \client -> do + let statusUrl = toString client.apiUrl <> "/repos/" ++ toString client.owner ++ "/" ++ toString client.repo ++ "/commits/" ++ toString sha ++ "/statuses" + initStatusRequest <- HTTP.parseRequest statusUrl + let statusReq = initStatusRequest + { HTTP.method = "GET" + , HTTP.requestHeaders = + [ ("Authorization", "Bearer " <> TE.encodeUtf8 client.accessToken) + , ("Accept", "application/vnd.github.v3+json") + , ("User-Agent", "restaumatic-bot") + ] + } + HTTP.httpLbs statusReq client.manager + if statusResponse.responseStatus.statusCode == 200 then do let mStatuses = eitherDecode @[StatusResponse] (HTTP.responseBody statusResponse) diff --git a/test/FakeGithubApi.hs b/test/FakeGithubApi.hs index c1645b4..616aa0f 100644 --- a/test/FakeGithubApi.hs +++ b/test/FakeGithubApi.hs @@ -1,18 +1,19 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecursiveDo #-} -module FakeGithubApi (Server, start, stop, clearOutput, getOutput, setTokenLifetime) where +module FakeGithubApi (Server, start, stop, clearOutput, getOutput, setTokenLifetime, setTokenExpirationOffset) where import Universum import Network.Wai import qualified Network.Wai.Handler.Warp as Warp -import Network.HTTP.Types (status200, status201, status400, status404, methodPost, methodGet) +import Network.HTTP.Types (status200, status201, status400, status401, status404, methodPost, methodGet) import Data.Aeson (encode, object, (.=), Value) import qualified Data.Aeson as Aeson import qualified Data.Map.Strict as Map -import Data.Time.Clock (getCurrentTime, addUTCTime) +import Data.Time.Clock (getCurrentTime, addUTCTime, UTCTime) import Data.Time.Format.ISO8601 (iso8601Show) +import qualified Data.ByteString as BS import Control.Concurrent (forkIO, ThreadId, killThread) @@ -35,36 +36,77 @@ handleAccessTokenRequest server instId req respond = then do -- Read token lifetime from server state lifetimeSeconds <- readIORef server.tokenLifetimeSeconds + offset <- readIORef server.tokenExpirationOffset now <- getCurrentTime - let expiresAt = addUTCTime (fromIntegral lifetimeSeconds) now + let actualExpiry = addUTCTime (fromIntegral lifetimeSeconds) now + let reportedExpiry = addUTCTime (fromIntegral offset) actualExpiry + + -- Issue unique token + n <- atomicModifyIORef server.tokenCounter (\c -> (c + 1, c + 1)) + let tokenText = "mock-access-token-" <> show n + + -- Track the token with its actual expiry + modifyIORef server.validTokens (Map.insert tokenText actualExpiry) + addOutput server $ "Requested access token for installation " <> instId respond $ responseLBS status200 [("Content-Type", "application/json")] (encode $ object - [ "token" .= ("mock-access-token" :: Text) - , "expires_at" .= iso8601Show expiresAt + [ "token" .= tokenText + , "expires_at" .= iso8601Show reportedExpiry , "installation_id" .= instId ]) else respond $ responseLBS status400 [] "Bad Request" +-- Validate the Bearer token from the Authorization header. +-- Returns Nothing if valid, or a 401 response if invalid/expired. +validateToken :: Server -> Request -> IO (Maybe Response) +validateToken server req = do + tokens <- readIORef server.validTokens + -- If no tokens have been issued yet, skip validation (backwards compat) + if Map.null tokens + then pure Nothing + else do + now <- getCurrentTime + let mAuth = fmap snd $ find (\(k, _) -> k == "Authorization") (requestHeaders req) + case mAuth of + Just authHeader + | Just tokenBS <- BS.stripPrefix "Bearer " authHeader -> do + let tokenText = decodeUtf8 tokenBS + case Map.lookup tokenText tokens of + Just expiry + | now < expiry -> pure Nothing -- Valid + | otherwise -> pure $ Just $ responseLBS status401 [] "Token expired" + Nothing -> pure $ Just $ responseLBS status401 [] "Unknown token" + | otherwise -> pure $ Just $ responseLBS status401 [] "Invalid Authorization header" + Nothing -> pure $ Just $ responseLBS status401 [] "Missing Authorization header" + handleCommitStatusRequest :: Server -> Text -> Text -> Text -> Request -> (Response -> IO ResponseReceived) -> IO ResponseReceived handleCommitStatusRequest server owner repo commitSha req respond = if requestMethod req == methodPost then do - body <- strictRequestBody req - -- Store the status for later retrieval - storeStatus server commitSha body - -- Note: commit SHA omitted because it's nondeterministic - addOutput server $ "Updated commit status for " <> owner <> "/" <> repo <> " to " <> decodeUtf8 body - respond $ responseLBS status201 [("Content-Type", "application/json")] - (encode $ object ["state" .= ("success" :: Text), "sha" .= commitSha, "repository" .= repo, "owner" .= owner]) + mReject <- validateToken server req + case mReject of + Just rejection -> respond rejection + Nothing -> do + body <- strictRequestBody req + -- Store the status for later retrieval + storeStatus server commitSha body + -- Note: commit SHA omitted because it's nondeterministic + addOutput server $ "Updated commit status for " <> owner <> "/" <> repo <> " to " <> decodeUtf8 body + respond $ responseLBS status201 [("Content-Type", "application/json")] + (encode $ object ["state" .= ("success" :: Text), "sha" .= commitSha, "repository" .= repo, "owner" .= owner]) else respond $ responseLBS status400 [] "Bad Request" handleGetCommitStatuses :: Server -> Text -> Text -> Text -> Request -> (Response -> IO ResponseReceived) -> IO ResponseReceived handleGetCommitStatuses server _owner _repo commitSha req respond = if requestMethod req == methodGet then do - statuses <- getStatuses server commitSha - respond $ responseLBS status200 [("Content-Type", "application/json")] (encode statuses) + mReject <- validateToken server req + case mReject of + Just rejection -> respond rejection + Nothing -> do + statuses <- getStatuses server commitSha + respond $ responseLBS status200 [("Content-Type", "application/json")] (encode statuses) else respond $ responseLBS status400 [] "Bad Request" data Server = Server @@ -72,6 +114,9 @@ data Server = Server , output :: IORef [Text] , statuses :: IORef (Map Text [Value]) -- Map from commit SHA to list of status objects , tokenLifetimeSeconds :: IORef Int + , tokenCounter :: IORef Int + , validTokens :: IORef (Map Text UTCTime) -- Map from token to actual expiry time + , tokenExpirationOffset :: IORef Int -- Seconds to add to reported expires_at (simulates clock skew) } start :: Int -> IO Server @@ -80,9 +125,12 @@ start port = do output <- newIORef [] statuses <- newIORef Map.empty tokenLifetimeSeconds <- newIORef 3600 -- Default: 1 hour + tokenCounter <- newIORef 0 + validTokens <- newIORef Map.empty + tokenExpirationOffset <- newIORef 0 let settings = Warp.setPort port $ Warp.setBeforeMainLoop (putMVar started ()) Warp.defaultSettings rec - let server = Server {tid, output, statuses, tokenLifetimeSeconds} + let server = Server {tid, output, statuses, tokenLifetimeSeconds, tokenCounter, validTokens, tokenExpirationOffset} tid <- forkIO $ Warp.runSettings settings $ app server takeMVar started pure server @@ -94,10 +142,13 @@ addOutput :: Server -> Text -> IO () addOutput (Server {output}) msg = modifyIORef output (msg :) clearOutput :: Server -> IO () -clearOutput (Server {output, statuses, tokenLifetimeSeconds}) = do - writeIORef output [] - writeIORef statuses Map.empty - writeIORef tokenLifetimeSeconds 3600 -- Reset to default +clearOutput server = do + writeIORef server.output [] + writeIORef server.statuses Map.empty + writeIORef server.tokenLifetimeSeconds 3600 -- Reset to default + writeIORef server.tokenCounter 0 + writeIORef server.validTokens Map.empty + writeIORef server.tokenExpirationOffset 0 getOutput :: Server -> IO [Text] getOutput (Server {output}) = reverse <$> readIORef output @@ -116,3 +167,6 @@ getStatuses (Server {statuses}) commitSha = do setTokenLifetime :: Server -> Int -> IO () setTokenLifetime server seconds = writeIORef server.tokenLifetimeSeconds seconds + +setTokenExpirationOffset :: Server -> Int -> IO () +setTokenExpirationOffset server seconds = writeIORef server.tokenExpirationOffset seconds diff --git a/test/Spec.hs b/test/Spec.hs index 2f8d1d6..20976f6 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -91,6 +91,10 @@ runTest fakeGithubServer source = do whenJust options.githubTokenLifetime $ \lifetime -> FakeGithubApi.setTokenLifetime fakeGithubServer lifetime + -- Set token expiration offset if specified in test + whenJust options.githubTokenExpirationOffset $ \offset -> + FakeGithubApi.setTokenExpirationOffset fakeGithubServer offset + (pipeRead, pipeWrite) <- createPipe path <- getEnv "PATH" @@ -175,6 +179,7 @@ data Options = Options , githubKeys :: Bool , quiet :: Bool , githubTokenLifetime :: Maybe Int + , githubTokenExpirationOffset :: Maybe Int } instance Default Options where @@ -185,6 +190,7 @@ instance Default Options where , githubKeys = False , quiet = False , githubTokenLifetime = Nothing + , githubTokenExpirationOffset = Nothing } getOptions :: Text -> Options @@ -207,6 +213,9 @@ getOptions source = flip execState def $ go (lines source) ["#", "github", "token", "lifetime", n] -> do modify (\s -> s { githubTokenLifetime = readMaybe (toString n) }) go rest + ["#", "github", "token", "expiration", "offset", n] -> do + modify (\s -> s { githubTokenExpirationOffset = readMaybe (toString n) }) + go rest ["#", "quiet"] -> do modify (\s -> (s :: Options) { quiet = True }) go rest diff --git a/test/t/slow/github-token-retry-on-401.out b/test/t/slow/github-token-retry-on-401.out new file mode 100644 index 0000000..115b1b8 --- /dev/null +++ b/test/t/slow/github-token-retry-on-401.out @@ -0,0 +1,9 @@ +-- output: +[mytask] stdout | Task started, pending status posted +[mytask] stdout | Task finishing (should retry on 401) +[mytask] warn | GitHub API returned 401, force-refreshing token and retrying... +-- github: +Requested access token for installation 123 +Updated commit status for fakeowner/fakerepo to {"context":"mytask","description":"not cached","state":"pending","target_url":null} +Requested access token for installation 123 +Updated commit status for fakeowner/fakerepo to {"context":"mytask","description":null,"state":"success","target_url":null} diff --git a/test/t/slow/github-token-retry-on-401.txt b/test/t/slow/github-token-retry-on-401.txt new file mode 100644 index 0000000..b20b652 --- /dev/null +++ b/test/t/slow/github-token-retry-on-401.txt @@ -0,0 +1,18 @@ +# check output github +# no toplevel +# github keys +# github token lifetime 2 +# github token expiration offset 5 + +export TASKRUNNER_ENABLE_COMMIT_STATUS=1 +export TASKRUNNER_GITHUB_TOKEN_REFRESH_THRESHOLD_SECONDS=0 + +git init -q +git commit --allow-empty -q -m "Initial commit" + +taskrunner -n mytask bash -e -c ' + snapshot -n --commit-status + echo "Task started, pending status posted" + sleep 3 + echo "Task finishing (should retry on 401)" +'