diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 02ab492..48d3517 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -54,7 +54,7 @@ func Run(args []string, stdout io.Writer, stderr io.Writer) error { func printUsage(w io.Writer) { fmt.Fprintln(w, `Usage: - sth start [--host 0.0.0.0] [--port 3939] [--db /path/to/sessions.db] + sth start [--host 0.0.0.0] [--bind 0.0.0.0] [--port 3939] [--server-name ] [--db /path/to/sessions.db] sth send --tag --category --project [--server http://127.0.0.1:3939] sth tag [--rm] [--db /path/to/sessions.db] [--server http://...] sth categorize [--db /path/to/sessions.db] [--server http://...] @@ -70,9 +70,16 @@ func runStart(args []string, stdout io.Writer) error { return err } - host := server.DefaultHost + bindAddr := server.DefaultHost + _, hasHost := flags["host"] + _, hasBind := flags["bind"] + if hasHost && hasBind { + fmt.Fprintln(os.Stderr, "warning: --host is deprecated, use --bind instead; --host takes precedence") + } if value, ok := flags["host"]; ok { - host = value + bindAddr = value + } else if value, ok := flags["bind"]; ok { + bindAddr = value } port := server.DefaultPort @@ -87,12 +94,17 @@ func runStart(args []string, stdout io.Writer) error { return errors.New("port must be a positive integer") } + var serverName string + if value, ok := flags["server-name"]; ok { + serverName = value + } + store, err := openStore(flags) if err != nil { return err } - srv, err := server.New(host, port, store) + srv, err := server.New(bindAddr, port, store, serverName) if err != nil { return errors.Join(err, store.Close()) } @@ -103,7 +115,7 @@ func runStart(args []string, stdout io.Writer) error { origins := srv.Origins() if len(origins) == 0 { - fmt.Fprintf(stdout, "HTML server listening on %s:%d\n", host, port) + fmt.Fprintf(stdout, "HTML server listening on %s:%d\n", bindAddr, port) } else if len(origins) == 1 { fmt.Fprintf(stdout, "HTML server listening on %s\n", origins[0]) } else { diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 2e482ae..6cdd018 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -22,7 +22,7 @@ func TestSendPrintsSessionURL(t *testing.T) { t.Parallel() store := newTestStore(t) - srv, err := server.New("127.0.0.1", 0, store) + srv, err := server.New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } diff --git a/internal/cli/metadata_test.go b/internal/cli/metadata_test.go index 4efd071..9a05edf 100644 --- a/internal/cli/metadata_test.go +++ b/internal/cli/metadata_test.go @@ -325,7 +325,7 @@ func TestHomePageShowsMetadata(t *testing.T) { } store := newTestStore(t) - srv, err := server.New("127.0.0.1", 0, store) + srv, err := server.New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } @@ -381,7 +381,7 @@ func TestHomePageFilterByTag(t *testing.T) { } store := newTestStore(t) - srv, err := server.New("127.0.0.1", 0, store) + srv, err := server.New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } @@ -429,7 +429,7 @@ func TestHomePageSearch(t *testing.T) { t.Parallel() store := newTestStore(t) - srv, err := server.New("127.0.0.1", 0, store) + srv, err := server.New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } diff --git a/internal/server/server.go b/internal/server/server.go index 684ed2c..74fc184 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -15,6 +15,7 @@ import ( "os" "path" "path/filepath" + "regexp" "strings" "sync" @@ -29,9 +30,23 @@ const ( maxArchiveFiles = 2048 ) +// hostnamePattern matches valid hostnames with dot-separated labels. +// Each label starts and ends with [a-zA-Z0-9]; inner characters allow hyphens. +// Consecutive dots are inherently rejected because each label requires +// at least one alphanumeric character between dots. +var hostnamePattern = regexp.MustCompile(`^[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?)*$`) + +func isValidServerName(name string) bool { + if ip := net.ParseIP(name); ip != nil { + return ip.To4() != nil + } + return hostnamePattern.MatchString(name) +} + type Server struct { host string port int + serverName string store *session.Store httpServer *http.Server listener net.Listener @@ -401,19 +416,26 @@ var homePageTemplate = template.Must(template.New("home").Parse(` `)) -func New(host string, port int, store *session.Store) (*Server, error) { +func New(host string, port int, store *session.Store, serverName string) (*Server, error) { if host == "" { host = DefaultHost } + if serverName != "" { + if !isValidServerName(serverName) { + return nil, errors.New("server: --server-name must be a valid IPv4 address or hostname (letters, digits, dots, hyphens)") + } + } + if store == nil { return nil, errors.New("server: store must not be nil") } srv := &Server{ - host: host, - port: port, - store: store, + host: host, + port: port, + serverName: serverName, + store: store, } srv.httpServer = &http.Server{ @@ -474,6 +496,10 @@ func (s *Server) Origin() string { return "" } + // NOTE: scheme is hardcoded to http, consistent with the rest of the codebase. + if s.serverName != "" { + return fmt.Sprintf("http://%s:%d", s.serverName, address.Port) + } ip := address.IP if ip.IsUnspecified() { ip = net.IPv4(127, 0, 0, 1) @@ -497,6 +523,12 @@ func (s *Server) Origins() []string { port := address.Port + if s.serverName != "" { + return []string{ + fmt.Sprintf("http://%s:%d", s.serverName, port), + } + } + if address.IP.IsUnspecified() { origins := []string{fmt.Sprintf("http://127.0.0.1:%d", port)} @@ -807,10 +839,10 @@ func (s *Server) handleCreatePathSession(w http.ResponseWriter, r *http.Request) return } - baseURL := baseURL(r) + base := s.serverBaseURL(r) response := createSessionResponse{ SessionID: session.ID, - URL: baseURL + "/s/" + session.ID + "/", + URL: base + "/s/" + session.ID + "/", EntryFile: session.EntryFile, RootDir: session.RootDir, } @@ -927,10 +959,10 @@ func (s *Server) handleCreateUploadedSession(w http.ResponseWriter, r *http.Requ return } - baseURL := baseURL(r) + base := s.serverBaseURL(r) response := createSessionResponse{ SessionID: session.ID, - URL: baseURL + "/s/" + session.ID + "/", + URL: base + "/s/" + session.ID + "/", EntryFile: session.EntryFile, RootDir: session.RootDir, } @@ -1294,13 +1326,39 @@ func writeJSON(w http.ResponseWriter, status int, payload any) { _ = json.NewEncoder(w).Encode(payload) } +// determineScheme returns the URL scheme for the request. +// It trusts the X-Forwarded-Proto header, so the server must +// run behind a trusted reverse proxy when relying on this logic. +func determineScheme(r *http.Request) string { + if r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil { + return "https" + } + return "http" +} + func baseURL(r *http.Request) string { - scheme := "http" - if r.TLS != nil { - scheme = "https" + return determineScheme(r) + "://" + r.Host +} + +func (s *Server) serverBaseURL(r *http.Request) string { + s.mu.RLock() + defer s.mu.RUnlock() + + if s.serverName != "" { + port := s.port + if s.listener != nil { + if addr, ok := s.listener.Addr().(*net.TCPAddr); ok { + port = addr.Port + } + } + scheme := determineScheme(r) + if (scheme == "http" && port == 80) || (scheme == "https" && port == 443) { + return fmt.Sprintf("%s://%s", scheme, s.serverName) + } + return fmt.Sprintf("%s://%s:%d", scheme, s.serverName, port) } - return scheme + "://" + r.Host + return baseURL(r) } func hasPrefixSuffix(path, prefix, suffix string) bool { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index d3a3393..6546a99 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -5,12 +5,15 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "mime/multipart" + "net" "net/http" "net/url" "os" "path/filepath" + "strings" "testing" "time" @@ -20,12 +23,285 @@ import ( func TestNewRequiresStore(t *testing.T) { t.Parallel() - _, err := New("127.0.0.1", 0, nil) + _, err := New("127.0.0.1", 0, nil, "") if err == nil { t.Fatal("expected nil store to be rejected") } } +func TestServerNameOverridesOrigin(t *testing.T) { + t.Parallel() + + store := newTestStore(t) + srv, err := New("127.0.0.1", 0, store, "192.168.2.14") + if err != nil { + t.Fatal(err) + } + if err := srv.Start(); err != nil { + t.Fatal(err) + } + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + _ = srv.Stop(ctx) + }() + + origin := srv.Origin() + if !strings.HasPrefix(origin, "http://192.168.2.14:") { + t.Fatalf("expected origin to use server-name 192.168.2.14, got %q", origin) + } + + origins := srv.Origins() + if len(origins) != 1 { + t.Fatalf("expected 1 origin with server-name, got %d", len(origins)) + } + if !strings.HasPrefix(origins[0], "http://192.168.2.14:") { + t.Fatalf("expected origins[0] to use server-name 192.168.2.14, got %q", origins[0]) + } +} + +func TestServerNameDomain(t *testing.T) { + t.Parallel() + + store := newTestStore(t) + srv, err := New("127.0.0.1", 0, store, "myhost.local") + if err != nil { + t.Fatal(err) + } + if err := srv.Start(); err != nil { + t.Fatal(err) + } + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + _ = srv.Stop(ctx) + }() + + origin := srv.Origin() + if !strings.HasPrefix(origin, "http://myhost.local:") { + t.Fatalf("expected origin to use server-name myhost.local, got %q", origin) + } +} + +func TestServerNameEmptyFallsBack(t *testing.T) { + t.Parallel() + + store := newTestStore(t) + srv, err := New("127.0.0.1", 0, store, "") + if err != nil { + t.Fatal(err) + } + if err := srv.Start(); err != nil { + t.Fatal(err) + } + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + _ = srv.Stop(ctx) + }() + + origin := srv.Origin() + if !strings.HasPrefix(origin, "http://127.0.0.1:") { + t.Fatalf("expected origin to default to 127.0.0.1, got %q", origin) + } +} + +func TestServerNameValidation(t *testing.T) { + t.Parallel() + + store := newTestStore(t) + + for _, invalid := range []string{ + "has space", + "has/slash", + "has:colon", + "http://host", + "https://host", + "host@evil", + "host#fragment", + "host?query=1", + "host%20name", + "host\nnewline", + "host\rtab", + "host\ttab", + "..", + ".start", + "end.", + "-hyphen", + "hyphen-", + "2001:db8::1", + "::1", + "fe80::1%eth0", + "a..b", + "host..domain.com", + } { + _, err := New("127.0.0.1", 0, store, invalid) + if err == nil { + t.Errorf("expected serverName %q to be rejected", invalid) + } + } +} + +func TestServerNameUsedInSessionURL(t *testing.T) { + t.Parallel() + + fixtureHTML, err := filepath.Abs(filepath.Join("..", "..", "fixtures", "basic", "index.html")) + if err != nil { + t.Fatal(err) + } + + store := newTestStore(t) + srv, err := New("127.0.0.1", 0, store, "192.168.2.14") + if err != nil { + t.Fatal(err) + } + if err := srv.Start(); err != nil { + t.Fatal(err) + } + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + _ = srv.Stop(ctx) + }() + + body, err := json.Marshal(map[string]any{ + "filePath": fixtureHTML, + "tags": []string{"test"}, + "category": "test", + "project": "test", + }) + if err != nil { + t.Fatal(err) + } + + srv.mu.RLock() + actualAddr, ok := srv.listener.Addr().(*net.TCPAddr) + if !ok { + t.Fatalf("listener address is not TCP: %T", srv.listener.Addr()) + } + srv.mu.RUnlock() + actualURL := fmt.Sprintf("http://127.0.0.1:%d", actualAddr.Port) + + createResp, err := http.Post(actualURL+"/api/sessions", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + defer createResp.Body.Close() + + if createResp.StatusCode != http.StatusCreated { + respBody, _ := io.ReadAll(createResp.Body) + t.Fatalf("unexpected status: %d body=%s", createResp.StatusCode, respBody) + } + + var payload struct { + URL string `json:"url"` + } + if err := json.NewDecoder(createResp.Body).Decode(&payload); err != nil { + t.Fatal(err) + } + + expectedPrefix := "http://192.168.2.14:" + if !strings.HasPrefix(payload.URL, expectedPrefix) { + t.Fatalf("expected session URL to use serverName, got %q", payload.URL) + } +} + +func TestServerNameValidationAcceptsValid(t *testing.T) { + t.Parallel() + + store := newTestStore(t) + + for _, valid := range []string{ + "192.168.2.14", + "myhost.local", + "sub.domain.example.com", + "host-with-dashes.local", + "a", + "localhost", + "10.0.0.1", + "255.255.255.255", + "host123", + } { + _, err := New("127.0.0.1", 0, store, valid) + if err != nil { + t.Errorf("expected serverName %q to be accepted, got error: %v", valid, err) + } + } +} + + +func TestServerBaseURLDefaultPort(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + serverName string + port int + scheme string + want string + }{ + {"http port 80 omits port", "example.com", 80, "http", "http://example.com"}, + {"https port 443 omits port", "example.com", 443, "https", "https://example.com"}, + {"http port 8080 includes port", "example.com", 8080, "http", "http://example.com:8080"}, + {"https port 8443 includes port", "example.com", 8443, "https", "https://example.com:8443"}, + {"http port 3939 includes port", "192.168.2.14", 3939, "http", "http://192.168.2.14:3939"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := newTestStore(t) + srv, err := New("127.0.0.1", tt.port, store, tt.serverName) + if err != nil { + t.Fatal(err) + } + + r := &http.Request{TLS: nil, Header: http.Header{}} + if tt.scheme == "https" { + r.Header.Set("X-Forwarded-Proto", "https") + } + + got := srv.serverBaseURL(r) + if got != tt.want { + t.Errorf("serverBaseURL() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestServerBaseURLFallback(t *testing.T) { + t.Parallel() + + store := newTestStore(t) + srv, err := New("127.0.0.1", 3939, store, "") + if err != nil { + t.Fatal(err) + } + + r := &http.Request{TLS: nil, Host: "127.0.0.1:3939", Header: http.Header{}} + got := srv.serverBaseURL(r) + if got != "http://127.0.0.1:3939" { + t.Errorf("serverBaseURL() fallback = %q, want http://127.0.0.1:3939", got) + } +} + +func TestServerBaseURLForwardedProto(t *testing.T) { + t.Parallel() + + store := newTestStore(t) + srv, err := New("127.0.0.1", 443, store, "secure.example.com") + if err != nil { + t.Fatal(err) + } + + r := &http.Request{TLS: nil, Host: "secure.example.com", Header: http.Header{}} + r.Header.Set("X-Forwarded-Proto", "https") + + got := srv.serverBaseURL(r) + if got != "https://secure.example.com" { + t.Errorf("serverBaseURL() with X-Forwarded-Proto = %q, want https://secure.example.com", got) + } +} + func TestCreateSessionAndServeAssets(t *testing.T) { t.Parallel() @@ -36,7 +312,7 @@ func TestCreateSessionAndServeAssets(t *testing.T) { } store := newTestStore(t) - srv, err := New("127.0.0.1", 0, store) + srv, err := New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } @@ -130,7 +406,7 @@ func TestTraversalIsRejected(t *testing.T) { } store := newTestStore(t) - srv, err := New("127.0.0.1", 0, store) + srv, err := New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } @@ -194,7 +470,7 @@ func TestCreateUploadedSessionAndServeAssets(t *testing.T) { } store := newTestStore(t) - srv, err := New("127.0.0.1", 0, store) + srv, err := New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } @@ -372,7 +648,7 @@ func TestDeleteSessionSuccess(t *testing.T) { } store := newTestStore(t) - srv, err := New("127.0.0.1", 0, store) + srv, err := New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } @@ -416,7 +692,7 @@ func TestDeleteSessionNotFound(t *testing.T) { t.Parallel() store := newTestStore(t) - srv, err := New("127.0.0.1", 0, store) + srv, err := New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } @@ -455,7 +731,7 @@ func TestSearchByFileContent(t *testing.T) { } store := newTestStore(t) - srv, err := New("127.0.0.1", 0, store) + srv, err := New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } @@ -500,7 +776,7 @@ func TestSearchNoResults(t *testing.T) { } store := newTestStore(t) - srv, err := New("127.0.0.1", 0, store) + srv, err := New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } @@ -546,7 +822,7 @@ func TestSearchContentNoDuplicate(t *testing.T) { } store := newTestStore(t) - srv, err := New("127.0.0.1", 0, store) + srv, err := New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } @@ -592,7 +868,7 @@ func TestDeleteSessionIdempotent(t *testing.T) { } store := newTestStore(t) - srv, err := New("127.0.0.1", 0, store) + srv, err := New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } @@ -634,7 +910,7 @@ func TestDownloadSession(t *testing.T) { } store := newTestStore(t) - srv, err := New("127.0.0.1", 0, store) + srv, err := New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) } @@ -692,7 +968,7 @@ func TestDownloadSessionNotFound(t *testing.T) { t.Parallel() store := newTestStore(t) - srv, err := New("127.0.0.1", 0, store) + srv, err := New("127.0.0.1", 0, store, "") if err != nil { t.Fatal(err) }