diff --git a/pkg/cmd/completion_test.go b/pkg/cmd/completion_test.go index 2fc9cb78..e6ee0e15 100644 --- a/pkg/cmd/completion_test.go +++ b/pkg/cmd/completion_test.go @@ -8,6 +8,10 @@ import ( "github.com/stretchr/testify/require" ) +// --------------------------------------------------------------------------- +// Tests from fish-completion branch (preserved) +// --------------------------------------------------------------------------- + func TestDetectShell(t *testing.T) { tests := []struct { name string diff --git a/pkg/cmd/sentinel.go b/pkg/cmd/sentinel.go new file mode 100644 index 00000000..89a8f695 --- /dev/null +++ b/pkg/cmd/sentinel.go @@ -0,0 +1,221 @@ +package cmd + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +// sentinelBegin and sentinelEnd mark the completion configuration block +// in shell config files (~/.zshrc, ~/.bashrc, ~/.bash_profile). This allows +// safe idempotent install/uninstall without corrupting the user's existing config. +const ( + sentinelBegin = "# begin stripe-completion -- managed by stripe cli, do not edit" + sentinelEnd = "# end stripe-completion" +) + +// computeAddSentinel returns the new file content with the sentinel block added +// or replaced. It performs no I/O. If both markers are present in the correct +// order, the existing block is replaced. If markers are absent, orphaned, or +// reversed, a new block is appended to the content. +func computeAddSentinel(content, line string) string { + block := fmt.Sprintf("%s\n%s\n%s", sentinelBegin, line, sentinelEnd) + + beginIdx := strings.Index(content, sentinelBegin) + endIdx := strings.Index(content, sentinelEnd) + if beginIdx >= 0 && endIdx >= 0 && endIdx > beginIdx { + end := endIdx + len(sentinelEnd) + // Include trailing newline if present + if end < len(content) && content[end] == '\n' { + end++ + } + return content[:beginIdx] + block + "\n" + content[end:] + } + + // Append sentinel block + if len(content) > 0 && !strings.HasSuffix(content, "\n") { + content += "\n" + } + return content + block + "\n" +} + +// computeRemoveSentinel returns the new file content with the sentinel block +// removed and a boolean indicating whether a block was found and removed. It +// performs no I/O. If markers are absent, orphaned, or reversed, the content +// is returned unchanged with false. +func computeRemoveSentinel(content string) (string, bool) { + beginIdx := strings.Index(content, sentinelBegin) + endIdx := strings.Index(content, sentinelEnd) + if beginIdx < 0 || endIdx < 0 || endIdx <= beginIdx { + return content, false + } + + end := endIdx + len(sentinelEnd) + // Include trailing newline if present + if end < len(content) && content[end] == '\n' { + end++ + } + return content[:beginIdx] + content[end:], true +} + +// readConfigFile opens the file at path, reads its contents, and returns the +// content string along with the file's permission bits. The file is opened once +// and stat'd on the same file descriptor to avoid TOCTOU races. If the file +// does not exist, ("", 0644, nil) is returned. +func readConfigFile(path string) (string, os.FileMode, error) { + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return "", 0644, nil + } + return "", 0, fmt.Errorf("reading %s: %w", path, err) + } + defer f.Close() + + info, err := f.Stat() + if err != nil { + return "", 0, fmt.Errorf("reading %s: %w", path, err) + } + perm := info.Mode().Perm() + + data, err := io.ReadAll(f) + if err != nil { + return "", 0, fmt.Errorf("reading %s: %w", path, err) + } + + return string(data), perm, nil +} + +// atomicWriteFile writes data to path atomically by creating a temporary file +// in the same directory, syncing, and renaming over the destination. This +// avoids partial writes visible to concurrent readers. On any error after the +// temp file is created, the temp file is removed. +func atomicWriteFile(path string, data []byte, perm os.FileMode) error { + dir := filepath.Dir(path) + tmp, err := os.CreateTemp(dir, ".stripe-*") + if err != nil { + return fmt.Errorf("writing %s: %w", path, err) + } + tmpName := tmp.Name() + + // Ensure cleanup on any error path after file creation. + var writeErr error + defer func() { + if writeErr != nil { + os.Remove(tmpName) + } + }() + + if _, writeErr = tmp.Write(data); writeErr != nil { + tmp.Close() + return fmt.Errorf("writing %s: %w", path, writeErr) + } + if writeErr = tmp.Sync(); writeErr != nil { + tmp.Close() + return fmt.Errorf("writing %s: %w", path, writeErr) + } + if writeErr = tmp.Close(); writeErr != nil { + return fmt.Errorf("writing %s: %w", path, writeErr) + } + if writeErr = os.Chmod(tmpName, perm); writeErr != nil { + return fmt.Errorf("writing %s: %w", path, writeErr) + } + if writeErr = os.Rename(tmpName, path); writeErr != nil { + return fmt.Errorf("writing %s: %w", path, writeErr) + } + return nil +} + +// addSentinelBlock adds or replaces a sentinel-delimited block in the given +// config file. If the file does not exist, it is created with mode 0644. +// Existing file permissions are preserved. The operation is idempotent: +// calling it twice with the same line produces the same result as calling +// it once. If the file contains orphaned or reversed markers, a new block +// is appended rather than attempting to repair the malformed state. +func addSentinelBlock(configPath, line string) error { + content, perm, err := readConfigFile(configPath) + if err != nil { + return err + } + + newContent := computeAddSentinel(content, line) + return atomicWriteFile(configPath, []byte(newContent), perm) +} + +// removeSentinelBlock removes the sentinel-delimited block from the given +// config file. If the file does not exist, this is a no-op. If the markers +// are orphaned or reversed, the file is left unchanged. Existing file +// permissions are preserved. +func removeSentinelBlock(configPath string) error { + content, perm, err := readConfigFile(configPath) + if err != nil { + return err + } + + // If the file did not exist, readConfigFile returns ("", 0644, nil). + // computeRemoveSentinel("") returns ("", false), so !found handles both + // the missing-file case and the no-block-present case uniformly. + newContent, found := computeRemoveSentinel(content) + if !found { + return nil + } + + return atomicWriteFile(configPath, []byte(newContent), perm) +} + +// manualRemnant represents a line in a shell config file that references the +// completion script but is outside our sentinel-managed block. +type manualRemnant struct { + lineNumber int // 1-based, for display in user-facing warnings + lineText string // trimmed content of the matching line +} + +// findManualRemnants scans a shell config file for lines referencing the +// completion script filename that are outside our sentinel block. This detects +// manually-added source/load lines that the user may need to clean up. +// +// Lines inside the sentinel block, blank lines, and comment lines (starting +// with #) are excluded from the scan. Returns nil if the file cannot be read +// or no matches are found. +func findManualRemnants(configPath, scriptFilename string) []manualRemnant { + data, err := os.ReadFile(configPath) + if err != nil { + return nil + } + + var remnants []manualRemnant + inSentinelBlock := false + + for i, line := range strings.Split(string(data), "\n") { + trimmed := strings.TrimSpace(line) + + if strings.Contains(trimmed, sentinelBegin) { + inSentinelBlock = true + continue + } + if strings.Contains(trimmed, sentinelEnd) { + inSentinelBlock = false + continue + } + + if inSentinelBlock { + continue + } + + // Skip blank lines and comments + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + + if strings.Contains(trimmed, scriptFilename) { + remnants = append(remnants, manualRemnant{ + lineNumber: i + 1, + lineText: trimmed, + }) + } + } + + return remnants +} diff --git a/pkg/cmd/sentinel_test.go b/pkg/cmd/sentinel_test.go new file mode 100644 index 00000000..dce8980e --- /dev/null +++ b/pkg/cmd/sentinel_test.go @@ -0,0 +1,352 @@ +package cmd + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Pure function tests — no filesystem needed +// --------------------------------------------------------------------------- + +func TestComputeAddSentinel(t *testing.T) { + tests := []struct { + name string + content string + line string + wantHas []string // substrings that must be present + wantNot []string // substrings that must be absent + wantOnce []string // substrings that must appear exactly once + }{ + { + name: "empty content", + content: "", + line: "source ~/.stripe/stripe-completion.zsh", + wantHas: []string{sentinelBegin, "source ~/.stripe/stripe-completion.zsh", sentinelEnd}, + wantOnce: []string{sentinelBegin, sentinelEnd}, + }, + { + name: "existing content without sentinel", + content: "export PATH=/usr/local/bin:$PATH\n", + line: "source ~/.stripe/stripe-completion.zsh", + wantHas: []string{"export PATH=/usr/local/bin:$PATH\n", sentinelBegin, sentinelEnd}, + wantOnce: []string{sentinelBegin, sentinelEnd}, + }, + { + name: "replace existing block", + content: fmt.Sprintf("before\n%s\nold source line\n%s\nafter\n", sentinelBegin, sentinelEnd), + line: "new source line", + wantHas: []string{"before\n", "new source line", "after\n"}, + wantNot: []string{"old source line"}, + wantOnce: []string{sentinelBegin, sentinelEnd}, + }, + { + name: "orphaned begin only — appends new block", + content: fmt.Sprintf("before\n%s\norphaned source line\nafter\n", sentinelBegin), + line: "new source line", + wantHas: []string{"new source line", sentinelEnd}, + }, + { + name: "orphaned end only — appends new block", + content: fmt.Sprintf("before\n%s\nafter\n", sentinelEnd), + line: "new source line", + wantHas: []string{"new source line", sentinelBegin}, + wantOnce: []string{sentinelBegin}, + }, + { + name: "reversed markers — appends new block", + content: fmt.Sprintf("before\n%s\norphaned\n%s\nafter\n", sentinelEnd, sentinelBegin), + line: "new source line", + wantHas: []string{"new source line"}, + }, + { + name: "missing trailing newline — newline inserted before block", + content: "no trailing newline", + line: "source line", + wantHas: []string{"no trailing newline\n" + sentinelBegin}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := computeAddSentinel(tt.content, tt.line) + for _, s := range tt.wantHas { + assert.Contains(t, got, s) + } + for _, s := range tt.wantNot { + assert.NotContains(t, got, s) + } + for _, s := range tt.wantOnce { + assert.Equal(t, 1, strings.Count(got, s), "expected exactly one occurrence of %q", s) + } + }) + } +} + +func TestComputeRemoveSentinel(t *testing.T) { + tests := []struct { + name string + content string + wantFound bool + wantHas []string + wantNot []string + }{ + { + name: "no block present", + content: "export FOO=bar\n", + wantFound: false, + wantHas: []string{"export FOO=bar\n"}, + }, + { + name: "empty content", + content: "", + wantFound: false, + }, + { + name: "valid block removed", + content: fmt.Sprintf("before\n%s\nsource line\n%s\nafter\n", sentinelBegin, sentinelEnd), + wantFound: true, + wantHas: []string{"before\n", "after\n"}, + wantNot: []string{sentinelBegin, sentinelEnd, "source line"}, + }, + { + name: "orphaned begin only — no-op", + content: fmt.Sprintf("before\n%s\norphaned\nafter\n", sentinelBegin), + wantFound: false, + wantHas: []string{sentinelBegin, "orphaned"}, + }, + { + name: "orphaned end only — no-op", + content: fmt.Sprintf("before\n%s\nafter\n", sentinelEnd), + wantFound: false, + wantHas: []string{sentinelEnd}, + }, + { + name: "reversed markers — no-op", + content: fmt.Sprintf("before\n%s\norphaned\n%s\nafter\n", sentinelEnd, sentinelBegin), + wantFound: false, + wantHas: []string{sentinelEnd, sentinelBegin}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, found := computeRemoveSentinel(tt.content) + assert.Equal(t, tt.wantFound, found) + for _, s := range tt.wantHas { + assert.Contains(t, got, s) + } + for _, s := range tt.wantNot { + assert.NotContains(t, got, s) + } + }) + } +} + +// --------------------------------------------------------------------------- +// I/O wrapper tests — addSentinelBlock / removeSentinelBlock +// --------------------------------------------------------------------------- + +func TestAddSentinelBlockNewFile(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, ".zshrc") + + err := addSentinelBlock(configPath, "source /home/user/.stripe/stripe-completion.zsh") + require.NoError(t, err) + + data, err := os.ReadFile(configPath) + require.NoError(t, err) + + content := string(data) + assert.Contains(t, content, sentinelBegin) + assert.Contains(t, content, "source /home/user/.stripe/stripe-completion.zsh") + assert.Contains(t, content, sentinelEnd) +} + +func TestAddSentinelBlockPreservesPermissions(t *testing.T) { + if runtime.GOOS == "windows" || os.Getuid() == 0 { + t.Skip("Cannot test Unix file permissions on Windows or as root") + } + + dir := t.TempDir() + configPath := filepath.Join(dir, ".zshrc") + require.NoError(t, os.WriteFile(configPath, []byte("existing\n"), 0600)) + + err := addSentinelBlock(configPath, "source line") + require.NoError(t, err) + + info, err := os.Stat(configPath) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0600), info.Mode().Perm(), "file permissions should be preserved") +} + +func TestAddSentinelBlockReadPermissionDenied(t *testing.T) { + if runtime.GOOS == "windows" || os.Getuid() == 0 { + t.Skip("Cannot test Unix file permissions on Windows or as root") + } + + dir := t.TempDir() + configPath := filepath.Join(dir, ".zshrc") + require.NoError(t, os.WriteFile(configPath, []byte("content"), 0644)) + require.NoError(t, os.Chmod(configPath, 0000)) + t.Cleanup(func() { os.Chmod(configPath, 0644) }) + + err := addSentinelBlock(configPath, "source line") + assert.Error(t, err) +} + +func TestAddSentinelBlockWritePermissionDenied(t *testing.T) { + if runtime.GOOS == "windows" || os.Getuid() == 0 { + t.Skip("Cannot test Unix file permissions on Windows or as root") + } + + // Atomic write creates a new temp file and renames it. To prevent the write, + // we make the containing directory read-only so CreateTemp fails. + dir := t.TempDir() + configPath := filepath.Join(dir, ".zshrc") + require.NoError(t, os.WriteFile(configPath, []byte("existing\n"), 0644)) + require.NoError(t, os.Chmod(dir, 0555)) + t.Cleanup(func() { os.Chmod(dir, 0755) }) + + err := addSentinelBlock(configPath, "source line") + assert.Error(t, err) +} + +func TestRemoveSentinelBlockMissingFile(t *testing.T) { + err := removeSentinelBlock(filepath.Join(t.TempDir(), "nonexistent")) + assert.NoError(t, err) +} + +func TestRemoveSentinelBlockPreservesPermissions(t *testing.T) { + if runtime.GOOS == "windows" || os.Getuid() == 0 { + t.Skip("Cannot test Unix file permissions on Windows or as root") + } + + dir := t.TempDir() + configPath := filepath.Join(dir, ".zshrc") + content := fmt.Sprintf("before\n%s\nline\n%s\nafter\n", sentinelBegin, sentinelEnd) + require.NoError(t, os.WriteFile(configPath, []byte(content), 0600)) + + err := removeSentinelBlock(configPath) + require.NoError(t, err) + + info, err := os.Stat(configPath) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0600), info.Mode().Perm(), "file permissions should be preserved") +} + +// --------------------------------------------------------------------------- +// findManualRemnants +// --------------------------------------------------------------------------- + +func TestFindManualRemnants(t *testing.T) { + tests := []struct { + name string + content string + scriptFilename string + wantLen int + wantLineNums []int + }{ + { + name: "detects manual source line", + content: "export PATH=/usr/local/bin:$PATH\nsource ~/.stripe/stripe-completion.zsh\nalias ls='ls -G'\n", + scriptFilename: "stripe-completion.zsh", + wantLen: 1, + wantLineNums: []int{2}, + }, + { + name: "detects dot-source syntax", + content: ". /some/custom/path/stripe-completion.bash\n", + scriptFilename: "stripe-completion.bash", + wantLen: 1, + wantLineNums: []int{1}, + }, + { + name: "detects line with other commands", + content: "[ -f ~/.stripe/stripe-completion.zsh ] && source ~/.stripe/stripe-completion.zsh\n", + scriptFilename: "stripe-completion.zsh", + wantLen: 1, + wantLineNums: []int{1}, + }, + { + name: "detects custom path", + content: "source /opt/completions/stripe-completion.zsh\n", + scriptFilename: "stripe-completion.zsh", + wantLen: 1, + }, + { + name: "ignores lines inside sentinel block", + content: fmt.Sprintf("before\n%s\nsource ~/.stripe/stripe-completion.zsh\n%s\nafter\n", + sentinelBegin, sentinelEnd), + scriptFilename: "stripe-completion.zsh", + wantLen: 0, + }, + { + name: "ignores comment lines", + content: "# source ~/.stripe/stripe-completion.zsh\n", + scriptFilename: "stripe-completion.zsh", + wantLen: 0, + }, + { + name: "no match returns nil", + content: "export PATH=/usr/local/bin:$PATH\nalias ls='ls -G'\n", + scriptFilename: "stripe-completion.zsh", + wantLen: 0, + }, + { + name: "multiple matches", + content: "source ~/.stripe/stripe-completion.zsh\nexport FOO=bar\n. /other/stripe-completion.zsh\n", + scriptFilename: "stripe-completion.zsh", + wantLen: 2, + wantLineNums: []int{1, 3}, + }, + { + name: "manual line before sentinel block only", + content: fmt.Sprintf("source ~/my/stripe-completion.zsh\n%s\nsource ~/.stripe/stripe-completion.zsh\n%s\n", + sentinelBegin, sentinelEnd), + scriptFilename: "stripe-completion.zsh", + wantLen: 1, + wantLineNums: []int{1}, + }, + { + name: "manual line after sentinel block only", + content: fmt.Sprintf("%s\nsource ~/.stripe/stripe-completion.zsh\n%s\nsource ~/custom/stripe-completion.zsh\n", + sentinelBegin, sentinelEnd), + scriptFilename: "stripe-completion.zsh", + wantLen: 1, + wantLineNums: []int{4}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config") + require.NoError(t, os.WriteFile(configPath, []byte(tt.content), 0644)) + + remnants := findManualRemnants(configPath, tt.scriptFilename) + + if tt.wantLen == 0 { + // nil and empty slice are both acceptable for "no results" + assert.Len(t, remnants, 0) + } else { + require.Len(t, remnants, tt.wantLen) + for i, wantLine := range tt.wantLineNums { + assert.Equal(t, wantLine, remnants[i].lineNumber) + } + } + }) + } +} + +func TestFindManualRemnantsReturnsNilForMissingFile(t *testing.T) { + remnants := findManualRemnants(filepath.Join(t.TempDir(), "nonexistent"), "stripe-completion.zsh") + assert.Nil(t, remnants) +}