mirror of
https://github.com/TecharoHQ/anubis.git
synced 2025-09-08 04:05:23 -04:00
feat(lib): use Checker type instead of ad-hoc logic (#318)
This makes each check into its own type that has encapsulated check logic, meaning that it's easier to add new checker implementations in the future. Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
parent
9b7bf8ee06
commit
84b28760b3
@ -20,7 +20,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -58,7 +57,7 @@ var (
|
|||||||
ogPassthrough = flag.Bool("og-passthrough", false, "enable Open Graph tag passthrough")
|
ogPassthrough = flag.Bool("og-passthrough", false, "enable Open Graph tag passthrough")
|
||||||
ogTimeToLive = flag.Duration("og-expiry-time", 24*time.Hour, "Open Graph tag cache expiration time")
|
ogTimeToLive = flag.Duration("og-expiry-time", 24*time.Hour, "Open Graph tag cache expiration time")
|
||||||
extractResources = flag.String("extract-resources", "", "if set, extract the static resources to the specified folder")
|
extractResources = flag.String("extract-resources", "", "if set, extract the static resources to the specified folder")
|
||||||
webmasterEmail = flag.String("webmaster-email", "", "if set, displays webmaster's email on the reject page for appeals")
|
webmasterEmail = flag.String("webmaster-email", "", "if set, displays webmaster's email on the reject page for appeals")
|
||||||
)
|
)
|
||||||
|
|
||||||
func keyFromHex(value string) (ed25519.PrivateKey, error) {
|
func keyFromHex(value string) (ed25519.PrivateKey, error) {
|
||||||
@ -205,22 +204,17 @@ func main() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
hash, err := rule.Hash()
|
hash := rule.Hash()
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("can't calculate checksum of rule %s: %v", rule.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("* %s: %s\n", rule.Name, hash)
|
fmt.Printf("* %s: %s\n", rule.Name, hash)
|
||||||
}
|
}
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
|
|
||||||
// replace the bot policy rules with a single rule that always benchmarks
|
// replace the bot policy rules with a single rule that always benchmarks
|
||||||
if *debugBenchmarkJS {
|
if *debugBenchmarkJS {
|
||||||
userAgent := regexp.MustCompile(".")
|
|
||||||
policy.Bots = []botPolicy.Bot{{
|
policy.Bots = []botPolicy.Bot{{
|
||||||
Name: "",
|
Name: "",
|
||||||
UserAgent: userAgent,
|
Rules: botPolicy.NewHeaderExistsChecker("User-Agent"),
|
||||||
Action: config.RuleBenchmark,
|
Action: config.RuleBenchmark,
|
||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -261,7 +255,7 @@ func main() {
|
|||||||
OGPassthrough: *ogPassthrough,
|
OGPassthrough: *ogPassthrough,
|
||||||
OGTimeToLive: *ogTimeToLive,
|
OGTimeToLive: *ogTimeToLive,
|
||||||
Target: *target,
|
Target: *target,
|
||||||
WebmasterEmail: *webmasterEmail,
|
WebmasterEmail: *webmasterEmail,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("can't construct libanubis.Server: %v", err)
|
log.Fatalf("can't construct libanubis.Server: %v", err)
|
||||||
|
@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
- Refactor check logic to be more generic and work on a Checker type
|
||||||
- Add more AI user agents based on the [ai.robots.txt](https://github.com/ai-robots-txt/ai.robots.txt) project
|
- Add more AI user agents based on the [ai.robots.txt](https://github.com/ai-robots-txt/ai.robots.txt) project
|
||||||
- Embedded challenge data in initial HTML response to improve performance
|
- Embedded challenge data in initial HTML response to improve performance
|
||||||
- Whitelisted [DuckDuckBot](https://duckduckgo.com/duckduckgo-help-pages/results/duckduckbot/) in botPolicies
|
- Whitelisted [DuckDuckBot](https://duckduckgo.com/duckduckgo-help-pages/results/duckduckbot/) in botPolicies
|
||||||
|
@ -8,7 +8,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
@ -238,12 +237,8 @@ func (s *Server) MaybeReverseProxy(w http.ResponseWriter, r *http.Request) {
|
|||||||
templ.Handler(web.Base("Oh noes!", web.ErrorPage("Other internal server error (contact the admin)", s.opts.WebmasterEmail)), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r)
|
templ.Handler(web.Base("Oh noes!", web.ErrorPage("Other internal server error (contact the admin)", s.opts.WebmasterEmail)), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
hash, err := rule.Hash()
|
hash := rule.Hash()
|
||||||
if err != nil {
|
|
||||||
lg.Error("can't calculate checksum of rule", "err", err)
|
|
||||||
templ.Handler(web.Base("Oh noes!", web.ErrorPage("Other internal server error (contact the admin)", s.opts.WebmasterEmail)), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
lg.Debug("rule hash", "hash", hash)
|
lg.Debug("rule hash", "hash", hash)
|
||||||
templ.Handler(web.Base("Oh noes!", web.ErrorPage(fmt.Sprintf("Access Denied: error code %s", hash), s.opts.WebmasterEmail)), templ.WithStatus(http.StatusOK)).ServeHTTP(w, r)
|
templ.Handler(web.Base("Oh noes!", web.ErrorPage(fmt.Sprintf("Access Denied: error code %s", hash), s.opts.WebmasterEmail)), templ.WithStatus(http.StatusOK)).ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
@ -337,7 +332,7 @@ func (s *Server) MaybeReverseProxy(w http.ResponseWriter, r *http.Request) {
|
|||||||
s.next.ServeHTTP(w, r)
|
s.next.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, cr CheckResult, rule *policy.Bot) {
|
func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, cr policy.CheckResult, rule *policy.Bot) {
|
||||||
lg := slog.With(
|
lg := slog.With(
|
||||||
"user_agent", r.UserAgent(),
|
"user_agent", r.UserAgent(),
|
||||||
"accept_language", r.Header.Get("Accept-Language"),
|
"accept_language", r.Header.Get("Accept-Language"),
|
||||||
@ -518,41 +513,33 @@ func (s *Server) TestError(w http.ResponseWriter, r *http.Request) {
|
|||||||
templ.Handler(web.Base("Oh noes!", web.ErrorPage(err, s.opts.WebmasterEmail)), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r)
|
templ.Handler(web.Base("Oh noes!", web.ErrorPage(err, s.opts.WebmasterEmail)), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cr(name string, rule config.Rule) policy.CheckResult {
|
||||||
|
return policy.CheckResult{
|
||||||
|
Name: name,
|
||||||
|
Rule: rule,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check evaluates the list of rules, and returns the result
|
// Check evaluates the list of rules, and returns the result
|
||||||
func (s *Server) check(r *http.Request) (CheckResult, *policy.Bot, error) {
|
func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error) {
|
||||||
host := r.Header.Get("X-Real-Ip")
|
host := r.Header.Get("X-Real-Ip")
|
||||||
if host == "" {
|
if host == "" {
|
||||||
return decaymap.Zilch[CheckResult](), nil, fmt.Errorf("[misconfiguration] X-Real-Ip header is not set")
|
return decaymap.Zilch[policy.CheckResult](), nil, fmt.Errorf("[misconfiguration] X-Real-Ip header is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := net.ParseIP(host)
|
addr := net.ParseIP(host)
|
||||||
if addr == nil {
|
if addr == nil {
|
||||||
return decaymap.Zilch[CheckResult](), nil, fmt.Errorf("[misconfiguration] %q is not an IP address", host)
|
return decaymap.Zilch[policy.CheckResult](), nil, fmt.Errorf("[misconfiguration] %q is not an IP address", host)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, b := range s.policy.Bots {
|
for _, b := range s.policy.Bots {
|
||||||
if b.UserAgent != nil {
|
match, err := b.Rules.Check(r)
|
||||||
if b.UserAgent.MatchString(r.UserAgent()) && s.checkRemoteAddress(b, addr) {
|
if err != nil {
|
||||||
return cr("bot/"+b.Name, b.Action), &b, nil
|
return decaymap.Zilch[policy.CheckResult](), nil, fmt.Errorf("can't run check %s: %w", b.Name, err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if b.Path != nil {
|
if match {
|
||||||
if b.Path.MatchString(r.URL.Path) && s.checkRemoteAddress(b, addr) {
|
return cr("bot/"+b.Name, b.Action), &b, nil
|
||||||
return cr("bot/"+b.Name, b.Action), &b, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if b.Ranger != nil {
|
|
||||||
if s.checkRemoteAddress(b, addr) {
|
|
||||||
return cr("bot/"+b.Name, b.Action), &b, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(b.Headers) > 0 {
|
|
||||||
if s.checkHeaders(b, r.Header) {
|
|
||||||
return cr("bot/"+b.Name, b.Action), &b, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -565,40 +552,6 @@ func (s *Server) check(r *http.Request) (CheckResult, *policy.Bot, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) checkRemoteAddress(b policy.Bot, addr net.IP) bool {
|
|
||||||
if b.Ranger == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := b.Ranger.Contains(addr)
|
|
||||||
if err != nil {
|
|
||||||
log.Panicf("[unexpected] something very funky is going on, %q does not have a calculable network number: %v", addr.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) checkHeaders(b policy.Bot, header http.Header) bool {
|
|
||||||
if len(b.Headers) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, expr := range b.Headers {
|
|
||||||
values := header.Values(name)
|
|
||||||
if values == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, value := range values {
|
|
||||||
if !expr.MatchString(value) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) CleanupDecayMap() {
|
func (s *Server) CleanupDecayMap() {
|
||||||
s.DNSBLCache.Cleanup()
|
s.DNSBLCache.Cleanup()
|
||||||
s.OGTags.Cleanup()
|
s.OGTags.Cleanup()
|
||||||
|
@ -2,45 +2,18 @@ package policy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/internal"
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
"github.com/TecharoHQ/anubis/lib/policy/config"
|
||||||
"github.com/yl2chen/cidranger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Bot struct {
|
type Bot struct {
|
||||||
Name string
|
Name string
|
||||||
UserAgent *regexp.Regexp
|
Action config.Rule
|
||||||
Path *regexp.Regexp
|
|
||||||
Headers map[string]*regexp.Regexp
|
|
||||||
Action config.Rule `json:"action"`
|
|
||||||
Challenge *config.ChallengeRules
|
Challenge *config.ChallengeRules
|
||||||
Ranger cidranger.Ranger
|
Rules Checker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b Bot) Hash() (string, error) {
|
func (b Bot) Hash() string {
|
||||||
var pathRex string
|
return internal.SHA256sum(fmt.Sprintf("%s::%s", b.Name, b.Rules.Hash()))
|
||||||
if b.Path != nil {
|
|
||||||
pathRex = b.Path.String()
|
|
||||||
}
|
|
||||||
var userAgentRex string
|
|
||||||
if b.UserAgent != nil {
|
|
||||||
userAgentRex = b.UserAgent.String()
|
|
||||||
}
|
|
||||||
var headersRex string
|
|
||||||
if len(b.Headers) > 0 {
|
|
||||||
var sb strings.Builder
|
|
||||||
sb.Grow(len(b.Headers) * 64)
|
|
||||||
|
|
||||||
for name, expr := range b.Headers {
|
|
||||||
sb.WriteString(name)
|
|
||||||
sb.WriteString(expr.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
headersRex = sb.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
return internal.SHA256sum(fmt.Sprintf("%s::%s::%s::%s", b.Name, pathRex, userAgentRex, headersRex)), nil
|
|
||||||
}
|
}
|
||||||
|
201
lib/policy/checker.go
Normal file
201
lib/policy/checker.go
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
package policy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/TecharoHQ/anubis/internal"
|
||||||
|
"github.com/yl2chen/cidranger"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Checker interface {
|
||||||
|
Check(*http.Request) (bool, error)
|
||||||
|
Hash() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CheckerList []Checker
|
||||||
|
|
||||||
|
func (cl CheckerList) Check(r *http.Request) (bool, error) {
|
||||||
|
for _, c := range cl {
|
||||||
|
ok, err := c.Check(r)
|
||||||
|
if err != nil {
|
||||||
|
return ok, err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cl CheckerList) Hash() string {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
for _, c := range cl {
|
||||||
|
fmt.Fprintln(&sb, c.Hash())
|
||||||
|
}
|
||||||
|
|
||||||
|
return internal.SHA256sum(sb.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
type RemoteAddrChecker struct {
|
||||||
|
ranger cidranger.Ranger
|
||||||
|
hash string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRemoteAddrChecker(cidrs []string) (Checker, error) {
|
||||||
|
ranger := cidranger.NewPCTrieRanger()
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
for _, cidr := range cidrs {
|
||||||
|
_, rng, err := net.ParseCIDR(cidr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: range %s not parsing: %w", ErrMisconfiguration, cidr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ranger.Insert(cidranger.NewBasicRangerEntry(*rng))
|
||||||
|
fmt.Fprintln(&sb, cidr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RemoteAddrChecker{
|
||||||
|
ranger: ranger,
|
||||||
|
hash: internal.SHA256sum(sb.String()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) {
|
||||||
|
host := r.Header.Get("X-Real-Ip")
|
||||||
|
if host == "" {
|
||||||
|
return false, fmt.Errorf("%w: header X-Real-Ip is not set", ErrMisconfiguration)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := net.ParseIP(host)
|
||||||
|
if addr == nil {
|
||||||
|
return false, fmt.Errorf("%w: %s is not an IP address", ErrMisconfiguration, host)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := rac.ranger.Contains(addr)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rac *RemoteAddrChecker) Hash() string {
|
||||||
|
return rac.hash
|
||||||
|
}
|
||||||
|
|
||||||
|
type HeaderMatchesChecker struct {
|
||||||
|
header string
|
||||||
|
regexp *regexp.Regexp
|
||||||
|
hash string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserAgentChecker(rexStr string) (Checker, error) {
|
||||||
|
return NewHeaderMatchesChecker("User-Agent", rexStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHeaderMatchesChecker(header, rexStr string) (Checker, error) {
|
||||||
|
rex, err := regexp.Compile(rexStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err)
|
||||||
|
}
|
||||||
|
return &HeaderMatchesChecker{header, rex, internal.SHA256sum(header + ": " + rexStr)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hmc *HeaderMatchesChecker) Check(r *http.Request) (bool, error) {
|
||||||
|
if hmc.regexp.MatchString(r.Header.Get(hmc.header)) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hmc *HeaderMatchesChecker) Hash() string {
|
||||||
|
return hmc.hash
|
||||||
|
}
|
||||||
|
|
||||||
|
type PathChecker struct {
|
||||||
|
regexp *regexp.Regexp
|
||||||
|
hash string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPathChecker(rexStr string) (Checker, error) {
|
||||||
|
rex, err := regexp.Compile(rexStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err)
|
||||||
|
}
|
||||||
|
return &PathChecker{rex, internal.SHA256sum(rexStr)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PathChecker) Check(r *http.Request) (bool, error) {
|
||||||
|
if pc.regexp.MatchString(r.URL.Path) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PathChecker) Hash() string {
|
||||||
|
return pc.hash
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHeaderExistsChecker(key string) Checker {
|
||||||
|
return headerExistsChecker{key}
|
||||||
|
}
|
||||||
|
|
||||||
|
type headerExistsChecker struct {
|
||||||
|
header string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hec headerExistsChecker) Check(r *http.Request) (bool, error) {
|
||||||
|
if r.Header.Get(hec.header) != "" {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hec headerExistsChecker) Hash() string {
|
||||||
|
return internal.SHA256sum(hec.header)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHeadersChecker(headermap map[string]string) (Checker, error) {
|
||||||
|
var result CheckerList
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
for key, rexStr := range headermap {
|
||||||
|
if rexStr == ".*" {
|
||||||
|
result = append(result, headerExistsChecker{key})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rex, err := regexp.Compile(rexStr)
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, fmt.Errorf("while compiling header %s regex %s: %w", key, rexStr, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result = append(result, &HeaderMatchesChecker{key, rex, internal.SHA256sum(key + ": " + rexStr)})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) != 0 {
|
||||||
|
return nil, errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
200
lib/policy/checker_test.go
Normal file
200
lib/policy/checker_test.go
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
package policy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRemoteAddrChecker(t *testing.T) {
|
||||||
|
for _, tt := range []struct {
|
||||||
|
name string
|
||||||
|
cidrs []string
|
||||||
|
ip string
|
||||||
|
ok bool
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "match_ipv4",
|
||||||
|
cidrs: []string{"0.0.0.0/0"},
|
||||||
|
ip: "1.1.1.1",
|
||||||
|
ok: true,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "match_ipv6",
|
||||||
|
cidrs: []string{"::/0"},
|
||||||
|
ip: "cafe:babe::",
|
||||||
|
ok: true,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_match_ipv4",
|
||||||
|
cidrs: []string{"1.1.1.1/32"},
|
||||||
|
ip: "1.1.1.2",
|
||||||
|
ok: false,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_match_ipv6",
|
||||||
|
cidrs: []string{"cafe:babe::/128"},
|
||||||
|
ip: "cafe:babe:4::/128",
|
||||||
|
ok: false,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no_ip_set",
|
||||||
|
cidrs: []string{"::/0"},
|
||||||
|
ok: false,
|
||||||
|
err: ErrMisconfiguration,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_ip",
|
||||||
|
cidrs: []string{"::/0"},
|
||||||
|
ip: "According to all natural laws of aviation",
|
||||||
|
ok: false,
|
||||||
|
err: ErrMisconfiguration,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rac, err := NewRemoteAddrChecker(tt.cidrs)
|
||||||
|
if err != nil && !errors.Is(err, tt.err) {
|
||||||
|
t.Fatalf("creating RemoteAddrChecker failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't make request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.ip != "" {
|
||||||
|
r.Header.Add("X-Real-Ip", tt.ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := rac.Check(r)
|
||||||
|
|
||||||
|
if tt.ok != ok {
|
||||||
|
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
|
||||||
|
t.Errorf("err: %v, wanted: %v", err, tt.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeaderMatchesChecker(t *testing.T) {
|
||||||
|
for _, tt := range []struct {
|
||||||
|
name string
|
||||||
|
header string
|
||||||
|
rexStr string
|
||||||
|
reqHeaderKey string
|
||||||
|
reqHeaderValue string
|
||||||
|
ok bool
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "match",
|
||||||
|
header: "Cf-Worker",
|
||||||
|
rexStr: ".*",
|
||||||
|
reqHeaderKey: "Cf-Worker",
|
||||||
|
reqHeaderValue: "true",
|
||||||
|
ok: true,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_match",
|
||||||
|
header: "Cf-Worker",
|
||||||
|
rexStr: "false",
|
||||||
|
reqHeaderKey: "Cf-Worker",
|
||||||
|
reqHeaderValue: "true",
|
||||||
|
ok: false,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_present",
|
||||||
|
header: "Cf-Worker",
|
||||||
|
rexStr: "foobar",
|
||||||
|
reqHeaderKey: "Something-Else",
|
||||||
|
reqHeaderValue: "true",
|
||||||
|
ok: false,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_regex",
|
||||||
|
rexStr: "a(b",
|
||||||
|
err: ErrMisconfiguration,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
hmc, err := NewHeaderMatchesChecker(tt.header, tt.rexStr)
|
||||||
|
if err != nil && !errors.Is(err, tt.err) {
|
||||||
|
t.Fatalf("creating HeaderMatchesChecker failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.err != nil && hmc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't make request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue)
|
||||||
|
|
||||||
|
ok, err := hmc.Check(r)
|
||||||
|
|
||||||
|
if tt.ok != ok {
|
||||||
|
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
|
||||||
|
t.Errorf("err: %v, wanted: %v", err, tt.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeaderExistsChecker(t *testing.T) {
|
||||||
|
for _, tt := range []struct {
|
||||||
|
name string
|
||||||
|
header string
|
||||||
|
reqHeader string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "match",
|
||||||
|
header: "Authorization",
|
||||||
|
reqHeader: "Authorization",
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_match",
|
||||||
|
header: "Authorization",
|
||||||
|
reqHeader: "Authentication",
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
hec := headerExistsChecker{tt.header}
|
||||||
|
|
||||||
|
r, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't make request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Header.Set(tt.reqHeader, "hunter2")
|
||||||
|
|
||||||
|
ok, err := hec.Check(r)
|
||||||
|
|
||||||
|
if tt.ok != ok {
|
||||||
|
t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("err: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package lib
|
package policy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@ -16,10 +16,3 @@ func (cr CheckResult) LogValue() slog.Value {
|
|||||||
slog.String("name", cr.Name),
|
slog.String("name", cr.Name),
|
||||||
slog.String("rule", string(cr.Rule)))
|
slog.String("rule", string(cr.Rule)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func cr(name string, rule config.Rule) CheckResult {
|
|
||||||
return CheckResult{
|
|
||||||
Name: name,
|
|
||||||
Rule: rule,
|
|
||||||
}
|
|
||||||
}
|
|
@ -4,12 +4,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"regexp"
|
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
"github.com/yl2chen/cidranger"
|
|
||||||
"k8s.io/apimachinery/pkg/util/yaml"
|
"k8s.io/apimachinery/pkg/util/yaml"
|
||||||
|
|
||||||
"github.com/TecharoHQ/anubis/lib/policy/config"
|
"github.com/TecharoHQ/anubis/lib/policy/config"
|
||||||
@ -58,57 +55,45 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon
|
|||||||
}
|
}
|
||||||
|
|
||||||
parsedBot := Bot{
|
parsedBot := Bot{
|
||||||
Name: b.Name,
|
Name: b.Name,
|
||||||
Action: b.Action,
|
Action: b.Action,
|
||||||
Headers: map[string]*regexp.Regexp{},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cl := CheckerList{}
|
||||||
|
|
||||||
if len(b.RemoteAddr) > 0 {
|
if len(b.RemoteAddr) > 0 {
|
||||||
parsedBot.Ranger = cidranger.NewPCTrieRanger()
|
c, err := NewRemoteAddrChecker(b.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
for _, cidr := range b.RemoteAddr {
|
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s remote addr set: %w", b.Name, err))
|
||||||
_, rng, err := net.ParseCIDR(cidr)
|
} else {
|
||||||
if err != nil {
|
cl = append(cl, c)
|
||||||
return nil, fmt.Errorf("[unexpected] range %s not parsing: %w", cidr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedBot.Ranger.Insert(cidranger.NewBasicRangerEntry(*rng))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if b.UserAgentRegex != nil {
|
if b.UserAgentRegex != nil {
|
||||||
userAgent, err := regexp.Compile(*b.UserAgentRegex)
|
c, err := NewUserAgentChecker(*b.UserAgentRegex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
validationErrs = append(validationErrs, fmt.Errorf("while compiling user agent regexp: %w", err))
|
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s user agent regex: %w", b.Name, err))
|
||||||
continue
|
|
||||||
} else {
|
} else {
|
||||||
parsedBot.UserAgent = userAgent
|
cl = append(cl, c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if b.PathRegex != nil {
|
if b.PathRegex != nil {
|
||||||
path, err := regexp.Compile(*b.PathRegex)
|
c, err := NewPathChecker(*b.PathRegex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
validationErrs = append(validationErrs, fmt.Errorf("while compiling path regexp: %w", err))
|
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s path regex: %w", b.Name, err))
|
||||||
continue
|
|
||||||
} else {
|
} else {
|
||||||
parsedBot.Path = path
|
cl = append(cl, c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(b.HeadersRegex) > 0 {
|
if len(b.HeadersRegex) > 0 {
|
||||||
for name, expr := range b.HeadersRegex {
|
c, err := NewHeadersChecker(b.HeadersRegex)
|
||||||
if name == "" {
|
if err != nil {
|
||||||
continue
|
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s headers regex map: %w", b.Name, err))
|
||||||
}
|
} else {
|
||||||
|
cl = append(cl, c)
|
||||||
header, err := regexp.Compile(expr)
|
|
||||||
if err != nil {
|
|
||||||
validationErrs = append(validationErrs, fmt.Errorf("while compiling header regexp: %w", err))
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
parsedBot.Headers[name] = header
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,6 +110,8 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
parsedBot.Rules = cl
|
||||||
|
|
||||||
result.Bots = append(result.Bots, parsedBot)
|
result.Bots = append(result.Bots, parsedBot)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user