chore: refactor error handling

This commit is contained in:
nyyu 2025-06-18 15:48:51 +02:00
parent 4b44a2f2ec
commit bfd71045d1
2 changed files with 33 additions and 18 deletions

View File

@ -1,6 +1,7 @@
package lib
import (
"errors"
"fmt"
"math/rand"
"net/http"
@ -69,23 +70,12 @@ func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, rule *polic
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Authorization required"))
} else {
proto := r.Header.Get("X-Forwarded-Proto")
host := r.Header.Get("X-Forwarded-Host")
uri := r.Header.Get("X-Forwarded-Uri")
if proto == "" || host == "" || uri == "" {
s.respondWithStatus(w, r, "Missing required X-Forwarded-* headers", http.StatusBadRequest)
redirectURL, err := s.constructRedirectURL(r)
if err != nil {
s.respondWithStatus(w, r, err.Error(), http.StatusBadRequest)
return
}
// Check if host is allowed in RedirectDomains
if len(s.opts.RedirectDomains) > 0 && !slices.Contains(s.opts.RedirectDomains, host) {
s.respondWithStatus(w, r, "Redirect domain not allowed", http.StatusBadRequest)
return
}
redir := proto + "://" + host + uri
escapedURL := url.QueryEscape(redir)
http.Redirect(w, r, fmt.Sprintf("%s/.within.website/?redir=%s", s.opts.PublicUrl, escapedURL), http.StatusTemporaryRedirect)
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
}
return
}
@ -137,6 +127,24 @@ func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, rule *polic
handler.ServeHTTP(w, r)
}
func (s *Server) constructRedirectURL(r *http.Request) (string, error) {
proto := r.Header.Get("X-Forwarded-Proto")
host := r.Header.Get("X-Forwarded-Host")
uri := r.Header.Get("X-Forwarded-Uri")
if proto == "" || host == "" || uri == "" {
return "", errors.New("missing required X-Forwarded-* headers")
}
// Check if host is allowed in RedirectDomains
if len(s.opts.RedirectDomains) > 0 && !slices.Contains(s.opts.RedirectDomains, host) {
return "", errors.New("redirect domain not allowed")
}
redir := proto + "://" + host + uri
escapedURL := url.QueryEscape(redir)
return fmt.Sprintf("%s/.within.website/?redir=%s", s.opts.PublicUrl, escapedURL), nil
}
func (s *Server) RenderBench(w http.ResponseWriter, r *http.Request) {
templ.Handler(
web.Base("Benchmarking Anubis!", web.Bench()),
@ -190,8 +198,12 @@ func (s *Server) ServeHTTPNext(w http.ResponseWriter, r *http.Request) {
return
}
if (len(urlParsed.Host) > 0 && len(s.opts.RedirectDomains) != 0 && !slices.Contains(s.opts.RedirectDomains, urlParsed.Host)) ||
(r.URL.Host != "" && urlParsed.Host != r.URL.Host) {
hostNotAllowed := len(urlParsed.Host) > 0 &&
len(s.opts.RedirectDomains) != 0 &&
!slices.Contains(s.opts.RedirectDomains, urlParsed.Host)
hostMismatch := r.URL.Host != "" && urlParsed.Host != r.URL.Host
if hostNotAllowed || hostMismatch {
s.respondWithStatus(w, r, "Redirect domain not allowed", http.StatusBadRequest)
return
}

View File

@ -77,7 +77,10 @@ func TestRenderIndexRedirect(t *testing.T) {
t.Errorf("expected status %d, got %d", http.StatusTemporaryRedirect, rr.Code)
}
location := rr.Header().Get("Location")
parsedURL, _ := url.Parse(location)
parsedURL, err := url.Parse(location)
if err != nil {
t.Fatalf("failed to parse location URL %q: %v", location, err)
}
scheme := "https"
if parsedURL.Scheme != scheme {