Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
/shush
/dist
*.sock
._*
.DS_Store
1 change: 1 addition & 0 deletions capability.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Capability token resolution: --capability/--cap flag > SHUSH_CAPABILITY env > empty.
package shush

import (
Expand Down
56 changes: 45 additions & 11 deletions client_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ func Ping(socketPath string) error {

// PingWithCapability checks if an agent is running with the provided capability.
func PingWithCapability(socketPath, capability string) error {
client := &Client{SocketPath: socketPath, Capability: capability}
_, err := client.do(request{Action: actionPing})
_, err := newAgentClient(socketPath, capability).do(request{Action: actionPing})
return err
}

Expand All @@ -27,14 +26,19 @@ func Stop(socketPath string) error {

// StopWithCapability requests the agent to shut down with the provided capability.
func StopWithCapability(socketPath, capability string) error {
client := &Client{SocketPath: socketPath, Capability: capability}
_, err := client.do(request{Action: actionStop})
_, err := newAgentClient(socketPath, capability).do(request{Action: actionStop})
return err
}

// ClearAll removes all cached secrets and tokens from the agent.
func ClearAll(socketPath string) error {
client := &Client{SocketPath: socketPath, Capability: ResolveCapability(nil)}
return ClearAllWithCapability(socketPath, ResolveCapability(nil))
}

// ClearAllWithCapability removes all cached secrets and tokens using the
// provided capability.
func ClearAllWithCapability(socketPath, capability string) error {
client := newAgentClient(socketPath, capability)
if _, err := client.do(request{Action: actionClear}); err != nil {
return err
}
Expand All @@ -44,14 +48,29 @@ func ClearAll(socketPath string) error {

// ClearPrefix removes all cached secrets whose keys start with the prefix.
func ClearPrefix(socketPath, prefix string) error {
client := &Client{SocketPath: socketPath, Capability: ResolveCapability(nil)}
return client.ClearPrefix(prefix)
return ClearPrefixWithCapability(socketPath, prefix, ResolveCapability(nil))
}

// ClearPrefixWithCapability removes cached secrets by prefix using the
// provided capability.
func ClearPrefixWithCapability(socketPath, prefix, capability string) error {
return newAgentClient(socketPath, capability).ClearPrefix(prefix)
}

// ClearByKeyPattern removes all cached secrets matching the glob pattern.
func ClearByKeyPattern(socketPath, pattern string) error {
client := &Client{SocketPath: socketPath, Capability: ResolveCapability(nil)}
return client.ClearByKeyPattern(pattern)
return ClearByKeyPatternWithCapability(socketPath, pattern, ResolveCapability(nil))
}

// ClearByKeyPatternWithCapability removes cached secrets by glob pattern
// using the provided capability.
func ClearByKeyPatternWithCapability(socketPath, pattern, capability string) error {
return newAgentClient(socketPath, capability).ClearByKeyPattern(pattern)
}

// newAgentClient builds a Client used for one-shot agent control calls.
func newAgentClient(socketPath, capability string) *Client {
return &Client{SocketPath: socketPath, Capability: capability}
}

// StartProcess starts an agent server in a background process.
Expand All @@ -66,6 +85,11 @@ func StartProcessWithCapability(socketPath string, serveArgs []string, capabilit
return startProcess(socketPath, serveArgs, capability)
}

const (
agentStartupAttempts = 40
agentStartupInterval = 25 * time.Millisecond
)

func startProcess(socketPath string, serveArgs []string, capability string) error {
capability = strings.TrimSpace(capability)
if err := PingWithCapability(socketPath, capability); err == nil {
Expand All @@ -74,6 +98,13 @@ func startProcess(socketPath string, serveArgs []string, capability string) erro
return err
}

if err := spawnAgentProcess(socketPath, serveArgs, capability); err != nil {
return err
}
return waitForAgentReady(socketPath, capability)
}

func spawnAgentProcess(socketPath string, serveArgs []string, capability string) error {
if len(serveArgs) == 0 {
exePath, err := os.Executable()
if err != nil {
Expand All @@ -94,10 +125,13 @@ func startProcess(socketPath string, serveArgs []string, capability string) erro
return fmt.Errorf("start agent process: %w", err)
}
_ = cmd.Process.Release()
return nil
}

func waitForAgentReady(socketPath, capability string) error {
var lastErr error
for i := 0; i < 40; i++ {
time.Sleep(25 * time.Millisecond)
for i := 0; i < agentStartupAttempts; i++ {
time.Sleep(agentStartupInterval)
lastErr = PingWithCapability(socketPath, capability)
if lastErr == nil {
return nil
Expand Down
68 changes: 49 additions & 19 deletions client_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net"
"os"
"strings"
"syscall"
"time"
)

Expand Down Expand Up @@ -71,17 +74,17 @@ func (c *Client) Set(secret []byte) error {
SecretB64: base64.StdEncoding.EncodeToString(secret),
ExpiresAt: time.Now().Add(c.TTL).Unix(),
}
if _, err := c.do(req); err == nil {
return nil
}

capability := strings.TrimSpace(c.Capability)
if capability == "" {
capability = ResolveCapability(nil)
}
if err := startProcess(c.SocketPath, c.ServeArgs, capability); err != nil {
return err
if err := PingWithCapability(c.SocketPath, capability); err != nil {
if !isDialError(err) {
return err
}
if err := startProcess(c.SocketPath, c.ServeArgs, capability); err != nil {
return err
}
}

_, err := c.do(req)
return err
}
Expand All @@ -95,10 +98,7 @@ func (c *Client) Clear() error {
Action: actionClear,
Key: c.Key,
})
if err != nil {
return nil
}
return nil
return err
}

// ClearPrefix removes cached secrets whose keys start with the provided prefix.
Expand Down Expand Up @@ -128,9 +128,6 @@ func (c *Client) ClearByKeyPattern(pattern string) error {
func (c *Client) do(req request) (response, error) {
if strings.TrimSpace(req.Capability) == "" {
req.Capability = strings.TrimSpace(c.Capability)
if req.Capability == "" {
req.Capability = ResolveCapability(nil)
}
}

conn, err := net.DialTimeout("unix", c.SocketPath, DefaultDialTimeout)
Expand All @@ -152,10 +149,43 @@ func (c *Client) do(req request) (response, error) {
return response{}, err
}
if !resp.OK {
if strings.TrimSpace(resp.Error) == "" {
return response{}, errors.New("agent request failed")
}
return response{}, errors.New(resp.Error)
return response{}, decodeRemoteError(resp.Error)
}
return resp, nil
}

// isDialError reports whether err indicates that the agent socket is
// unreachable (missing or refusing connections), distinguishing transport
// failures from protocol-level rejections like capability errors.
func isDialError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT) {
return true
}
if os.IsNotExist(err) {
return true
}
var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Op == "dial" {
return true
}
return false
}

// decodeRemoteError rebinds sentinel errors that travel as strings over the
// wire so callers can use errors.Is across the process boundary.
func decodeRemoteError(msg string) error {
trimmed := strings.TrimSpace(msg)
if trimmed == "" {
return errors.New("agent request failed")
}
switch trimmed {
case ErrCapabilityRequired.Error():
return fmt.Errorf("%w", ErrCapabilityRequired)
case ErrCapabilityInvalid.Error():
return fmt.Errorf("%w", ErrCapabilityInvalid)
}
return errors.New(trimmed)
}
36 changes: 30 additions & 6 deletions client_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func RegisterTokenWithCapability(socketPath, capability, secret string, ttl time
if err != nil {
return "", nil, err
}
client := &Client{SocketPath: socketPath, Capability: capability}
client := newAgentClient(socketPath, capability)
_, err = client.do(request{
Action: actionSetToken,
Token: token,
Expand All @@ -58,28 +58,52 @@ func RegisterTokenWithCapability(socketPath, capability, secret string, ttl time
}

// ResolveToken retrieves and consumes a one-time token.
//
// Deprecated: use ResolveTokenBytes to keep the secret as []byte and avoid
// leaving an immutable string copy in memory.
func ResolveToken(socketPath, token string) (string, error) {
return ResolveTokenWithCapability(socketPath, ResolveCapability(nil), token)
}

// ResolveTokenWithCapability retrieves and consumes a token with capability.
//
// Deprecated: use ResolveTokenBytesWithCapability to keep the secret as
// []byte and avoid leaving an immutable string copy in memory.
func ResolveTokenWithCapability(socketPath, capability, token string) (string, error) {
client := &Client{SocketPath: socketPath, Capability: capability}
secret, err := ResolveTokenBytesWithCapability(socketPath, capability, token)
if err != nil {
return "", err
}
out := string(secret)
Wipe(secret)
return out, nil
}

// ResolveTokenBytes retrieves and consumes a one-time token, returning the
// secret as a mutable []byte so callers can Wipe it after use.
func ResolveTokenBytes(socketPath, token string) ([]byte, error) {
return ResolveTokenBytesWithCapability(socketPath, ResolveCapability(nil), token)
}

// ResolveTokenBytesWithCapability retrieves and consumes a token with
// capability, returning the secret as a mutable []byte.
func ResolveTokenBytesWithCapability(socketPath, capability, token string) ([]byte, error) {
client := newAgentClient(socketPath, capability)
resp, err := client.do(request{
Action: actionGetToken,
Token: token,
})
if err != nil {
return "", err
return nil, err
}
if !resp.Found {
return "", errors.New("token not found or expired")
return nil, errors.New("token not found or expired")
}
secret, err := base64.StdEncoding.DecodeString(resp.SecretB64)
if err != nil || len(secret) == 0 {
return "", errors.New("invalid token secret payload")
return nil, errors.New("invalid token secret payload")
}
return string(secret), nil
return secret, nil
}

func generateToken() (string, error) {
Expand Down
Loading