-
Notifications
You must be signed in to change notification settings - Fork 733
Fix/graphql client token refresh #8791
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
88fc30c
b588955
611ff3e
83ccc18
1d380ae
ab0a5fb
07132ec
6b1268d
4e49845
f8c3078
afe168b
9cb0970
c7c5e72
c8927e0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,12 +20,14 @@ package api | |
| import ( | ||
| "context" | ||
| "fmt" | ||
| "strconv" | ||
| "sync" | ||
| "time" | ||
|
|
||
| "github.com/apache/incubator-devlake/core/errors" | ||
| "github.com/apache/incubator-devlake/core/log" | ||
| "github.com/apache/incubator-devlake/core/plugin" | ||
| "github.com/apache/incubator-devlake/core/utils" | ||
| "sync" | ||
| "time" | ||
|
|
||
| "github.com/merico-ai/graphql" | ||
| ) | ||
|
|
@@ -47,30 +49,52 @@ type GraphqlAsyncClient struct { | |
| getRateCost func(q interface{}) int | ||
| } | ||
|
|
||
| // defaultRateLimitConst is the generic fallback rate limit for GraphQL requests. | ||
| // It is used as the initial remaining quota when dynamic rate limit | ||
| // information is unavailable from the provider. | ||
| const defaultRateLimitConst = 1000 | ||
|
|
||
| // CreateAsyncGraphqlClient creates a new GraphqlAsyncClient | ||
| func CreateAsyncGraphqlClient( | ||
| taskCtx plugin.TaskContext, | ||
| graphqlClient *graphql.Client, | ||
| logger log.Logger, | ||
| getRateRemaining func(context.Context, *graphql.Client, log.Logger) (rateRemaining int, resetAt *time.Time, err errors.Error), | ||
| opts ...func(*GraphqlAsyncClient), | ||
| ) (*GraphqlAsyncClient, errors.Error) { | ||
| ctxWithCancel, cancel := context.WithCancel(taskCtx.GetContext()) | ||
|
|
||
| graphqlAsyncClient := &GraphqlAsyncClient{ | ||
| ctx: ctxWithCancel, | ||
| cancel: cancel, | ||
| client: graphqlClient, | ||
| logger: logger, | ||
| rateExhaustCond: sync.NewCond(&sync.Mutex{}), | ||
| rateRemaining: 0, | ||
| rateRemaining: defaultRateLimitConst, | ||
| getRateRemaining: getRateRemaining, | ||
| } | ||
|
|
||
| // apply options | ||
| for _, opt := range opts { | ||
| opt(graphqlAsyncClient) | ||
| } | ||
|
|
||
| // Env config wins over everything, only if explicitly set | ||
| if rateLimit := resolveRateLimit(taskCtx, logger); rateLimit != -1 { | ||
| logger.Info("GRAPHQL_RATE_LIMIT env override applied: %d (was %d)", rateLimit, graphqlAsyncClient.rateRemaining) | ||
| graphqlAsyncClient.rateRemaining = rateLimit | ||
| } | ||
|
|
||
| if getRateRemaining != nil { | ||
| rateRemaining, resetAt, err := getRateRemaining(taskCtx.GetContext(), graphqlClient, logger) | ||
| if err != nil { | ||
| panic(err) | ||
| graphqlAsyncClient.logger.Info("failed to fetch initial graphql rate limit, fallback to default: %v", err) | ||
| graphqlAsyncClient.updateRateRemaining(graphqlAsyncClient.rateRemaining, nil) | ||
| } else { | ||
| graphqlAsyncClient.updateRateRemaining(rateRemaining, resetAt) | ||
| } | ||
| graphqlAsyncClient.updateRateRemaining(rateRemaining, resetAt) | ||
| } else { | ||
| graphqlAsyncClient.updateRateRemaining(graphqlAsyncClient.rateRemaining, nil) | ||
| } | ||
|
|
||
| // load retry/timeout from configuration | ||
|
|
@@ -115,6 +139,10 @@ func (apiClient *GraphqlAsyncClient) updateRateRemaining(rateRemaining int, rese | |
| apiClient.rateExhaustCond.Signal() | ||
| } | ||
| go func() { | ||
| if apiClient.getRateRemaining == nil { | ||
| return | ||
| } | ||
|
|
||
| nextDuring := 3 * time.Minute | ||
| if resetAt != nil && resetAt.After(time.Now()) { | ||
| nextDuring = time.Until(*resetAt) | ||
|
|
@@ -126,7 +154,9 @@ func (apiClient *GraphqlAsyncClient) updateRateRemaining(rateRemaining int, rese | |
| case <-time.After(nextDuring): | ||
| newRateRemaining, newResetAt, err := apiClient.getRateRemaining(apiClient.ctx, apiClient.client, apiClient.logger) | ||
| if err != nil { | ||
| panic(err) | ||
| apiClient.logger.Info("failed to update graphql rate limit, will retry next cycle: %v", err) | ||
| apiClient.updateRateRemaining(apiClient.rateRemaining, nil) | ||
| return | ||
| } | ||
| apiClient.updateRateRemaining(newRateRemaining, newResetAt) | ||
| } | ||
|
|
@@ -218,3 +248,25 @@ func (apiClient *GraphqlAsyncClient) Wait() { | |
| func (apiClient *GraphqlAsyncClient) Release() { | ||
| apiClient.cancel() | ||
| } | ||
|
|
||
| // WithFallbackRateLimit sets the initial/fallback rate limit used when | ||
| // rate limit information cannot be fetched dynamically. | ||
| // This value may be overridden later by getRateRemaining. | ||
| func WithFallbackRateLimit(limit int) func(*GraphqlAsyncClient) { | ||
| return func(c *GraphqlAsyncClient) { | ||
| if limit > 0 { | ||
| c.rateRemaining = limit | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // resolveRateLimit returns -1 if GRAPHQL_RATE_LIMIT is not set or invalid | ||
| func resolveRateLimit(taskCtx plugin.TaskContext, logger log.Logger) int { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
| if v := taskCtx.GetConfig("GRAPHQL_RATE_LIMIT"); v != "" { | ||
| if parsed, err := strconv.Atoi(v); err == nil { | ||
| return parsed | ||
| } | ||
| logger.Warn(nil, "invalid GRAPHQL_RATE_LIMIT, using default") | ||
| } | ||
| return -1 | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| /* | ||
| Licensed to the Apache Software Foundation (ASF) under one or more | ||
| contributor license agreements. See the NOTICE file distributed with | ||
| this work for additional information regarding copyright ownership. | ||
| The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| (the "License"); you may not use this file except in compliance with | ||
| the License. You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| */ | ||
|
|
||
| package tasks | ||
|
|
||
| import ( | ||
| "net/http" | ||
|
|
||
| "github.com/apache/incubator-devlake/core/errors" | ||
| "github.com/apache/incubator-devlake/core/plugin" | ||
| "github.com/apache/incubator-devlake/plugins/github/models" | ||
| "github.com/apache/incubator-devlake/plugins/github/token" | ||
| ) | ||
|
|
||
| func CreateAuthenticatedHttpClient( | ||
| taskCtx plugin.TaskContext, | ||
| connection *models.GithubConnection, | ||
| baseClient *http.Client, | ||
| ) (*http.Client, errors.Error) { | ||
|
|
||
| logger := taskCtx.GetLogger() | ||
| db := taskCtx.GetDal() | ||
| encryptionSecret := taskCtx.GetConfig(plugin.EncodeKeyEnvStr) | ||
|
|
||
| if baseClient == nil { | ||
| baseClient = &http.Client{} | ||
| } | ||
|
|
||
| // Inject TokenProvider for OAuth refresh or GitHub App installation tokens. | ||
| var tp *token.TokenProvider | ||
| if connection.RefreshToken != "" { | ||
| tp = token.NewTokenProvider(connection, db, baseClient, logger, encryptionSecret) | ||
| } else if connection.AuthMethod == models.AppKey && connection.InstallationID != 0 { | ||
| tp = token.NewAppInstallationTokenProvider(connection, db, baseClient, logger, encryptionSecret) | ||
| } | ||
|
|
||
| baseTransport := baseClient.Transport | ||
| if baseTransport == nil { | ||
| baseTransport = http.DefaultTransport | ||
| } | ||
|
|
||
| if tp != nil { | ||
| baseClient.Transport = token.NewRefreshRoundTripper(baseTransport, tp) | ||
| logger.Info( | ||
| "Installed token refresh round tripper for connection %d (authMethod=%s)", | ||
| connection.ID, | ||
| connection.AuthMethod, | ||
| ) | ||
|
|
||
| } else if connection.Token != "" { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Implemented that in static round tripper |
||
| baseClient.Transport = token.NewStaticRoundTripper( | ||
| baseTransport, | ||
| connection.Token, | ||
| ) | ||
| logger.Info( | ||
| "Installed static token round tripper for connection %d", | ||
| connection.ID, | ||
| ) | ||
| } | ||
|
|
||
| // Persist the freshly minted token so the DB has a correctly encrypted value. | ||
| // PrepareApiClient (called by NewApiClientFromConnection) mints the token | ||
| // in-memory but does not persist it; without this, the DB may contain a stale | ||
| // or corrupted token that breaks GET /connections. | ||
| if connection.AuthMethod == models.AppKey && connection.Token != "" { | ||
| if err := token.PersistEncryptedTokenColumns(db, connection, encryptionSecret, logger, false); err != nil { | ||
| logger.Warn(err, "Failed to persist initial token for connection %d", connection.ID) | ||
| } else { | ||
| logger.Info("Persisted initial token for connection %d", connection.ID) | ||
| } | ||
| } | ||
|
|
||
| return baseClient, nil | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,8 @@ package token | |
|
|
||
| import ( | ||
| "net/http" | ||
| "strings" | ||
| "sync/atomic" | ||
| ) | ||
|
|
||
| // RefreshRoundTripper is an HTTP transport middleware that automatically manages OAuth token refreshes. | ||
|
|
@@ -93,3 +95,36 @@ func (rt *RefreshRoundTripper) roundTripWithRetry(req *http.Request, refreshAtte | |
|
|
||
| return resp, nil | ||
| } | ||
|
|
||
| // StaticRoundTripper is an HTTP transport that injects a fixed bearer token. | ||
| // Unlike RefreshRoundTripper, it does NOT attempt refresh or retries. | ||
| type StaticRoundTripper struct { | ||
| base http.RoundTripper | ||
| tokens []string | ||
| idx atomic.Uint64 | ||
| } | ||
|
|
||
| func NewStaticRoundTripper(base http.RoundTripper, rawToken string) *StaticRoundTripper { | ||
| if base == nil { | ||
| base = http.DefaultTransport | ||
| } | ||
| parts := strings.Split(rawToken, ",") | ||
| tokens := make([]string, 0, len(parts)) | ||
| for _, t := range parts { | ||
| if t = strings.TrimSpace(t); t != "" { | ||
| tokens = append(tokens, t) | ||
| } | ||
| } | ||
| if len(tokens) == 0 { | ||
| tokens = []string{rawToken} | ||
| } | ||
| return &StaticRoundTripper{base: base, tokens: tokens} | ||
| } | ||
|
|
||
| func (rt *StaticRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { | ||
| // always overrides headers put by SetupAuthentication, to make sure the token is always injected | ||
| tok := rt.tokens[rt.idx.Add(1)%uint64(len(rt.tokens))] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Uint64 starts at 0, |
||
| reqClone := req.Clone(req.Context()) | ||
| reqClone.Header.Set("Authorization", "Bearer "+tok) | ||
| return rt.base.RoundTrip(reqClone) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,10 +20,7 @@ package impl | |
| import ( | ||
| "context" | ||
| "fmt" | ||
| "net/http" | ||
| "net/url" | ||
| "reflect" | ||
| "strings" | ||
| "time" | ||
|
|
||
| "github.com/apache/incubator-devlake/core/models/domainlayer/devops" | ||
|
|
@@ -39,7 +36,6 @@ import ( | |
| "github.com/apache/incubator-devlake/plugins/github_graphql/model/migrationscripts" | ||
| "github.com/apache/incubator-devlake/plugins/github_graphql/tasks" | ||
| "github.com/merico-ai/graphql" | ||
| "golang.org/x/oauth2" | ||
| ) | ||
|
|
||
| // make sure interface is implemented | ||
|
|
@@ -180,46 +176,10 @@ func (p GithubGraphql) PrepareTaskData(taskCtx plugin.TaskContext, options map[s | |
| return nil, err | ||
| } | ||
|
|
||
| tokens := strings.Split(connection.Token, ",") | ||
| src := oauth2.StaticTokenSource( | ||
| &oauth2.Token{AccessToken: tokens[0]}, | ||
| ) | ||
| oauthContext := taskCtx.GetContext() | ||
| proxy := connection.GetProxy() | ||
| if proxy != "" { | ||
| pu, err := url.Parse(proxy) | ||
| if err != nil { | ||
| return nil, errors.Convert(err) | ||
| } | ||
| if pu.Scheme == "http" || pu.Scheme == "socks5" { | ||
| proxyClient := &http.Client{ | ||
| Transport: &http.Transport{Proxy: http.ProxyURL(pu)}, | ||
| } | ||
| oauthContext = context.WithValue( | ||
| taskCtx.GetContext(), | ||
| oauth2.HTTPClient, | ||
| proxyClient, | ||
| ) | ||
| logger.Debug("Proxy set in oauthContext to %s", proxy) | ||
| } else { | ||
| return nil, errors.BadInput.New("Unsupported scheme set in proxy") | ||
| } | ||
| } | ||
|
|
||
| httpClient := oauth2.NewClient(oauthContext, src) | ||
| endpoint, err := errors.Convert01(url.Parse(connection.Endpoint)) | ||
| if err != nil { | ||
| return nil, errors.BadInput.Wrap(err, fmt.Sprintf("malformed connection endpoint supplied: %s", connection.Endpoint)) | ||
| } | ||
|
|
||
| // github.com and github enterprise have different graphql endpoints | ||
| endpoint.Path = "/graphql" // see https://docs.github.com/en/graphql/guides/forming-calls-with-graphql | ||
| if endpoint.Hostname() != "api.github.com" { | ||
| // see https://docs.github.com/en/enterprise-server@3.11/graphql/guides/forming-calls-with-graphql | ||
| endpoint.Path = "/api/graphql" | ||
| } | ||
| client := graphql.NewClient(endpoint.String(), httpClient) | ||
| graphqlClient, err := helper.CreateAsyncGraphqlClient(taskCtx, client, taskCtx.GetLogger(), | ||
| graphqlClient, err := tasks.CreateGraphqlClient( | ||
| taskCtx, | ||
| connection, | ||
| apiClient.ApiClient.GetClient(), | ||
| func(ctx context.Context, client *graphql.Client, logger log.Logger) (rateRemaining int, resetAt *time.Time, err errors.Error) { | ||
| var query GraphQueryRateLimit | ||
| dataErrors, err := errors.Convert01(client.Query(taskCtx.GetContext(), &query, nil)) | ||
|
|
@@ -230,8 +190,7 @@ func (p GithubGraphql) PrepareTaskData(taskCtx plugin.TaskContext, options map[s | |
| return 0, nil, errors.Default.Wrap(dataErrors[0], `query rate limit fail`) | ||
| } | ||
| if query.RateLimit == nil { | ||
| logger.Info(`github graphql rate limit are disabled, fallback to 5000req/hour`) | ||
| return 5000, nil, nil | ||
| return 0, nil, errors.Default.New("rate limit unavailable") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The old code returned
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed the Warn logs as Info logs. but kept the error here so the handling is centralized |
||
| } | ||
| logger.Info(`github graphql init success with remaining %d/%d and will reset at %s`, | ||
| query.RateLimit.Remaining, query.RateLimit.Limit, query.RateLimit.ResetAt) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Potential deadlock on GitHub Enterprise? The Signal() at L139 only fires when
rateRemaining > 0, so all goroutines waiting in Query (L183: for rateRemaining <= 0 { Wait() }) block forever. Since Github returns this error every cycle rate limit never recovers.Check if use
max(apiClient.rateRemaining, defaultRateLimitConst)instead ofapiClient.rateRemainingworks, to ensure the fallback is always positive and Signal() fires.