From 436878b7c1da71550c2d13c21fe56c46ea909d18 Mon Sep 17 00:00:00 2001 From: Julien Semaan Date: Sat, 4 Apr 2026 19:05:54 -0400 Subject: [PATCH 01/10] refactor(api): restructure monolithic server into modular internal packages --- api/Makefile | 3 +- api/afn-rest.go | 339 ------------------- api/afn-rest_test.go | 268 --------------- api/internal/app/config.go | 39 +++ api/internal/app/run.go | 51 +++ api/internal/contact/service.go | 108 ++++++ api/internal/contact/service_test.go | 94 +++++ api/internal/httpapi/handler_stats.go | 42 +++ api/internal/httpapi/handler_stats_test.go | 88 +++++ api/internal/httpapi/middleware_auth.go | 38 +++ api/internal/httpapi/middleware_auth_test.go | 71 ++++ api/internal/httpapi/middleware_cors.go | 12 + api/internal/httpapi/router.go | 25 ++ api/internal/httpapi/router_test.go | 68 ++++ api/internal/resources/index.go | 44 +++ api/internal/resources/schema.go | 144 ++++++++ api/internal/stats/service.go | 77 +++++ api/internal/stats/service_test.go | 89 +++++ 18 files changed, 991 insertions(+), 609 deletions(-) delete mode 100644 api/afn-rest.go delete mode 100644 api/afn-rest_test.go create mode 100644 api/internal/app/config.go create mode 100644 api/internal/app/run.go create mode 100644 api/internal/contact/service.go create mode 100644 api/internal/contact/service_test.go create mode 100644 api/internal/httpapi/handler_stats.go create mode 100644 api/internal/httpapi/handler_stats_test.go create mode 100644 api/internal/httpapi/middleware_auth.go create mode 100644 api/internal/httpapi/middleware_auth_test.go create mode 100644 api/internal/httpapi/middleware_cors.go create mode 100644 api/internal/httpapi/router.go create mode 100644 api/internal/httpapi/router_test.go create mode 100644 api/internal/resources/index.go create mode 100644 api/internal/resources/schema.go create mode 100644 api/internal/stats/service.go create mode 100644 api/internal/stats/service_test.go diff --git a/api/Makefile b/api/Makefile index a2d5db3..ad99ba2 100644 --- a/api/Makefile +++ b/api/Makefile @@ -1,5 +1,4 @@ .PHONY: api api: - CGO_ENABLED=0 go build -v - + CGO_ENABLED=0 go build -v -o api ./cmd/api diff --git a/api/afn-rest.go b/api/afn-rest.go deleted file mode 100644 index 08b88ea..0000000 --- a/api/afn-rest.go +++ /dev/null @@ -1,339 +0,0 @@ -package main - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "net/http" - "os" - "regexp" - "strings" - "text/template" - "time" - - "github.com/julsemaan/anyfile-notepad/utils" - "github.com/julsemaan/rest-layer-file" - cache "github.com/patrickmn/go-cache" - "github.com/rs/rest-layer/resource" - "github.com/rs/rest-layer/rest" - "github.com/rs/rest-layer/schema" - "gopkg.in/alexcesaro/statsd.v2" -) - -// Precompiled regexes -var pathStatsRE = regexp.MustCompile(`^/stats`) -var emailRegex = regexp.MustCompile(`\S+@\S+`) - -// Small constants for collection names to avoid magic strings -const ( - mimeTypesColl = "mime_types" - extensionsColl = "extensions" - syntaxesColl = "syntaxes" - settingsColl = "settings" - contactColl = "contact_requests" -) - -var contactRequestsCache = cache.New(24*time.Hour, 1*time.Minute) -var maxContactRequestsPerDay = 10 -var sendEmail = utils.SendEmail - -var statsdConn, _ = statsd.New(statsd.Address(os.Getenv("AFN_STATSD_URI"))) - -func main() { - defer statsdConn.Close() - - schema.CreatedField.ReadOnly = false - schema.UpdatedField.ReadOnly = false - - var ( - mime_type = schema.Schema{ - Description: `The mime_type object`, - Fields: schema.Fields{ - "id": schema.IDField, - "created_at": schema.CreatedField, - "updated_at": schema.UpdatedField, - "type_name": { - Required: true, - Filterable: true, - }, - "integrated": { - Default: false, - Filterable: true, - Validator: &schema.Bool{}, - }, - "discovered_by": { - Default: "John Doe", - Filterable: true, - }, - }, - } - - extension = schema.Schema{ - Description: `Represents an extension`, - Fields: schema.Fields{ - "id": schema.IDField, - "created_at": schema.CreatedField, - "updated_at": schema.UpdatedField, - "name": { - Required: true, - Filterable: true, - }, - "syntax_id": { - Required: true, - Filterable: true, - Validator: &schema.Reference{ - Path: "syntaxes", - }, - }, - "mime_type_id": { - Required: true, - Filterable: true, - Validator: &schema.Reference{ - Path: "mime_types", - }, - }, - }, - } - - syntax = schema.Schema{ - Description: `Represents a syntax`, - Fields: schema.Fields{ - "id": schema.IDField, - "created_at": schema.CreatedField, - "updated_at": schema.UpdatedField, - "display_name": { - Required: true, - Filterable: true, - }, - "ace_js_mode": { - Required: true, - Filterable: true, - }, - }, - } - - setting = schema.Schema{ - Description: `Represents a setting`, - Fields: schema.Fields{ - "id": schema.IDField, - "created_at": schema.CreatedField, - "updated_at": schema.UpdatedField, - "var_name": { - Required: true, - Filterable: true, - }, - "value": { - Required: true, - Filterable: true, - }, - }, - } - - contactRequest = schema.Schema{ - Description: "Represents a contact request", - Fields: schema.Fields{ - "id": schema.IDField, - "created_at": schema.CreatedField, - "updated_at": schema.UpdatedField, - "contact_email": { - Required: true, - Filterable: true, - Validator: &emailValidator{}, - }, - "message": { - Required: true, - Filterable: true, - Validator: &schema.String{ - MinLen: 10, - MaxLen: 2000, - }, - }, - }, - } - ) - - // Create a REST API resource index - index := resource.NewIndex() - - directory := os.Getenv("AFN_REST_DATA_DIR") - if directory == "" { - directory = "./db" - } - - index.Bind("mime_types", mime_type, filestore.NewHandler(directory, mimeTypesColl, []string{"type_name"}), resource.Conf{ - AllowedModes: resource.ReadWrite, - }) - - index.Bind("extensions", extension, filestore.NewHandler(directory, extensionsColl, []string{"name"}), resource.Conf{ - AllowedModes: resource.ReadWrite, - }) - - index.Bind("syntaxes", syntax, filestore.NewHandler(directory, syntaxesColl, []string{"ace_js_mode", "display_name"}), resource.Conf{ - AllowedModes: resource.ReadWrite, - }) - - index.Bind("settings", setting, filestore.NewHandler(directory, settingsColl, []string{"var_name"}), resource.Conf{ - AllowedModes: resource.ReadWrite, - }) - - contactRequests := index.Bind("contact_requests", contactRequest, filestore.NewHandler(directory, contactColl, []string{"id"}), resource.Conf{ - AllowedModes: resource.ReadWrite, - }) - contactRequests.Use(resource.InsertEventHandlerFunc(insertContactRequestHook)) - contactRequests.Use(resource.InsertedEventHandlerFunc(insertedContactRequestHook)) - - // Create API HTTP handler for the resource graph - api, err := rest.NewHandler(index) - if err != nil { - log.Fatalf("Invalid API configuration: %s", err) - } - - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") - - if pathStatsRE.MatchString(r.URL.Path) { - handleStats(w, r) - return - } else if isOpenResource(r) { - // pass - } else if !authenticate(w, r) { - return - } - api.ServeHTTP(w, r) - }) - - // Serve it - log.Print("Serving API on http://localhost:8080") - if err := http.ListenAndServe(":8080", nil); err != nil { - log.Fatal(err) - } -} - -func parseStatsPayload(w http.ResponseWriter, r *http.Request) (map[string]string, error) { - buf, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Invalid payload", http.StatusBadRequest) - return nil, err - } - dec := json.NewDecoder(bytes.NewBuffer(buf)) - var s map[string]string - if err := dec.Decode(&s); err != nil { - http.Error(w, "Invalid JSON", http.StatusBadRequest) - return nil, err - } - if r.Header.Get("X-Forwarded-For") != "" { - s["ip"] = strings.Split(r.Header.Get("X-Forwarded-For"), ",")[0] - } else { - re := regexp.MustCompile("^([0-9.]+):") - s["ip"] = re.FindAllStringSubmatch(r.RemoteAddr, 1)[0][1] - } - return s, nil -} - -func isOpenResource(r *http.Request) bool { - if r.Method == "OPTIONS" { - return true - } - - if strings.HasPrefix(r.URL.Path, "/contact_requests") { - if r.Method == "POST" { - log.Print("Allowing without authentication for creating a contact request") - return true - } else { - return false - } - } - - if r.Method == "GET" { - log.Print("Allowing without authentication for namespace that don't modify resources") - return true - } - - return false -} - -func authenticate(w http.ResponseWriter, r *http.Request) bool { - username, password, ok := r.BasicAuth() - if !ok || username != os.Getenv("AFN_REST_USERNAME") || password != os.Getenv("AFN_REST_PASSWORD") { - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte("Unauthorized")) - return false - } - return true -} - -func insertContactRequestHook(ctx context.Context, items []*resource.Item) error { - if contactRequestsCache.ItemCount() >= maxContactRequestsPerDay { - return errors.New("Too many contact requests, try again later") - } else { - for _, item := range items { - contactRequestsCache.SetDefault(item.ID.(string), item.ID) - } - return nil - } -} - -func insertedContactRequestHook(ctx context.Context, items []*resource.Item, err *error) { - if *err != nil { - log.Println("Not sending contact request email because there was an error saving the contact request") - return - } - - emails := []string{os.Getenv("AFN_SUPPORT_EMAIL")} - for _, item := range items { - msgTemplate, _ := template.New("contact-email").Parse(`Subject: Anyfile Notepad - Message from {{.ReplyTo}} -To: {{.Emails}} -Reply-To: {{.ReplyTo}} - -{{.Message}} -`) - var msgBytes bytes.Buffer - msgTemplate.Execute(&msgBytes, struct { - Emails string - Message string - ReplyTo string - }{ - Emails: strings.Join(emails, ";"), - Message: item.Payload["message"].(string), - ReplyTo: item.Payload["contact_email"].(string), - }) - msg, _ := io.ReadAll(&msgBytes) - sendEmail(emails, msg) - } - -} - -type emailValidator struct { -} - -func (emailValidator) Validate(value interface{}) (interface{}, error) { - email := value.(string) - if res := emailRegex.MatchString(email); !res { - return email, errors.New("Invalid email format") - } else { - return email, nil - } -} - -func handleStats(w http.ResponseWriter, r *http.Request) { - log.Print("Allowing without authentication for stats namespace") - if statsRequest, err := parseStatsPayload(w, r); err == nil { - statsdConn.Increment(fmt.Sprintf("afn.stats-hits.%s", strings.Replace(statsRequest["ip"], ".", "_", -1))) - log.Printf("Stats request from %s", statsRequest["ip"]) - switch statsRequest["type"] { - case "increment": - log.Printf("afn.stats-hits.%s from %s", statsRequest["key"], statsRequest["ip"]) - statsdConn.Increment(statsRequest["key"]) - } - w.Write([]byte("OK")) - return - } else { - return - } -} diff --git a/api/afn-rest_test.go b/api/afn-rest_test.go deleted file mode 100644 index cc52ab5..0000000 --- a/api/afn-rest_test.go +++ /dev/null @@ -1,268 +0,0 @@ -package main - -import ( - "context" - "errors" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - cache "github.com/patrickmn/go-cache" - "github.com/rs/rest-layer/resource" -) - -type errReader struct{} - -func (errReader) Read(p []byte) (int, error) { - return 0, errors.New("read failed") -} - -func (errReader) Close() error { - return nil -} - -func TestParseStatsPayload(t *testing.T) { - t.Run("uses forwarded for ip", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader(`{"type":"increment","key":"hits"}`)) - req.Header.Set("X-Forwarded-For", "203.0.113.10, 70.41.3.18") - w := httptest.NewRecorder() - - payload, err := parseStatsPayload(w, req) - if err != nil { - t.Fatalf("unexpected parse error: %v", err) - } - - if payload["ip"] != "203.0.113.10" { - t.Fatalf("expected forwarded ip, got %q", payload["ip"]) - } - }) - - t.Run("uses remote addr when header missing", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader(`{"type":"increment","key":"hits"}`)) - req.RemoteAddr = "192.0.2.15:43210" - w := httptest.NewRecorder() - - payload, err := parseStatsPayload(w, req) - if err != nil { - t.Fatalf("unexpected parse error: %v", err) - } - - if payload["ip"] != "192.0.2.15" { - t.Fatalf("expected remote ip, got %q", payload["ip"]) - } - }) - - t.Run("invalid json returns bad request", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader(`{"broken":`)) - req.RemoteAddr = "192.0.2.15:43210" - w := httptest.NewRecorder() - - _, err := parseStatsPayload(w, req) - if err == nil { - t.Fatal("expected parse error") - } - if w.Code != http.StatusBadRequest { - t.Fatalf("expected bad request status, got %d", w.Code) - } - }) - - t.Run("body read failures return bad request", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/stats", nil) - req.Body = io.NopCloser(errReader{}) - req.RemoteAddr = "192.0.2.15:43210" - w := httptest.NewRecorder() - - _, err := parseStatsPayload(w, req) - if err == nil { - t.Fatal("expected read error") - } - if w.Code != http.StatusBadRequest { - t.Fatalf("expected bad request status, got %d", w.Code) - } - }) -} - -func TestIsOpenResource(t *testing.T) { - cases := []struct { - name string - method string - path string - expected bool - }{ - {name: "options always open", method: http.MethodOptions, path: "/anything", expected: true}, - {name: "contact_requests post open", method: http.MethodPost, path: "/contact_requests", expected: true}, - {name: "contact_requests get closed", method: http.MethodGet, path: "/contact_requests", expected: false}, - {name: "get open", method: http.MethodGet, path: "/syntaxes", expected: true}, - {name: "post protected", method: http.MethodPost, path: "/syntaxes", expected: false}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - req := httptest.NewRequest(tc.method, tc.path, nil) - if isOpenResource(req) != tc.expected { - t.Fatalf("expected %v", tc.expected) - } - }) - } -} - -func TestAuthenticate(t *testing.T) { - t.Setenv("AFN_REST_USERNAME", "test-user") - t.Setenv("AFN_REST_PASSWORD", "test-pass") - - t.Run("accepts valid basic auth", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/syntaxes", nil) - req.SetBasicAuth("test-user", "test-pass") - w := httptest.NewRecorder() - - if !authenticate(w, req) { - t.Fatal("expected authentication to pass") - } - }) - - t.Run("rejects invalid basic auth", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/syntaxes", nil) - req.SetBasicAuth("bad-user", "bad-pass") - w := httptest.NewRecorder() - - if authenticate(w, req) { - t.Fatal("expected authentication to fail") - } - if w.Code != http.StatusUnauthorized { - t.Fatalf("expected unauthorized status, got %d", w.Code) - } - if strings.TrimSpace(w.Body.String()) != "Unauthorized" { - t.Fatalf("expected unauthorized body, got %q", w.Body.String()) - } - }) - - t.Run("rejects missing basic auth", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/syntaxes", nil) - w := httptest.NewRecorder() - - if authenticate(w, req) { - t.Fatal("expected missing auth to fail") - } - if w.Code != http.StatusUnauthorized { - t.Fatalf("expected unauthorized status, got %d", w.Code) - } - }) -} - -func TestInsertContactRequestHook(t *testing.T) { - originalCache := contactRequestsCache - originalMax := maxContactRequestsPerDay - contactRequestsCache = cache.New(24*time.Hour, time.Minute) - maxContactRequestsPerDay = 2 - t.Cleanup(func() { - contactRequestsCache = originalCache - maxContactRequestsPerDay = originalMax - }) - - items := []*resource.Item{{ID: "req-1"}, {ID: "req-2"}} - if err := insertContactRequestHook(context.Background(), items); err != nil { - t.Fatalf("unexpected insert hook error: %v", err) - } - - if contactRequestsCache.ItemCount() != 2 { - t.Fatalf("expected cache to contain two entries, got %d", contactRequestsCache.ItemCount()) - } - - err := insertContactRequestHook(context.Background(), []*resource.Item{{ID: "req-3"}}) - if err == nil { - t.Fatal("expected insert hook to reject above threshold") - } - if err.Error() != "Too many contact requests, try again later" { - t.Fatalf("unexpected error message: %v", err) - } -} - -func TestInsertedContactRequestHookNoopOnError(t *testing.T) { - incomingErr := errors.New("storage failure") - err := error(incomingErr) - insertedContactRequestHook(context.Background(), []*resource.Item{{ID: "req-1"}}, &err) - - if err == nil || err.Error() != "storage failure" { - t.Fatalf("expected input error to be preserved, got %v", err) - } -} - -func TestEmailValidator(t *testing.T) { - v := emailValidator{} - - if _, err := v.Validate("name@example.com"); err != nil { - t.Fatalf("expected valid email, got error: %v", err) - } - - if _, err := v.Validate("not-an-email"); err == nil { - t.Fatal("expected invalid email to fail validation") - } -} - -func TestInsertedContactRequestHookSendsEmail(t *testing.T) { - originalSend := sendEmail - t.Cleanup(func() { - sendEmail = originalSend - }) - - t.Setenv("AFN_SUPPORT_EMAIL", "support@example.com") - sent := false - sendEmail = func(to []string, msg []byte) error { - sent = true - if len(to) != 1 || to[0] != "support@example.com" { - t.Fatalf("unexpected recipients: %#v", to) - } - if !strings.Contains(string(msg), "Need help") { - t.Fatalf("expected message content in email body, got %q", string(msg)) - } - return nil - } - - var err error - insertedContactRequestHook(context.Background(), []*resource.Item{{ - ID: "req-1", - Payload: map[string]interface{}{ - "message": "Need help", - "contact_email": "user@example.com", - }, - }}, &err) - - if !sent { - t.Fatal("expected sendEmail to be called") - } -} - -func TestHandleStats(t *testing.T) { - t.Run("returns OK on increment stats payload", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader(`{"type":"increment","key":"hits"}`)) - req.RemoteAddr = "192.0.2.15:12345" - w := httptest.NewRecorder() - - handleStats(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d", w.Code) - } - if strings.TrimSpace(w.Body.String()) != "OK" { - t.Fatalf("unexpected body: %q", w.Body.String()) - } - }) - - t.Run("returns OK on non-increment payload", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader(`{"type":"noop","key":"hits"}`)) - req.RemoteAddr = "192.0.2.16:12345" - w := httptest.NewRecorder() - - handleStats(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d", w.Code) - } - if strings.TrimSpace(w.Body.String()) != "OK" { - t.Fatalf("unexpected body: %q", w.Body.String()) - } - }) -} diff --git a/api/internal/app/config.go b/api/internal/app/config.go new file mode 100644 index 0000000..751508c --- /dev/null +++ b/api/internal/app/config.go @@ -0,0 +1,39 @@ +package app + +import "os" + +const defaultDataDir = "./db" +const defaultListenAddr = ":8080" +const defaultContactRequestsPerDay = 10 + +type Config struct { + DataDir string + ListenAddr string + Username string + Password string + SupportEmail string + StatsdAddress string + MaxContactRequestsPerDay int +} + +func LoadConfigFromEnv() Config { + dataDir := os.Getenv("AFN_REST_DATA_DIR") + if dataDir == "" { + dataDir = defaultDataDir + } + + listenAddr := os.Getenv("AFN_REST_LISTEN_ADDR") + if listenAddr == "" { + listenAddr = defaultListenAddr + } + + return Config{ + DataDir: dataDir, + ListenAddr: listenAddr, + Username: os.Getenv("AFN_REST_USERNAME"), + Password: os.Getenv("AFN_REST_PASSWORD"), + SupportEmail: os.Getenv("AFN_SUPPORT_EMAIL"), + StatsdAddress: os.Getenv("AFN_STATSD_URI"), + MaxContactRequestsPerDay: defaultContactRequestsPerDay, + } +} diff --git a/api/internal/app/run.go b/api/internal/app/run.go new file mode 100644 index 0000000..6dc2b89 --- /dev/null +++ b/api/internal/app/run.go @@ -0,0 +1,51 @@ +package app + +import ( + "log" + "net/http" + + "github.com/julsemaan/anyfile-notepad/api/internal/contact" + "github.com/julsemaan/anyfile-notepad/api/internal/httpapi" + "github.com/julsemaan/anyfile-notepad/api/internal/resources" + "github.com/julsemaan/anyfile-notepad/api/internal/stats" + "github.com/julsemaan/anyfile-notepad/utils" + cache "github.com/patrickmn/go-cache" + "github.com/rs/rest-layer/resource" + "github.com/rs/rest-layer/rest" + "github.com/rs/rest-layer/schema" + "gopkg.in/alexcesaro/statsd.v2" +) + +func Run(cfg Config) error { + schema.CreatedField.ReadOnly = false + schema.UpdatedField.ReadOnly = false + + statsConn, _ := statsd.New(statsd.Address(cfg.StatsdAddress)) + if statsConn != nil { + defer statsConn.Close() + } + + statsService := stats.NewService(statsConn) + contactCache := cache.New(24*60*60*1000000000, 60*1000000000) + contactService := contact.NewService(contactCache, cfg.MaxContactRequestsPerDay, cfg.SupportEmail, utils.SendEmail) + + index := resources.BuildIndex(cfg.DataDir, resources.ContactHooks{ + Insert: resource.InsertEventHandlerFunc(contactService.BeforeInsert), + Inserted: resource.InsertedEventHandlerFunc(contactService.AfterInsert), + }) + + restHandler, err := rest.NewHandler(index) + if err != nil { + return err + } + + router := httpapi.NewRouter( + restHandler, + httpapi.NewStatsHandler(statsService), + cfg.Username, + cfg.Password, + ) + + log.Printf("Serving API on http://localhost%s", cfg.ListenAddr) + return http.ListenAndServe(cfg.ListenAddr, router) +} diff --git a/api/internal/contact/service.go b/api/internal/contact/service.go new file mode 100644 index 0000000..65a6813 --- /dev/null +++ b/api/internal/contact/service.go @@ -0,0 +1,108 @@ +package contact + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "log" + "strings" + "text/template" + + "github.com/rs/rest-layer/resource" +) + +var errTooManyRequests = errors.New("Too many contact requests, try again later") + +var messageTemplate = template.Must(template.New("contact-email").Parse(`Subject: Anyfile Notepad - Message from {{.ReplyTo}} +To: {{.Emails}} +Reply-To: {{.ReplyTo}} + +{{.Message}} +`)) + +type Cache interface { + ItemCount() int + SetDefault(key string, value interface{}) +} + +type Sender func(to []string, msg []byte) error + +type Service struct { + cache Cache + maxPerDay int + supportEmail string + sendEmail Sender +} + +func NewService(cache Cache, maxPerDay int, supportEmail string, sendEmail Sender) *Service { + return &Service{ + cache: cache, + maxPerDay: maxPerDay, + supportEmail: supportEmail, + sendEmail: sendEmail, + } +} + +func (s *Service) BeforeInsert(_ context.Context, items []*resource.Item) error { + if s.cache != nil && s.cache.ItemCount() >= s.maxPerDay { + return errTooManyRequests + } + + for _, item := range items { + if s.cache != nil { + id := fmt.Sprint(item.ID) + s.cache.SetDefault(id, item.ID) + } + } + + return nil +} + +func (s *Service) AfterInsert(_ context.Context, items []*resource.Item, err *error) { + if err != nil && *err != nil { + log.Println("Not sending contact request email because there was an error saving the contact request") + return + } + + if s.sendEmail == nil || s.supportEmail == "" { + return + } + + recipients := []string{s.supportEmail} + for _, item := range items { + msg, buildErr := buildMessage(recipients, item) + if buildErr != nil { + log.Printf("Unable to build contact request email: %v", buildErr) + continue + } + _ = s.sendEmail(recipients, msg) + } +} + +func buildMessage(recipients []string, item *resource.Item) ([]byte, error) { + message, ok := item.Payload["message"].(string) + if !ok { + return nil, errors.New("message is missing") + } + replyTo, ok := item.Payload["contact_email"].(string) + if !ok { + return nil, errors.New("contact_email is missing") + } + + var msgBytes bytes.Buffer + if err := messageTemplate.Execute(&msgBytes, struct { + Emails string + Message string + ReplyTo string + }{ + Emails: strings.Join(recipients, ";"), + Message: message, + ReplyTo: replyTo, + }); err != nil { + return nil, err + } + + return io.ReadAll(&msgBytes) +} diff --git a/api/internal/contact/service_test.go b/api/internal/contact/service_test.go new file mode 100644 index 0000000..ab64cc3 --- /dev/null +++ b/api/internal/contact/service_test.go @@ -0,0 +1,94 @@ +package contact + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/rs/rest-layer/resource" +) + +type cacheStub struct { + items map[string]interface{} +} + +func newCacheStub() *cacheStub { + return &cacheStub{items: map[string]interface{}{}} +} + +func (c *cacheStub) ItemCount() int { + return len(c.items) +} + +func (c *cacheStub) SetDefault(key string, value interface{}) { + c.items[key] = value +} + +func TestBeforeInsert(t *testing.T) { + cache := newCacheStub() + svc := NewService(cache, 2, "support@example.com", nil) + + items := []*resource.Item{{ID: "req-1"}, {ID: "req-2"}} + if err := svc.BeforeInsert(context.Background(), items); err != nil { + t.Fatalf("unexpected insert hook error: %v", err) + } + + if cache.ItemCount() != 2 { + t.Fatalf("expected cache to contain two entries, got %d", cache.ItemCount()) + } + + err := svc.BeforeInsert(context.Background(), []*resource.Item{{ID: "req-3"}}) + if err == nil { + t.Fatal("expected insert hook to reject above threshold") + } + if err.Error() != "Too many contact requests, try again later" { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestAfterInsertNoopOnError(t *testing.T) { + sent := false + svc := NewService(newCacheStub(), 10, "support@example.com", func([]string, []byte) error { + sent = true + return nil + }) + + incomingErr := errors.New("storage failure") + err := error(incomingErr) + svc.AfterInsert(context.Background(), []*resource.Item{{ID: "req-1"}}, &err) + + if sent { + t.Fatal("did not expect email to be sent") + } + if err == nil || err.Error() != "storage failure" { + t.Fatalf("expected input error to be preserved, got %v", err) + } +} + +func TestAfterInsertSendsEmail(t *testing.T) { + sent := false + svc := NewService(newCacheStub(), 10, "support@example.com", func(to []string, msg []byte) error { + sent = true + if len(to) != 1 || to[0] != "support@example.com" { + t.Fatalf("unexpected recipients: %#v", to) + } + if !strings.Contains(string(msg), "Need help") { + t.Fatalf("expected message content in email body, got %q", string(msg)) + } + return nil + }) + + var err error + svc.AfterInsert(context.Background(), []*resource.Item{{ + ID: "req-1", + Payload: map[string]interface{}{ + "message": "Need help", + "contact_email": "user@example.com", + }, + }}, &err) + + if !sent { + t.Fatal("expected sendEmail to be called") + } +} diff --git a/api/internal/httpapi/handler_stats.go b/api/internal/httpapi/handler_stats.go new file mode 100644 index 0000000..72d0f75 --- /dev/null +++ b/api/internal/httpapi/handler_stats.go @@ -0,0 +1,42 @@ +package httpapi + +import ( + "errors" + "fmt" + "log" + "net/http" + + "github.com/julsemaan/anyfile-notepad/api/internal/stats" +) + +type StatsService interface { + ParsePayload(r *http.Request) (map[string]string, error) + Record(payload map[string]string) +} + +func NewStatsHandler(statsService StatsService) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Print("Allowing without authentication for stats namespace") + + payload, err := statsService.ParsePayload(r) + if err != nil { + switch { + case errors.Is(err, stats.ErrInvalidPayload): + http.Error(w, "Invalid payload", http.StatusBadRequest) + case errors.Is(err, stats.ErrInvalidJSON): + http.Error(w, "Invalid JSON", http.StatusBadRequest) + default: + http.Error(w, "Invalid payload", http.StatusBadRequest) + } + return + } + + statsService.Record(payload) + log.Printf("Stats request from %s", payload["ip"]) + if payload["type"] == "increment" { + log.Printf("afn.stats-hits.%s from %s", payload["key"], payload["ip"]) + } + + _, _ = w.Write([]byte(fmt.Sprint("OK"))) + }) +} diff --git a/api/internal/httpapi/handler_stats_test.go b/api/internal/httpapi/handler_stats_test.go new file mode 100644 index 0000000..00e816e --- /dev/null +++ b/api/internal/httpapi/handler_stats_test.go @@ -0,0 +1,88 @@ +package httpapi + +import ( + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/julsemaan/anyfile-notepad/api/internal/stats" +) + +type statsServiceStub struct { + payload map[string]string + err error + recorded bool +} + +func (s *statsServiceStub) ParsePayload(*http.Request) (map[string]string, error) { + if s.err != nil { + return nil, s.err + } + return s.payload, nil +} + +func (s *statsServiceStub) Record(map[string]string) { + s.recorded = true +} + +func TestStatsHandler(t *testing.T) { + t.Run("returns OK on increment payload", func(t *testing.T) { + stub := &statsServiceStub{payload: map[string]string{"ip": "192.0.2.15", "type": "increment", "key": "hits"}} + handler := NewStatsHandler(stub) + + req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader(`{"type":"increment","key":"hits"}`)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if strings.TrimSpace(w.Body.String()) != "OK" { + t.Fatalf("unexpected body: %q", w.Body.String()) + } + if !stub.recorded { + t.Fatal("expected stats payload to be recorded") + } + }) + + t.Run("invalid json returns bad request", func(t *testing.T) { + stub := &statsServiceStub{err: stats.ErrInvalidJSON} + handler := NewStatsHandler(stub) + + req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader(`{"broken":`)) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } + }) + + t.Run("invalid payload returns bad request", func(t *testing.T) { + stub := &statsServiceStub{err: stats.ErrInvalidPayload} + handler := NewStatsHandler(stub) + + req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader("{}")) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } + }) + + t.Run("unknown errors return bad request", func(t *testing.T) { + stub := &statsServiceStub{err: errors.New("boom")} + handler := NewStatsHandler(stub) + + req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader("{}")) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } + }) +} diff --git a/api/internal/httpapi/middleware_auth.go b/api/internal/httpapi/middleware_auth.go new file mode 100644 index 0000000..d20379e --- /dev/null +++ b/api/internal/httpapi/middleware_auth.go @@ -0,0 +1,38 @@ +package httpapi + +import ( + "log" + "net/http" + "strings" +) + +func IsOpenResource(r *http.Request) bool { + if r.Method == http.MethodOptions { + return true + } + + if strings.HasPrefix(r.URL.Path, "/contact_requests") { + if r.Method == http.MethodPost { + log.Print("Allowing without authentication for creating a contact request") + return true + } + return false + } + + if r.Method == http.MethodGet { + log.Print("Allowing without authentication for namespace that don't modify resources") + return true + } + + return false +} + +func Authenticate(w http.ResponseWriter, r *http.Request, username string, password string) bool { + requestUsername, requestPassword, ok := r.BasicAuth() + if !ok || requestUsername != username || requestPassword != password { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte("Unauthorized")) + return false + } + return true +} diff --git a/api/internal/httpapi/middleware_auth_test.go b/api/internal/httpapi/middleware_auth_test.go new file mode 100644 index 0000000..e7927ee --- /dev/null +++ b/api/internal/httpapi/middleware_auth_test.go @@ -0,0 +1,71 @@ +package httpapi + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestIsOpenResource(t *testing.T) { + testCases := []struct { + name string + method string + path string + expected bool + }{ + {name: "options always open", method: http.MethodOptions, path: "/anything", expected: true}, + {name: "contact requests post open", method: http.MethodPost, path: "/contact_requests", expected: true}, + {name: "contact requests get closed", method: http.MethodGet, path: "/contact_requests", expected: false}, + {name: "read resource open", method: http.MethodGet, path: "/syntaxes", expected: true}, + {name: "write resource closed", method: http.MethodPost, path: "/syntaxes", expected: false}, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + req := httptest.NewRequest(testCase.method, testCase.path, nil) + if IsOpenResource(req) != testCase.expected { + t.Fatalf("expected %v", testCase.expected) + } + }) + } +} + +func TestAuthenticate(t *testing.T) { + t.Run("accepts valid basic auth", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/syntaxes", nil) + req.SetBasicAuth("good-user", "good-password") + w := httptest.NewRecorder() + + if !Authenticate(w, req, "good-user", "good-password") { + t.Fatal("expected authentication to pass") + } + }) + + t.Run("rejects invalid basic auth", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/syntaxes", nil) + req.SetBasicAuth("bad-user", "bad-password") + w := httptest.NewRecorder() + + if Authenticate(w, req, "good-user", "good-password") { + t.Fatal("expected authentication to fail") + } + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected unauthorized status, got %d", w.Code) + } + if w.Body.String() != "Unauthorized" { + t.Fatalf("expected unauthorized body, got %q", w.Body.String()) + } + }) + + t.Run("rejects missing basic auth", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/syntaxes", nil) + w := httptest.NewRecorder() + + if Authenticate(w, req, "good-user", "good-password") { + t.Fatal("expected missing auth to fail") + } + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected unauthorized status, got %d", w.Code) + } + }) +} diff --git a/api/internal/httpapi/middleware_cors.go b/api/internal/httpapi/middleware_cors.go new file mode 100644 index 0000000..cc917f4 --- /dev/null +++ b/api/internal/httpapi/middleware_cors.go @@ -0,0 +1,12 @@ +package httpapi + +import "net/http" + +func withCORS(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + next.ServeHTTP(w, r) + }) +} diff --git a/api/internal/httpapi/router.go b/api/internal/httpapi/router.go new file mode 100644 index 0000000..bed2c9e --- /dev/null +++ b/api/internal/httpapi/router.go @@ -0,0 +1,25 @@ +package httpapi + +import ( + "net/http" + "regexp" +) + +var pathStatsRE = regexp.MustCompile(`^/stats`) + +func NewRouter(apiHandler http.Handler, statsHandler http.Handler, username string, password string) http.Handler { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if pathStatsRE.MatchString(r.URL.Path) { + statsHandler.ServeHTTP(w, r) + return + } + + if !IsOpenResource(r) && !Authenticate(w, r, username, password) { + return + } + + apiHandler.ServeHTTP(w, r) + }) + + return withCORS(handler) +} diff --git a/api/internal/httpapi/router_test.go b/api/internal/httpapi/router_test.go new file mode 100644 index 0000000..4c2ab33 --- /dev/null +++ b/api/internal/httpapi/router_test.go @@ -0,0 +1,68 @@ +package httpapi + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestRouter(t *testing.T) { + apiCalled := false + statsCalled := false + + apiHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiCalled = true + w.WriteHeader(http.StatusAccepted) + }) + statsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + statsCalled = true + w.WriteHeader(http.StatusOK) + }) + + router := NewRouter(apiHandler, statsHandler, "user", "password") + + t.Run("stats route bypasses auth", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/stats", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if !statsCalled { + t.Fatal("expected stats handler to be called") + } + if apiCalled { + t.Fatal("did not expect api handler to be called") + } + }) + + t.Run("protected route requires auth", func(t *testing.T) { + apiCalled = false + statsCalled = false + + req := httptest.NewRequest(http.MethodPost, "/syntaxes", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } + if apiCalled { + t.Fatal("did not expect api handler to be called") + } + }) + + t.Run("open route and cors headers", func(t *testing.T) { + apiCalled = false + statsCalled = false + + req := httptest.NewRequest(http.MethodGet, "/syntaxes", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if !apiCalled { + t.Fatal("expected api handler to be called") + } + if w.Header().Get("Access-Control-Allow-Origin") != "*" { + t.Fatal("expected cors headers") + } + }) +} diff --git a/api/internal/resources/index.go b/api/internal/resources/index.go new file mode 100644 index 0000000..015713e --- /dev/null +++ b/api/internal/resources/index.go @@ -0,0 +1,44 @@ +package resources + +import ( + filestore "github.com/julsemaan/rest-layer-file" + "github.com/rs/rest-layer/resource" +) + +type ContactHooks struct { + Insert resource.InsertEventHandler + Inserted resource.InsertedEventHandler +} + +func BuildIndex(directory string, hooks ContactHooks) resource.Index { + index := resource.NewIndex() + + index.Bind(MimeTypesCollection, MimeTypeSchema(), filestore.NewHandler(directory, MimeTypesCollection, []string{"type_name"}), resource.Conf{ + AllowedModes: resource.ReadWrite, + }) + + index.Bind(ExtensionsCollection, ExtensionSchema(), filestore.NewHandler(directory, ExtensionsCollection, []string{"name"}), resource.Conf{ + AllowedModes: resource.ReadWrite, + }) + + index.Bind(SyntaxesCollection, SyntaxSchema(), filestore.NewHandler(directory, SyntaxesCollection, []string{"ace_js_mode", "display_name"}), resource.Conf{ + AllowedModes: resource.ReadWrite, + }) + + index.Bind(SettingsCollection, SettingSchema(), filestore.NewHandler(directory, SettingsCollection, []string{"var_name"}), resource.Conf{ + AllowedModes: resource.ReadWrite, + }) + + contactRequests := index.Bind(ContactRequestsCollection, ContactRequestSchema(), filestore.NewHandler(directory, ContactRequestsCollection, []string{"id"}), resource.Conf{ + AllowedModes: resource.ReadWrite, + }) + + if hooks.Insert != nil { + contactRequests.Use(hooks.Insert) + } + if hooks.Inserted != nil { + contactRequests.Use(hooks.Inserted) + } + + return index +} diff --git a/api/internal/resources/schema.go b/api/internal/resources/schema.go new file mode 100644 index 0000000..33e0ca9 --- /dev/null +++ b/api/internal/resources/schema.go @@ -0,0 +1,144 @@ +package resources + +import ( + "errors" + "regexp" + + "github.com/rs/rest-layer/schema" +) + +const MimeTypesCollection = "mime_types" +const ExtensionsCollection = "extensions" +const SyntaxesCollection = "syntaxes" +const SettingsCollection = "settings" +const ContactRequestsCollection = "contact_requests" + +var emailRegex = regexp.MustCompile(`\S+@\S+`) + +func MimeTypeSchema() schema.Schema { + return schema.Schema{ + Description: "The mime_type object", + Fields: schema.Fields{ + "id": schema.IDField, + "created_at": schema.CreatedField, + "updated_at": schema.UpdatedField, + "type_name": { + Required: true, + Filterable: true, + }, + "integrated": { + Default: false, + Filterable: true, + Validator: &schema.Bool{}, + }, + "discovered_by": { + Default: "John Doe", + Filterable: true, + }, + }, + } +} + +func ExtensionSchema() schema.Schema { + return schema.Schema{ + Description: "Represents an extension", + Fields: schema.Fields{ + "id": schema.IDField, + "created_at": schema.CreatedField, + "updated_at": schema.UpdatedField, + "name": { + Required: true, + Filterable: true, + }, + "syntax_id": { + Required: true, + Filterable: true, + Validator: &schema.Reference{ + Path: SyntaxesCollection, + }, + }, + "mime_type_id": { + Required: true, + Filterable: true, + Validator: &schema.Reference{ + Path: MimeTypesCollection, + }, + }, + }, + } +} + +func SyntaxSchema() schema.Schema { + return schema.Schema{ + Description: "Represents a syntax", + Fields: schema.Fields{ + "id": schema.IDField, + "created_at": schema.CreatedField, + "updated_at": schema.UpdatedField, + "display_name": { + Required: true, + Filterable: true, + }, + "ace_js_mode": { + Required: true, + Filterable: true, + }, + }, + } +} + +func SettingSchema() schema.Schema { + return schema.Schema{ + Description: "Represents a setting", + Fields: schema.Fields{ + "id": schema.IDField, + "created_at": schema.CreatedField, + "updated_at": schema.UpdatedField, + "var_name": { + Required: true, + Filterable: true, + }, + "value": { + Required: true, + Filterable: true, + }, + }, + } +} + +func ContactRequestSchema() schema.Schema { + return schema.Schema{ + Description: "Represents a contact request", + Fields: schema.Fields{ + "id": schema.IDField, + "created_at": schema.CreatedField, + "updated_at": schema.UpdatedField, + "contact_email": { + Required: true, + Filterable: true, + Validator: emailValidator{}, + }, + "message": { + Required: true, + Filterable: true, + Validator: &schema.String{ + MinLen: 10, + MaxLen: 2000, + }, + }, + }, + } +} + +type emailValidator struct{} + +func (emailValidator) Validate(value interface{}) (interface{}, error) { + email, ok := value.(string) + if !ok { + return value, errors.New("Invalid email format") + } + if !emailRegex.MatchString(email) { + return email, errors.New("Invalid email format") + } + return email, nil +} diff --git a/api/internal/stats/service.go b/api/internal/stats/service.go new file mode 100644 index 0000000..a9a48ad --- /dev/null +++ b/api/internal/stats/service.go @@ -0,0 +1,77 @@ +package stats + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net" + "net/http" + "regexp" + "strings" +) + +var ErrInvalidPayload = errors.New("invalid payload") +var ErrInvalidJSON = errors.New("invalid json") + +var remoteAddrRegex = regexp.MustCompile(`^([0-9.]+):`) + +type Metrics interface { + Increment(bucket string) +} + +type Service struct { + metrics Metrics +} + +func NewService(metrics Metrics) *Service { + return &Service{metrics: metrics} +} + +func (s *Service) ParsePayload(r *http.Request) (map[string]string, error) { + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, ErrInvalidPayload + } + + var payload map[string]string + decoder := json.NewDecoder(bytes.NewBuffer(body)) + if err := decoder.Decode(&payload); err != nil { + return nil, ErrInvalidJSON + } + + payload["ip"] = extractIP(r) + return payload, nil +} + +func (s *Service) Record(payload map[string]string) { + if s.metrics == nil { + return + } + + ipKey := strings.ReplaceAll(payload["ip"], ".", "_") + s.metrics.Increment("afn.stats-hits." + ipKey) + + if payload["type"] == "increment" { + s.metrics.Increment(payload["key"]) + } +} + +func extractIP(r *http.Request) string { + forwardedFor := r.Header.Get("X-Forwarded-For") + if forwardedFor != "" { + return strings.TrimSpace(strings.Split(forwardedFor, ",")[0]) + } + + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err == nil { + return host + } + + matches := remoteAddrRegex.FindAllStringSubmatch(r.RemoteAddr, 1) + if len(matches) > 0 && len(matches[0]) > 1 { + return matches[0][1] + } + + return r.RemoteAddr +} diff --git a/api/internal/stats/service_test.go b/api/internal/stats/service_test.go new file mode 100644 index 0000000..5ac1a30 --- /dev/null +++ b/api/internal/stats/service_test.go @@ -0,0 +1,89 @@ +package stats + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" +) + +type errReader struct{} + +func (errReader) Read([]byte) (int, error) { + return 0, errors.New("read failed") +} + +func (errReader) Close() error { + return nil +} + +type metricsStub struct { + keys []string +} + +func (s *metricsStub) Increment(bucket string) { + s.keys = append(s.keys, bucket) +} + +func TestParsePayload(t *testing.T) { + svc := NewService(nil) + + t.Run("uses forwarded for ip", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader(`{"type":"increment","key":"hits"}`)) + req.Header.Set("X-Forwarded-For", "203.0.113.10, 70.41.3.18") + + payload, err := svc.ParsePayload(req) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + if payload["ip"] != "203.0.113.10" { + t.Fatalf("expected forwarded ip, got %q", payload["ip"]) + } + }) + + t.Run("uses remote addr when header missing", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader(`{"type":"increment","key":"hits"}`)) + req.RemoteAddr = "192.0.2.15:43210" + + payload, err := svc.ParsePayload(req) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + if payload["ip"] != "192.0.2.15" { + t.Fatalf("expected remote ip, got %q", payload["ip"]) + } + }) + + t.Run("invalid json returns error", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader(`{"broken":`)) + _, err := svc.ParsePayload(req) + if !errors.Is(err, ErrInvalidJSON) { + t.Fatalf("expected ErrInvalidJSON, got %v", err) + } + }) + + t.Run("body read failures return error", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/stats", nil) + req.Body = io.NopCloser(errReader{}) + + _, err := svc.ParsePayload(req) + if !errors.Is(err, ErrInvalidPayload) { + t.Fatalf("expected ErrInvalidPayload, got %v", err) + } + }) +} + +func TestRecord(t *testing.T) { + stub := &metricsStub{} + svc := NewService(stub) + + svc.Record(map[string]string{"ip": "192.0.2.15", "type": "increment", "key": "hits"}) + + expected := []string{"afn.stats-hits.192_0_2_15", "hits"} + if !reflect.DeepEqual(stub.keys, expected) { + t.Fatalf("unexpected metrics keys: %#v", stub.keys) + } +} From 467b33922a1fa116ae2826c6aceb38a59bb95deb Mon Sep 17 00:00:00 2001 From: Julien Semaan Date: Sat, 4 Apr 2026 19:14:39 -0400 Subject: [PATCH 02/10] feat(api): add fallback path resolution for api build target --- api/Makefile | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/api/Makefile b/api/Makefile index ad99ba2..eed8f04 100644 --- a/api/Makefile +++ b/api/Makefile @@ -1,4 +1,11 @@ .PHONY: api api: - CGO_ENABLED=0 go build -v -o api ./cmd/api + @if [ -d ./cmd/api ]; then \ + CGO_ENABLED=0 go build -v -o api ./cmd/api; \ + elif [ -d ./api/cmd/api ]; then \ + CGO_ENABLED=0 go build -v -o api ./api/cmd/api; \ + else \ + echo "cannot find api main package (expected ./cmd/api or ./api/cmd/api)"; \ + exit 1; \ + fi From d4d7426ec117341fd056d1e7943bae58687ef707 Mon Sep 17 00:00:00 2001 From: Julien Semaan Date: Sat, 4 Apr 2026 19:20:39 -0400 Subject: [PATCH 03/10] fix: harden security and fix bugs across auth, routing, and validation --- api/internal/app/config.go | 17 +++++++- api/internal/app/config_test.go | 41 ++++++++++++++++++++ api/internal/app/run.go | 8 +++- api/internal/httpapi/middleware_auth.go | 1 + api/internal/httpapi/middleware_auth_test.go | 3 ++ api/internal/httpapi/router.go | 2 +- api/internal/httpapi/router_test.go | 19 +++++++++ api/internal/resources/schema.go | 11 ++++-- api/internal/resources/schema_test.go | 35 +++++++++++++++++ api/internal/stats/service.go | 5 ++- api/internal/stats/service_test.go | 19 +++++++++ 11 files changed, 151 insertions(+), 10 deletions(-) create mode 100644 api/internal/app/config_test.go create mode 100644 api/internal/resources/schema_test.go diff --git a/api/internal/app/config.go b/api/internal/app/config.go index 751508c..8ccc958 100644 --- a/api/internal/app/config.go +++ b/api/internal/app/config.go @@ -1,10 +1,14 @@ package app -import "os" +import ( + "os" + "strconv" +) const defaultDataDir = "./db" const defaultListenAddr = ":8080" const defaultContactRequestsPerDay = 10 +const envMaxContactRequestsPerDay = "AFN_MAX_CONTACT_REQUESTS_PER_DAY" type Config struct { DataDir string @@ -34,6 +38,15 @@ func LoadConfigFromEnv() Config { Password: os.Getenv("AFN_REST_PASSWORD"), SupportEmail: os.Getenv("AFN_SUPPORT_EMAIL"), StatsdAddress: os.Getenv("AFN_STATSD_URI"), - MaxContactRequestsPerDay: defaultContactRequestsPerDay, + MaxContactRequestsPerDay: loadMaxContactRequestsPerDay(), } } + +func loadMaxContactRequestsPerDay() int { + maxPerDay, err := strconv.Atoi(os.Getenv(envMaxContactRequestsPerDay)) + if err != nil || maxPerDay <= 0 { + return defaultContactRequestsPerDay + } + + return maxPerDay +} diff --git a/api/internal/app/config_test.go b/api/internal/app/config_test.go new file mode 100644 index 0000000..d25335f --- /dev/null +++ b/api/internal/app/config_test.go @@ -0,0 +1,41 @@ +package app + +import "testing" + +func TestLoadConfigFromEnv(t *testing.T) { + t.Run("uses default max contact requests when env missing", func(t *testing.T) { + t.Setenv(envMaxContactRequestsPerDay, "") + + cfg := LoadConfigFromEnv() + if cfg.MaxContactRequestsPerDay != defaultContactRequestsPerDay { + t.Fatalf("expected default max contact requests, got %d", cfg.MaxContactRequestsPerDay) + } + }) + + t.Run("uses configured max contact requests when valid", func(t *testing.T) { + t.Setenv(envMaxContactRequestsPerDay, "25") + + cfg := LoadConfigFromEnv() + if cfg.MaxContactRequestsPerDay != 25 { + t.Fatalf("expected configured max contact requests, got %d", cfg.MaxContactRequestsPerDay) + } + }) + + t.Run("falls back to default when invalid", func(t *testing.T) { + t.Setenv(envMaxContactRequestsPerDay, "invalid") + + cfg := LoadConfigFromEnv() + if cfg.MaxContactRequestsPerDay != defaultContactRequestsPerDay { + t.Fatalf("expected default max contact requests, got %d", cfg.MaxContactRequestsPerDay) + } + }) + + t.Run("falls back to default when non positive", func(t *testing.T) { + t.Setenv(envMaxContactRequestsPerDay, "0") + + cfg := LoadConfigFromEnv() + if cfg.MaxContactRequestsPerDay != defaultContactRequestsPerDay { + t.Fatalf("expected default max contact requests, got %d", cfg.MaxContactRequestsPerDay) + } + }) +} diff --git a/api/internal/app/run.go b/api/internal/app/run.go index 6dc2b89..6d3ddf3 100644 --- a/api/internal/app/run.go +++ b/api/internal/app/run.go @@ -3,6 +3,7 @@ package app import ( "log" "net/http" + "time" "github.com/julsemaan/anyfile-notepad/api/internal/contact" "github.com/julsemaan/anyfile-notepad/api/internal/httpapi" @@ -20,13 +21,16 @@ func Run(cfg Config) error { schema.CreatedField.ReadOnly = false schema.UpdatedField.ReadOnly = false - statsConn, _ := statsd.New(statsd.Address(cfg.StatsdAddress)) + statsConn, err := statsd.New(statsd.Address(cfg.StatsdAddress)) + if err != nil { + log.Printf("warning: statsd initialization failed: %v", err) + } if statsConn != nil { defer statsConn.Close() } statsService := stats.NewService(statsConn) - contactCache := cache.New(24*60*60*1000000000, 60*1000000000) + contactCache := cache.New(24*time.Hour, time.Minute) contactService := contact.NewService(contactCache, cfg.MaxContactRequestsPerDay, cfg.SupportEmail, utils.SendEmail) index := resources.BuildIndex(cfg.DataDir, resources.ContactHooks{ diff --git a/api/internal/httpapi/middleware_auth.go b/api/internal/httpapi/middleware_auth.go index d20379e..12f978d 100644 --- a/api/internal/httpapi/middleware_auth.go +++ b/api/internal/httpapi/middleware_auth.go @@ -30,6 +30,7 @@ func IsOpenResource(r *http.Request) bool { func Authenticate(w http.ResponseWriter, r *http.Request, username string, password string) bool { requestUsername, requestPassword, ok := r.BasicAuth() if !ok || requestUsername != username || requestPassword != password { + w.Header().Set("WWW-Authenticate", `Basic realm="restricted"`) w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte("Unauthorized")) return false diff --git a/api/internal/httpapi/middleware_auth_test.go b/api/internal/httpapi/middleware_auth_test.go index e7927ee..d7aab6c 100644 --- a/api/internal/httpapi/middleware_auth_test.go +++ b/api/internal/httpapi/middleware_auth_test.go @@ -55,6 +55,9 @@ func TestAuthenticate(t *testing.T) { if w.Body.String() != "Unauthorized" { t.Fatalf("expected unauthorized body, got %q", w.Body.String()) } + if w.Header().Get("WWW-Authenticate") != `Basic realm="restricted"` { + t.Fatalf("expected basic auth challenge header, got %q", w.Header().Get("WWW-Authenticate")) + } }) t.Run("rejects missing basic auth", func(t *testing.T) { diff --git a/api/internal/httpapi/router.go b/api/internal/httpapi/router.go index bed2c9e..435e476 100644 --- a/api/internal/httpapi/router.go +++ b/api/internal/httpapi/router.go @@ -5,7 +5,7 @@ import ( "regexp" ) -var pathStatsRE = regexp.MustCompile(`^/stats`) +var pathStatsRE = regexp.MustCompile(`^/stats(?:/|$)`) func NewRouter(apiHandler http.Handler, statsHandler http.Handler, username string, password string) http.Handler { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/api/internal/httpapi/router_test.go b/api/internal/httpapi/router_test.go index 4c2ab33..b4e6252 100644 --- a/api/internal/httpapi/router_test.go +++ b/api/internal/httpapi/router_test.go @@ -22,6 +22,9 @@ func TestRouter(t *testing.T) { router := NewRouter(apiHandler, statsHandler, "user", "password") t.Run("stats route bypasses auth", func(t *testing.T) { + apiCalled = false + statsCalled = false + req := httptest.NewRequest(http.MethodPost, "/stats", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) @@ -34,6 +37,22 @@ func TestRouter(t *testing.T) { } }) + t.Run("stats prefix does not bypass auth", func(t *testing.T) { + apiCalled = false + statsCalled = false + + req := httptest.NewRequest(http.MethodPost, "/statsanything", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } + if statsCalled { + t.Fatal("did not expect stats handler to be called") + } + }) + t.Run("protected route requires auth", func(t *testing.T) { apiCalled = false statsCalled = false diff --git a/api/internal/resources/schema.go b/api/internal/resources/schema.go index 33e0ca9..aa8a7cd 100644 --- a/api/internal/resources/schema.go +++ b/api/internal/resources/schema.go @@ -2,7 +2,8 @@ package resources import ( "errors" - "regexp" + "net/mail" + "strings" "github.com/rs/rest-layer/schema" ) @@ -13,8 +14,6 @@ const SyntaxesCollection = "syntaxes" const SettingsCollection = "settings" const ContactRequestsCollection = "contact_requests" -var emailRegex = regexp.MustCompile(`\S+@\S+`) - func MimeTypeSchema() schema.Schema { return schema.Schema{ Description: "The mime_type object", @@ -137,7 +136,11 @@ func (emailValidator) Validate(value interface{}) (interface{}, error) { if !ok { return value, errors.New("Invalid email format") } - if !emailRegex.MatchString(email) { + if strings.ContainsAny(email, "\r\n") { + return email, errors.New("Invalid email format") + } + addr, err := mail.ParseAddress(email) + if err != nil || addr.Address != email { return email, errors.New("Invalid email format") } return email, nil diff --git a/api/internal/resources/schema_test.go b/api/internal/resources/schema_test.go new file mode 100644 index 0000000..bda0c4f --- /dev/null +++ b/api/internal/resources/schema_test.go @@ -0,0 +1,35 @@ +package resources + +import "testing" + +func TestEmailValidatorValidate(t *testing.T) { + validator := emailValidator{} + + t.Run("accepts valid email", func(t *testing.T) { + value, err := validator.Validate("john.doe@example.com") + if err != nil { + t.Fatalf("expected valid email, got error: %v", err) + } + if value != "john.doe@example.com" { + t.Fatalf("expected preserved email, got %#v", value) + } + }) + + testCases := []struct { + name string + input interface{} + }{ + {name: "non email value", input: "not-an-email"}, + {name: "crlf injection", input: "john.doe@example.com\r\nBcc:evil@example.com"}, + {name: "embedded whitespace", input: "john doe@example.com"}, + {name: "non string value", input: 123}, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + if _, err := validator.Validate(testCase.input); err == nil { + t.Fatalf("expected invalid input %#v to be rejected", testCase.input) + } + }) + } +} diff --git a/api/internal/stats/service.go b/api/internal/stats/service.go index a9a48ad..2496048 100644 --- a/api/internal/stats/service.go +++ b/api/internal/stats/service.go @@ -39,6 +39,9 @@ func (s *Service) ParsePayload(r *http.Request) (map[string]string, error) { if err := decoder.Decode(&payload); err != nil { return nil, ErrInvalidJSON } + if payload == nil { + payload = map[string]string{} + } payload["ip"] = extractIP(r) return payload, nil @@ -49,7 +52,7 @@ func (s *Service) Record(payload map[string]string) { return } - ipKey := strings.ReplaceAll(payload["ip"], ".", "_") + ipKey := strings.NewReplacer(".", "_", ":", "_").Replace(payload["ip"]) s.metrics.Increment("afn.stats-hits." + ipKey) if payload["type"] == "increment" { diff --git a/api/internal/stats/service_test.go b/api/internal/stats/service_test.go index 5ac1a30..813e404 100644 --- a/api/internal/stats/service_test.go +++ b/api/internal/stats/service_test.go @@ -65,6 +65,19 @@ func TestParsePayload(t *testing.T) { } }) + t.Run("json null initializes payload map", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader(`null`)) + req.RemoteAddr = "192.0.2.15:43210" + + payload, err := svc.ParsePayload(req) + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + if payload["ip"] != "192.0.2.15" { + t.Fatalf("expected remote ip, got %q", payload["ip"]) + } + }) + t.Run("body read failures return error", func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/stats", nil) req.Body = io.NopCloser(errReader{}) @@ -86,4 +99,10 @@ func TestRecord(t *testing.T) { if !reflect.DeepEqual(stub.keys, expected) { t.Fatalf("unexpected metrics keys: %#v", stub.keys) } + + stub.keys = nil + svc.Record(map[string]string{"ip": "2001:db8::1"}) + if len(stub.keys) != 1 || stub.keys[0] != "afn.stats-hits.2001_db8__1" { + t.Fatalf("expected sanitized ipv6 metric key, got %#v", stub.keys) + } } From 164766d91fe24e29363b672e561c8ecb9aebc6d3 Mon Sep 17 00:00:00 2001 From: Julien Semaan Date: Sat, 4 Apr 2026 19:31:48 -0400 Subject: [PATCH 04/10] feat(api): add main entry point and fix gitignore pattern --- api/.gitignore | 2 +- api/cmd/api/main.go | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 api/cmd/api/main.go diff --git a/api/.gitignore b/api/.gitignore index cf6b797..79a74ae 100644 --- a/api/.gitignore +++ b/api/.gitignore @@ -1,4 +1,4 @@ -api +/api afn-rest-32 afn-rest-64 db/ diff --git a/api/cmd/api/main.go b/api/cmd/api/main.go new file mode 100644 index 0000000..8e6b4f7 --- /dev/null +++ b/api/cmd/api/main.go @@ -0,0 +1,13 @@ +package main + +import ( + "log" + + "github.com/julsemaan/anyfile-notepad/api/internal/app" +) + +func main() { + if err := app.Run(app.LoadConfigFromEnv()); err != nil { + log.Fatal(err) + } +} From 800c546094645d03e3e3b4f9827688cc604e968f Mon Sep 17 00:00:00 2001 From: Julien Semaan Date: Sat, 4 Apr 2026 19:37:53 -0400 Subject: [PATCH 05/10] fix(api): improve auth security, error handling, and input validation --- Makefile | 3 +-- api/internal/contact/service.go | 4 +++- api/internal/httpapi/handler_stats.go | 4 ++-- api/internal/httpapi/middleware_auth.go | 12 +++++++++++- api/internal/httpapi/middleware_auth_test.go | 13 +++++++++++++ api/internal/resources/schema.go | 8 +++++--- api/internal/stats/service.go | 5 ++++- api/internal/stats/service_test.go | 6 ++++++ 8 files changed, 45 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index 7e6af65..fc20c34 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,7 @@ all-golang: cd webserver && make webserver && cd .. - cd api && make afn-rest-32 && cd .. - cd api && make afn-rest-64 && cd .. + cd api && make api && cd .. client-dist.tgz: make client/dist diff --git a/api/internal/contact/service.go b/api/internal/contact/service.go index 65a6813..6017f7b 100644 --- a/api/internal/contact/service.go +++ b/api/internal/contact/service.go @@ -77,7 +77,9 @@ func (s *Service) AfterInsert(_ context.Context, items []*resource.Item, err *er log.Printf("Unable to build contact request email: %v", buildErr) continue } - _ = s.sendEmail(recipients, msg) + if sendErr := s.sendEmail(recipients, msg); sendErr != nil { + log.Printf("Unable to send contact request email: %v", sendErr) + } } } diff --git a/api/internal/httpapi/handler_stats.go b/api/internal/httpapi/handler_stats.go index 72d0f75..17ff44e 100644 --- a/api/internal/httpapi/handler_stats.go +++ b/api/internal/httpapi/handler_stats.go @@ -32,9 +32,9 @@ func NewStatsHandler(statsService StatsService) http.Handler { } statsService.Record(payload) - log.Printf("Stats request from %s", payload["ip"]) + log.Printf("Stats request from %q", payload["ip"]) if payload["type"] == "increment" { - log.Printf("afn.stats-hits.%s from %s", payload["key"], payload["ip"]) + log.Printf("afn.stats-hits.%q from %q", payload["key"], payload["ip"]) } _, _ = w.Write([]byte(fmt.Sprint("OK"))) diff --git a/api/internal/httpapi/middleware_auth.go b/api/internal/httpapi/middleware_auth.go index 12f978d..fdfd8aa 100644 --- a/api/internal/httpapi/middleware_auth.go +++ b/api/internal/httpapi/middleware_auth.go @@ -1,6 +1,7 @@ package httpapi import ( + "crypto/subtle" "log" "net/http" "strings" @@ -28,8 +29,17 @@ func IsOpenResource(r *http.Request) bool { } func Authenticate(w http.ResponseWriter, r *http.Request, username string, password string) bool { + if username == "" || password == "" { + w.Header().Set("WWW-Authenticate", `Basic realm="restricted"`) + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte("Unauthorized")) + return false + } + requestUsername, requestPassword, ok := r.BasicAuth() - if !ok || requestUsername != username || requestPassword != password { + validUsername := subtle.ConstantTimeCompare([]byte(requestUsername), []byte(username)) == 1 + validPassword := subtle.ConstantTimeCompare([]byte(requestPassword), []byte(password)) == 1 + if !ok || !validUsername || !validPassword { w.Header().Set("WWW-Authenticate", `Basic realm="restricted"`) w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte("Unauthorized")) diff --git a/api/internal/httpapi/middleware_auth_test.go b/api/internal/httpapi/middleware_auth_test.go index d7aab6c..684000f 100644 --- a/api/internal/httpapi/middleware_auth_test.go +++ b/api/internal/httpapi/middleware_auth_test.go @@ -71,4 +71,17 @@ func TestAuthenticate(t *testing.T) { t.Fatalf("expected unauthorized status, got %d", w.Code) } }) + + t.Run("rejects empty configured credentials", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/syntaxes", nil) + req.SetBasicAuth("", "") + w := httptest.NewRecorder() + + if Authenticate(w, req, "", "") { + t.Fatal("expected authentication to fail with empty config") + } + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected unauthorized status, got %d", w.Code) + } + }) } diff --git a/api/internal/resources/schema.go b/api/internal/resources/schema.go index aa8a7cd..8b12aee 100644 --- a/api/internal/resources/schema.go +++ b/api/internal/resources/schema.go @@ -14,6 +14,8 @@ const SyntaxesCollection = "syntaxes" const SettingsCollection = "settings" const ContactRequestsCollection = "contact_requests" +var errInvalidEmailFormat = errors.New("invalid email format") + func MimeTypeSchema() schema.Schema { return schema.Schema{ Description: "The mime_type object", @@ -134,14 +136,14 @@ type emailValidator struct{} func (emailValidator) Validate(value interface{}) (interface{}, error) { email, ok := value.(string) if !ok { - return value, errors.New("Invalid email format") + return value, errInvalidEmailFormat } if strings.ContainsAny(email, "\r\n") { - return email, errors.New("Invalid email format") + return email, errInvalidEmailFormat } addr, err := mail.ParseAddress(email) if err != nil || addr.Address != email { - return email, errors.New("Invalid email format") + return email, errInvalidEmailFormat } return email, nil } diff --git a/api/internal/stats/service.go b/api/internal/stats/service.go index 2496048..9b6ddb1 100644 --- a/api/internal/stats/service.go +++ b/api/internal/stats/service.go @@ -15,6 +15,7 @@ var ErrInvalidPayload = errors.New("invalid payload") var ErrInvalidJSON = errors.New("invalid json") var remoteAddrRegex = regexp.MustCompile(`^([0-9.]+):`) +var metricKeyRegex = regexp.MustCompile(`^[a-zA-Z0-9_.-]{1,64}$`) type Metrics interface { Increment(bucket string) @@ -56,7 +57,9 @@ func (s *Service) Record(payload map[string]string) { s.metrics.Increment("afn.stats-hits." + ipKey) if payload["type"] == "increment" { - s.metrics.Increment(payload["key"]) + if metricKeyRegex.MatchString(payload["key"]) { + s.metrics.Increment(payload["key"]) + } } } diff --git a/api/internal/stats/service_test.go b/api/internal/stats/service_test.go index 813e404..f776b61 100644 --- a/api/internal/stats/service_test.go +++ b/api/internal/stats/service_test.go @@ -105,4 +105,10 @@ func TestRecord(t *testing.T) { if len(stub.keys) != 1 || stub.keys[0] != "afn.stats-hits.2001_db8__1" { t.Fatalf("expected sanitized ipv6 metric key, got %#v", stub.keys) } + + stub.keys = nil + svc.Record(map[string]string{"ip": "192.0.2.15", "type": "increment", "key": "bad:key"}) + if len(stub.keys) != 1 || stub.keys[0] != "afn.stats-hits.192_0_2_15" { + t.Fatalf("expected invalid metric key to be ignored, got %#v", stub.keys) + } } From 37e5635d5d20bb028f62e013b92c513500987b47 Mon Sep 17 00:00:00 2001 From: Julien Semaan Date: Sat, 4 Apr 2026 23:44:28 +0000 Subject: [PATCH 06/10] fix(resources): use pointer for emailValidator and reorder make in dev script --- api/internal/resources/schema.go | 2 +- docker4dev/docker-compose.yaml | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/api/internal/resources/schema.go b/api/internal/resources/schema.go index 8b12aee..3f2608a 100644 --- a/api/internal/resources/schema.go +++ b/api/internal/resources/schema.go @@ -117,7 +117,7 @@ func ContactRequestSchema() schema.Schema { "contact_email": { Required: true, Filterable: true, - Validator: emailValidator{}, + Validator: &emailValidator{}, }, "message": { Required: true, diff --git a/docker4dev/docker-compose.yaml b/docker4dev/docker-compose.yaml index 4ff79fb..952b5c4 100644 --- a/docker4dev/docker-compose.yaml +++ b/docker4dev/docker-compose.yaml @@ -20,10 +20,10 @@ services: apt update && apt install -y inotify-tools while true; do cd /src/api + make ./api & inotifywait -r -e create,modify,delete . pkill -9 -f api - make done build: context: ../api/ @@ -88,4 +88,3 @@ services: volumes: api: client: - From e760e958597a9cd861abb9f2e380e6cc5b866cc9 Mon Sep 17 00:00:00 2001 From: Julien Semaan Date: Sun, 5 Apr 2026 00:06:35 +0000 Subject: [PATCH 07/10] feat: add optional TLS email sending and HTTP access log middleware --- api/internal/app/email.go | 89 +++++++++++++ api/internal/app/email_test.go | 121 ++++++++++++++++++ api/internal/app/run.go | 3 +- api/internal/httpapi/middleware_access_log.go | 34 +++++ .../httpapi/middleware_access_log_test.go | 52 ++++++++ api/internal/httpapi/router.go | 2 +- pages/contact.markdown | 2 +- utils/utils.go | 70 +++++++++- utils/utils_test.go | 55 ++++++++ webserver/billing_handlers.go | 3 +- webserver/email.go | 89 +++++++++++++ webserver/email_test.go | 121 ++++++++++++++++++ 12 files changed, 632 insertions(+), 9 deletions(-) create mode 100644 api/internal/app/email.go create mode 100644 api/internal/app/email_test.go create mode 100644 api/internal/httpapi/middleware_access_log.go create mode 100644 api/internal/httpapi/middleware_access_log_test.go create mode 100644 webserver/email.go create mode 100644 webserver/email_test.go diff --git a/api/internal/app/email.go b/api/internal/app/email.go new file mode 100644 index 0000000..6f34dbd --- /dev/null +++ b/api/internal/app/email.go @@ -0,0 +1,89 @@ +package app + +import ( + "crypto/tls" + "fmt" + "net" + "net/smtp" + "os" + "strconv" + "strings" +) + +var smtpSendMail = smtp.SendMail +var smtpSendMailWithTLSConfig = sendMailWithTLSConfig + +func sendEmailWithOptionalTLS(to []string, msg []byte) error { + host := os.Getenv("SMTP_HOST") + port := os.Getenv("SMTP_PORT") + from := os.Getenv("SMTP_FROM") + user := os.Getenv("SMTP_USER") + password := os.Getenv("SMTP_PASSWORD") + addr := net.JoinHostPort(host, port) + + var auth smtp.Auth + if user != "" || password != "" { + auth = smtp.PlainAuth("", user, password, host) + } + + rawSkipTLSVerify := strings.TrimSpace(os.Getenv("SMTP_SKIP_TLS_VERIFY")) + rawSkipTLSVerify = strings.Trim(rawSkipTLSVerify, "\"'") + skipTLSVerify, _ := strconv.ParseBool(rawSkipTLSVerify) + + err := error(nil) + if skipTLSVerify { + err = smtpSendMailWithTLSConfig(addr, host, from, to, msg, auth, true) + } else { + err = smtpSendMail(addr, auth, from, to, msg) + } + if err != nil { + fmt.Println("ERROR: Unable to send email:", err) + return err + } + + return nil +} + +func sendMailWithTLSConfig(addr string, host string, from string, to []string, msg []byte, auth smtp.Auth, skipTLSVerify bool) error { + client, err := smtp.Dial(addr) + if err != nil { + return err + } + defer client.Close() + + tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: skipTLSVerify} + if err := client.StartTLS(tlsConfig); err != nil { + return err + } + + if auth != nil { + if err := client.Auth(auth); err != nil { + return err + } + } + + if err := client.Mail(from); err != nil { + return err + } + + for _, recipient := range to { + if err := client.Rcpt(recipient); err != nil { + return err + } + } + + writer, err := client.Data() + if err != nil { + return err + } + + if _, err := writer.Write(msg); err != nil { + return err + } + + if err := writer.Close(); err != nil { + return err + } + + return client.Quit() +} diff --git a/api/internal/app/email_test.go b/api/internal/app/email_test.go new file mode 100644 index 0000000..17d1c2c --- /dev/null +++ b/api/internal/app/email_test.go @@ -0,0 +1,121 @@ +package app + +import ( + "errors" + "net/smtp" + "testing" +) + +func TestSendEmailWithOptionalTLS(t *testing.T) { + t.Setenv("SMTP_USER", "smtp-user") + t.Setenv("SMTP_PASSWORD", "smtp-pass") + t.Setenv("SMTP_HOST", "smtp.example.com") + t.Setenv("SMTP_PORT", "2525") + t.Setenv("SMTP_FROM", "noreply@example.com") + + originalSender := smtpSendMail + originalSenderWithTLSConfig := smtpSendMailWithTLSConfig + t.Cleanup(func() { + smtpSendMail = originalSender + smtpSendMailWithTLSConfig = originalSenderWithTLSConfig + }) + + t.Run("successful send with default sender", func(t *testing.T) { + called := false + smtpSendMail = func(addr string, auth smtp.Auth, from string, to []string, msg []byte) error { + called = true + if addr != "smtp.example.com:2525" { + t.Fatalf("unexpected smtp addr: %s", addr) + } + if from != "noreply@example.com" { + t.Fatalf("unexpected from address: %s", from) + } + if len(to) != 1 || to[0] != "support@example.com" { + t.Fatalf("unexpected recipients: %#v", to) + } + if string(msg) != "Subject: test\n\nHello" { + t.Fatalf("unexpected message body: %q", string(msg)) + } + if auth == nil { + t.Fatal("expected smtp auth to be created") + } + return nil + } + + err := sendEmailWithOptionalTLS([]string{"support@example.com"}, []byte("Subject: test\n\nHello")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Fatal("expected smtp sender to be called") + } + }) + + t.Run("empty credentials disable smtp auth", func(t *testing.T) { + t.Setenv("SMTP_USER", "") + t.Setenv("SMTP_PASSWORD", "") + + smtpSendMail = func(addr string, auth smtp.Auth, from string, to []string, msg []byte) error { + if auth != nil { + t.Fatal("expected smtp auth to be nil") + } + return nil + } + + err := sendEmailWithOptionalTLS([]string{"support@example.com"}, []byte("msg")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("sender error bubbles up", func(t *testing.T) { + expectedErr := errors.New("smtp unavailable") + smtpSendMail = func(addr string, auth smtp.Auth, from string, to []string, msg []byte) error { + return expectedErr + } + + err := sendEmailWithOptionalTLS([]string{"support@example.com"}, []byte("msg")) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, expectedErr) { + t.Fatalf("expected %v, got %v", expectedErr, err) + } + }) + + t.Run("quoted true enables tls skip verify sender", func(t *testing.T) { + t.Setenv("SMTP_SKIP_TLS_VERIFY", "'true'") + + normalSenderCalled := false + smtpSendMail = func(addr string, auth smtp.Auth, from string, to []string, msg []byte) error { + normalSenderCalled = true + return nil + } + + customSenderCalled := false + smtpSendMailWithTLSConfig = func(addr string, host string, from string, to []string, msg []byte, auth smtp.Auth, skipTLSVerify bool) error { + customSenderCalled = true + if addr != "smtp.example.com:2525" { + t.Fatalf("unexpected smtp addr: %s", addr) + } + if host != "smtp.example.com" { + t.Fatalf("unexpected smtp host: %s", host) + } + if !skipTLSVerify { + t.Fatal("expected skipTLSVerify to be true") + } + return nil + } + + err := sendEmailWithOptionalTLS([]string{"support@example.com"}, []byte("msg")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !customSenderCalled { + t.Fatal("expected custom tls sender to be called") + } + if normalSenderCalled { + t.Fatal("did not expect normal sender to be called") + } + }) +} diff --git a/api/internal/app/run.go b/api/internal/app/run.go index 6d3ddf3..276230d 100644 --- a/api/internal/app/run.go +++ b/api/internal/app/run.go @@ -9,7 +9,6 @@ import ( "github.com/julsemaan/anyfile-notepad/api/internal/httpapi" "github.com/julsemaan/anyfile-notepad/api/internal/resources" "github.com/julsemaan/anyfile-notepad/api/internal/stats" - "github.com/julsemaan/anyfile-notepad/utils" cache "github.com/patrickmn/go-cache" "github.com/rs/rest-layer/resource" "github.com/rs/rest-layer/rest" @@ -31,7 +30,7 @@ func Run(cfg Config) error { statsService := stats.NewService(statsConn) contactCache := cache.New(24*time.Hour, time.Minute) - contactService := contact.NewService(contactCache, cfg.MaxContactRequestsPerDay, cfg.SupportEmail, utils.SendEmail) + contactService := contact.NewService(contactCache, cfg.MaxContactRequestsPerDay, cfg.SupportEmail, sendEmailWithOptionalTLS) index := resources.BuildIndex(cfg.DataDir, resources.ContactHooks{ Insert: resource.InsertEventHandlerFunc(contactService.BeforeInsert), diff --git a/api/internal/httpapi/middleware_access_log.go b/api/internal/httpapi/middleware_access_log.go new file mode 100644 index 0000000..cb15d92 --- /dev/null +++ b/api/internal/httpapi/middleware_access_log.go @@ -0,0 +1,34 @@ +package httpapi + +import ( + "log" + "net/http" + "time" +) + +type accessLogResponseWriter struct { + http.ResponseWriter + status int +} + +func (w *accessLogResponseWriter) WriteHeader(statusCode int) { + w.status = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + +func withAccessLog(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + startedAt := time.Now() + wrappedWriter := &accessLogResponseWriter{ResponseWriter: w, status: http.StatusOK} + + next.ServeHTTP(wrappedWriter, r) + + log.Printf( + "access method=%s path=%s status=%d duration=%s", + r.Method, + r.URL.Path, + wrappedWriter.status, + time.Since(startedAt).Round(time.Millisecond), + ) + }) +} diff --git a/api/internal/httpapi/middleware_access_log_test.go b/api/internal/httpapi/middleware_access_log_test.go new file mode 100644 index 0000000..0e51801 --- /dev/null +++ b/api/internal/httpapi/middleware_access_log_test.go @@ -0,0 +1,52 @@ +package httpapi + +import ( + "bytes" + "log" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestWithAccessLog(t *testing.T) { + t.Run("logs response status", func(t *testing.T) { + var logBuffer bytes.Buffer + originalWriter := log.Writer() + defer log.SetOutput(originalWriter) + log.SetOutput(&logBuffer) + + handler := withAccessLog(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + })) + + req := httptest.NewRequest(http.MethodGet, "/syntaxes", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + output := logBuffer.String() + if !strings.Contains(output, "access method=GET path=/syntaxes status=202") { + t.Fatalf("expected access log line for status 202, got %q", output) + } + }) + + t.Run("logs default status when no explicit write header", func(t *testing.T) { + var logBuffer bytes.Buffer + originalWriter := log.Writer() + defer log.SetOutput(originalWriter) + log.SetOutput(&logBuffer) + + handler := withAccessLog(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("ok")) + })) + + req := httptest.NewRequest(http.MethodGet, "/stats", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + output := logBuffer.String() + if !strings.Contains(output, "access method=GET path=/stats status=200") { + t.Fatalf("expected access log line for status 200, got %q", output) + } + }) +} diff --git a/api/internal/httpapi/router.go b/api/internal/httpapi/router.go index 435e476..f13197b 100644 --- a/api/internal/httpapi/router.go +++ b/api/internal/httpapi/router.go @@ -21,5 +21,5 @@ func NewRouter(apiHandler http.Handler, statsHandler http.Handler, username stri apiHandler.ServeHTTP(w, r) }) - return withCORS(handler) + return withAccessLog(withCORS(handler)) } diff --git a/pages/contact.markdown b/pages/contact.markdown index 60ab74e..36cdab7 100644 --- a/pages/contact.markdown +++ b/pages/contact.markdown @@ -61,7 +61,7 @@ $(function(){ e.preventDefault(); $.ajax({ type: "POST", - url: "https://api.anyfile-notepad.semaan.ca/contact_requests", + url: "http://localhost:8001/contact_requests", data: JSON.stringify({ "contact_email":$('#contact_email').val(), "message":$('#message').val(), diff --git a/utils/utils.go b/utils/utils.go index ec209dc..d97b934 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -1,19 +1,39 @@ package utils import ( + "crypto/tls" "fmt" + "net" "net/smtp" "os" + "strconv" ) var smtpSendMail = smtp.SendMail +var smtpSendMailWithTLSConfig = sendMailWithTLSConfig func SendEmail(to []string, msg []byte) error { - // Choose auth method and set it up - auth := smtp.PlainAuth("", os.Getenv("SMTP_USER"), os.Getenv("SMTP_PASSWORD"), os.Getenv("SMTP_HOST")) + host := os.Getenv("SMTP_HOST") + port := os.Getenv("SMTP_PORT") + from := os.Getenv("SMTP_FROM") + user := os.Getenv("SMTP_USER") + password := os.Getenv("SMTP_PASSWORD") + addr := net.JoinHostPort(host, port) + + var auth smtp.Auth + if user != "" || password != "" { + auth = smtp.PlainAuth("", user, password, host) + } + + skipTLSVerify, _ := strconv.ParseBool(os.Getenv("SMTP_SKIP_TLS_VERIFY")) // Here we do it all: connect to our server, set up a message and send it - err := smtpSendMail(os.Getenv("SMTP_HOST")+":"+os.Getenv("SMTP_PORT"), auth, os.Getenv("SMTP_FROM"), to, msg) + err := error(nil) + if skipTLSVerify { + err = smtpSendMailWithTLSConfig(addr, host, from, to, msg, auth, true) + } else { + err = smtpSendMail(addr, auth, from, to, msg) + } if err != nil { fmt.Println("ERROR: Unable to send email:", err) return err @@ -21,3 +41,47 @@ func SendEmail(to []string, msg []byte) error { return nil } + +func sendMailWithTLSConfig(addr string, host string, from string, to []string, msg []byte, auth smtp.Auth, skipTLSVerify bool) error { + client, err := smtp.Dial(addr) + if err != nil { + return err + } + defer client.Close() + + tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: skipTLSVerify} + if err := client.StartTLS(tlsConfig); err != nil { + return err + } + + if auth != nil { + if err := client.Auth(auth); err != nil { + return err + } + } + + if err := client.Mail(from); err != nil { + return err + } + + for _, recipient := range to { + if err := client.Rcpt(recipient); err != nil { + return err + } + } + + writer, err := client.Data() + if err != nil { + return err + } + + if _, err := writer.Write(msg); err != nil { + return err + } + + if err := writer.Close(); err != nil { + return err + } + + return client.Quit() +} diff --git a/utils/utils_test.go b/utils/utils_test.go index e42cdd7..8ac5de6 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -14,8 +14,10 @@ func TestSendEmail(t *testing.T) { t.Setenv("SMTP_FROM", "noreply@example.com") originalSender := smtpSendMail + originalSenderWithTLSConfig := smtpSendMailWithTLSConfig t.Cleanup(func() { smtpSendMail = originalSender + smtpSendMailWithTLSConfig = originalSenderWithTLSConfig }) t.Run("successful send", func(t *testing.T) { @@ -49,6 +51,23 @@ func TestSendEmail(t *testing.T) { } }) + t.Run("empty credentials disable smtp auth", func(t *testing.T) { + t.Setenv("SMTP_USER", "") + t.Setenv("SMTP_PASSWORD", "") + + smtpSendMail = func(addr string, auth smtp.Auth, from string, to []string, msg []byte) error { + if auth != nil { + t.Fatal("expected smtp auth to be nil") + } + return nil + } + + err := SendEmail([]string{"support@example.com"}, []byte("msg")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + t.Run("sender error bubbles up", func(t *testing.T) { expectedErr := errors.New("smtp unavailable") smtpSendMail = func(addr string, auth smtp.Auth, from string, to []string, msg []byte) error { @@ -63,4 +82,40 @@ func TestSendEmail(t *testing.T) { t.Fatalf("expected %v, got %v", expectedErr, err) } }) + + t.Run("skip tls verify uses custom tls sender", func(t *testing.T) { + t.Setenv("SMTP_SKIP_TLS_VERIFY", "true") + + normalSenderCalled := false + smtpSendMail = func(addr string, auth smtp.Auth, from string, to []string, msg []byte) error { + normalSenderCalled = true + return nil + } + + customSenderCalled := false + smtpSendMailWithTLSConfig = func(addr string, host string, from string, to []string, msg []byte, auth smtp.Auth, skipTLSVerify bool) error { + customSenderCalled = true + if addr != "smtp.example.com:2525" { + t.Fatalf("unexpected smtp addr: %s", addr) + } + if host != "smtp.example.com" { + t.Fatalf("unexpected smtp host: %s", host) + } + if !skipTLSVerify { + t.Fatal("expected skipTLSVerify to be true") + } + return nil + } + + err := SendEmail([]string{"support@example.com"}, []byte("msg")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !customSenderCalled { + t.Fatal("expected custom tls sender to be called") + } + if normalSenderCalled { + t.Fatal("did not expect normal sender to be called") + } + }) } diff --git a/webserver/billing_handlers.go b/webserver/billing_handlers.go index 76bf6fa..0f814b1 100644 --- a/webserver/billing_handlers.go +++ b/webserver/billing_handlers.go @@ -14,7 +14,6 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/gin-gonic/gin" "github.com/inverse-inc/packetfence/go/sharedutils" - "github.com/julsemaan/anyfile-notepad/utils" stripe "github.com/stripe/stripe-go" "github.com/stripe/stripe-go/customer" "github.com/stripe/stripe-go/sub" @@ -32,7 +31,7 @@ var ( stripeCustomerUpdate = customer.Update stripeCustomerNew = customer.New stripeWebhookConstructEvent = webhook.ConstructEvent - sendEmail = utils.SendEmail + sendEmail = sendEmailWithOptionalTLS generateCancelLinkID = func() string { return secureRandomString(16) } ) diff --git a/webserver/email.go b/webserver/email.go new file mode 100644 index 0000000..9d9eeef --- /dev/null +++ b/webserver/email.go @@ -0,0 +1,89 @@ +package main + +import ( + "crypto/tls" + "fmt" + "net" + "net/smtp" + "os" + "strconv" + "strings" +) + +var smtpSendMail = smtp.SendMail +var smtpSendMailWithTLSConfig = sendMailWithTLSConfig + +func sendEmailWithOptionalTLS(to []string, msg []byte) error { + host := os.Getenv("SMTP_HOST") + port := os.Getenv("SMTP_PORT") + from := os.Getenv("SMTP_FROM") + user := os.Getenv("SMTP_USER") + password := os.Getenv("SMTP_PASSWORD") + addr := net.JoinHostPort(host, port) + + var auth smtp.Auth + if user != "" || password != "" { + auth = smtp.PlainAuth("", user, password, host) + } + + rawSkipTLSVerify := strings.TrimSpace(os.Getenv("SMTP_SKIP_TLS_VERIFY")) + rawSkipTLSVerify = strings.Trim(rawSkipTLSVerify, "\"'") + skipTLSVerify, _ := strconv.ParseBool(rawSkipTLSVerify) + + err := error(nil) + if skipTLSVerify { + err = smtpSendMailWithTLSConfig(addr, host, from, to, msg, auth, true) + } else { + err = smtpSendMail(addr, auth, from, to, msg) + } + if err != nil { + fmt.Println("ERROR: Unable to send email:", err) + return err + } + + return nil +} + +func sendMailWithTLSConfig(addr string, host string, from string, to []string, msg []byte, auth smtp.Auth, skipTLSVerify bool) error { + client, err := smtp.Dial(addr) + if err != nil { + return err + } + defer client.Close() + + tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: skipTLSVerify} + if err := client.StartTLS(tlsConfig); err != nil { + return err + } + + if auth != nil { + if err := client.Auth(auth); err != nil { + return err + } + } + + if err := client.Mail(from); err != nil { + return err + } + + for _, recipient := range to { + if err := client.Rcpt(recipient); err != nil { + return err + } + } + + writer, err := client.Data() + if err != nil { + return err + } + + if _, err := writer.Write(msg); err != nil { + return err + } + + if err := writer.Close(); err != nil { + return err + } + + return client.Quit() +} diff --git a/webserver/email_test.go b/webserver/email_test.go new file mode 100644 index 0000000..a284840 --- /dev/null +++ b/webserver/email_test.go @@ -0,0 +1,121 @@ +package main + +import ( + "errors" + "net/smtp" + "testing" +) + +func TestSendEmailWithOptionalTLS(t *testing.T) { + t.Setenv("SMTP_USER", "smtp-user") + t.Setenv("SMTP_PASSWORD", "smtp-pass") + t.Setenv("SMTP_HOST", "smtp.example.com") + t.Setenv("SMTP_PORT", "2525") + t.Setenv("SMTP_FROM", "noreply@example.com") + + originalSender := smtpSendMail + originalSenderWithTLSConfig := smtpSendMailWithTLSConfig + t.Cleanup(func() { + smtpSendMail = originalSender + smtpSendMailWithTLSConfig = originalSenderWithTLSConfig + }) + + t.Run("successful send with default sender", func(t *testing.T) { + called := false + smtpSendMail = func(addr string, auth smtp.Auth, from string, to []string, msg []byte) error { + called = true + if addr != "smtp.example.com:2525" { + t.Fatalf("unexpected smtp addr: %s", addr) + } + if from != "noreply@example.com" { + t.Fatalf("unexpected from address: %s", from) + } + if len(to) != 1 || to[0] != "support@example.com" { + t.Fatalf("unexpected recipients: %#v", to) + } + if string(msg) != "Subject: test\n\nHello" { + t.Fatalf("unexpected message body: %q", string(msg)) + } + if auth == nil { + t.Fatal("expected smtp auth to be created") + } + return nil + } + + err := sendEmailWithOptionalTLS([]string{"support@example.com"}, []byte("Subject: test\n\nHello")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Fatal("expected smtp sender to be called") + } + }) + + t.Run("empty credentials disable smtp auth", func(t *testing.T) { + t.Setenv("SMTP_USER", "") + t.Setenv("SMTP_PASSWORD", "") + + smtpSendMail = func(addr string, auth smtp.Auth, from string, to []string, msg []byte) error { + if auth != nil { + t.Fatal("expected smtp auth to be nil") + } + return nil + } + + err := sendEmailWithOptionalTLS([]string{"support@example.com"}, []byte("msg")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("sender error bubbles up", func(t *testing.T) { + expectedErr := errors.New("smtp unavailable") + smtpSendMail = func(addr string, auth smtp.Auth, from string, to []string, msg []byte) error { + return expectedErr + } + + err := sendEmailWithOptionalTLS([]string{"support@example.com"}, []byte("msg")) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, expectedErr) { + t.Fatalf("expected %v, got %v", expectedErr, err) + } + }) + + t.Run("quoted true enables tls skip verify sender", func(t *testing.T) { + t.Setenv("SMTP_SKIP_TLS_VERIFY", "'true'") + + normalSenderCalled := false + smtpSendMail = func(addr string, auth smtp.Auth, from string, to []string, msg []byte) error { + normalSenderCalled = true + return nil + } + + customSenderCalled := false + smtpSendMailWithTLSConfig = func(addr string, host string, from string, to []string, msg []byte, auth smtp.Auth, skipTLSVerify bool) error { + customSenderCalled = true + if addr != "smtp.example.com:2525" { + t.Fatalf("unexpected smtp addr: %s", addr) + } + if host != "smtp.example.com" { + t.Fatalf("unexpected smtp host: %s", host) + } + if !skipTLSVerify { + t.Fatal("expected skipTLSVerify to be true") + } + return nil + } + + err := sendEmailWithOptionalTLS([]string{"support@example.com"}, []byte("msg")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !customSenderCalled { + t.Fatal("expected custom tls sender to be called") + } + if normalSenderCalled { + t.Fatal("did not expect normal sender to be called") + } + }) +} From a3cb48c57d101fbece1554b26b2278dbda50d163 Mon Sep 17 00:00:00 2001 From: Julien Semaan Date: Sun, 5 Apr 2026 00:16:11 +0000 Subject: [PATCH 08/10] refactor(contact): move email sending from AfterInsert to BeforeInsert hook --- api/internal/contact/service.go | 32 ++++++++------- api/internal/contact/service_test.go | 58 +++++++++++++++++++++------- 2 files changed, 62 insertions(+), 28 deletions(-) diff --git a/api/internal/contact/service.go b/api/internal/contact/service.go index 6017f7b..0edff2a 100644 --- a/api/internal/contact/service.go +++ b/api/internal/contact/service.go @@ -14,6 +14,7 @@ import ( ) var errTooManyRequests = errors.New("Too many contact requests, try again later") +var errUnableToSendEmail = errors.New("Unable to process contact request, please try again later") var messageTemplate = template.Must(template.New("contact-email").Parse(`Subject: Anyfile Notepad - Message from {{.ReplyTo}} To: {{.Emails}} @@ -50,10 +51,20 @@ func (s *Service) BeforeInsert(_ context.Context, items []*resource.Item) error return errTooManyRequests } + if s.sendEmail == nil || s.supportEmail == "" { + return nil + } + + recipients := []string{s.supportEmail} for _, item := range items { - if s.cache != nil { - id := fmt.Sprint(item.ID) - s.cache.SetDefault(id, item.ID) + msg, buildErr := buildMessage(recipients, item) + if buildErr != nil { + log.Printf("Unable to build contact request email: %v", buildErr) + return errUnableToSendEmail + } + if sendErr := s.sendEmail(recipients, msg); sendErr != nil { + log.Printf("Unable to send contact request email: %v", sendErr) + return errUnableToSendEmail } } @@ -66,19 +77,10 @@ func (s *Service) AfterInsert(_ context.Context, items []*resource.Item, err *er return } - if s.sendEmail == nil || s.supportEmail == "" { - return - } - - recipients := []string{s.supportEmail} for _, item := range items { - msg, buildErr := buildMessage(recipients, item) - if buildErr != nil { - log.Printf("Unable to build contact request email: %v", buildErr) - continue - } - if sendErr := s.sendEmail(recipients, msg); sendErr != nil { - log.Printf("Unable to send contact request email: %v", sendErr) + if s.cache != nil { + id := fmt.Sprint(item.ID) + s.cache.SetDefault(id, item.ID) } } } diff --git a/api/internal/contact/service_test.go b/api/internal/contact/service_test.go index ab64cc3..887f284 100644 --- a/api/internal/contact/service_test.go +++ b/api/internal/contact/service_test.go @@ -34,10 +34,12 @@ func TestBeforeInsert(t *testing.T) { t.Fatalf("unexpected insert hook error: %v", err) } - if cache.ItemCount() != 2 { - t.Fatalf("expected cache to contain two entries, got %d", cache.ItemCount()) + if cache.ItemCount() != 0 { + t.Fatalf("expected cache to remain unchanged before insert, got %d", cache.ItemCount()) } + cache.SetDefault("existing-1", "existing-1") + cache.SetDefault("existing-2", "existing-2") err := svc.BeforeInsert(context.Background(), []*resource.Item{{ID: "req-3"}}) if err == nil { t.Fatal("expected insert hook to reject above threshold") @@ -48,25 +50,22 @@ func TestBeforeInsert(t *testing.T) { } func TestAfterInsertNoopOnError(t *testing.T) { - sent := false - svc := NewService(newCacheStub(), 10, "support@example.com", func([]string, []byte) error { - sent = true - return nil - }) + cache := newCacheStub() + svc := NewService(cache, 10, "support@example.com", nil) incomingErr := errors.New("storage failure") err := error(incomingErr) svc.AfterInsert(context.Background(), []*resource.Item{{ID: "req-1"}}, &err) - if sent { - t.Fatal("did not expect email to be sent") + if cache.ItemCount() != 0 { + t.Fatalf("did not expect cache to be updated, got %d", cache.ItemCount()) } if err == nil || err.Error() != "storage failure" { t.Fatalf("expected input error to be preserved, got %v", err) } } -func TestAfterInsertSendsEmail(t *testing.T) { +func TestBeforeInsertSendsEmail(t *testing.T) { sent := false svc := NewService(newCacheStub(), 10, "support@example.com", func(to []string, msg []byte) error { sent = true @@ -79,16 +78,49 @@ func TestAfterInsertSendsEmail(t *testing.T) { return nil }) - var err error - svc.AfterInsert(context.Background(), []*resource.Item{{ + if err := svc.BeforeInsert(context.Background(), []*resource.Item{{ ID: "req-1", Payload: map[string]interface{}{ "message": "Need help", "contact_email": "user@example.com", }, - }}, &err) + }}); err != nil { + t.Fatalf("unexpected insert hook error: %v", err) + } if !sent { t.Fatal("expected sendEmail to be called") } } + +func TestBeforeInsertFailsWhenEmailSendFails(t *testing.T) { + svc := NewService(newCacheStub(), 10, "support@example.com", func([]string, []byte) error { + return errors.New("smtp failure") + }) + + err := svc.BeforeInsert(context.Background(), []*resource.Item{{ + ID: "req-1", + Payload: map[string]interface{}{ + "message": "Need help", + "contact_email": "user@example.com", + }, + }}) + if err == nil { + t.Fatal("expected insert hook to fail when email send fails") + } + if err.Error() != "Unable to process contact request, please try again later" { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestAfterInsertCachesSuccessfulInsert(t *testing.T) { + cache := newCacheStub() + svc := NewService(cache, 10, "support@example.com", nil) + + var err error + svc.AfterInsert(context.Background(), []*resource.Item{{ID: "req-1"}, {ID: "req-2"}}, &err) + + if cache.ItemCount() != 2 { + t.Fatalf("expected cache to contain two entries, got %d", cache.ItemCount()) + } +} From 2c2a5eb123b4452e7dbc3b57a7068b1992fa0b8b Mon Sep 17 00:00:00 2001 From: Julien Semaan Date: Sun, 5 Apr 2026 00:44:32 +0000 Subject: [PATCH 09/10] fix(api): harden stats/contact paths and simplify build --- api/.gitignore | 2 - api/Makefile | 9 +--- api/internal/app/run.go | 6 ++- api/internal/contact/service.go | 7 ++- api/internal/contact/service_test.go | 8 ++-- api/internal/httpapi/handler_stats.go | 2 + api/internal/httpapi/handler_stats_test.go | 13 +++++ api/internal/httpapi/middleware_auth.go | 1 - api/internal/stats/service.go | 55 +++++++++++++++++++--- api/internal/stats/service_test.go | 16 +++++++ 10 files changed, 92 insertions(+), 27 deletions(-) diff --git a/api/.gitignore b/api/.gitignore index 79a74ae..6e59138 100644 --- a/api/.gitignore +++ b/api/.gitignore @@ -1,6 +1,4 @@ /api -afn-rest-32 -afn-rest-64 db/ vendor/*/ test-server.sh diff --git a/api/Makefile b/api/Makefile index eed8f04..ad99ba2 100644 --- a/api/Makefile +++ b/api/Makefile @@ -1,11 +1,4 @@ .PHONY: api api: - @if [ -d ./cmd/api ]; then \ - CGO_ENABLED=0 go build -v -o api ./cmd/api; \ - elif [ -d ./api/cmd/api ]; then \ - CGO_ENABLED=0 go build -v -o api ./api/cmd/api; \ - else \ - echo "cannot find api main package (expected ./cmd/api or ./api/cmd/api)"; \ - exit 1; \ - fi + CGO_ENABLED=0 go build -v -o api ./cmd/api diff --git a/api/internal/app/run.go b/api/internal/app/run.go index 276230d..d5efdbf 100644 --- a/api/internal/app/run.go +++ b/api/internal/app/run.go @@ -28,7 +28,11 @@ func Run(cfg Config) error { defer statsConn.Close() } - statsService := stats.NewService(statsConn) + var metrics stats.Metrics + if statsConn != nil { + metrics = statsConn + } + statsService := stats.NewService(metrics) contactCache := cache.New(24*time.Hour, time.Minute) contactService := contact.NewService(contactCache, cfg.MaxContactRequestsPerDay, cfg.SupportEmail, sendEmailWithOptionalTLS) diff --git a/api/internal/contact/service.go b/api/internal/contact/service.go index 0edff2a..265f290 100644 --- a/api/internal/contact/service.go +++ b/api/internal/contact/service.go @@ -5,7 +5,6 @@ import ( "context" "errors" "fmt" - "io" "log" "strings" "text/template" @@ -13,8 +12,8 @@ import ( "github.com/rs/rest-layer/resource" ) -var errTooManyRequests = errors.New("Too many contact requests, try again later") -var errUnableToSendEmail = errors.New("Unable to process contact request, please try again later") +var errTooManyRequests = errors.New("too many contact requests, try again later") +var errUnableToSendEmail = errors.New("unable to process contact request, please try again later") var messageTemplate = template.Must(template.New("contact-email").Parse(`Subject: Anyfile Notepad - Message from {{.ReplyTo}} To: {{.Emails}} @@ -108,5 +107,5 @@ func buildMessage(recipients []string, item *resource.Item) ([]byte, error) { return nil, err } - return io.ReadAll(&msgBytes) + return msgBytes.Bytes(), nil } diff --git a/api/internal/contact/service_test.go b/api/internal/contact/service_test.go index 887f284..d824fd4 100644 --- a/api/internal/contact/service_test.go +++ b/api/internal/contact/service_test.go @@ -44,8 +44,8 @@ func TestBeforeInsert(t *testing.T) { if err == nil { t.Fatal("expected insert hook to reject above threshold") } - if err.Error() != "Too many contact requests, try again later" { - t.Fatalf("unexpected error message: %v", err) + if !errors.Is(err, errTooManyRequests) { + t.Fatalf("expected errTooManyRequests, got %v", err) } } @@ -108,8 +108,8 @@ func TestBeforeInsertFailsWhenEmailSendFails(t *testing.T) { if err == nil { t.Fatal("expected insert hook to fail when email send fails") } - if err.Error() != "Unable to process contact request, please try again later" { - t.Fatalf("unexpected error message: %v", err) + if !errors.Is(err, errUnableToSendEmail) { + t.Fatalf("expected errUnableToSendEmail, got %v", err) } } diff --git a/api/internal/httpapi/handler_stats.go b/api/internal/httpapi/handler_stats.go index 17ff44e..18b1b9f 100644 --- a/api/internal/httpapi/handler_stats.go +++ b/api/internal/httpapi/handler_stats.go @@ -21,6 +21,8 @@ func NewStatsHandler(statsService StatsService) http.Handler { payload, err := statsService.ParsePayload(r) if err != nil { switch { + case errors.Is(err, stats.ErrPayloadTooLarge): + http.Error(w, "Payload too large", http.StatusRequestEntityTooLarge) case errors.Is(err, stats.ErrInvalidPayload): http.Error(w, "Invalid payload", http.StatusBadRequest) case errors.Is(err, stats.ErrInvalidJSON): diff --git a/api/internal/httpapi/handler_stats_test.go b/api/internal/httpapi/handler_stats_test.go index 00e816e..7769fce 100644 --- a/api/internal/httpapi/handler_stats_test.go +++ b/api/internal/httpapi/handler_stats_test.go @@ -73,6 +73,19 @@ func TestStatsHandler(t *testing.T) { } }) + t.Run("payload too large returns request entity too large", func(t *testing.T) { + stub := &statsServiceStub{err: stats.ErrPayloadTooLarge} + handler := NewStatsHandler(stub) + + req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader("{}")) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("expected 413, got %d", w.Code) + } + }) + t.Run("unknown errors return bad request", func(t *testing.T) { stub := &statsServiceStub{err: errors.New("boom")} handler := NewStatsHandler(stub) diff --git a/api/internal/httpapi/middleware_auth.go b/api/internal/httpapi/middleware_auth.go index fdfd8aa..963c235 100644 --- a/api/internal/httpapi/middleware_auth.go +++ b/api/internal/httpapi/middleware_auth.go @@ -21,7 +21,6 @@ func IsOpenResource(r *http.Request) bool { } if r.Method == http.MethodGet { - log.Print("Allowing without authentication for namespace that don't modify resources") return true } diff --git a/api/internal/stats/service.go b/api/internal/stats/service.go index 9b6ddb1..46d4897 100644 --- a/api/internal/stats/service.go +++ b/api/internal/stats/service.go @@ -1,7 +1,6 @@ package stats import ( - "bytes" "encoding/json" "errors" "io" @@ -13,6 +12,10 @@ import ( var ErrInvalidPayload = errors.New("invalid payload") var ErrInvalidJSON = errors.New("invalid json") +var ErrPayloadTooLarge = errors.New("payload too large") + +const maxPayloadSizeBytes = 4 * 1024 +const unknownIPMetricKey = "unknown" var remoteAddrRegex = regexp.MustCompile(`^([0-9.]+):`) var metricKeyRegex = regexp.MustCompile(`^[a-zA-Z0-9_.-]{1,64}$`) @@ -30,14 +33,27 @@ func NewService(metrics Metrics) *Service { } func (s *Service) ParsePayload(r *http.Request) (map[string]string, error) { - body, err := io.ReadAll(r.Body) - if err != nil { + limitedBody := &io.LimitedReader{R: r.Body, N: maxPayloadSizeBytes + 1} + var payload map[string]string + decoder := json.NewDecoder(limitedBody) + if err := decoder.Decode(&payload); err != nil { + if limitedBody.N <= 0 { + return nil, ErrPayloadTooLarge + } + + var syntaxErr *json.SyntaxError + var unmarshalTypeErr *json.UnmarshalTypeError + if errors.As(err, &syntaxErr) || errors.As(err, &unmarshalTypeErr) || errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) { + return nil, ErrInvalidJSON + } return nil, ErrInvalidPayload } - var payload map[string]string - decoder := json.NewDecoder(bytes.NewBuffer(body)) - if err := decoder.Decode(&payload); err != nil { + if limitedBody.N <= 0 { + return nil, ErrPayloadTooLarge + } + + if err := decoder.Decode(&struct{}{}); err != io.EOF { return nil, ErrInvalidJSON } if payload == nil { @@ -53,7 +69,7 @@ func (s *Service) Record(payload map[string]string) { return } - ipKey := strings.NewReplacer(".", "_", ":", "_").Replace(payload["ip"]) + ipKey := normalizeIPMetricKey(payload["ip"]) s.metrics.Increment("afn.stats-hits." + ipKey) if payload["type"] == "increment" { @@ -63,6 +79,31 @@ func (s *Service) Record(payload map[string]string) { } } +func normalizeIPMetricKey(raw string) string { + candidate := strings.TrimSpace(raw) + if candidate == "" { + return unknownIPMetricKey + } + + if host, _, err := net.SplitHostPort(candidate); err == nil { + candidate = host + } else { + candidate = strings.Trim(candidate, "[]") + } + + ip := net.ParseIP(candidate) + if ip == nil { + return unknownIPMetricKey + } + + ipKey := strings.NewReplacer(".", "_", ":", "_").Replace(ip.String()) + if ipKey == "" || len(ipKey) > 64 { + return unknownIPMetricKey + } + + return ipKey +} + func extractIP(r *http.Request) string { forwardedFor := r.Header.Get("X-Forwarded-For") if forwardedFor != "" { diff --git a/api/internal/stats/service_test.go b/api/internal/stats/service_test.go index f776b61..c58ce09 100644 --- a/api/internal/stats/service_test.go +++ b/api/internal/stats/service_test.go @@ -78,6 +78,16 @@ func TestParsePayload(t *testing.T) { } }) + t.Run("too large payload returns error", func(t *testing.T) { + tooLargeValue := strings.Repeat("a", int(maxPayloadSizeBytes)+32) + req := httptest.NewRequest(http.MethodPost, "/stats", strings.NewReader(`{"type":"increment","key":"`+tooLargeValue+`"}`)) + + _, err := svc.ParsePayload(req) + if !errors.Is(err, ErrPayloadTooLarge) { + t.Fatalf("expected ErrPayloadTooLarge, got %v", err) + } + }) + t.Run("body read failures return error", func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/stats", nil) req.Body = io.NopCloser(errReader{}) @@ -111,4 +121,10 @@ func TestRecord(t *testing.T) { if len(stub.keys) != 1 || stub.keys[0] != "afn.stats-hits.192_0_2_15" { t.Fatalf("expected invalid metric key to be ignored, got %#v", stub.keys) } + + stub.keys = nil + svc.Record(map[string]string{"ip": "not-an-ip"}) + if len(stub.keys) != 1 || stub.keys[0] != "afn.stats-hits.unknown" { + t.Fatalf("expected invalid ip metric key to be normalized, got %#v", stub.keys) + } } From 72874797b0d1f716ef046363a21457fb3d78ef22 Mon Sep 17 00:00:00 2001 From: Julien Semaan Date: Mon, 13 Apr 2026 23:35:22 +0000 Subject: [PATCH 10/10] fix: enforce contact batch limits and normalize SMTP TLS env flag --- api/internal/contact/service.go | 2 +- api/internal/contact/service_test.go | 14 +++++++++++++ pages/contact.markdown | 3 +-- utils/utils.go | 5 ++++- utils/utils_test.go | 30 ++++++++++++++++++++++++++++ 5 files changed, 50 insertions(+), 4 deletions(-) diff --git a/api/internal/contact/service.go b/api/internal/contact/service.go index 265f290..f340afd 100644 --- a/api/internal/contact/service.go +++ b/api/internal/contact/service.go @@ -46,7 +46,7 @@ func NewService(cache Cache, maxPerDay int, supportEmail string, sendEmail Sende } func (s *Service) BeforeInsert(_ context.Context, items []*resource.Item) error { - if s.cache != nil && s.cache.ItemCount() >= s.maxPerDay { + if s.cache != nil && s.cache.ItemCount()+len(items) > s.maxPerDay { return errTooManyRequests } diff --git a/api/internal/contact/service_test.go b/api/internal/contact/service_test.go index d824fd4..495853b 100644 --- a/api/internal/contact/service_test.go +++ b/api/internal/contact/service_test.go @@ -49,6 +49,20 @@ func TestBeforeInsert(t *testing.T) { } } +func TestBeforeInsertRejectsBatchThatExceedsLimit(t *testing.T) { + cache := newCacheStub() + cache.SetDefault("existing-1", "existing-1") + + svc := NewService(cache, 2, "support@example.com", nil) + err := svc.BeforeInsert(context.Background(), []*resource.Item{{ID: "req-2"}, {ID: "req-3"}}) + if err == nil { + t.Fatal("expected insert hook to reject batch that exceeds limit") + } + if !errors.Is(err, errTooManyRequests) { + t.Fatalf("expected errTooManyRequests, got %v", err) + } +} + func TestAfterInsertNoopOnError(t *testing.T) { cache := newCacheStub() svc := NewService(cache, 10, "support@example.com", nil) diff --git a/pages/contact.markdown b/pages/contact.markdown index 36cdab7..2f0e632 100644 --- a/pages/contact.markdown +++ b/pages/contact.markdown @@ -61,7 +61,7 @@ $(function(){ e.preventDefault(); $.ajax({ type: "POST", - url: "http://localhost:8001/contact_requests", + url: "https://api.anyfile-notepad.semaan.ca/contact_requests", data: JSON.stringify({ "contact_email":$('#contact_email').val(), "message":$('#message').val(), @@ -86,4 +86,3 @@ $(function(){ - diff --git a/utils/utils.go b/utils/utils.go index d97b934..190729a 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -7,6 +7,7 @@ import ( "net/smtp" "os" "strconv" + "strings" ) var smtpSendMail = smtp.SendMail @@ -25,7 +26,9 @@ func SendEmail(to []string, msg []byte) error { auth = smtp.PlainAuth("", user, password, host) } - skipTLSVerify, _ := strconv.ParseBool(os.Getenv("SMTP_SKIP_TLS_VERIFY")) + rawSkipTLSVerify := strings.TrimSpace(os.Getenv("SMTP_SKIP_TLS_VERIFY")) + rawSkipTLSVerify = strings.Trim(rawSkipTLSVerify, "\"'") + skipTLSVerify, _ := strconv.ParseBool(rawSkipTLSVerify) // Here we do it all: connect to our server, set up a message and send it err := error(nil) diff --git a/utils/utils_test.go b/utils/utils_test.go index 8ac5de6..f50e86e 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -118,4 +118,34 @@ func TestSendEmail(t *testing.T) { t.Fatal("did not expect normal sender to be called") } }) + + t.Run("skip tls verify accepts quoted bool", func(t *testing.T) { + t.Setenv("SMTP_SKIP_TLS_VERIFY", "'true'") + + normalSenderCalled := false + smtpSendMail = func(addr string, auth smtp.Auth, from string, to []string, msg []byte) error { + normalSenderCalled = true + return nil + } + + customSenderCalled := false + smtpSendMailWithTLSConfig = func(addr string, host string, from string, to []string, msg []byte, auth smtp.Auth, skipTLSVerify bool) error { + customSenderCalled = true + if !skipTLSVerify { + t.Fatal("expected skipTLSVerify to be true") + } + return nil + } + + err := SendEmail([]string{"support@example.com"}, []byte("msg")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !customSenderCalled { + t.Fatal("expected custom tls sender to be called") + } + if normalSenderCalled { + t.Fatal("did not expect normal sender to be called") + } + }) }