diff --git a/go.mod b/go.mod index 8b62549..cb23d80 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,10 @@ module github.com/sun-praise/static-html go 1.24.1 -require github.com/mattn/go-sqlite3 v1.14.37 +require ( + github.com/coder/websocket v1.8.14 + github.com/fsnotify/fsnotify v1.10.1 + github.com/mattn/go-sqlite3 v1.14.37 +) + +require golang.org/x/sys v0.13.0 // indirect diff --git a/go.sum b/go.sum index 9c79a75..3b399a6 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,8 @@ +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= +github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho= +github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo= github.com/mattn/go-sqlite3 v1.14.37 h1:3DOZp4cXis1cUIpCfXLtmlGolNLp2VEqhiB/PARNBIg= github.com/mattn/go-sqlite3 v1.14.37/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 5e9692c..8345a34 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -47,6 +47,8 @@ func Run(args []string, stdout io.Writer, stderr io.Writer) error { return runSearch(args[1:], stdout) case "delete": return runDelete(args[1:], stdout) + case "watch": + return runWatch(args[1:], stdout) default: return fmt.Errorf("unknown command: %s", args[0]) } @@ -61,7 +63,8 @@ func printUsage(w io.Writer) { sth project [--db /path/to/sessions.db] [--server http://...] sth list [--tag ] [--category ] [--project ] [--limit ] [--offset ] [--db /path/to/sessions.db] sth search [--tag ] [--category ] [--project ] [--limit ] [--offset ] [--db /path/to/sessions.db] - sth delete [--db /path/to/sessions.db]`) + sth delete [--db /path/to/sessions.db] + sth watch --session [--server http://127.0.0.1:3939]`) } func runStart(args []string, stdout io.Writer) error { diff --git a/internal/cli/watch.go b/internal/cli/watch.go new file mode 100644 index 0000000..2406ff4 --- /dev/null +++ b/internal/cli/watch.go @@ -0,0 +1,258 @@ +package cli + +import ( + "context" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + "os" + "os/signal" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/sun-praise/static-html/internal/server" +) + +func runWatch(args []string, stdout io.Writer) error { + flags, positionals, err := parseArgs(args) + if err != nil { + return err + } + + if len(positionals) < 1 { + return errors.New("usage: sth watch --session [--server http://127.0.0.1:3939]") + } + + sessionID := flags["session"] + if sessionID == "" { + return errors.New("--session is required") + } + + serverURL := server.DefaultServerURL + if v, ok := flags["server"]; ok { + serverURL = v + } + + watchPath, err := filepath.Abs(positionals[0]) + if err != nil { + return fmt.Errorf("failed to resolve path: %w", err) + } + + info, err := os.Stat(watchPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("path does not exist: %q", watchPath) + } + return err + } + if !info.IsDir() { + return errors.New("watch path must be a directory") + } + + if err := validateSession(serverURL, sessionID); err != nil { + return err + } + + fmt.Fprintf(stdout, "Watching %s → session %s\n", watchPath, sessionID) + + return watchAndSync(context.Background(), watchPath, sessionID, serverURL, stdout) +} + +func validateSession(serverURL, sessionID string) error { + parsedURL, err := url.Parse(serverURL) + if err != nil { + return fmt.Errorf("invalid server URL: %w", err) + } + + resp, err := http.Get(parsedURL.ResolveReference(&url.URL{Path: "/api/sessions/" + sessionID + "/metadata"}).String()) + if err != nil { + return fmt.Errorf("could not reach server at %s: %w", parsedURL.Host, err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return fmt.Errorf("session %q not found", sessionID) + } + if resp.StatusCode >= 400 { + return fmt.Errorf("server returned status %d", resp.StatusCode) + } + return nil +} + +func watchAndSync(ctx context.Context, dir, sessionID, serverURL string, stdout io.Writer) error { + w, err := fsnotify.NewWatcher() + if err != nil { + return fmt.Errorf("failed to create file watcher: %w", err) + } + defer w.Close() + + if err := addWatchDirs(w, dir); err != nil { + return fmt.Errorf("failed to set up file watching: %w", err) + } + + ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) + defer stop() + + var ( + mu sync.Mutex + pending = make(map[string]struct{}) + debounce *time.Timer + ) + const debounceMs = 300 * time.Millisecond + + flush := func() { + mu.Lock() + paths := make([]string, 0, len(pending)) + for p := range pending { + paths = append(paths, p) + } + pending = make(map[string]struct{}) + debounce = nil + mu.Unlock() + + if len(paths) == 0 { + return + } + + if err := uploadFiles(serverURL, sessionID, dir, paths); err != nil { + fmt.Fprintf(stdout, "Error syncing %d file(s): %v\n", len(paths), err) + } else { + for _, p := range paths { + rel, _ := filepath.Rel(dir, p) + fmt.Fprintf(stdout, "Synced: %s\n", rel) + } + } + } + + for { + select { + case <-ctx.Done(): + if debounce != nil { + debounce.Stop() + } + fmt.Fprintln(stdout, "\nStopped watching.") + return nil + case event, ok := <-w.Events: + if !ok { + return nil + } + if shouldIgnorePath(event.Name) { + continue + } + if event.Has(fsnotify.Create) { + if info, err := os.Stat(event.Name); err == nil && info.IsDir() { + _ = addWatchDirs(w, event.Name) + continue + } + } + mu.Lock() + pending[event.Name] = struct{}{} + mu.Unlock() + if debounce != nil { + debounce.Reset(debounceMs) + } else { + debounce = time.AfterFunc(debounceMs, flush) + } + case _, ok := <-w.Errors: + if !ok { + return nil + } + } + } +} + +func addWatchDirs(w *fsnotify.Watcher, root string) error { + return filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() && !shouldIgnorePath(path) { + return w.Add(path) + } + return nil + }) +} + +func uploadFiles(serverURL, sessionID, watchRoot string, paths []string) error { + pr, pw := io.Pipe() + mpw := multipart.NewWriter(pw) + + go func() { + defer pw.Close() + for _, p := range paths { + f, err := os.Open(p) + if err != nil { + fmt.Fprintf(os.Stderr, "watch: cannot open %s: %v\n", p, err) + continue + } + info, err := f.Stat() + if err != nil || !info.Mode().IsRegular() { + f.Close() + continue + } + rel, err := filepath.Rel(watchRoot, p) + if err != nil { + f.Close() + continue + } + part, err := mpw.CreateFormFile("files", filepath.ToSlash(rel)) + if err != nil { + f.Close() + _ = pw.CloseWithError(err) + return + } + if _, err := io.Copy(part, f); err != nil { + f.Close() + _ = pw.CloseWithError(err) + return + } + f.Close() + } + _ = mpw.Close() + }() + + parsedURL, err := url.Parse(serverURL) + if err != nil { + return err + } + + req, err := http.NewRequest(http.MethodPut, + parsedURL.ResolveReference(&url.URL{Path: "/api/sessions/" + sessionID + "/files"}).String(), + pr) + if err != nil { + return err + } + req.Header.Set("Content-Type", mpw.FormDataContentType()) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + return fmt.Errorf("server returned %d", resp.StatusCode) + } + return nil +} + +func shouldIgnorePath(path string) bool { + base := filepath.Base(path) + if strings.HasPrefix(base, ".") && base != "." && base != ".." { + return true + } + ext := strings.ToLower(filepath.Ext(base)) + switch ext { + case ".swp", ".tmp": + return true + } + return false +} diff --git a/internal/live/client.go b/internal/live/client.go new file mode 100644 index 0000000..29350dc --- /dev/null +++ b/internal/live/client.go @@ -0,0 +1,30 @@ +package live + +import ( + "context" + "sync" + "time" + + "github.com/coder/websocket" +) + +type WSClient struct { + conn *websocket.Conn + mu sync.Mutex +} + +func NewWSClient(conn *websocket.Conn) *WSClient { + return &WSClient{conn: conn} +} + +func (c *WSClient) Send(data []byte) { + c.mu.Lock() + defer c.mu.Unlock() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = c.conn.Write(ctx, websocket.MessageText, data) +} + +func (c *WSClient) Close() { + _ = c.conn.Close(websocket.StatusNormalClosure, "") +} diff --git a/internal/live/handler.go b/internal/live/handler.go new file mode 100644 index 0000000..8f87381 --- /dev/null +++ b/internal/live/handler.go @@ -0,0 +1,74 @@ +package live + +import ( + "context" + "net/http" + "strings" + "time" + + "github.com/coder/websocket" +) + +func HandleWebSocket(mgr *Manager, getSessionDir func(sessionID string) string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + sessionID := extractSessionID(r.URL.Path) + if sessionID == "" { + http.NotFound(w, r) + return + } + + dir := getSessionDir(sessionID) + if dir == "" { + http.NotFound(w, r) + return + } + + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + }) + if err != nil { + return + } + + client := NewWSClient(conn) + hub := mgr.GetOrCreateHub(sessionID, dir) + hub.Register(client) + + go func() { + defer hub.Unregister(client) + ctx, cancel := context.WithTimeout(context.Background(), 24*time.Hour) + defer cancel() + for { + _, _, err := conn.Read(ctx) + if err != nil { + return + } + } + }() + } +} + +func extractSessionID(urlPath string) string { + s := strings.TrimPrefix(urlPath, "/s/") + if s == urlPath || s == "" { + return "" + } + end := len(s) + if idx := indexOf(s, '/'); idx >= 0 { + end = idx + } + sessionID := s[:end] + if sessionID == "" { + return "" + } + return sessionID +} + +func indexOf(s string, c byte) int { + for i := 0; i < len(s); i++ { + if s[i] == c { + return i + } + } + return -1 +} diff --git a/internal/live/hub.go b/internal/live/hub.go new file mode 100644 index 0000000..5850365 --- /dev/null +++ b/internal/live/hub.go @@ -0,0 +1,146 @@ +package live + +import ( + "context" + "encoding/json" + "sync" +) + +type Message struct { + Type string `json:"type"` +} + +var ReloadMessage = Message{Type: "reload"} + +func ReloadJSON() []byte { + data, _ := json.Marshal(ReloadMessage) + return data +} + +type Client interface { + Send(data []byte) + Close() +} + +type Hub struct { + mu sync.Mutex + clients map[Client]struct{} + onChange func() + onEmpty func() +} + +func NewHub(onChange, onEmpty func()) *Hub { + return &Hub{ + clients: make(map[Client]struct{}), + onChange: onChange, + onEmpty: onEmpty, + } +} + +func (h *Hub) Register(c Client) { + h.mu.Lock() + defer h.mu.Unlock() + h.clients[c] = struct{}{} +} + +func (h *Hub) Unregister(c Client) { + h.mu.Lock() + delete(h.clients, c) + empty := len(h.clients) == 0 + onEmpty := h.onEmpty + h.mu.Unlock() + + c.Close() + + if empty && onEmpty != nil { + onEmpty() + } +} + +func (h *Hub) Broadcast(data []byte) { + h.mu.Lock() + clients := make([]Client, 0, len(h.clients)) + for c := range h.clients { + clients = append(clients, c) + } + h.mu.Unlock() + + for _, c := range clients { + c.Send(data) + } +} + +func (h *Hub) ClientCount() int { + h.mu.Lock() + defer h.mu.Unlock() + return len(h.clients) +} + +type Manager struct { + mu sync.Mutex + hubs map[string]*hubEntry + watchDir func(sessionID, dir string, notify func()) (context.CancelFunc, error) +} + +type hubEntry struct { + hub *Hub + cancelWatch context.CancelFunc +} + +func NewManager(watchDir func(sessionID, dir string, notify func()) (context.CancelFunc, error)) *Manager { + return &Manager{ + hubs: make(map[string]*hubEntry), + watchDir: watchDir, + } +} + +func (m *Manager) GetOrCreateHub(sessionID, dir string) *Hub { + m.mu.Lock() + defer m.mu.Unlock() + + if entry, ok := m.hubs[sessionID]; ok { + return entry.hub + } + + hub := NewHub(nil, nil) + entry := &hubEntry{hub: hub} + + hub.onChange = func() { + hub.Broadcast(ReloadJSON()) + } + hub.onEmpty = func() { + m.mu.Lock() + defer m.mu.Unlock() + if e, ok := m.hubs[sessionID]; ok { + if e.cancelWatch != nil { + e.cancelWatch() + } + delete(m.hubs, sessionID) + } + } + + if m.watchDir != nil && dir != "" { + cancel, err := m.watchDir(sessionID, dir, func() { + hub.Broadcast(ReloadJSON()) + }) + if err == nil { + entry.cancelWatch = cancel + } + } + + m.hubs[sessionID] = entry + return hub +} + +func (m *Manager) BroadcastTo(sessionID string, data []byte) { + m.mu.Lock() + entry, ok := m.hubs[sessionID] + m.mu.Unlock() + + // Safe TOCTOU: even if the hub is deleted concurrently after we release + // m.mu, the hub's Broadcast is a no-op on a stale client set. The worst + // case is a missed broadcast, which is acceptable for reload notifications. + if ok { + entry.hub.Broadcast(data) + } +} diff --git a/internal/live/inject.go b/internal/live/inject.go new file mode 100644 index 0000000..78237cb --- /dev/null +++ b/internal/live/inject.go @@ -0,0 +1,132 @@ +package live + +import ( + "bufio" + "bytes" + "io" + "net" + "net/http" + "strconv" + "strings" +) + +const liveReloadScript = `` + +var scriptBytes = []byte(liveReloadScript) +var headClose = []byte("") +var bodyClose = []byte("") + +func InjectMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/s/") { + h.ServeHTTP(w, r) + return + } + + if strings.HasSuffix(r.URL.Path, "/ws") { + h.ServeHTTP(w, r) + return + } + + bw := &bufferedWriter{ResponseWriter: w} + h.ServeHTTP(bw, r) + + if bw.hijacked { + return + } + + ct := bw.Header().Get("Content-Type") + if !strings.HasPrefix(ct, "text/html") { + bw.Header().Set("Content-Length", strconv.Itoa(bw.buf.Len())) + bw.realHeader(bw.code) + bw.buf.WriteTo(bw.ResponseWriter) + return + } + + injected := injectScript(bw.buf.Bytes(), scriptBytes) + bw.Header().Set("Content-Length", strconv.Itoa(len(injected))) + bw.realHeader(bw.code) + bw.ResponseWriter.Write(injected) + }) +} + +type bufferedWriter struct { + http.ResponseWriter + buf bytes.Buffer + code int + wroteHdr bool + hijacked bool +} + +func (bw *bufferedWriter) Write(b []byte) (int, error) { + return bw.buf.Write(b) +} + +func (bw *bufferedWriter) WriteHeader(code int) { + if !bw.wroteHdr { + bw.code = code + bw.wroteHdr = true + } +} + +func (bw *bufferedWriter) realHeader(code int) { + if !bw.wroteHdr { + bw.wroteHdr = true + bw.code = code + } + bw.ResponseWriter.WriteHeader(bw.code) +} + +func (bw *bufferedWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + bw.hijacked = true + return bw.ResponseWriter.(http.Hijacker).Hijack() +} + +func (bw *bufferedWriter) ReadFrom(r io.Reader) (int64, error) { + return io.Copy(&bw.buf, r) +} + +var _ http.Hijacker = (*bufferedWriter)(nil) +var _ io.ReaderFrom = (*bufferedWriter)(nil) + +func injectScript(body []byte, script []byte) []byte { + if idx := bytes.LastIndex(body, headClose); idx != -1 { + var result bytes.Buffer + result.Write(body[:idx]) + result.Write(script) + result.Write(body[idx:]) + return result.Bytes() + } + + if idx := bytes.LastIndex(body, bodyClose); idx != -1 { + var result bytes.Buffer + result.Write(body[:idx]) + result.Write(script) + result.Write(body[idx:]) + return result.Bytes() + } + + var result bytes.Buffer + result.Write(body) + result.Write(script) + return result.Bytes() +} diff --git a/internal/live/live_test.go b/internal/live/live_test.go new file mode 100644 index 0000000..fdb6448 --- /dev/null +++ b/internal/live/live_test.go @@ -0,0 +1,199 @@ +package live + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/coder/websocket" +) + +func TestInjectMiddlewareSkipsNonSessionPaths(t *testing.T) { + t.Parallel() + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"ok":true}`)) + }) + + handler := InjectMiddleware(inner) + req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + body := rec.Body.String() + if strings.Contains(body, "WebSocket") { + t.Fatal("non-session path should not have script injected") + } + if body != `{"ok":true}` { + t.Fatalf("body = %q, want unchanged JSON", body) + } +} + +func TestInjectMiddlewareAddsScriptToHTML(t *testing.T) { + t.Parallel() + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`TestHello`)) + }) + + handler := InjectMiddleware(inner) + req := httptest.NewRequest(http.MethodGet, "/s/abc123/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", rec.Code) + } + body := rec.Body.String() + if !strings.Contains(body, "WebSocket") { + t.Fatal("HTML response should contain WebSocket script") + } + if !strings.Contains(body, "Test") { + t.Fatal("original content should be preserved") + } +} + +func TestInjectMiddlewareSkipsNonHTML(t *testing.T) { + t.Parallel() + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/css") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`body { color: red; }`)) + }) + + handler := InjectMiddleware(inner) + req := httptest.NewRequest(http.MethodGet, "/s/abc123/style.css", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + body := rec.Body.String() + if strings.Contains(body, "WebSocket") { + t.Fatal("CSS response should not have script injected") + } + if body != `body { color: red; }` { + t.Fatalf("body = %q, want unchanged CSS", body) + } +} + +func TestInjectMiddlewareSkipsWSPath(t *testing.T) { + t.Parallel() + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusSwitchingProtocols) + w.Write([]byte("upgraded")) + }) + + handler := InjectMiddleware(inner) + req := httptest.NewRequest(http.MethodGet, "/s/abc123/ws", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusSwitchingProtocols { + t.Fatalf("status = %d, want 101", rec.Code) + } +} + +func TestHubBroadcast(t *testing.T) { + hub := NewHub(nil, nil) + + received := make(chan []byte, 1) + c := &mockClient{send: func(data []byte) { received <- data }} + hub.Register(c) + + hub.Broadcast(ReloadJSON()) + + select { + case data := <-received: + var msg Message + if err := json.Unmarshal(data, &msg); err != nil { + t.Fatal(err) + } + if msg.Type != "reload" { + t.Fatalf("type = %q, want reload", msg.Type) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for broadcast") + } +} + +func TestHubUnregisterCallsOnEmpty(t *testing.T) { + called := make(chan struct{}, 1) + hub := NewHub(nil, func() { called <- struct{}{} }) + + c := &mockClient{} + hub.Register(c) + hub.Unregister(c) + + select { + case <-called: + case <-time.After(time.Second): + t.Fatal("onEmpty not called") + } +} + +func TestWebSocketHandler(t *testing.T) { + mgr := NewManager(nil) + handler := HandleWebSocket(mgr, func(sessionID string) string { + return "/tmp/test-session" + }) + + server := httptest.NewServer(handler) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/s/test-session/ws" + conn, _, err := websocket.Dial(ctx, wsURL, nil) + if err != nil { + t.Fatalf("websocket dial: %v", err) + } + + // Wait for server goroutine to register the client + for i := 0; i < 50; i++ { + mgr.mu.Lock() + entry, ok := mgr.hubs["test-session"] + mgr.mu.Unlock() + if ok && entry.hub.ClientCount() > 0 { + break + } + time.Sleep(20 * time.Millisecond) + } + + mgr.mu.Lock() + entry, ok := mgr.hubs["test-session"] + mgr.mu.Unlock() + + if !ok { + t.Fatal("expected hub to be created for test-session") + } + if entry.hub.ClientCount() != 1 { + t.Fatalf("expected 1 client, got %d", entry.hub.ClientCount()) + } + + conn.Close(websocket.StatusNormalClosure, "") +} + +type mockClient struct { + send func(data []byte) +} + +func (m *mockClient) Send(data []byte) { + if m.send != nil { + m.send(data) + } +} + +func (m *mockClient) Close() {} diff --git a/internal/live/watcher.go b/internal/live/watcher.go new file mode 100644 index 0000000..95a8c99 --- /dev/null +++ b/internal/live/watcher.go @@ -0,0 +1,94 @@ +package live + +import ( + "context" + "os" + "path/filepath" + "strings" + "time" + + "github.com/fsnotify/fsnotify" +) + +const debounceInterval = 300 * time.Millisecond + +func WatchDir(ctx context.Context, dir string, notify func()) (context.CancelFunc, error) { + w, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + if err := addWatchRecursive(w, dir); err != nil { + w.Close() + return nil, err + } + + ctx, cancel := context.WithCancel(ctx) + + go func() { + defer w.Close() + + var timer *time.Timer + for { + select { + case <-ctx.Done(): + if timer != nil { + timer.Stop() + } + return + case event, ok := <-w.Events: + if !ok { + return + } + if shouldIgnorePath(event.Name) { + continue + } + if event.Has(fsnotify.Create) { + if info, err := os.Stat(event.Name); err == nil && info.IsDir() { + _ = addWatchRecursive(w, event.Name) + } + } + if timer != nil { + timer.Reset(debounceInterval) + } else { + timer = time.AfterFunc(debounceInterval, func() { + notify() + }) + } + case <-w.Errors: + return + } + } + }() + + return cancel, nil +} + +func addWatchRecursive(w *fsnotify.Watcher, root string) error { + return filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + if shouldIgnorePath(path) { + if d.IsDir() { + return filepath.SkipDir + } + return nil + } + if d.IsDir() { + return w.Add(path) + } + return nil + }) +} + +func shouldIgnorePath(path string) bool { + base := filepath.Base(path) + if strings.HasPrefix(base, ".") && base != "." { + return true + } + if strings.HasSuffix(base, ".swp") || strings.HasSuffix(base, ".tmp") { + return true + } + return false +} diff --git a/internal/server/server.go b/internal/server/server.go index 74fc184..d7dbfaa 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -19,6 +19,7 @@ import ( "strings" "sync" + "github.com/sun-praise/static-html/internal/live" "github.com/sun-praise/static-html/internal/session" ) @@ -50,6 +51,7 @@ type Server struct { store *session.Store httpServer *http.Server listener net.Listener + liveMgr *live.Manager mu sync.RWMutex } @@ -436,10 +438,13 @@ func New(host string, port int, store *session.Store, serverName string) (*Serve port: port, serverName: serverName, store: store, + liveMgr: live.NewManager(func(sessionID, dir string, notify func()) (context.CancelFunc, error) { + return live.WatchDir(context.Background(), dir, notify) + }), } srv.httpServer = &http.Server{ - Handler: srv.routes(), + Handler: live.InjectMiddleware(srv.routes()), } return srv, nil @@ -555,6 +560,23 @@ func (s *Server) routes() http.Handler { case r.Method == http.MethodPost && r.URL.Path == "/api/sessions": s.handleCreateSession(w, r) case r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, "/s/"): + if strings.HasSuffix(r.URL.Path, "/ws") { + sid := strings.TrimPrefix(r.URL.Path, "/s/") + sid = strings.TrimSuffix(sid, "/ws") + _, found, err := s.store.Get(sid) + if err != nil || !found { + http.NotFound(w, r) + return + } + live.HandleWebSocket(s.liveMgr, func(id string) string { + s, ok, _ := s.store.Get(id) + if ok { + return s.StoredRootDir + } + return "" + }).ServeHTTP(w, r) + return + } s.handlePreview(w, r) case r.Method == http.MethodPut && hasPrefixSuffix(r.URL.Path, "/api/sessions/", "/tags"): s.handleAddTags(w, r) @@ -574,6 +596,8 @@ func (s *Server) routes() http.Handler { s.handleDownloadSession(w, r) case r.Method == http.MethodDelete && isExactSessionPath(r.URL.Path): s.handleDeleteSession(w, r) + case r.Method == http.MethodPut && hasPrefixSuffix(r.URL.Path, "/api/sessions/", "/files"): + s.handleUpdateFiles(w, r) default: http.NotFound(w, r) } @@ -1567,3 +1591,100 @@ func (s *Server) handleDeleteSession(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, map[string]string{"status": "deleted"}) } + +const maxIncrementalBytes = 50 << 20 + +func (s *Server) handleUpdateFiles(w http.ResponseWriter, r *http.Request) { + sessionID, ok := extractSessionIDFromMetaPath(r.URL.Path, "/api/sessions/", "/files") + if !ok { + writeJSONError(w, http.StatusBadRequest, "Invalid session ID.") + return + } + + sess, found, err := s.store.Get(sessionID) + if err != nil { + writeJSONError(w, http.StatusInternalServerError, err.Error()) + return + } + if !found { + writeJSONError(w, http.StatusNotFound, "Session not found.") + return + } + + r.Body = http.MaxBytesReader(w, r.Body, maxIncrementalBytes) + if err := r.ParseMultipartForm(maxIncrementalBytes); err != nil { + writeJSONError(w, http.StatusBadRequest, "Failed to parse multipart form upload.") + return + } + defer func() { + if r.MultipartForm != nil { + _ = r.MultipartForm.RemoveAll() + } + }() + + count := 0 + for _, headers := range r.MultipartForm.File { + for _, hdr := range headers { + if hdr.Size > maxIncrementalBytes { + writeJSONError(w, http.StatusRequestEntityTooLarge, + fmt.Sprintf("File %q exceeds 50MB limit.", hdr.Filename)) + return + } + + src, err := hdr.Open() + if err != nil { + writeJSONError(w, http.StatusBadRequest, fmt.Sprintf("Failed to read file %q.", hdr.Filename)) + return + } + + relPath := filepath.Clean(filepath.FromSlash(hdr.Filename)) + if relPath == "." || relPath == ".." || strings.HasPrefix(relPath, ".."+string(filepath.Separator)) { + src.Close() + continue + } + if hdr.Filename == "" || strings.HasSuffix(relPath, string(filepath.Separator)) { + src.Close() + continue + } + + targetPath := filepath.Join(sess.StoredRootDir, relPath) + if !IsSubpath(sess.StoredRootDir, targetPath) { + src.Close() + writeJSONError(w, http.StatusBadRequest, "File path escapes session root.") + return + } + + if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil { + src.Close() + writeJSONError(w, http.StatusInternalServerError, err.Error()) + return + } + + dst, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) + if err != nil { + src.Close() + writeJSONError(w, http.StatusInternalServerError, err.Error()) + return + } + + _, copyErr := io.Copy(dst, src) + closeErr := errors.Join(src.Close(), dst.Close()) + if copyErr != nil { + writeJSONError(w, http.StatusInternalServerError, copyErr.Error()) + return + } + if closeErr != nil { + writeJSONError(w, http.StatusInternalServerError, closeErr.Error()) + return + } + count++ + } + } + + s.liveMgr.BroadcastTo(sessionID, live.ReloadJSON()) + + writeJSON(w, http.StatusOK, map[string]any{ + "status": "ok", + "files_updated": count, + }) +}