diff --git a/libs/aitools/installer/installer.go b/libs/aitools/installer/installer.go index 548be8dcf3..d2caca5af0 100644 --- a/libs/aitools/installer/installer.go +++ b/libs/aitools/installer/installer.go @@ -12,6 +12,7 @@ import ( "path/filepath" "slices" "strings" + "sync" "time" "github.com/databricks/cli/internal/build" @@ -21,6 +22,7 @@ import ( "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/log" "golang.org/x/mod/semver" + "golang.org/x/sync/errgroup" ) const ( @@ -28,8 +30,28 @@ const ( skillsRepoName = "databricks-agent-skills" stableSkillsRepoPath = "skills" experimentalRepoPath = "experimental" + + // fetchConcurrency caps the number of in-flight skill file fetches. + // Each file is one HTTPS GET to raw.githubusercontent.com; sequential + // fetches were latency-bound on TLS handshakes. 8 is enough to amortise + // the round-trip across a typical skill's files without overwhelming the + // upstream CDN. + fetchConcurrency = 8 ) +// httpClient is shared across all skill file fetches so the underlying +// transport reuses TCP+TLS connections. The default MaxIdleConnsPerHost +// (2) is bumped to leave headroom above fetchConcurrency so a brief overlap +// between a closing and a new connection doesn't force a fresh handshake. +var httpClient = sync.OnceValue(func() *http.Client { + t := http.DefaultTransport.(*http.Transport).Clone() + t.MaxIdleConnsPerHost = fetchConcurrency * 2 + return &http.Client{ + Timeout: 30 * time.Second, + Transport: t, + } +}) + func manifestHasExperimental(m *Manifest) bool { for _, meta := range m.Skills { if meta.IsExperimental() { @@ -121,8 +143,7 @@ func fetchSkillFile(ctx context.Context, ref, repoDir, skillName, filePath strin return nil, fmt.Errorf("failed to create request: %w", err) } - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := httpClient().Do(req) if err != nil { return nil, fmt.Errorf("failed to fetch %s: %w", filePath, err) } @@ -555,25 +576,29 @@ func installSkillToDir(ctx context.Context, ref, repoDir, skillName, destDir str return fmt.Errorf("failed to create directory: %w", err) } + // Fetch files concurrently. Each file is a separate HTTPS GET, so + // wall-clock time is dominated by per-request TLS handshakes rather + // than payload size. + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(fetchConcurrency) for _, file := range files { - content, err := fetchFileFn(ctx, ref, repoDir, skillName, file) - if err != nil { - return err - } - - destPath := filepath.Join(destDir, file) - - if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { - return fmt.Errorf("failed to create directory: %w", err) - } - - log.Debugf(ctx, "Downloading %s/%s", skillName, file) - if err := os.WriteFile(destPath, content, 0o644); err != nil { - return fmt.Errorf("failed to write %s: %w", file, err) - } + g.Go(func() error { + content, err := fetchFileFn(gctx, ref, repoDir, skillName, file) + if err != nil { + return err + } + destPath := filepath.Join(destDir, file) + if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + log.Debugf(gctx, "Downloading %s/%s", skillName, file) + if err := os.WriteFile(destPath, content, 0o644); err != nil { + return fmt.Errorf("failed to write %s: %w", file, err) + } + return nil + }) } - - return nil + return g.Wait() } // copyDir copies all files from src to dest, recreating the directory structure. diff --git a/libs/aitools/installer/installer_test.go b/libs/aitools/installer/installer_test.go index 710f486143..ad7c1810e9 100644 --- a/libs/aitools/installer/installer_test.go +++ b/libs/aitools/installer/installer_test.go @@ -3,11 +3,14 @@ package installer import ( "bytes" "context" + "errors" "io/fs" "log/slog" "os" "path/filepath" + "sync" "testing" + "time" "github.com/databricks/cli/internal/build" "github.com/databricks/cli/libs/aitools/agents" @@ -188,6 +191,114 @@ func TestBackupThirdPartySkillRegularFile(t *testing.T) { assert.ErrorIs(t, err, fs.ErrNotExist) } +func TestInstallSkillToDirFetchesFilesConcurrently(t *testing.T) { + baseCtx, cancel := context.WithTimeout(t.Context(), 2*time.Second) + defer cancel() + ctx := cmdio.MockDiscard(baseCtx) + + orig := fetchFileFn + t.Cleanup(func() { fetchFileFn = orig }) + + started := make(chan string, 2) + release := make(chan struct{}) + releaseOnce := sync.OnceFunc(func() { close(release) }) + t.Cleanup(releaseOnce) + + fetchFileFn = func(ctx context.Context, _, _, _, filePath string) ([]byte, error) { + started <- filePath + select { + case <-release: + return []byte(filePath), nil + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + destDir := filepath.Join(t.TempDir(), "databricks-test") + errCh := make(chan error, 1) + go func() { + errCh <- installSkillToDir(ctx, testSkillsRef, stableSkillsRepoPath, "databricks-test", destDir, []string{"one.md", "two.md"}) + }() + + fetched := make(map[string]bool, 2) + for range 2 { + select { + case filePath := <-started: + fetched[filePath] = true + case <-ctx.Done(): + require.FailNow(t, "timed out waiting for concurrent fetches to start") + } + } + assert.Equal(t, map[string]bool{"one.md": true, "two.md": true}, fetched) + + releaseOnce() + require.NoError(t, <-errCh) + + one, err := os.ReadFile(filepath.Join(destDir, "one.md")) + require.NoError(t, err) + assert.Equal(t, "one.md", string(one)) + two, err := os.ReadFile(filepath.Join(destDir, "two.md")) + require.NoError(t, err) + assert.Equal(t, "two.md", string(two)) +} + +func TestInstallSkillToDirCancelsInFlightFetchesOnError(t *testing.T) { + baseCtx, cancel := context.WithCancel(t.Context()) + defer cancel() + ctx := cmdio.MockDiscard(baseCtx) + + orig := fetchFileFn + t.Cleanup(func() { fetchFileFn = orig }) + + fetchErr := errors.New("fetch failed") + blockedStarted := make(chan struct{}) + cancelled := make(chan struct{}) + + fetchFileFn = func(ctx context.Context, _, _, _, filePath string) ([]byte, error) { + switch filePath { + case "blocked.md": + close(blockedStarted) + <-ctx.Done() + close(cancelled) + return nil, ctx.Err() + case "fail.md": + select { + case <-blockedStarted: + return nil, fetchErr + case <-ctx.Done(): + return nil, ctx.Err() + } + default: + return []byte(filePath), nil + } + } + + destDir := filepath.Join(t.TempDir(), "databricks-test") + errCh := make(chan error, 1) + go func() { + errCh <- installSkillToDir(ctx, testSkillsRef, stableSkillsRepoPath, "databricks-test", destDir, []string{"blocked.md", "fail.md"}) + }() + + var err error + select { + case err = <-errCh: + case <-time.After(5 * time.Second): + cancel() + select { + case <-errCh: + case <-time.After(time.Second): + } + require.FailNow(t, "timed out waiting for errgroup cancellation") + } + require.ErrorIs(t, err, fetchErr) + + select { + case <-cancelled: + default: + require.Fail(t, "expected in-flight fetch to observe context cancellation") + } +} + // --- InstallSkillsForAgents tests --- func TestInstallSkillsForAgentsWritesState(t *testing.T) {