From bb434a335163d54f27737cc7442b6e13b3abf43c Mon Sep 17 00:00:00 2001 From: Xe Iaso Date: Thu, 24 Jul 2025 11:24:58 -0400 Subject: [PATCH] fix(lib): add comprehensive XSS protection logic (#905) Signed-off-by: Xe Iaso --- lib/anubis.go | 15 +++++--- lib/anubis_test.go | 91 ++++++++++++++++++++++++++-------------------- lib/http.go | 10 ++--- 3 files changed, 65 insertions(+), 51 deletions(-) diff --git a/lib/anubis.go b/lib/anubis.go index eb88b70..dec822f 100644 --- a/lib/anubis.go +++ b/lib/anubis.go @@ -264,7 +264,7 @@ func (s *Server) checkRules(w http.ResponseWriter, r *http.Request, cr policy.Ch hash := rule.Hash() lg.Debug("rule hash", "hash", hash) - s.respondWithStatus(w, r, fmt.Sprintf("%s %s", localizer.T("access_denied"), hash), s.policy.StatusCodes.Deny) + s.respondWithStatus(w, r, fmt.Sprintf("%s %s", localizer.T("access_denied"), hash), "/", s.policy.StatusCodes.Deny) return true case config.RuleChallenge: lg.Debug("challenge requested") @@ -302,7 +302,7 @@ func (s *Server) handleDNSBL(w http.ResponseWriter, r *http.Request, ip string, localizer.T("dronebl_entry"), resp.String(), localizer.T("see_dronebl_lookup"), - ip), s.policy.StatusCodes.Deny) + ip), "/", s.policy.StatusCodes.Deny) return true } } @@ -388,13 +388,16 @@ func (s *Server) PassChallenge(w http.ResponseWriter, r *http.Request) { redirURL, err := url.ParseRequestURI(redir) if err != nil { lg.Error("invalid redirect", "err", err) - s.respondWithError(w, r, localizer.T("invalid_redirect")) + s.respondWithStatus(w, r, localizer.T("invalid_redirect"), "/", http.StatusBadRequest) return } - if redirURL.Scheme != "" && redirURL.Scheme != "http" && redirURL.Scheme != "https" { + switch redirURL.Scheme { + case "", "http", "https": + // allowed + default: lg.Error("XSS attempt blocked, invalid redirect scheme", "scheme", redirURL.Scheme) - s.respondWithStatus(w, r, localizer.T("invalid_redirect"), http.StatusBadRequest) + s.respondWithStatus(w, r, localizer.T("invalid_redirect"), "/", http.StatusBadRequest) return } @@ -463,7 +466,7 @@ func (s *Server) PassChallenge(w http.ResponseWriter, r *http.Request) { case errors.As(err, &cerr): switch { case errors.Is(err, challenge.ErrFailed): - s.respondWithStatus(w, r, cerr.PublicReason, cerr.StatusCode) + s.respondWithStatus(w, r, cerr.PublicReason, "/", cerr.StatusCode) case errors.Is(err, challenge.ErrInvalidFormat), errors.Is(err, challenge.ErrMissingField): s.respondWithError(w, r, cerr.PublicReason) } diff --git a/lib/anubis_test.go b/lib/anubis_test.go index d1e7212..cc20352 100644 --- a/lib/anubis_test.go +++ b/lib/anubis_test.go @@ -1,6 +1,7 @@ package lib import ( + "bytes" "encoding/json" "fmt" "io" @@ -834,49 +835,56 @@ func TestPassChallengeXSS(t *testing.T) { }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - nonce := 0 - elapsedTime := 420 - calculated := "" - calcString := fmt.Sprintf("%s%d", chall.Challenge, nonce) - calculated = internal.SHA256sum(calcString) + t.Run("with test cookie", func(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + nonce := 0 + elapsedTime := 420 + calculated := "" + calcString := fmt.Sprintf("%s%d", chall.Challenge, nonce) + calculated = internal.SHA256sum(calcString) - req, err := http.NewRequest(http.MethodGet, ts.URL+"/.within.website/x/cmd/anubis/api/pass-challenge", nil) - if err != nil { - t.Fatalf("can't make request: %v", err) - } - - q := req.URL.Query() - q.Set("response", calculated) - q.Set("nonce", fmt.Sprint(nonce)) - q.Set("redir", tc.redir) - q.Set("elapsedTime", fmt.Sprint(elapsedTime)) - req.URL.RawQuery = q.Encode() - - u, err := url.Parse(ts.URL) - if err != nil { - t.Fatal(err) - } - - for _, ckie := range cli.Jar.Cookies(u) { - if ckie.Name == anubis.TestCookieName { - req.AddCookie(ckie) + req, err := http.NewRequest(http.MethodGet, ts.URL+"/.within.website/x/cmd/anubis/api/pass-challenge", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) } - } - resp, err := cli.Do(req) - if err != nil { - t.Fatalf("can't do request: %v", err) - } + q := req.URL.Query() + q.Set("response", calculated) + q.Set("nonce", fmt.Sprint(nonce)) + q.Set("redir", tc.redir) + q.Set("elapsedTime", fmt.Sprint(elapsedTime)) + req.URL.RawQuery = q.Encode() - body, _ := io.ReadAll(resp.Body) + u, err := url.Parse(ts.URL) + if err != nil { + t.Fatal(err) + } - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("wanted status %d, got %d. body: %s", http.StatusBadRequest, resp.StatusCode, body) - } - }) - } + for _, ckie := range cli.Jar.Cookies(u) { + if ckie.Name == anubis.TestCookieName { + req.AddCookie(ckie) + } + } + + resp, err := cli.Do(req) + if err != nil { + t.Fatalf("can't do request: %v", err) + } + + body, _ := io.ReadAll(resp.Body) + + if bytes.Contains(body, []byte(tc.redir)) { + t.Log(string(body)) + t.Error("found XSS in HTML body") + } + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("wanted status %d, got %d. body: %s", http.StatusBadRequest, resp.StatusCode, body) + } + }) + } + }) t.Run("no test cookie", func(t *testing.T) { for _, tc := range testCases { @@ -899,8 +907,6 @@ func TestPassChallengeXSS(t *testing.T) { q.Set("elapsedTime", fmt.Sprint(elapsedTime)) req.URL.RawQuery = q.Encode() - // Do NOT add the test cookie here - resp, err := cli.Do(req) if err != nil { t.Fatalf("can't do request: %v", err) @@ -908,6 +914,11 @@ func TestPassChallengeXSS(t *testing.T) { body, _ := io.ReadAll(resp.Body) + if bytes.Contains(body, []byte(tc.redir)) { + t.Log(string(body)) + t.Error("found XSS in HTML body") + } + if resp.StatusCode != http.StatusBadRequest { t.Errorf("wanted status %d, got %d. body: %s", http.StatusBadRequest, resp.StatusCode, body) } diff --git a/lib/http.go b/lib/http.go index d1e8233..905724d 100644 --- a/lib/http.go +++ b/lib/http.go @@ -192,13 +192,13 @@ func (s *Server) RenderBench(w http.ResponseWriter, r *http.Request) { } func (s *Server) respondWithError(w http.ResponseWriter, r *http.Request, message string) { - s.respondWithStatus(w, r, message, http.StatusInternalServerError) + s.respondWithStatus(w, r, message, "/", http.StatusInternalServerError) } -func (s *Server) respondWithStatus(w http.ResponseWriter, r *http.Request, msg string, status int) { +func (s *Server) respondWithStatus(w http.ResponseWriter, r *http.Request, msg, redirect string, status int) { localizer := localization.GetLocalizer(r) - templ.Handler(web.Base(localizer.T("oh_noes"), web.ErrorPage(msg, s.opts.WebmasterEmail, r.FormValue("redir"), localizer), s.policy.Impressum, localizer), templ.WithStatus(status)).ServeHTTP(w, r) + templ.Handler(web.Base(localizer.T("oh_noes"), web.ErrorPage(msg, s.opts.WebmasterEmail, redirect, localizer), s.policy.Impressum, localizer), templ.WithStatus(status)).ServeHTTP(w, r) } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -238,12 +238,12 @@ func (s *Server) ServeHTTPNext(w http.ResponseWriter, r *http.Request) { redir := r.FormValue("redir") urlParsed, err := r.URL.Parse(redir) if err != nil { - s.respondWithStatus(w, r, localizer.T("redirect_not_parseable"), http.StatusBadRequest) + s.respondWithStatus(w, r, localizer.T("redirect_not_parseable"), "/", http.StatusBadRequest) return } if (len(urlParsed.Host) > 0 && len(s.opts.RedirectDomains) != 0 && !slices.Contains(s.opts.RedirectDomains, urlParsed.Host)) || urlParsed.Host != r.URL.Host { - s.respondWithStatus(w, r, localizer.T("redirect_domain_not_allowed"), http.StatusBadRequest) + s.respondWithStatus(w, r, localizer.T("redirect_domain_not_allowed"), "/", http.StatusBadRequest) return }