Skip to content
22 changes: 17 additions & 5 deletions internal/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func Run(args []string, stdout io.Writer, stderr io.Writer) error {

func printUsage(w io.Writer) {
fmt.Fprintln(w, `Usage:
sth start [--host 0.0.0.0] [--port 3939] [--db /path/to/sessions.db]
sth start [--host 0.0.0.0] [--bind 0.0.0.0] [--port 3939] [--server-name <addr>] [--db /path/to/sessions.db]
sth send <file.html> --tag <tag1,tag2,...> --category <cat> --project <proj> [--server http://127.0.0.1:3939]
sth tag [--rm] <session-id> <tag...> [--db /path/to/sessions.db] [--server http://...]
sth categorize <session-id> <category> [--db /path/to/sessions.db] [--server http://...]
Expand All @@ -70,9 +70,16 @@ func runStart(args []string, stdout io.Writer) error {
return err
}

host := server.DefaultHost
bindAddr := server.DefaultHost
_, hasHost := flags["host"]
_, hasBind := flags["bind"]
if hasHost && hasBind {
fmt.Fprintln(os.Stderr, "warning: --host is deprecated, use --bind instead; --host takes precedence")
}
if value, ok := flags["host"]; ok {
host = value
bindAddr = value
} else if value, ok := flags["bind"]; ok {
bindAddr = value
}

port := server.DefaultPort
Expand All @@ -87,12 +94,17 @@ func runStart(args []string, stdout io.Writer) error {
return errors.New("port must be a positive integer")
}

var serverName string
if value, ok := flags["server-name"]; ok {
serverName = value
}

store, err := openStore(flags)
if err != nil {
return err
}

srv, err := server.New(host, port, store)
srv, err := server.New(bindAddr, port, store, serverName)
if err != nil {
return errors.Join(err, store.Close())
}
Expand All @@ -103,7 +115,7 @@ func runStart(args []string, stdout io.Writer) error {

origins := srv.Origins()
if len(origins) == 0 {
fmt.Fprintf(stdout, "HTML server listening on %s:%d\n", host, port)
fmt.Fprintf(stdout, "HTML server listening on %s:%d\n", bindAddr, port)
} else if len(origins) == 1 {
fmt.Fprintf(stdout, "HTML server listening on %s\n", origins[0])
} else {
Expand Down
2 changes: 1 addition & 1 deletion internal/cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestSendPrintsSessionURL(t *testing.T) {
t.Parallel()

store := newTestStore(t)
srv, err := server.New("127.0.0.1", 0, store)
srv, err := server.New("127.0.0.1", 0, store, "")
if err != nil {
t.Fatal(err)
}
Expand Down
6 changes: 3 additions & 3 deletions internal/cli/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ func TestHomePageShowsMetadata(t *testing.T) {
}

store := newTestStore(t)
srv, err := server.New("127.0.0.1", 0, store)
srv, err := server.New("127.0.0.1", 0, store, "")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -381,7 +381,7 @@ func TestHomePageFilterByTag(t *testing.T) {
}

store := newTestStore(t)
srv, err := server.New("127.0.0.1", 0, store)
srv, err := server.New("127.0.0.1", 0, store, "")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -429,7 +429,7 @@ func TestHomePageSearch(t *testing.T) {
t.Parallel()

store := newTestStore(t)
srv, err := server.New("127.0.0.1", 0, store)
srv, err := server.New("127.0.0.1", 0, store, "")
if err != nil {
t.Fatal(err)
}
Expand Down
82 changes: 70 additions & 12 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"os"
"path"
"path/filepath"
"regexp"
"strings"
"sync"

Expand All @@ -29,9 +30,23 @@ const (
maxArchiveFiles = 2048
)

// hostnamePattern matches valid hostnames with dot-separated labels.
// Each label starts and ends with [a-zA-Z0-9]; inner characters allow hyphens.
// Consecutive dots are inherently rejected because each label requires
// at least one alphanumeric character between dots.
var hostnamePattern = regexp.MustCompile(`^[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?)*$`)

func isValidServerName(name string) bool {
if ip := net.ParseIP(name); ip != nil {
return ip.To4() != nil
}
return hostnamePattern.MatchString(name)
}

type Server struct {
host string
port int
serverName string
store *session.Store
httpServer *http.Server
listener net.Listener
Expand Down Expand Up @@ -401,19 +416,26 @@ var homePageTemplate = template.Must(template.New("home").Parse(`<!doctype html>
</body>
</html>`))

func New(host string, port int, store *session.Store) (*Server, error) {
func New(host string, port int, store *session.Store, serverName string) (*Server, error) {
if host == "" {
host = DefaultHost
}

if serverName != "" {
if !isValidServerName(serverName) {
return nil, errors.New("server: --server-name must be a valid IPv4 address or hostname (letters, digits, dots, hyphens)")
}
}

if store == nil {
return nil, errors.New("server: store must not be nil")
}

srv := &Server{
host: host,
port: port,
store: store,
host: host,
port: port,
serverName: serverName,
store: store,
}

srv.httpServer = &http.Server{
Expand Down Expand Up @@ -474,6 +496,10 @@ func (s *Server) Origin() string {
return ""
}

// NOTE: scheme is hardcoded to http, consistent with the rest of the codebase.
if s.serverName != "" {
return fmt.Sprintf("http://%s:%d", s.serverName, address.Port)
}
ip := address.IP
if ip.IsUnspecified() {
ip = net.IPv4(127, 0, 0, 1)
Expand All @@ -497,6 +523,12 @@ func (s *Server) Origins() []string {

port := address.Port

if s.serverName != "" {
return []string{
fmt.Sprintf("http://%s:%d", s.serverName, port),
}
}

if address.IP.IsUnspecified() {
origins := []string{fmt.Sprintf("http://127.0.0.1:%d", port)}

Expand Down Expand Up @@ -807,10 +839,10 @@ func (s *Server) handleCreatePathSession(w http.ResponseWriter, r *http.Request)
return
}

baseURL := baseURL(r)
base := s.serverBaseURL(r)
response := createSessionResponse{
SessionID: session.ID,
URL: baseURL + "/s/" + session.ID + "/",
URL: base + "/s/" + session.ID + "/",
EntryFile: session.EntryFile,
RootDir: session.RootDir,
}
Expand Down Expand Up @@ -927,10 +959,10 @@ func (s *Server) handleCreateUploadedSession(w http.ResponseWriter, r *http.Requ
return
}

baseURL := baseURL(r)
base := s.serverBaseURL(r)
response := createSessionResponse{
SessionID: session.ID,
URL: baseURL + "/s/" + session.ID + "/",
URL: base + "/s/" + session.ID + "/",
EntryFile: session.EntryFile,
RootDir: session.RootDir,
}
Expand Down Expand Up @@ -1294,13 +1326,39 @@ func writeJSON(w http.ResponseWriter, status int, payload any) {
_ = json.NewEncoder(w).Encode(payload)
}

// determineScheme returns the URL scheme for the request.
// It trusts the X-Forwarded-Proto header, so the server must
// run behind a trusted reverse proxy when relying on this logic.
func determineScheme(r *http.Request) string {
if r.Header.Get("X-Forwarded-Proto") == "https" || r.TLS != nil {
return "https"
}
return "http"
}

func baseURL(r *http.Request) string {
scheme := "http"
if r.TLS != nil {
scheme = "https"
return determineScheme(r) + "://" + r.Host
}

func (s *Server) serverBaseURL(r *http.Request) string {
s.mu.RLock()
defer s.mu.RUnlock()

if s.serverName != "" {
port := s.port
if s.listener != nil {
if addr, ok := s.listener.Addr().(*net.TCPAddr); ok {
port = addr.Port
}
}
scheme := determineScheme(r)
if (scheme == "http" && port == 80) || (scheme == "https" && port == 443) {
return fmt.Sprintf("%s://%s", scheme, s.serverName)
}
return fmt.Sprintf("%s://%s:%d", scheme, s.serverName, port)
}

return scheme + "://" + r.Host
return baseURL(r)
}

func hasPrefixSuffix(path, prefix, suffix string) bool {
Expand Down
Loading
Loading