From e98d749bf2f7cc4a3c8d27a9c71438323c0a90ee Mon Sep 17 00:00:00 2001 From: Xe Iaso Date: Fri, 25 Jul 2025 19:52:07 +0000 Subject: [PATCH] refactor: move CEL checker to its own package Signed-off-by: Xe Iaso --- cmd/anubis/main.go | 3 +- cmd/robots2policy/main.go | 21 ++++---- lib/anubis.go | 3 +- lib/challenge/all/all.go | 6 +++ lib/checker/all/all.go | 1 + .../expression/checker.go} | 17 +++---- .../expression/config.go} | 28 +++++------ .../expression/config_test.go} | 48 +++++++++---------- lib/checker/expression/factory.go | 43 +++++++++++++++++ lib/policy/config/config.go | 19 ++++---- lib/policy/config/threshold.go | 11 +++-- lib/policy/config/threshold_test.go | 10 ++-- lib/policy/policy.go | 3 +- 13 files changed, 135 insertions(+), 78 deletions(-) create mode 100644 lib/challenge/all/all.go rename lib/{policy/celchecker.go => checker/expression/checker.go} (82%) rename lib/{policy/config/expressionorlist.go => checker/expression/config.go} (71%) rename lib/{policy/config/expressionorlist_test.go => checker/expression/config_test.go} (87%) create mode 100644 lib/checker/expression/factory.go diff --git a/cmd/anubis/main.go b/cmd/anubis/main.go index cc8ebd1..c48feeb 100644 --- a/cmd/anubis/main.go +++ b/cmd/anubis/main.go @@ -31,6 +31,7 @@ import ( "github.com/TecharoHQ/anubis/data" "github.com/TecharoHQ/anubis/internal" libanubis "github.com/TecharoHQ/anubis/lib" + "github.com/TecharoHQ/anubis/lib/checker/headerexists" botPolicy "github.com/TecharoHQ/anubis/lib/policy" "github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/thoth" @@ -323,7 +324,7 @@ func main() { if *debugBenchmarkJS { policy.Bots = []botPolicy.Bot{{ Name: "", - Rules: botPolicy.NewHeaderExistsChecker("User-Agent"), + Rules: headerexists.New("User-Agent"), Action: config.RuleBenchmark, }} } diff --git a/cmd/robots2policy/main.go b/cmd/robots2policy/main.go index eaa4d7f..34652b0 100644 --- a/cmd/robots2policy/main.go +++ b/cmd/robots2policy/main.go @@ -12,6 +12,7 @@ import ( "regexp" "strings" + "github.com/TecharoHQ/anubis/lib/checker/expression" "github.com/TecharoHQ/anubis/lib/policy/config" "sigs.k8s.io/yaml" @@ -37,11 +38,11 @@ type RobotsRule struct { } type AnubisRule struct { - Expression *config.ExpressionOrList `yaml:"expression,omitempty" json:"expression,omitempty"` - Challenge *config.ChallengeRules `yaml:"challenge,omitempty" json:"challenge,omitempty"` - Weight *config.Weight `yaml:"weight,omitempty" json:"weight,omitempty"` - Name string `yaml:"name" json:"name"` - Action string `yaml:"action" json:"action"` + Expression *expression.Config `yaml:"expression,omitempty" json:"expression,omitempty"` + Challenge *config.ChallengeRules `yaml:"challenge,omitempty" json:"challenge,omitempty"` + Weight *config.Weight `yaml:"weight,omitempty" json:"weight,omitempty"` + Name string `yaml:"name" json:"name"` + Action string `yaml:"action" json:"action"` } func init() { @@ -224,11 +225,11 @@ func convertToAnubisRules(robotsRules []RobotsRule) []AnubisRule { } if userAgent == "*" { - rule.Expression = &config.ExpressionOrList{ + rule.Expression = &expression.Config{ All: []string{"true"}, // Always applies } } else { - rule.Expression = &config.ExpressionOrList{ + rule.Expression = &expression.Config{ All: []string{fmt.Sprintf("userAgent.contains(%q)", userAgent)}, } } @@ -249,11 +250,11 @@ func convertToAnubisRules(robotsRules []RobotsRule) []AnubisRule { rule.Name = fmt.Sprintf("%s-global-restriction-%d", *policyName, ruleCounter) rule.Action = "WEIGH" rule.Weight = &config.Weight{Adjust: 20} // Increase difficulty significantly - rule.Expression = &config.ExpressionOrList{ + rule.Expression = &expression.Config{ All: []string{"true"}, // Always applies } } else { - rule.Expression = &config.ExpressionOrList{ + rule.Expression = &expression.Config{ All: []string{fmt.Sprintf("userAgent.contains(%q)", userAgent)}, } } @@ -285,7 +286,7 @@ func convertToAnubisRules(robotsRules []RobotsRule) []AnubisRule { pathCondition := buildPathCondition(disallow) conditions = append(conditions, pathCondition) - rule.Expression = &config.ExpressionOrList{ + rule.Expression = &expression.Config{ All: conditions, } diff --git a/lib/anubis.go b/lib/anubis.go index b06d016..702a6dd 100644 --- a/lib/anubis.go +++ b/lib/anubis.go @@ -38,8 +38,7 @@ import ( _ "github.com/TecharoHQ/anubis/lib/checker/all" // challenge implementations - _ "github.com/TecharoHQ/anubis/lib/challenge/metarefresh" - _ "github.com/TecharoHQ/anubis/lib/challenge/proofofwork" + _ "github.com/TecharoHQ/anubis/lib/challenge/all" ) var ( diff --git a/lib/challenge/all/all.go b/lib/challenge/all/all.go new file mode 100644 index 0000000..eb1f32c --- /dev/null +++ b/lib/challenge/all/all.go @@ -0,0 +1,6 @@ +package all + +import ( + _ "github.com/TecharoHQ/anubis/lib/challenge/metarefresh" + _ "github.com/TecharoHQ/anubis/lib/challenge/proofofwork" +) diff --git a/lib/checker/all/all.go b/lib/checker/all/all.go index a5bf6a6..aca7316 100644 --- a/lib/checker/all/all.go +++ b/lib/checker/all/all.go @@ -2,6 +2,7 @@ package all import ( + _ "github.com/TecharoHQ/anubis/lib/checker/expression" _ "github.com/TecharoHQ/anubis/lib/checker/headerexists" _ "github.com/TecharoHQ/anubis/lib/checker/headermatches" _ "github.com/TecharoHQ/anubis/lib/checker/path" diff --git a/lib/policy/celchecker.go b/lib/checker/expression/checker.go similarity index 82% rename from lib/policy/celchecker.go rename to lib/checker/expression/checker.go index a385398..557d69a 100644 --- a/lib/policy/celchecker.go +++ b/lib/checker/expression/checker.go @@ -1,22 +1,22 @@ -package policy +package expression import ( "fmt" "net/http" "github.com/TecharoHQ/anubis/internal" - "github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/policy/expressions" "github.com/google/cel-go/cel" "github.com/google/cel-go/common/types" ) -type CELChecker struct { +type Checker struct { program cel.Program src string + hash string } -func NewCELChecker(cfg *config.ExpressionOrList) (*CELChecker, error) { +func New(cfg *Config) (*Checker, error) { env, err := expressions.BotEnvironment() if err != nil { return nil, err @@ -27,17 +27,18 @@ func NewCELChecker(cfg *config.ExpressionOrList) (*CELChecker, error) { return nil, fmt.Errorf("can't compile CEL program: %w", err) } - return &CELChecker{ + return &Checker{ src: cfg.String(), + hash: internal.FastHash(cfg.String()), program: program, }, nil } -func (cc *CELChecker) Hash() string { - return internal.FastHash(cc.src) +func (cc *Checker) Hash() string { + return cc.hash } -func (cc *CELChecker) Check(r *http.Request) (bool, error) { +func (cc *Checker) Check(r *http.Request) (bool, error) { result, _, err := cc.program.ContextEval(r.Context(), &CELRequest{r}) if err != nil { diff --git a/lib/policy/config/expressionorlist.go b/lib/checker/expression/config.go similarity index 71% rename from lib/policy/config/expressionorlist.go rename to lib/checker/expression/config.go index b4e64c4..8066187 100644 --- a/lib/policy/config/expressionorlist.go +++ b/lib/checker/expression/config.go @@ -1,4 +1,4 @@ -package config +package expression import ( "encoding/json" @@ -9,18 +9,18 @@ import ( ) var ( - ErrExpressionOrListMustBeStringOrObject = errors.New("config: this must be a string or an object") - ErrExpressionEmpty = errors.New("config: this expression is empty") - ErrExpressionCantHaveBoth = errors.New("config: expression block can't contain multiple expression types") + ErrExpressionOrListMustBeStringOrObject = errors.New("expression: this must be a string or an object") + ErrExpressionEmpty = errors.New("expression: this expression is empty") + ErrExpressionCantHaveBoth = errors.New("expression: expression block can't contain multiple expression types") ) -type ExpressionOrList struct { +type Config struct { Expression string `json:"-" yaml:"-"` All []string `json:"all,omitempty" yaml:"all,omitempty"` Any []string `json:"any,omitempty" yaml:"any,omitempty"` } -func (eol ExpressionOrList) String() string { +func (eol Config) String() string { switch { case len(eol.Expression) != 0: return eol.Expression @@ -46,7 +46,7 @@ func (eol ExpressionOrList) String() string { panic("this should not happen") } -func (eol ExpressionOrList) Equal(rhs *ExpressionOrList) bool { +func (eol Config) Equal(rhs *Config) bool { if eol.Expression != rhs.Expression { return false } @@ -62,7 +62,7 @@ func (eol ExpressionOrList) Equal(rhs *ExpressionOrList) bool { return true } -func (eol *ExpressionOrList) MarshalYAML() (any, error) { +func (eol *Config) MarshalYAML() (any, error) { switch { case len(eol.All) == 1 && len(eol.Any) == 0: eol.Expression = eol.All[0] @@ -76,11 +76,11 @@ func (eol *ExpressionOrList) MarshalYAML() (any, error) { return eol.Expression, nil } - type RawExpressionOrList ExpressionOrList + type RawExpressionOrList Config return RawExpressionOrList(*eol), nil } -func (eol *ExpressionOrList) MarshalJSON() ([]byte, error) { +func (eol *Config) MarshalJSON() ([]byte, error) { switch { case len(eol.All) == 1 && len(eol.Any) == 0: eol.Expression = eol.All[0] @@ -94,17 +94,17 @@ func (eol *ExpressionOrList) MarshalJSON() ([]byte, error) { return json.Marshal(string(eol.Expression)) } - type RawExpressionOrList ExpressionOrList + type RawExpressionOrList Config val := RawExpressionOrList(*eol) return json.Marshal(val) } -func (eol *ExpressionOrList) UnmarshalJSON(data []byte) error { +func (eol *Config) UnmarshalJSON(data []byte) error { switch string(data[0]) { case `"`: // string return json.Unmarshal(data, &eol.Expression) case "{": // object - type RawExpressionOrList ExpressionOrList + type RawExpressionOrList Config var val RawExpressionOrList if err := json.Unmarshal(data, &val); err != nil { return err @@ -118,7 +118,7 @@ func (eol *ExpressionOrList) UnmarshalJSON(data []byte) error { return ErrExpressionOrListMustBeStringOrObject } -func (eol *ExpressionOrList) Valid() error { +func (eol *Config) Valid() error { if eol.Expression == "" && len(eol.All) == 0 && len(eol.Any) == 0 { return ErrExpressionEmpty } diff --git a/lib/policy/config/expressionorlist_test.go b/lib/checker/expression/config_test.go similarity index 87% rename from lib/policy/config/expressionorlist_test.go rename to lib/checker/expression/config_test.go index a09baf3..293b53e 100644 --- a/lib/policy/config/expressionorlist_test.go +++ b/lib/checker/expression/config_test.go @@ -1,4 +1,4 @@ -package config +package expression import ( "bytes" @@ -12,13 +12,13 @@ import ( func TestExpressionOrListMarshalJSON(t *testing.T) { for _, tt := range []struct { name string - input *ExpressionOrList + input *Config output []byte err error }{ { name: "single expression", - input: &ExpressionOrList{ + input: &Config{ Expression: "true", }, output: []byte(`"true"`), @@ -26,7 +26,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) { }, { name: "all", - input: &ExpressionOrList{ + input: &Config{ All: []string{"true", "true"}, }, output: []byte(`{"all":["true","true"]}`), @@ -34,7 +34,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) { }, { name: "all one", - input: &ExpressionOrList{ + input: &Config{ All: []string{"true"}, }, output: []byte(`"true"`), @@ -42,7 +42,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) { }, { name: "any", - input: &ExpressionOrList{ + input: &Config{ Any: []string{"true", "false"}, }, output: []byte(`{"any":["true","false"]}`), @@ -50,7 +50,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) { }, { name: "any one", - input: &ExpressionOrList{ + input: &Config{ Any: []string{"true"}, }, output: []byte(`"true"`), @@ -75,13 +75,13 @@ func TestExpressionOrListMarshalJSON(t *testing.T) { func TestExpressionOrListMarshalYAML(t *testing.T) { for _, tt := range []struct { name string - input *ExpressionOrList + input *Config output []byte err error }{ { name: "single expression", - input: &ExpressionOrList{ + input: &Config{ Expression: "true", }, output: []byte(`"true"`), @@ -89,7 +89,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) { }, { name: "all", - input: &ExpressionOrList{ + input: &Config{ All: []string{"true", "true"}, }, output: []byte(`all: @@ -99,7 +99,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) { }, { name: "all one", - input: &ExpressionOrList{ + input: &Config{ All: []string{"true"}, }, output: []byte(`"true"`), @@ -107,7 +107,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) { }, { name: "any", - input: &ExpressionOrList{ + input: &Config{ Any: []string{"true", "false"}, }, output: []byte(`any: @@ -117,7 +117,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) { }, { name: "any one", - input: &ExpressionOrList{ + input: &Config{ Any: []string{"true"}, }, output: []byte(`"true"`), @@ -145,14 +145,14 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) { for _, tt := range []struct { err error validErr error - result *ExpressionOrList + result *Config name string inp string }{ { name: "simple", inp: `"\"User-Agent\" in headers"`, - result: &ExpressionOrList{ + result: &Config{ Expression: `"User-Agent" in headers`, }, }, @@ -161,7 +161,7 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) { inp: `{ "all": ["\"User-Agent\" in headers"] }`, - result: &ExpressionOrList{ + result: &Config{ All: []string{ `"User-Agent" in headers`, }, @@ -172,7 +172,7 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) { inp: `{ "any": ["\"User-Agent\" in headers"] }`, - result: &ExpressionOrList{ + result: &Config{ Any: []string{ `"User-Agent" in headers`, }, @@ -195,7 +195,7 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) { }, } { t.Run(tt.name, func(t *testing.T) { - var eol ExpressionOrList + var eol Config if err := json.Unmarshal([]byte(tt.inp), &eol); !errors.Is(err, tt.err) { t.Errorf("wanted unmarshal error: %v but got: %v", tt.err, err) @@ -217,40 +217,40 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) { func TestExpressionOrListString(t *testing.T) { for _, tt := range []struct { name string - in ExpressionOrList + in Config out string }{ { name: "single expression", - in: ExpressionOrList{ + in: Config{ Expression: "true", }, out: "true", }, { name: "all", - in: ExpressionOrList{ + in: Config{ All: []string{"true"}, }, out: "( true )", }, { name: "all with &&", - in: ExpressionOrList{ + in: Config{ All: []string{"true", "true"}, }, out: "( true ) && ( true )", }, { name: "any", - in: ExpressionOrList{ + in: Config{ All: []string{"true"}, }, out: "( true )", }, { name: "any with ||", - in: ExpressionOrList{ + in: Config{ Any: []string{"true", "true"}, }, out: "( true ) || ( true )", diff --git a/lib/checker/expression/factory.go b/lib/checker/expression/factory.go new file mode 100644 index 0000000..03b8ff4 --- /dev/null +++ b/lib/checker/expression/factory.go @@ -0,0 +1,43 @@ +package expression + +import ( + "context" + "encoding/json" + "errors" + + "github.com/TecharoHQ/anubis/lib/checker" +) + +func init() { + checker.Register("expression", Factory{}) +} + +type Factory struct{} + +func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) { + var fc = &Config{} + + if err := json.Unmarshal([]byte(data), fc); err != nil { + return nil, errors.Join(checker.ErrUnparseableConfig, err) + } + + if err := fc.Valid(); err != nil { + return nil, errors.Join(checker.ErrInvalidConfig, err) + } + + return New(fc) +} + +func (f Factory) Valid(ctx context.Context, data json.RawMessage) error { + var fc = &Config{} + + if err := json.Unmarshal([]byte(data), fc); err != nil { + return err + } + + if err := fc.Valid(); err != nil { + return err + } + + return nil +} diff --git a/lib/policy/config/config.go b/lib/policy/config/config.go index 4d67df1..05f1064 100644 --- a/lib/policy/config/config.go +++ b/lib/policy/config/config.go @@ -12,6 +12,7 @@ import ( "time" "github.com/TecharoHQ/anubis/data" + "github.com/TecharoHQ/anubis/lib/checker/expression" "github.com/TecharoHQ/anubis/lib/checker/headermatches" "github.com/TecharoHQ/anubis/lib/checker/path" "github.com/TecharoHQ/anubis/lib/checker/remoteaddress" @@ -58,15 +59,15 @@ func (r Rule) Valid() error { const DefaultAlgorithm = "fast" type BotConfig struct { - UserAgentRegex *string `json:"user_agent_regex,omitempty" yaml:"user_agent_regex,omitempty"` - PathRegex *string `json:"path_regex,omitempty" yaml:"path_regex,omitempty"` - HeadersRegex map[string]string `json:"headers_regex,omitempty" yaml:"headers_regex,omitempty"` - Expression *ExpressionOrList `json:"expression,omitempty" yaml:"expression,omitempty"` - Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"` - Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"` - Name string `json:"name" yaml:"name"` - Action Rule `json:"action" yaml:"action"` - RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"` + UserAgentRegex *string `json:"user_agent_regex,omitempty" yaml:"user_agent_regex,omitempty"` + PathRegex *string `json:"path_regex,omitempty" yaml:"path_regex,omitempty"` + HeadersRegex map[string]string `json:"headers_regex,omitempty" yaml:"headers_regex,omitempty"` + Expression *expression.Config `json:"expression,omitempty" yaml:"expression,omitempty"` + Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"` + Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"` + Name string `json:"name" yaml:"name"` + Action Rule `json:"action" yaml:"action"` + RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"` // Thoth features GeoIP *GeoIP `json:"geoip,omitempty"` diff --git a/lib/policy/config/threshold.go b/lib/policy/config/threshold.go index d9a0ed0..3c7b615 100644 --- a/lib/policy/config/threshold.go +++ b/lib/policy/config/threshold.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/TecharoHQ/anubis" + "github.com/TecharoHQ/anubis/lib/checker/expression" ) var ( @@ -17,7 +18,7 @@ var ( DefaultThresholds = []Threshold{ { Name: "legacy-anubis-behaviour", - Expression: &ExpressionOrList{ + Expression: &expression.Config{ Expression: "weight > 0", }, Action: RuleChallenge, @@ -31,10 +32,10 @@ var ( ) type Threshold struct { - Name string `json:"name" yaml:"name"` - Expression *ExpressionOrList `json:"expression" yaml:"expression"` - Action Rule `json:"action" yaml:"action"` - Challenge *ChallengeRules `json:"challenge" yaml:"challenge"` + Name string `json:"name" yaml:"name"` + Expression *expression.Config `json:"expression" yaml:"expression"` + Action Rule `json:"action" yaml:"action"` + Challenge *ChallengeRules `json:"challenge" yaml:"challenge"` } func (t Threshold) Valid() error { diff --git a/lib/policy/config/threshold_test.go b/lib/policy/config/threshold_test.go index 9024fe8..33d6533 100644 --- a/lib/policy/config/threshold_test.go +++ b/lib/policy/config/threshold_test.go @@ -6,6 +6,8 @@ import ( "os" "path/filepath" "testing" + + "github.com/TecharoHQ/anubis/lib/checker/expression" ) func TestThresholdValid(t *testing.T) { @@ -18,7 +20,7 @@ func TestThresholdValid(t *testing.T) { name: "basic allow", input: &Threshold{ Name: "basic-allow", - Expression: &ExpressionOrList{Expression: "true"}, + Expression: &expression.Config{Expression: "true"}, Action: RuleAllow, }, err: nil, @@ -27,7 +29,7 @@ func TestThresholdValid(t *testing.T) { name: "basic challenge", input: &Threshold{ Name: "basic-challenge", - Expression: &ExpressionOrList{Expression: "true"}, + Expression: &expression.Config{Expression: "true"}, Action: RuleChallenge, Challenge: &ChallengeRules{ Algorithm: "fast", @@ -50,9 +52,9 @@ func TestThresholdValid(t *testing.T) { { name: "invalid expression", input: &Threshold{ - Expression: &ExpressionOrList{}, + Expression: &expression.Config{}, }, - err: ErrExpressionEmpty, + err: expression.ErrExpressionEmpty, }, { name: "invalid action", diff --git a/lib/policy/policy.go b/lib/policy/policy.go index dc2c58f..180db8f 100644 --- a/lib/policy/policy.go +++ b/lib/policy/policy.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "github.com/TecharoHQ/anubis/lib/checker" + "github.com/TecharoHQ/anubis/lib/checker/expression" "github.com/TecharoHQ/anubis/lib/checker/headermatches" "github.com/TecharoHQ/anubis/lib/checker/path" "github.com/TecharoHQ/anubis/lib/checker/remoteaddress" @@ -115,7 +116,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic } if b.Expression != nil { - c, err := NewCELChecker(b.Expression) + c, err := expression.New(b.Expression) if err != nil { validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s expressions: %w", b.Name, err)) } else {