Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 56 additions & 29 deletions src/CommitStatus.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import qualified Data.Text as Text
import qualified Data.ByteString.Lazy as BL
import System.FileLock (withFileLock, SharedExclusive(..))
import System.Directory (doesFileExist)
import Utils (getCurrentCommit, logError, logDebug)
import Utils (getCurrentCommit, logError, logDebug, logWarn)
import Types (AppState(..), GithubClient(..), Settings(..))

-- Define the data types for the status update
Expand Down Expand Up @@ -118,6 +118,33 @@ loadOrRefreshClient appState = do
writeIORef appState.githubClient (Just client)
pure client

-- Force-refresh the token, ignoring both in-memory and file caches
forceRefreshClient :: AppState -> IO GithubClient
forceRefreshClient appState = do
let cacheFile = credentialsCacheFile appState.settings
let lockFile = cacheFile <> ".lock"

writeIORef appState.githubClient Nothing

client <- withFileLock lockFile Exclusive \_ ->
refreshToken appState cacheFile

writeIORef appState.githubClient (Just client)
pure client

-- Execute a GitHub API request, retrying once with a fresh token on 401
withFreshClient :: AppState -> (GithubClient -> IO (HTTP.Response BL.ByteString)) -> IO (HTTP.Response BL.ByteString)
withFreshClient appState doRequest = do
client <- getClient appState
response <- doRequest client
if response.responseStatus.statusCode == 401
then do
logWarn appState "GitHub API returned 401, force-refreshing token and retrying..."
freshClient <- forceRefreshClient appState
doRequest freshClient
else
pure response

-- Create new token and write to cache (caller should hold EXCLUSIVE lock)
refreshToken :: AppState -> FilePath -> IO GithubClient
refreshToken appState cacheFile = do
Expand Down Expand Up @@ -209,23 +236,23 @@ createTokenFromGitHub appState = do

updateCommitStatus :: MonadIO m => AppState -> StatusRequest -> m ()
updateCommitStatus appState statusRequest = liftIO do
client <- getClient appState
sha <- getCurrentCommit appState

-- Prepare the status update request
let statusUrl = toString client.apiUrl <> "/repos/" ++ toString client.owner ++ "/" ++ toString client.repo ++ "/statuses/" ++ toString sha
initStatusRequest <- HTTP.parseRequest statusUrl
let statusReq = initStatusRequest
{ HTTP.method = "POST"
, HTTP.requestHeaders =
[ ("Authorization", "Bearer " <> TE.encodeUtf8 client.accessToken)
, ("Accept", "application/vnd.github.v3+json")
, ("Content-Type", "application/json")
, ("User-Agent", "restaumatic-bot")
]
, HTTP.requestBody = HTTP.RequestBodyLBS $ encode statusRequest
}
statusResponse <- HTTP.httpLbs statusReq client.manager
statusResponse <- withFreshClient appState \client -> do
let statusUrl = toString client.apiUrl <> "/repos/" ++ toString client.owner ++ "/" ++ toString client.repo ++ "/statuses/" ++ toString sha
initStatusRequest <- HTTP.parseRequest statusUrl
let statusReq = initStatusRequest
{ HTTP.method = "POST"
, HTTP.requestHeaders =
[ ("Authorization", "Bearer " <> TE.encodeUtf8 client.accessToken)
, ("Accept", "application/vnd.github.v3+json")
, ("Content-Type", "application/json")
, ("User-Agent", "restaumatic-bot")
]
, HTTP.requestBody = HTTP.RequestBodyLBS $ encode statusRequest
}
HTTP.httpLbs statusReq client.manager

if statusResponse.responseStatus.statusCode == 201
then
logDebug appState "Commit status updated successfully"
Expand All @@ -237,21 +264,21 @@ updateCommitStatus appState statusRequest = liftIO do
-- Check if a status exists for the current commit and context
checkExistingStatus :: MonadIO m => AppState -> T.Text -> m Bool
checkExistingStatus appState contextName = liftIO do
client <- getClient appState
sha <- getCurrentCommit appState

-- Prepare the GET request for statuses
let statusUrl = toString client.apiUrl <> "/repos/" ++ toString client.owner ++ "/" ++ toString client.repo ++ "/commits/" ++ toString sha ++ "/statuses"
initStatusRequest <- HTTP.parseRequest statusUrl
let statusReq = initStatusRequest
{ HTTP.method = "GET"
, HTTP.requestHeaders =
[ ("Authorization", "Bearer " <> TE.encodeUtf8 client.accessToken)
, ("Accept", "application/vnd.github.v3+json")
, ("User-Agent", "restaumatic-bot")
]
}
statusResponse <- HTTP.httpLbs statusReq client.manager
statusResponse <- withFreshClient appState \client -> do
let statusUrl = toString client.apiUrl <> "/repos/" ++ toString client.owner ++ "/" ++ toString client.repo ++ "/commits/" ++ toString sha ++ "/statuses"
initStatusRequest <- HTTP.parseRequest statusUrl
let statusReq = initStatusRequest
{ HTTP.method = "GET"
, HTTP.requestHeaders =
[ ("Authorization", "Bearer " <> TE.encodeUtf8 client.accessToken)
, ("Accept", "application/vnd.github.v3+json")
, ("User-Agent", "restaumatic-bot")
]
}
HTTP.httpLbs statusReq client.manager

if statusResponse.responseStatus.statusCode == 200
then do
let mStatuses = eitherDecode @[StatusResponse] (HTTP.responseBody statusResponse)
Expand Down
94 changes: 74 additions & 20 deletions test/FakeGithubApi.hs
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecursiveDo #-}

module FakeGithubApi (Server, start, stop, clearOutput, getOutput, setTokenLifetime) where
module FakeGithubApi (Server, start, stop, clearOutput, getOutput, setTokenLifetime, setTokenExpirationOffset) where

import Universum

import Network.Wai
import qualified Network.Wai.Handler.Warp as Warp
import Network.HTTP.Types (status200, status201, status400, status404, methodPost, methodGet)
import Network.HTTP.Types (status200, status201, status400, status401, status404, methodPost, methodGet)
import Data.Aeson (encode, object, (.=), Value)
import qualified Data.Aeson as Aeson
import qualified Data.Map.Strict as Map
import Data.Time.Clock (getCurrentTime, addUTCTime)
import Data.Time.Clock (getCurrentTime, addUTCTime, UTCTime)
import Data.Time.Format.ISO8601 (iso8601Show)
import qualified Data.ByteString as BS

import Control.Concurrent (forkIO, ThreadId, killThread)

Expand All @@ -35,43 +36,87 @@ handleAccessTokenRequest server instId req respond =
then do
-- Read token lifetime from server state
lifetimeSeconds <- readIORef server.tokenLifetimeSeconds
offset <- readIORef server.tokenExpirationOffset
now <- getCurrentTime
let expiresAt = addUTCTime (fromIntegral lifetimeSeconds) now
let actualExpiry = addUTCTime (fromIntegral lifetimeSeconds) now
let reportedExpiry = addUTCTime (fromIntegral offset) actualExpiry

-- Issue unique token
n <- atomicModifyIORef server.tokenCounter (\c -> (c + 1, c + 1))
let tokenText = "mock-access-token-" <> show n

-- Track the token with its actual expiry
modifyIORef server.validTokens (Map.insert tokenText actualExpiry)

addOutput server $ "Requested access token for installation " <> instId
respond $ responseLBS status200 [("Content-Type", "application/json")]
(encode $ object
[ "token" .= ("mock-access-token" :: Text)
, "expires_at" .= iso8601Show expiresAt
[ "token" .= tokenText
, "expires_at" .= iso8601Show reportedExpiry
, "installation_id" .= instId
])
else respond $ responseLBS status400 [] "Bad Request"

-- Validate the Bearer token from the Authorization header.
-- Returns Nothing if valid, or a 401 response if invalid/expired.
validateToken :: Server -> Request -> IO (Maybe Response)
validateToken server req = do
tokens <- readIORef server.validTokens
-- If no tokens have been issued yet, skip validation (backwards compat)
if Map.null tokens
then pure Nothing
else do
now <- getCurrentTime
let mAuth = fmap snd $ find (\(k, _) -> k == "Authorization") (requestHeaders req)
case mAuth of
Just authHeader
| Just tokenBS <- BS.stripPrefix "Bearer " authHeader -> do
let tokenText = decodeUtf8 tokenBS
case Map.lookup tokenText tokens of
Just expiry
| now < expiry -> pure Nothing -- Valid
| otherwise -> pure $ Just $ responseLBS status401 [] "Token expired"
Nothing -> pure $ Just $ responseLBS status401 [] "Unknown token"
| otherwise -> pure $ Just $ responseLBS status401 [] "Invalid Authorization header"
Nothing -> pure $ Just $ responseLBS status401 [] "Missing Authorization header"

handleCommitStatusRequest :: Server -> Text -> Text -> Text -> Request -> (Response -> IO ResponseReceived) -> IO ResponseReceived
handleCommitStatusRequest server owner repo commitSha req respond =
if requestMethod req == methodPost
then do
body <- strictRequestBody req
-- Store the status for later retrieval
storeStatus server commitSha body
-- Note: commit SHA omitted because it's nondeterministic
addOutput server $ "Updated commit status for " <> owner <> "/" <> repo <> " to " <> decodeUtf8 body
respond $ responseLBS status201 [("Content-Type", "application/json")]
(encode $ object ["state" .= ("success" :: Text), "sha" .= commitSha, "repository" .= repo, "owner" .= owner])
mReject <- validateToken server req
case mReject of
Just rejection -> respond rejection
Nothing -> do
body <- strictRequestBody req
-- Store the status for later retrieval
storeStatus server commitSha body
-- Note: commit SHA omitted because it's nondeterministic
addOutput server $ "Updated commit status for " <> owner <> "/" <> repo <> " to " <> decodeUtf8 body
respond $ responseLBS status201 [("Content-Type", "application/json")]
(encode $ object ["state" .= ("success" :: Text), "sha" .= commitSha, "repository" .= repo, "owner" .= owner])
else respond $ responseLBS status400 [] "Bad Request"

handleGetCommitStatuses :: Server -> Text -> Text -> Text -> Request -> (Response -> IO ResponseReceived) -> IO ResponseReceived
handleGetCommitStatuses server _owner _repo commitSha req respond =
if requestMethod req == methodGet
then do
statuses <- getStatuses server commitSha
respond $ responseLBS status200 [("Content-Type", "application/json")] (encode statuses)
mReject <- validateToken server req
case mReject of
Just rejection -> respond rejection
Nothing -> do
statuses <- getStatuses server commitSha
respond $ responseLBS status200 [("Content-Type", "application/json")] (encode statuses)
else respond $ responseLBS status400 [] "Bad Request"

data Server = Server
{ tid :: ThreadId
, output :: IORef [Text]
, statuses :: IORef (Map Text [Value]) -- Map from commit SHA to list of status objects
, tokenLifetimeSeconds :: IORef Int
, tokenCounter :: IORef Int
, validTokens :: IORef (Map Text UTCTime) -- Map from token to actual expiry time
, tokenExpirationOffset :: IORef Int -- Seconds to add to reported expires_at (simulates clock skew)
}

start :: Int -> IO Server
Expand All @@ -80,9 +125,12 @@ start port = do
output <- newIORef []
statuses <- newIORef Map.empty
tokenLifetimeSeconds <- newIORef 3600 -- Default: 1 hour
tokenCounter <- newIORef 0
validTokens <- newIORef Map.empty
tokenExpirationOffset <- newIORef 0
let settings = Warp.setPort port $ Warp.setBeforeMainLoop (putMVar started ()) Warp.defaultSettings
rec
let server = Server {tid, output, statuses, tokenLifetimeSeconds}
let server = Server {tid, output, statuses, tokenLifetimeSeconds, tokenCounter, validTokens, tokenExpirationOffset}
tid <- forkIO $ Warp.runSettings settings $ app server
takeMVar started
pure server
Expand All @@ -94,10 +142,13 @@ addOutput :: Server -> Text -> IO ()
addOutput (Server {output}) msg = modifyIORef output (msg :)

clearOutput :: Server -> IO ()
clearOutput (Server {output, statuses, tokenLifetimeSeconds}) = do
writeIORef output []
writeIORef statuses Map.empty
writeIORef tokenLifetimeSeconds 3600 -- Reset to default
clearOutput server = do
writeIORef server.output []
writeIORef server.statuses Map.empty
writeIORef server.tokenLifetimeSeconds 3600 -- Reset to default
writeIORef server.tokenCounter 0
writeIORef server.validTokens Map.empty
writeIORef server.tokenExpirationOffset 0

getOutput :: Server -> IO [Text]
getOutput (Server {output}) = reverse <$> readIORef output
Expand All @@ -116,3 +167,6 @@ getStatuses (Server {statuses}) commitSha = do

setTokenLifetime :: Server -> Int -> IO ()
setTokenLifetime server seconds = writeIORef server.tokenLifetimeSeconds seconds

setTokenExpirationOffset :: Server -> Int -> IO ()
setTokenExpirationOffset server seconds = writeIORef server.tokenExpirationOffset seconds
9 changes: 9 additions & 0 deletions test/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@
whenJust options.githubTokenLifetime $ \lifetime ->
FakeGithubApi.setTokenLifetime fakeGithubServer lifetime

-- Set token expiration offset if specified in test
whenJust options.githubTokenExpirationOffset $ \offset ->
FakeGithubApi.setTokenExpirationOffset fakeGithubServer offset

(pipeRead, pipeWrite) <- createPipe
path <- getEnv "PATH"

Expand Down Expand Up @@ -175,6 +179,7 @@
, githubKeys :: Bool
, quiet :: Bool
, githubTokenLifetime :: Maybe Int
, githubTokenExpirationOffset :: Maybe Int
}

instance Default Options where
Expand All @@ -185,6 +190,7 @@
, githubKeys = False
, quiet = False
, githubTokenLifetime = Nothing
, githubTokenExpirationOffset = Nothing
}

getOptions :: Text -> Options
Expand All @@ -207,8 +213,11 @@
["#", "github", "token", "lifetime", n] -> do
modify (\s -> s { githubTokenLifetime = readMaybe (toString n) })
go rest
["#", "github", "token", "expiration", "offset", n] -> do
modify (\s -> s { githubTokenExpirationOffset = readMaybe (toString n) })
go rest
["#", "quiet"] -> do
modify (\s -> (s :: Options) { quiet = True })

Check warning on line 220 in test/Spec.hs

View workflow job for this annotation

GitHub Actions / build (x86_64, ubuntu-latest)

The record update (s :: Options)

Check warning on line 220 in test/Spec.hs

View workflow job for this annotation

GitHub Actions / build (arm64, ubuntu-24.04-arm)

The record update (s :: Options)
go rest
-- TODO: validate?
_ ->
Expand Down
9 changes: 9 additions & 0 deletions test/t/slow/github-token-retry-on-401.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-- output:
[mytask] stdout | Task started, pending status posted
[mytask] stdout | Task finishing (should retry on 401)
[mytask] warn | GitHub API returned 401, force-refreshing token and retrying...
-- github:
Requested access token for installation 123
Updated commit status for fakeowner/fakerepo to {"context":"mytask","description":"not cached","state":"pending","target_url":null}
Requested access token for installation 123
Updated commit status for fakeowner/fakerepo to {"context":"mytask","description":null,"state":"success","target_url":null}
18 changes: 18 additions & 0 deletions test/t/slow/github-token-retry-on-401.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# check output github
# no toplevel
# github keys
# github token lifetime 2
# github token expiration offset 5

export TASKRUNNER_ENABLE_COMMIT_STATUS=1
export TASKRUNNER_GITHUB_TOKEN_REFRESH_THRESHOLD_SECONDS=0

git init -q
git commit --allow-empty -q -m "Initial commit"

taskrunner -n mytask bash -e -c '
snapshot -n --commit-status
echo "Task started, pending status posted"
sleep 3
echo "Task finishing (should retry on 401)"
'
Loading