fix(config): actually load threshold config (#696)

* fix(config): actually load threshold config

Signed-off-by: Xe Iaso <me@xeiaso.net>

* chore: spelling

Signed-off-by: Xe Iaso <me@xeiaso.net>

* test(lib): fix test failures

Signed-off-by: Xe Iaso <me@xeiaso.net>

---------

Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
Xe Iaso 2025-06-19 17:13:01 -04:00 committed by GitHub
parent 226cf36bf7
commit 7aa732c700
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 201 additions and 30 deletions

View File

@ -83,7 +83,9 @@
^\Q.github/FUNDING.yml\E$
^\Q.github/workflows/spelling.yml\E$
^data/crawlers/
^docs/blog/tags\.yml$
^docs/manifest/.*$
^docs/static/\.nojekyll$
^lib/policy/config/testdata/bad/unparseable\.json$
ignore$
robots.txt

View File

@ -44,7 +44,6 @@ chall
challengemozilla
checkpath
checkresult
chen
chibi
cidranger
ckie
@ -61,7 +60,6 @@ DDOS
Debian
debrpm
decaymap
decompiling
Diffbot
discordapp
discordbot
@ -300,6 +298,7 @@ xess
xff
XForwarded
XNG
XOB
XReal
yae
YAMLTo

View File

@ -24,12 +24,16 @@ func init() {
internal.InitSlog("debug")
}
func loadPolicies(t *testing.T, fname string) *policy.ParsedConfig {
func loadPolicies(t *testing.T, fname string, difficulty int) *policy.ParsedConfig {
t.Helper()
ctx := thothmock.WithMockThoth(t)
anubisPolicy, err := LoadPoliciesOrDefault(ctx, fname, anubis.DefaultDifficulty)
if fname == "" {
fname = "./testdata/test_config.yaml"
}
anubisPolicy, err := LoadPoliciesOrDefault(ctx, fname, difficulty)
if err != nil {
t.Fatal(err)
}
@ -176,8 +180,7 @@ func TestLoadPolicies(t *testing.T) {
// Regression test for CVE-2025-24369
func TestCVE2025_24369(t *testing.T) {
pol := loadPolicies(t, "")
pol.DefaultDifficulty = 4
pol := loadPolicies(t, "", anubis.DefaultDifficulty)
srv := spawnAnubis(t, Options{
Next: http.NewServeMux(),
@ -200,8 +203,7 @@ func TestCVE2025_24369(t *testing.T) {
}
func TestCookieCustomExpiration(t *testing.T) {
pol := loadPolicies(t, "")
pol.DefaultDifficulty = 0
pol := loadPolicies(t, "", 0)
ckieExpiration := 10 * time.Minute
srv := spawnAnubis(t, Options{
@ -250,8 +252,7 @@ func TestCookieCustomExpiration(t *testing.T) {
}
func TestCookieSettings(t *testing.T) {
pol := loadPolicies(t, "")
pol.DefaultDifficulty = 0
pol := loadPolicies(t, "", 0)
srv := spawnAnubis(t, Options{
Next: http.NewServeMux(),
@ -316,10 +317,7 @@ func TestCheckDefaultDifficultyMatchesPolicy(t *testing.T) {
for i := 1; i < 10; i++ {
t.Run(fmt.Sprint(i), func(t *testing.T) {
anubisPolicy, err := LoadPoliciesOrDefault(t.Context(), "", i)
if err != nil {
t.Fatal(err)
}
anubisPolicy := loadPolicies(t, "", i)
s, err := New(Options{
Next: h,
@ -337,11 +335,13 @@ func TestCheckDefaultDifficultyMatchesPolicy(t *testing.T) {
req.Header.Add("X-Real-Ip", "127.0.0.1")
_, bot, err := s.check(req)
cr, bot, err := s.check(req)
if err != nil {
t.Fatal(err)
}
t.Log(cr.Name)
if bot.Challenge.Difficulty != i {
t.Errorf("Challenge.Difficulty is wrong, wanted %d, got: %d", i, bot.Challenge.Difficulty)
}
@ -389,8 +389,7 @@ func TestBasePrefix(t *testing.T) {
// Reset the global BasePrefix before each test
anubis.BasePrefix = ""
pol := loadPolicies(t, "")
pol.DefaultDifficulty = 4
pol := loadPolicies(t, "", 4)
srv := spawnAnubis(t, Options{
Next: h,
@ -518,8 +517,7 @@ func TestCustomStatusCodes(t *testing.T) {
"DENY": 403,
}
pol := loadPolicies(t, "./testdata/aggressive_403.yaml")
pol.DefaultDifficulty = 4
pol := loadPolicies(t, "./testdata/aggressive_403.yaml", 4)
srv := spawnAnubis(t, Options{
Next: h,
@ -553,7 +551,7 @@ func TestCustomStatusCodes(t *testing.T) {
func TestCloudflareWorkersRule(t *testing.T) {
for _, variant := range []string{"cel", "header"} {
t.Run(variant, func(t *testing.T) {
pol := loadPolicies(t, "./testdata/cloudflare-workers-"+variant+".yaml")
pol := loadPolicies(t, "./testdata/cloudflare-workers-"+variant+".yaml", 0)
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "OK")
@ -609,8 +607,7 @@ func TestCloudflareWorkersRule(t *testing.T) {
}
func TestRuleChange(t *testing.T) {
pol := loadPolicies(t, "testdata/rule_change.yaml")
pol.DefaultDifficulty = 0
pol := loadPolicies(t, "testdata/rule_change.yaml", 0)
ckieExpiration := 10 * time.Minute
srv := spawnAnubis(t, Options{

View File

@ -26,7 +26,7 @@ func TestBadConfigs(t *testing.T) {
for _, st := range finfos {
st := st
t.Run(st.Name(), func(t *testing.T) {
if _, err := LoadPoliciesOrDefault(t.Context(), filepath.Join("policy", "config", "testdata", "good", st.Name()), anubis.DefaultDifficulty); err == nil {
if _, err := LoadPoliciesOrDefault(t.Context(), filepath.Join("policy", "config", "testdata", "bad", st.Name()), anubis.DefaultDifficulty); err == nil {
t.Fatal(err)
} else {
t.Log(err)

View File

@ -0,0 +1,55 @@
package config
import (
"errors"
"fmt"
"testing"
)
func TestASNsValid(t *testing.T) {
for _, tt := range []struct {
name string
input *ASNs
err error
}{
{
name: "basic valid",
input: &ASNs{
Match: []uint32{13335}, // Cloudflare
},
},
{
name: "private ASN",
input: &ASNs{
Match: []uint32{64513, 4206942069}, // 16 and 32 bit private ASN
},
err: ErrPrivateASN,
},
} {
t.Run(tt.name, func(t *testing.T) {
if err := tt.input.Valid(); !errors.Is(err, tt.err) {
t.Logf("want: %v", tt.err)
t.Logf("got: %v", err)
t.Error("got wrong validation error")
}
})
}
}
func TestIsPrivateASN(t *testing.T) {
for _, tt := range []struct {
input uint32
output bool
}{
{13335, false}, // Cloudflare
{64513, true}, // 16 bit private ASN
{4206942069, true}, // 32 bit private ASN
} {
t.Run(fmt.Sprint(tt.input, "->", tt.output), func(t *testing.T) {
result := isPrivateASN(tt.input)
if result != tt.output {
t.Errorf("wanted isPrivateASN(%d) == %v, got: %v", tt.input, tt.output, result)
}
})
}
}

View File

@ -326,7 +326,7 @@ type fileConfig struct {
Bots []BotOrImport `json:"bots"`
DNSBL bool `json:"dnsbl"`
StatusCodes StatusCodes `json:"status_codes"`
Thresholds []Threshold `json:"threshold"`
Thresholds []Threshold `json:"thresholds"`
}
func (c *fileConfig) Valid() error {
@ -346,10 +346,6 @@ func (c *fileConfig) Valid() error {
errs = append(errs, err)
}
if len(c.Thresholds) == 0 {
errs = append(errs, ErrNoThresholdRulesDefined)
}
for i, t := range c.Thresholds {
if err := t.Valid(); err != nil {
errs = append(errs, fmt.Errorf("threshold %d: %w", i, err))
@ -369,7 +365,6 @@ func Load(fin io.Reader, fname string) (*Config, error) {
Challenge: http.StatusOK,
Deny: http.StatusOK,
},
Thresholds: DefaultThresholds,
}
if err := yaml.NewYAMLToJSONDecoder(fin).Decode(&c); err != nil {
@ -407,6 +402,10 @@ func Load(fin io.Reader, fname string) (*Config, error) {
}
}
if len(c.Thresholds) == 0 {
c.Thresholds = DefaultThresholds
}
for _, t := range c.Thresholds {
if err := t.Valid(); err != nil {
validationErrs = append(validationErrs, err)

View File

@ -8,7 +8,7 @@ import (
)
var (
countryCodeRegexp = regexp.MustCompile(`^\w{2}$`)
countryCodeRegexp = regexp.MustCompile(`^[a-zA-Z]{2}$`)
ErrNotCountryCode = errors.New("config.Bot: invalid country code")
)

View File

@ -0,0 +1,36 @@
package config
import (
"errors"
"testing"
)
func TestGeoIPValid(t *testing.T) {
for _, tt := range []struct {
name string
input *GeoIP
err error
}{
{
name: "basic valid",
input: &GeoIP{
Countries: []string{"CA"},
},
},
{
name: "invalid country",
input: &GeoIP{
Countries: []string{"XOB"},
},
err: ErrNotCountryCode,
},
} {
t.Run(tt.name, func(t *testing.T) {
if err := tt.input.Valid(); !errors.Is(err, tt.err) {
t.Logf("want: %v", tt.err)
t.Logf("got: %v", err)
t.Error("got wrong validation error")
}
})
}
}

View File

@ -0,0 +1,11 @@
bots:
- name: simple-weight-adjust
action: WEIGH
user_agent_regex: Mozilla
weight:
adjust: 5
thresholds:
- name: extreme-suspicion
expression: "true"
action: WEIGH

View File

@ -0,0 +1,15 @@
bots:
- name: simple-weight-adjust
action: WEIGH
user_agent_regex: Mozilla
weight:
adjust: 5
thresholds:
- name: extreme-suspicion
expression: "true"
action: WEIGH
challenge:
algorithm: fast
difficulty: 4
report_as: 4

View File

@ -3,6 +3,8 @@ package config
import (
"errors"
"fmt"
"os"
"path/filepath"
"testing"
)
@ -90,3 +92,20 @@ func TestDefaultThresholdsValid(t *testing.T) {
})
}
}
func TestLoadActuallyLoadsThresholds(t *testing.T) {
fin, err := os.Open(filepath.Join(".", "testdata", "good", "thresholds.yaml"))
if err != nil {
t.Fatal(err)
}
defer fin.Close()
c, err := Load(fin, fin.Name())
if err != nil {
t.Fatal(err)
}
if len(c.Thresholds) != 4 {
t.Errorf("wanted 4 thresholds, got %d thresholds", len(c.Thresholds))
}
}

38
lib/testdata/test_config.yaml vendored Normal file
View File

@ -0,0 +1,38 @@
bots:
- import: (data)/bots/_deny-pathological.yaml
- import: (data)/bots/aggressive-brazilian-scrapers.yaml
- import: (data)/meta/ai-block-aggressive.yaml
- import: (data)/crawlers/_allow-good.yaml
- import: (data)/clients/x-firefox-ai.yaml
- import: (data)/common/keep-internet-working.yaml
- name: countries-with-aggressive-scrapers
action: WEIGH
geoip:
countries:
- BR
- CN
weight:
adjust: 10
- name: aggressive-asns-without-functional-abuse-contact
action: WEIGH
asns:
match:
- 13335 # Cloudflare
- 136907 # Huawei Cloud
- 45102 # Alibaba Cloud
weight:
adjust: 10
- name: generic-browser
user_agent_regex: >-
Mozilla|Opera
action: WEIGH
weight:
adjust: 10
dnsbl: false
status_codes:
CHALLENGE: 200
DENY: 200
thresholds: []