diff --git a/cmd/anubis/main.go b/cmd/anubis/main.go index 724f88a..222e21f 100644 --- a/cmd/anubis/main.go +++ b/cmd/anubis/main.go @@ -20,7 +20,6 @@ import ( "os" "os/signal" "path/filepath" - "regexp" "strconv" "strings" "sync" @@ -58,7 +57,7 @@ var ( ogPassthrough = flag.Bool("og-passthrough", false, "enable Open Graph tag passthrough") ogTimeToLive = flag.Duration("og-expiry-time", 24*time.Hour, "Open Graph tag cache expiration time") extractResources = flag.String("extract-resources", "", "if set, extract the static resources to the specified folder") - webmasterEmail = flag.String("webmaster-email", "", "if set, displays webmaster's email on the reject page for appeals") + webmasterEmail = flag.String("webmaster-email", "", "if set, displays webmaster's email on the reject page for appeals") ) func keyFromHex(value string) (ed25519.PrivateKey, error) { @@ -205,22 +204,17 @@ func main() { continue } - hash, err := rule.Hash() - if err != nil { - log.Fatalf("can't calculate checksum of rule %s: %v", rule.Name, err) - } - + hash := rule.Hash() fmt.Printf("* %s: %s\n", rule.Name, hash) } fmt.Println() // replace the bot policy rules with a single rule that always benchmarks if *debugBenchmarkJS { - userAgent := regexp.MustCompile(".") policy.Bots = []botPolicy.Bot{{ - Name: "", - UserAgent: userAgent, - Action: config.RuleBenchmark, + Name: "", + Rules: botPolicy.NewHeaderExistsChecker("User-Agent"), + Action: config.RuleBenchmark, }} } @@ -261,7 +255,7 @@ func main() { OGPassthrough: *ogPassthrough, OGTimeToLive: *ogTimeToLive, Target: *target, - WebmasterEmail: *webmasterEmail, + WebmasterEmail: *webmasterEmail, }) if err != nil { log.Fatalf("can't construct libanubis.Server: %v", err) diff --git a/docs/docs/CHANGELOG.md b/docs/docs/CHANGELOG.md index 9200ade..6a90c81 100644 --- a/docs/docs/CHANGELOG.md +++ b/docs/docs/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- Refactor check logic to be more generic and work on a Checker type - Add more AI user agents based on the [ai.robots.txt](https://github.com/ai-robots-txt/ai.robots.txt) project - Embedded challenge data in initial HTML response to improve performance - Whitelisted [DuckDuckBot](https://duckduckgo.com/duckduckgo-help-pages/results/duckduckbot/) in botPolicies diff --git a/lib/anubis.go b/lib/anubis.go index afc3d86..82e69d0 100644 --- a/lib/anubis.go +++ b/lib/anubis.go @@ -8,7 +8,6 @@ import ( "encoding/json" "fmt" "io" - "log" "log/slog" "math" "net" @@ -238,12 +237,8 @@ func (s *Server) MaybeReverseProxy(w http.ResponseWriter, r *http.Request) { templ.Handler(web.Base("Oh noes!", web.ErrorPage("Other internal server error (contact the admin)", s.opts.WebmasterEmail)), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) return } - hash, err := rule.Hash() - if err != nil { - lg.Error("can't calculate checksum of rule", "err", err) - templ.Handler(web.Base("Oh noes!", web.ErrorPage("Other internal server error (contact the admin)", s.opts.WebmasterEmail)), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) - return - } + hash := rule.Hash() + lg.Debug("rule hash", "hash", hash) templ.Handler(web.Base("Oh noes!", web.ErrorPage(fmt.Sprintf("Access Denied: error code %s", hash), s.opts.WebmasterEmail)), templ.WithStatus(http.StatusOK)).ServeHTTP(w, r) return @@ -337,7 +332,7 @@ func (s *Server) MaybeReverseProxy(w http.ResponseWriter, r *http.Request) { s.next.ServeHTTP(w, r) } -func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, cr CheckResult, rule *policy.Bot) { +func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, cr policy.CheckResult, rule *policy.Bot) { lg := slog.With( "user_agent", r.UserAgent(), "accept_language", r.Header.Get("Accept-Language"), @@ -518,41 +513,33 @@ func (s *Server) TestError(w http.ResponseWriter, r *http.Request) { templ.Handler(web.Base("Oh noes!", web.ErrorPage(err, s.opts.WebmasterEmail)), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) } +func cr(name string, rule config.Rule) policy.CheckResult { + return policy.CheckResult{ + Name: name, + Rule: rule, + } +} + // Check evaluates the list of rules, and returns the result -func (s *Server) check(r *http.Request) (CheckResult, *policy.Bot, error) { +func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error) { host := r.Header.Get("X-Real-Ip") if host == "" { - return decaymap.Zilch[CheckResult](), nil, fmt.Errorf("[misconfiguration] X-Real-Ip header is not set") + return decaymap.Zilch[policy.CheckResult](), nil, fmt.Errorf("[misconfiguration] X-Real-Ip header is not set") } addr := net.ParseIP(host) if addr == nil { - return decaymap.Zilch[CheckResult](), nil, fmt.Errorf("[misconfiguration] %q is not an IP address", host) + return decaymap.Zilch[policy.CheckResult](), nil, fmt.Errorf("[misconfiguration] %q is not an IP address", host) } for _, b := range s.policy.Bots { - if b.UserAgent != nil { - if b.UserAgent.MatchString(r.UserAgent()) && s.checkRemoteAddress(b, addr) { - return cr("bot/"+b.Name, b.Action), &b, nil - } + match, err := b.Rules.Check(r) + if err != nil { + return decaymap.Zilch[policy.CheckResult](), nil, fmt.Errorf("can't run check %s: %w", b.Name, err) } - if b.Path != nil { - if b.Path.MatchString(r.URL.Path) && s.checkRemoteAddress(b, addr) { - return cr("bot/"+b.Name, b.Action), &b, nil - } - } - - if b.Ranger != nil { - if s.checkRemoteAddress(b, addr) { - return cr("bot/"+b.Name, b.Action), &b, nil - } - } - - if len(b.Headers) > 0 { - if s.checkHeaders(b, r.Header) { - return cr("bot/"+b.Name, b.Action), &b, nil - } + if match { + return cr("bot/"+b.Name, b.Action), &b, nil } } @@ -565,40 +552,6 @@ func (s *Server) check(r *http.Request) (CheckResult, *policy.Bot, error) { }, nil } -func (s *Server) checkRemoteAddress(b policy.Bot, addr net.IP) bool { - if b.Ranger == nil { - return true - } - - ok, err := b.Ranger.Contains(addr) - if err != nil { - log.Panicf("[unexpected] something very funky is going on, %q does not have a calculable network number: %v", addr.String(), err) - } - - return ok -} - -func (s *Server) checkHeaders(b policy.Bot, header http.Header) bool { - if len(b.Headers) == 0 { - return true - } - - for name, expr := range b.Headers { - values := header.Values(name) - if values == nil { - return false - } - - for _, value := range values { - if !expr.MatchString(value) { - return false - } - } - } - - return true -} - func (s *Server) CleanupDecayMap() { s.DNSBLCache.Cleanup() s.OGTags.Cleanup() diff --git a/lib/policy/bot.go b/lib/policy/bot.go index e656d9a..3a43655 100644 --- a/lib/policy/bot.go +++ b/lib/policy/bot.go @@ -2,45 +2,18 @@ package policy import ( "fmt" - "regexp" - "strings" "github.com/TecharoHQ/anubis/internal" "github.com/TecharoHQ/anubis/lib/policy/config" - "github.com/yl2chen/cidranger" ) type Bot struct { Name string - UserAgent *regexp.Regexp - Path *regexp.Regexp - Headers map[string]*regexp.Regexp - Action config.Rule `json:"action"` + Action config.Rule Challenge *config.ChallengeRules - Ranger cidranger.Ranger + Rules Checker } -func (b Bot) Hash() (string, error) { - var pathRex string - if b.Path != nil { - pathRex = b.Path.String() - } - var userAgentRex string - if b.UserAgent != nil { - userAgentRex = b.UserAgent.String() - } - var headersRex string - if len(b.Headers) > 0 { - var sb strings.Builder - sb.Grow(len(b.Headers) * 64) - - for name, expr := range b.Headers { - sb.WriteString(name) - sb.WriteString(expr.String()) - } - - headersRex = sb.String() - } - - return internal.SHA256sum(fmt.Sprintf("%s::%s::%s::%s", b.Name, pathRex, userAgentRex, headersRex)), nil +func (b Bot) Hash() string { + return internal.SHA256sum(fmt.Sprintf("%s::%s", b.Name, b.Rules.Hash())) } diff --git a/lib/policy/checker.go b/lib/policy/checker.go new file mode 100644 index 0000000..ad98ced --- /dev/null +++ b/lib/policy/checker.go @@ -0,0 +1,201 @@ +package policy + +import ( + "errors" + "fmt" + "net" + "net/http" + "regexp" + "strings" + + "github.com/TecharoHQ/anubis/internal" + "github.com/yl2chen/cidranger" +) + +var ( + ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration") +) + +type Checker interface { + Check(*http.Request) (bool, error) + Hash() string +} + +type CheckerList []Checker + +func (cl CheckerList) Check(r *http.Request) (bool, error) { + for _, c := range cl { + ok, err := c.Check(r) + if err != nil { + return ok, err + } + if ok { + return ok, nil + } + } + + return false, nil +} + +func (cl CheckerList) Hash() string { + var sb strings.Builder + + for _, c := range cl { + fmt.Fprintln(&sb, c.Hash()) + } + + return internal.SHA256sum(sb.String()) +} + +type RemoteAddrChecker struct { + ranger cidranger.Ranger + hash string +} + +func NewRemoteAddrChecker(cidrs []string) (Checker, error) { + ranger := cidranger.NewPCTrieRanger() + var sb strings.Builder + + for _, cidr := range cidrs { + _, rng, err := net.ParseCIDR(cidr) + if err != nil { + return nil, fmt.Errorf("%w: range %s not parsing: %w", ErrMisconfiguration, cidr, err) + } + + ranger.Insert(cidranger.NewBasicRangerEntry(*rng)) + fmt.Fprintln(&sb, cidr) + } + + return &RemoteAddrChecker{ + ranger: ranger, + hash: internal.SHA256sum(sb.String()), + }, nil +} + +func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) { + host := r.Header.Get("X-Real-Ip") + if host == "" { + return false, fmt.Errorf("%w: header X-Real-Ip is not set", ErrMisconfiguration) + } + + addr := net.ParseIP(host) + if addr == nil { + return false, fmt.Errorf("%w: %s is not an IP address", ErrMisconfiguration, host) + } + + ok, err := rac.ranger.Contains(addr) + if err != nil { + return false, err + } + + if ok { + return true, nil + } + + return false, nil +} + +func (rac *RemoteAddrChecker) Hash() string { + return rac.hash +} + +type HeaderMatchesChecker struct { + header string + regexp *regexp.Regexp + hash string +} + +func NewUserAgentChecker(rexStr string) (Checker, error) { + return NewHeaderMatchesChecker("User-Agent", rexStr) +} + +func NewHeaderMatchesChecker(header, rexStr string) (Checker, error) { + rex, err := regexp.Compile(rexStr) + if err != nil { + return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err) + } + return &HeaderMatchesChecker{header, rex, internal.SHA256sum(header + ": " + rexStr)}, nil +} + +func (hmc *HeaderMatchesChecker) Check(r *http.Request) (bool, error) { + if hmc.regexp.MatchString(r.Header.Get(hmc.header)) { + return true, nil + } + + return false, nil +} + +func (hmc *HeaderMatchesChecker) Hash() string { + return hmc.hash +} + +type PathChecker struct { + regexp *regexp.Regexp + hash string +} + +func NewPathChecker(rexStr string) (Checker, error) { + rex, err := regexp.Compile(rexStr) + if err != nil { + return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err) + } + return &PathChecker{rex, internal.SHA256sum(rexStr)}, nil +} + +func (pc *PathChecker) Check(r *http.Request) (bool, error) { + if pc.regexp.MatchString(r.URL.Path) { + return true, nil + } + + return false, nil +} + +func (pc *PathChecker) Hash() string { + return pc.hash +} + +func NewHeaderExistsChecker(key string) Checker { + return headerExistsChecker{key} +} + +type headerExistsChecker struct { + header string +} + +func (hec headerExistsChecker) Check(r *http.Request) (bool, error) { + if r.Header.Get(hec.header) != "" { + return true, nil + } + + return false, nil +} + +func (hec headerExistsChecker) Hash() string { + return internal.SHA256sum(hec.header) +} + +func NewHeadersChecker(headermap map[string]string) (Checker, error) { + var result CheckerList + var errs []error + + for key, rexStr := range headermap { + if rexStr == ".*" { + result = append(result, headerExistsChecker{key}) + continue + } + + rex, err := regexp.Compile(rexStr) + if err != nil { + errs = append(errs, fmt.Errorf("while compiling header %s regex %s: %w", key, rexStr, err)) + continue + } + + result = append(result, &HeaderMatchesChecker{key, rex, internal.SHA256sum(key + ": " + rexStr)}) + } + + if len(errs) != 0 { + return nil, errors.Join(errs...) + } + + return result, nil +} diff --git a/lib/policy/checker_test.go b/lib/policy/checker_test.go new file mode 100644 index 0000000..6739509 --- /dev/null +++ b/lib/policy/checker_test.go @@ -0,0 +1,200 @@ +package policy + +import ( + "errors" + "net/http" + "testing" +) + +func TestRemoteAddrChecker(t *testing.T) { + for _, tt := range []struct { + name string + cidrs []string + ip string + ok bool + err error + }{ + { + name: "match_ipv4", + cidrs: []string{"0.0.0.0/0"}, + ip: "1.1.1.1", + ok: true, + err: nil, + }, + { + name: "match_ipv6", + cidrs: []string{"::/0"}, + ip: "cafe:babe::", + ok: true, + err: nil, + }, + { + name: "not_match_ipv4", + cidrs: []string{"1.1.1.1/32"}, + ip: "1.1.1.2", + ok: false, + err: nil, + }, + { + name: "not_match_ipv6", + cidrs: []string{"cafe:babe::/128"}, + ip: "cafe:babe:4::/128", + ok: false, + err: nil, + }, + { + name: "no_ip_set", + cidrs: []string{"::/0"}, + ok: false, + err: ErrMisconfiguration, + }, + { + name: "invalid_ip", + cidrs: []string{"::/0"}, + ip: "According to all natural laws of aviation", + ok: false, + err: ErrMisconfiguration, + }, + } { + t.Run(tt.name, func(t *testing.T) { + rac, err := NewRemoteAddrChecker(tt.cidrs) + if err != nil && !errors.Is(err, tt.err) { + t.Fatalf("creating RemoteAddrChecker failed: %v", err) + } + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + if tt.ip != "" { + r.Header.Add("X-Real-Ip", tt.ip) + } + + ok, err := rac.Check(r) + + if tt.ok != ok { + t.Errorf("ok: %v, wanted: %v", ok, tt.ok) + } + + if err != nil && tt.err != nil && !errors.Is(err, tt.err) { + t.Errorf("err: %v, wanted: %v", err, tt.err) + } + }) + } +} + +func TestHeaderMatchesChecker(t *testing.T) { + for _, tt := range []struct { + name string + header string + rexStr string + reqHeaderKey string + reqHeaderValue string + ok bool + err error + }{ + { + name: "match", + header: "Cf-Worker", + rexStr: ".*", + reqHeaderKey: "Cf-Worker", + reqHeaderValue: "true", + ok: true, + err: nil, + }, + { + name: "not_match", + header: "Cf-Worker", + rexStr: "false", + reqHeaderKey: "Cf-Worker", + reqHeaderValue: "true", + ok: false, + err: nil, + }, + { + name: "not_present", + header: "Cf-Worker", + rexStr: "foobar", + reqHeaderKey: "Something-Else", + reqHeaderValue: "true", + ok: false, + err: nil, + }, + { + name: "invalid_regex", + rexStr: "a(b", + err: ErrMisconfiguration, + }, + } { + t.Run(tt.name, func(t *testing.T) { + hmc, err := NewHeaderMatchesChecker(tt.header, tt.rexStr) + if err != nil && !errors.Is(err, tt.err) { + t.Fatalf("creating HeaderMatchesChecker failed") + } + + if tt.err != nil && hmc == nil { + return + } + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue) + + ok, err := hmc.Check(r) + + if tt.ok != ok { + t.Errorf("ok: %v, wanted: %v", ok, tt.ok) + } + + if err != nil && tt.err != nil && !errors.Is(err, tt.err) { + t.Errorf("err: %v, wanted: %v", err, tt.err) + } + }) + } +} + +func TestHeaderExistsChecker(t *testing.T) { + for _, tt := range []struct { + name string + header string + reqHeader string + ok bool + }{ + { + name: "match", + header: "Authorization", + reqHeader: "Authorization", + ok: true, + }, + { + name: "not_match", + header: "Authorization", + reqHeader: "Authentication", + }, + } { + t.Run(tt.name, func(t *testing.T) { + hec := headerExistsChecker{tt.header} + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + r.Header.Set(tt.reqHeader, "hunter2") + + ok, err := hec.Check(r) + + if tt.ok != ok { + t.Errorf("ok: %v, wanted: %v", ok, tt.ok) + } + + if err != nil { + t.Errorf("err: %v", err) + } + }) + } +} diff --git a/lib/checkresult.go b/lib/policy/checkresult.go similarity index 70% rename from lib/checkresult.go rename to lib/policy/checkresult.go index 3803df2..c84f326 100644 --- a/lib/checkresult.go +++ b/lib/policy/checkresult.go @@ -1,4 +1,4 @@ -package lib +package policy import ( "log/slog" @@ -16,10 +16,3 @@ func (cr CheckResult) LogValue() slog.Value { slog.String("name", cr.Name), slog.String("rule", string(cr.Rule))) } - -func cr(name string, rule config.Rule) CheckResult { - return CheckResult{ - Name: name, - Rule: rule, - } -} diff --git a/lib/policy/policy.go b/lib/policy/policy.go index 2d610c8..5923f16 100644 --- a/lib/policy/policy.go +++ b/lib/policy/policy.go @@ -4,12 +4,9 @@ import ( "errors" "fmt" "io" - "net" - "regexp" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/yl2chen/cidranger" "k8s.io/apimachinery/pkg/util/yaml" "github.com/TecharoHQ/anubis/lib/policy/config" @@ -58,57 +55,45 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon } parsedBot := Bot{ - Name: b.Name, - Action: b.Action, - Headers: map[string]*regexp.Regexp{}, + Name: b.Name, + Action: b.Action, } + cl := CheckerList{} + if len(b.RemoteAddr) > 0 { - parsedBot.Ranger = cidranger.NewPCTrieRanger() - - for _, cidr := range b.RemoteAddr { - _, rng, err := net.ParseCIDR(cidr) - if err != nil { - return nil, fmt.Errorf("[unexpected] range %s not parsing: %w", cidr, err) - } - - parsedBot.Ranger.Insert(cidranger.NewBasicRangerEntry(*rng)) + c, err := NewRemoteAddrChecker(b.RemoteAddr) + if err != nil { + validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s remote addr set: %w", b.Name, err)) + } else { + cl = append(cl, c) } } if b.UserAgentRegex != nil { - userAgent, err := regexp.Compile(*b.UserAgentRegex) + c, err := NewUserAgentChecker(*b.UserAgentRegex) if err != nil { - validationErrs = append(validationErrs, fmt.Errorf("while compiling user agent regexp: %w", err)) - continue + validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s user agent regex: %w", b.Name, err)) } else { - parsedBot.UserAgent = userAgent + cl = append(cl, c) } } if b.PathRegex != nil { - path, err := regexp.Compile(*b.PathRegex) + c, err := NewPathChecker(*b.PathRegex) if err != nil { - validationErrs = append(validationErrs, fmt.Errorf("while compiling path regexp: %w", err)) - continue + validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s path regex: %w", b.Name, err)) } else { - parsedBot.Path = path + cl = append(cl, c) } } if len(b.HeadersRegex) > 0 { - for name, expr := range b.HeadersRegex { - if name == "" { - continue - } - - header, err := regexp.Compile(expr) - if err != nil { - validationErrs = append(validationErrs, fmt.Errorf("while compiling header regexp: %w", err)) - continue - } else { - parsedBot.Headers[name] = header - } + c, err := NewHeadersChecker(b.HeadersRegex) + if err != nil { + validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s headers regex map: %w", b.Name, err)) + } else { + cl = append(cl, c) } } @@ -125,6 +110,8 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon } } + parsedBot.Rules = cl + result.Bots = append(result.Bots, parsedBot) }