From bfd71045d1f77c9e2960fa03e8793bbdd7f2cb6b Mon Sep 17 00:00:00 2001 From: nyyu Date: Wed, 18 Jun 2025 15:48:51 +0200 Subject: [PATCH] chore: refactor error handling --- lib/http.go | 46 +++++++++++++++++++++++++++++----------------- lib/http_test.go | 5 ++++- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/lib/http.go b/lib/http.go index 905ab6d..9c16b3f 100644 --- a/lib/http.go +++ b/lib/http.go @@ -1,6 +1,7 @@ package lib import ( + "errors" "fmt" "math/rand" "net/http" @@ -69,23 +70,12 @@ func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, rule *polic w.WriteHeader(http.StatusUnauthorized) w.Write([]byte("Authorization required")) } else { - proto := r.Header.Get("X-Forwarded-Proto") - host := r.Header.Get("X-Forwarded-Host") - uri := r.Header.Get("X-Forwarded-Uri") - - if proto == "" || host == "" || uri == "" { - s.respondWithStatus(w, r, "Missing required X-Forwarded-* headers", http.StatusBadRequest) + redirectURL, err := s.constructRedirectURL(r) + if err != nil { + s.respondWithStatus(w, r, err.Error(), http.StatusBadRequest) return } - // Check if host is allowed in RedirectDomains - if len(s.opts.RedirectDomains) > 0 && !slices.Contains(s.opts.RedirectDomains, host) { - s.respondWithStatus(w, r, "Redirect domain not allowed", http.StatusBadRequest) - return - } - - redir := proto + "://" + host + uri - escapedURL := url.QueryEscape(redir) - http.Redirect(w, r, fmt.Sprintf("%s/.within.website/?redir=%s", s.opts.PublicUrl, escapedURL), http.StatusTemporaryRedirect) + http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect) } return } @@ -137,6 +127,24 @@ func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, rule *polic handler.ServeHTTP(w, r) } +func (s *Server) constructRedirectURL(r *http.Request) (string, error) { + proto := r.Header.Get("X-Forwarded-Proto") + host := r.Header.Get("X-Forwarded-Host") + uri := r.Header.Get("X-Forwarded-Uri") + + if proto == "" || host == "" || uri == "" { + return "", errors.New("missing required X-Forwarded-* headers") + } + // Check if host is allowed in RedirectDomains + if len(s.opts.RedirectDomains) > 0 && !slices.Contains(s.opts.RedirectDomains, host) { + return "", errors.New("redirect domain not allowed") + } + + redir := proto + "://" + host + uri + escapedURL := url.QueryEscape(redir) + return fmt.Sprintf("%s/.within.website/?redir=%s", s.opts.PublicUrl, escapedURL), nil +} + func (s *Server) RenderBench(w http.ResponseWriter, r *http.Request) { templ.Handler( web.Base("Benchmarking Anubis!", web.Bench()), @@ -190,8 +198,12 @@ func (s *Server) ServeHTTPNext(w http.ResponseWriter, r *http.Request) { return } - if (len(urlParsed.Host) > 0 && len(s.opts.RedirectDomains) != 0 && !slices.Contains(s.opts.RedirectDomains, urlParsed.Host)) || - (r.URL.Host != "" && urlParsed.Host != r.URL.Host) { + hostNotAllowed := len(urlParsed.Host) > 0 && + len(s.opts.RedirectDomains) != 0 && + !slices.Contains(s.opts.RedirectDomains, urlParsed.Host) + hostMismatch := r.URL.Host != "" && urlParsed.Host != r.URL.Host + + if hostNotAllowed || hostMismatch { s.respondWithStatus(w, r, "Redirect domain not allowed", http.StatusBadRequest) return } diff --git a/lib/http_test.go b/lib/http_test.go index 856d0a3..c4b2527 100644 --- a/lib/http_test.go +++ b/lib/http_test.go @@ -77,7 +77,10 @@ func TestRenderIndexRedirect(t *testing.T) { t.Errorf("expected status %d, got %d", http.StatusTemporaryRedirect, rr.Code) } location := rr.Header().Get("Location") - parsedURL, _ := url.Parse(location) + parsedURL, err := url.Parse(location) + if err != nil { + t.Fatalf("failed to parse location URL %q: %v", location, err) + } scheme := "https" if parsedURL.Scheme != scheme {