Skip to content

Commit 2304836

Browse files
committed
Add 2FA enforcement and changelog
1 parent d04d459 commit 2304836

6 files changed

Lines changed: 400 additions & 11 deletions

File tree

auth/middleware.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,51 @@ func ClearSessionCookie(w http.ResponseWriter) {
8484
}
8585
http.SetCookie(w, cookie)
8686
}
87+
88+
// TOTPEnforcement configures 2FA enforcement behavior.
89+
type TOTPEnforcement struct {
90+
Enabled func() bool
91+
GraceDays func() int
92+
SetupURL string
93+
}
94+
95+
// RequireTOTP is a middleware that enforces 2FA setup.
96+
// It should be used after RequireAuth.
97+
func RequireTOTP(cfg TOTPEnforcement) func(http.Handler) http.Handler {
98+
return func(next http.Handler) http.Handler {
99+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
100+
if cfg.Enabled == nil || !cfg.Enabled() {
101+
next.ServeHTTP(w, r)
102+
return
103+
}
104+
105+
user, ok := GetUser(r.Context())
106+
if !ok {
107+
next.ServeHTTP(w, r)
108+
return
109+
}
110+
111+
if user.TOTPEnabled {
112+
next.ServeHTTP(w, r)
113+
return
114+
}
115+
116+
graceDays := 0
117+
if cfg.GraceDays != nil {
118+
graceDays = cfg.GraceDays()
119+
}
120+
121+
if user.InTOTPGracePeriod(graceDays) {
122+
next.ServeHTTP(w, r)
123+
return
124+
}
125+
126+
setupURL := cfg.SetupURL
127+
if setupURL == "" {
128+
setupURL = "/totp-setup"
129+
}
130+
131+
http.Redirect(w, r, setupURL, http.StatusSeeOther)
132+
})
133+
}
134+
}

auth/middleware_test.go

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"net/http"
66
"net/http/httptest"
77
"testing"
8+
"time"
89

910
"github.com/hatmaxkit/hatmax/config"
1011
"github.com/hatmaxkit/hatmax/log"
@@ -227,3 +228,144 @@ func TestClearSessionCookie(t *testing.T) {
227228
t.Errorf("ClearSessionCookie() cookie.MaxAge = %v, want -1", cookie.MaxAge)
228229
}
229230
}
231+
232+
func TestRequireTOTP(t *testing.T) {
233+
tests := []struct {
234+
name string
235+
cfg TOTPEnforcement
236+
user *User
237+
wantStatus int
238+
wantRedirect string
239+
}{
240+
{
241+
name: "enforcement disabled",
242+
cfg: TOTPEnforcement{
243+
Enabled: func() bool { return false },
244+
},
245+
user: &User{TOTPEnabled: false},
246+
wantStatus: http.StatusOK,
247+
},
248+
{
249+
name: "enforcement nil",
250+
cfg: TOTPEnforcement{},
251+
user: &User{TOTPEnabled: false},
252+
wantStatus: http.StatusOK,
253+
},
254+
{
255+
name: "totp enabled",
256+
cfg: TOTPEnforcement{
257+
Enabled: func() bool { return true },
258+
},
259+
user: &User{TOTPEnabled: true},
260+
wantStatus: http.StatusOK,
261+
},
262+
{
263+
name: "totp not enabled redirects",
264+
cfg: TOTPEnforcement{
265+
Enabled: func() bool { return true },
266+
GraceDays: func() int { return 0 },
267+
},
268+
user: &User{TOTPEnabled: false},
269+
wantStatus: http.StatusSeeOther,
270+
wantRedirect: "/totp-setup",
271+
},
272+
{
273+
name: "custom setup url",
274+
cfg: TOTPEnforcement{
275+
Enabled: func() bool { return true },
276+
GraceDays: func() int { return 0 },
277+
SetupURL: "/settings/2fa",
278+
},
279+
user: &User{TOTPEnabled: false},
280+
wantStatus: http.StatusSeeOther,
281+
wantRedirect: "/settings/2fa",
282+
},
283+
{
284+
name: "no user in context passes",
285+
cfg: TOTPEnforcement{
286+
Enabled: func() bool { return true },
287+
},
288+
user: nil,
289+
wantStatus: http.StatusOK,
290+
},
291+
}
292+
293+
for _, tt := range tests {
294+
t.Run(tt.name, func(t *testing.T) {
295+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
296+
w.WriteHeader(http.StatusOK)
297+
})
298+
299+
middleware := RequireTOTP(tt.cfg)
300+
wrapped := middleware(handler)
301+
302+
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
303+
if tt.user != nil {
304+
ctx := WithUser(req.Context(), tt.user)
305+
req = req.WithContext(ctx)
306+
}
307+
308+
w := httptest.NewRecorder()
309+
wrapped.ServeHTTP(w, req)
310+
311+
if w.Code != tt.wantStatus {
312+
t.Errorf("RequireTOTP() status = %v, want %v", w.Code, tt.wantStatus)
313+
}
314+
315+
if tt.wantRedirect != "" {
316+
location := w.Header().Get("Location")
317+
if location != tt.wantRedirect {
318+
t.Errorf("RequireTOTP() redirect = %v, want %v", location, tt.wantRedirect)
319+
}
320+
}
321+
})
322+
}
323+
}
324+
325+
func TestRequireTOTPGracePeriod(t *testing.T) {
326+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
327+
w.WriteHeader(http.StatusOK)
328+
})
329+
330+
cfg := TOTPEnforcement{
331+
Enabled: func() bool { return true },
332+
GraceDays: func() int { return 7 },
333+
}
334+
335+
middleware := RequireTOTP(cfg)
336+
wrapped := middleware(handler)
337+
338+
// User created recently should pass
339+
recentUser := &User{
340+
TOTPEnabled: false,
341+
CreatedAt: time.Now().AddDate(0, 0, -3),
342+
}
343+
344+
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
345+
ctx := WithUser(req.Context(), recentUser)
346+
req = req.WithContext(ctx)
347+
348+
w := httptest.NewRecorder()
349+
wrapped.ServeHTTP(w, req)
350+
351+
if w.Code != http.StatusOK {
352+
t.Errorf("RequireTOTP() with grace period status = %v, want %v", w.Code, http.StatusOK)
353+
}
354+
355+
// User created long ago should redirect
356+
oldUser := &User{
357+
TOTPEnabled: false,
358+
CreatedAt: time.Now().AddDate(0, 0, -30),
359+
}
360+
361+
req = httptest.NewRequest(http.MethodGet, "/protected", nil)
362+
ctx = WithUser(req.Context(), oldUser)
363+
req = req.WithContext(ctx)
364+
365+
w = httptest.NewRecorder()
366+
wrapped.ServeHTTP(w, req)
367+
368+
if w.Code != http.StatusSeeOther {
369+
t.Errorf("RequireTOTP() outside grace period status = %v, want %v", w.Code, http.StatusSeeOther)
370+
}
371+
}

auth/models.go

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@ import "time"
44

55
// User represents an authenticated user in the system.
66
type User struct {
7-
ID string
8-
Email string
9-
PasswordHash string
10-
Roles []string
11-
Active bool
12-
CreatedAt time.Time
13-
UpdatedAt time.Time
7+
ID string
8+
Email string
9+
PasswordHash string
10+
Roles []string
11+
Active bool
12+
TOTPSecret string
13+
TOTPEnabled bool
14+
TOTPVerifiedAt *time.Time
15+
CreatedAt time.Time
16+
UpdatedAt time.Time
1417
}
1518

1619
// HasRole checks if the user has the specified role.
@@ -33,6 +36,20 @@ func (u *User) HasAnyRole(roles ...string) bool {
3336
return false
3437
}
3538

39+
// NeedsTOTPSetup returns true if the user has not set up TOTP yet.
40+
func (u *User) NeedsTOTPSetup() bool {
41+
return u.TOTPSecret == ""
42+
}
43+
44+
// InTOTPGracePeriod returns true if the user is within the grace period.
45+
func (u *User) InTOTPGracePeriod(days int) bool {
46+
if days <= 0 {
47+
return false
48+
}
49+
gracePeriodEnd := u.CreatedAt.AddDate(0, 0, days)
50+
return time.Now().Before(gracePeriodEnd)
51+
}
52+
3653
// Session represents a user session.
3754
type Session struct {
3855
ID string

auth/models_test.go

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package auth
22

3-
import "testing"
3+
import (
4+
"testing"
5+
"time"
6+
)
47

58
func TestUserHasRole(t *testing.T) {
69
tests := []struct {
@@ -113,3 +116,90 @@ func TestUserHasAnyRole(t *testing.T) {
113116
})
114117
}
115118
}
119+
120+
func TestUserNeedsTOTPSetup(t *testing.T) {
121+
tests := []struct {
122+
name string
123+
totpSecret string
124+
want bool
125+
}{
126+
{
127+
name: "no secret needs setup",
128+
totpSecret: "",
129+
want: true,
130+
},
131+
{
132+
name: "has secret no setup needed",
133+
totpSecret: "JBSWY3DPEHPK3PXP",
134+
want: false,
135+
},
136+
}
137+
138+
for _, tt := range tests {
139+
t.Run(tt.name, func(t *testing.T) {
140+
u := &User{TOTPSecret: tt.totpSecret}
141+
got := u.NeedsTOTPSetup()
142+
if got != tt.want {
143+
t.Errorf("NeedsTOTPSetup() = %v, want %v", got, tt.want)
144+
}
145+
})
146+
}
147+
}
148+
149+
func TestUserInTOTPGracePeriod(t *testing.T) {
150+
now := time.Now()
151+
152+
tests := []struct {
153+
name string
154+
createdAt time.Time
155+
graceDays int
156+
want bool
157+
}{
158+
{
159+
name: "within grace period",
160+
createdAt: now.AddDate(0, 0, -3),
161+
graceDays: 7,
162+
want: true,
163+
},
164+
{
165+
name: "outside grace period",
166+
createdAt: now.AddDate(0, 0, -10),
167+
graceDays: 7,
168+
want: false,
169+
},
170+
{
171+
name: "exactly at grace period end",
172+
createdAt: now.AddDate(0, 0, -7),
173+
graceDays: 7,
174+
want: false,
175+
},
176+
{
177+
name: "zero grace days",
178+
createdAt: now,
179+
graceDays: 0,
180+
want: false,
181+
},
182+
{
183+
name: "negative grace days",
184+
createdAt: now,
185+
graceDays: -1,
186+
want: false,
187+
},
188+
{
189+
name: "created today",
190+
createdAt: now,
191+
graceDays: 7,
192+
want: true,
193+
},
194+
}
195+
196+
for _, tt := range tests {
197+
t.Run(tt.name, func(t *testing.T) {
198+
u := &User{CreatedAt: tt.createdAt}
199+
got := u.InTOTPGracePeriod(tt.graceDays)
200+
if got != tt.want {
201+
t.Errorf("InTOTPGracePeriod(%d) = %v, want %v", tt.graceDays, got, tt.want)
202+
}
203+
})
204+
}
205+
}

0 commit comments

Comments
 (0)