diff --git a/lib/checker/all/all.go b/lib/checker/all/all.go new file mode 100644 index 0000000..89a8d7b --- /dev/null +++ b/lib/checker/all/all.go @@ -0,0 +1,8 @@ +// Package all imports all of the standard checker types. +package all + +import ( + _ "github.com/TecharoHQ/anubis/lib/checker/headerexists" + _ "github.com/TecharoHQ/anubis/lib/checker/headermatches" + _ "github.com/TecharoHQ/anubis/lib/checker/remoteaddress" +) diff --git a/lib/checker/headerexists/checker.go b/lib/checker/headerexists/checker.go new file mode 100644 index 0000000..27cf11b --- /dev/null +++ b/lib/checker/headerexists/checker.go @@ -0,0 +1,32 @@ +package headerexists + +import ( + "net/http" + "strings" + + "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/lib/checker" +) + +func New(key string) checker.Interface { + return headerExistsChecker{ + header: strings.TrimSpace(http.CanonicalHeaderKey(key)), + hash: internal.FastHash(key), + } +} + +type headerExistsChecker struct { + header, hash string +} + +func (hec headerExistsChecker) Check(r *http.Request) (bool, error) { + if r.Header.Get(hec.header) != "" { + return true, nil + } + + return false, nil +} + +func (hec headerExistsChecker) Hash() string { + return hec.hash +} diff --git a/lib/checker/headerexists/checker_test.go b/lib/checker/headerexists/checker_test.go new file mode 100644 index 0000000..627cab2 --- /dev/null +++ b/lib/checker/headerexists/checker_test.go @@ -0,0 +1,57 @@ +package headerexists + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" +) + +func TestChecker(t *testing.T) { + fac := Factory{} + + for _, tt := range []struct { + name string + header string + reqHeader string + ok bool + }{ + { + name: "match", + header: "Authorization", + reqHeader: "Authorization", + ok: true, + }, + { + name: "not_match", + header: "Authorization", + reqHeader: "Authentication", + }, + } { + t.Run(tt.name, func(t *testing.T) { + hec, err := fac.Build(t.Context(), json.RawMessage(fmt.Sprintf("%q", tt.header))) + if err != nil { + t.Fatal(err) + } + + t.Log(hec.Hash()) + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + r.Header.Set(tt.reqHeader, "hunter2") + + ok, err := hec.Check(r) + + if tt.ok != ok { + t.Errorf("ok: %v, wanted: %v", ok, tt.ok) + } + + if err != nil { + t.Errorf("err: %v", err) + } + }) + } +} diff --git a/lib/checker/headerexists/factory.go b/lib/checker/headerexists/factory.go new file mode 100644 index 0000000..7953e01 --- /dev/null +++ b/lib/checker/headerexists/factory.go @@ -0,0 +1,40 @@ +package headerexists + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/TecharoHQ/anubis/lib/checker" +) + +type Factory struct{} + +func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) { + var headerName string + + if err := json.Unmarshal([]byte(data), &headerName); err != nil { + return nil, fmt.Errorf("%w: want string", checker.ErrUnparseableConfig) + } + + if err := f.Valid(ctx, data); err != nil { + return nil, err + } + + return New(http.CanonicalHeaderKey(headerName)), nil +} + +func (Factory) Valid(ctx context.Context, data json.RawMessage) error { + var headerName string + + if err := json.Unmarshal([]byte(data), &headerName); err != nil { + return fmt.Errorf("%w: want string", checker.ErrUnparseableConfig) + } + + if headerName == "" { + return fmt.Errorf("%w: string must not be empty", checker.ErrInvalidConfig) + } + + return nil +} diff --git a/lib/checker/headerexists/factory_test.go b/lib/checker/headerexists/factory_test.go new file mode 100644 index 0000000..644d9fb --- /dev/null +++ b/lib/checker/headerexists/factory_test.go @@ -0,0 +1,60 @@ +package headerexists + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestFactoryGood(t *testing.T) { + files, err := os.ReadDir("./testdata/good") + if err != nil { + t.Fatal(err) + } + + fac := Factory{} + + for _, fname := range files { + t.Run(fname.Name(), func(t *testing.T) { + data, err := os.ReadFile(filepath.Join("testdata", "good", fname.Name())) + if err != nil { + t.Fatal(err) + } + + if err := fac.Valid(t.Context(), json.RawMessage(data)); err != nil { + t.Fatal(err) + } + }) + } +} + +func TestFactoryBad(t *testing.T) { + files, err := os.ReadDir("./testdata/bad") + if err != nil { + t.Fatal(err) + } + + fac := Factory{} + + for _, fname := range files { + t.Run(fname.Name(), func(t *testing.T) { + data, err := os.ReadFile(filepath.Join("testdata", "bad", fname.Name())) + if err != nil { + t.Fatal(err) + } + + t.Run("Build", func(t *testing.T) { + if _, err := fac.Build(t.Context(), json.RawMessage(data)); err == nil { + t.Fatal(err) + } + }) + + t.Run("Valid", func(t *testing.T) { + if err := fac.Valid(t.Context(), json.RawMessage(data)); err == nil { + t.Fatal(err) + } + }) + }) + } +} diff --git a/lib/checker/headerexists/testdata/bad/empty.json b/lib/checker/headerexists/testdata/bad/empty.json new file mode 100644 index 0000000..3cc762b --- /dev/null +++ b/lib/checker/headerexists/testdata/bad/empty.json @@ -0,0 +1 @@ +"" \ No newline at end of file diff --git a/lib/checker/headerexists/testdata/bad/object.json b/lib/checker/headerexists/testdata/bad/object.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/lib/checker/headerexists/testdata/bad/object.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/lib/checker/headerexists/testdata/good/authorization.json b/lib/checker/headerexists/testdata/good/authorization.json new file mode 100644 index 0000000..7329827 --- /dev/null +++ b/lib/checker/headerexists/testdata/good/authorization.json @@ -0,0 +1 @@ +"Authorization" \ No newline at end of file diff --git a/lib/checker/headermatches/checker.go b/lib/checker/headermatches/checker.go new file mode 100644 index 0000000..0f52f29 --- /dev/null +++ b/lib/checker/headermatches/checker.go @@ -0,0 +1,24 @@ +package headermatches + +import ( + "net/http" + "regexp" +) + +type Checker struct { + header string + regexp *regexp.Regexp + hash string +} + +func (c *Checker) Check(r *http.Request) (bool, error) { + if c.regexp.MatchString(r.Header.Get(c.header)) { + return true, nil + } + + return false, nil +} + +func (c *Checker) Hash() string { + return c.hash +} diff --git a/lib/checker/headermatches/checker_test.go b/lib/checker/headermatches/checker_test.go new file mode 100644 index 0000000..9928c7a --- /dev/null +++ b/lib/checker/headermatches/checker_test.go @@ -0,0 +1,98 @@ +package headermatches + +import ( + "encoding/json" + "errors" + "net/http" + "testing" +) + +func TestChecker(t *testing.T) { + +} + +func TestHeaderMatchesChecker(t *testing.T) { + fac := Factory{} + + for _, tt := range []struct { + err error + name string + header string + rexStr string + reqHeaderKey string + reqHeaderValue string + ok bool + }{ + { + name: "match", + header: "Cf-Worker", + rexStr: ".*", + reqHeaderKey: "Cf-Worker", + reqHeaderValue: "true", + ok: true, + err: nil, + }, + { + name: "not_match", + header: "Cf-Worker", + rexStr: "false", + reqHeaderKey: "Cf-Worker", + reqHeaderValue: "true", + ok: false, + err: nil, + }, + { + name: "not_present", + header: "Cf-Worker", + rexStr: "foobar", + reqHeaderKey: "Something-Else", + reqHeaderValue: "true", + ok: false, + err: nil, + }, + { + name: "invalid_regex", + rexStr: "a(b", + err: ErrInvalidRegex, + }, + } { + t.Run(tt.name, func(t *testing.T) { + fc := fileConfig{ + Header: tt.header, + ValueRegex: tt.rexStr, + } + data, err := json.Marshal(fc) + if err != nil { + t.Fatal(err) + } + + hmc, err := fac.Build(t.Context(), json.RawMessage(data)) + if err != nil && !errors.Is(err, tt.err) { + t.Fatalf("creating HeaderMatchesChecker failed") + } + + if tt.err != nil && hmc == nil { + return + } + + t.Log(hmc.Hash()) + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue) + + ok, err := hmc.Check(r) + + if tt.ok != ok { + t.Errorf("ok: %v, wanted: %v", ok, tt.ok) + } + + if err != nil && tt.err != nil && !errors.Is(err, tt.err) { + t.Errorf("err: %v, wanted: %v", err, tt.err) + } + }) + } +} diff --git a/lib/checker/headermatches/config.go b/lib/checker/headermatches/config.go new file mode 100644 index 0000000..7a0e7be --- /dev/null +++ b/lib/checker/headermatches/config.go @@ -0,0 +1,44 @@ +package headermatches + +import ( + "errors" + "fmt" + "regexp" +) + +var ( + ErrNoHeader = errors.New("headermatches: no header is configured") + ErrNoValueRegex = errors.New("headermatches: no value regex is configured") + ErrInvalidRegex = errors.New("headermatches: value regex is invalid") +) + +type fileConfig struct { + Header string `json:"header" yaml:"header"` + ValueRegex string `json:"value_regex" yaml:"value_regex"` +} + +func (fc fileConfig) String() string { + return fmt.Sprintf("header=%q value_regex=%q", fc.Header, fc.ValueRegex) +} + +func (fc fileConfig) Valid() error { + var errs []error + + if fc.Header == "" { + errs = append(errs, ErrNoHeader) + } + + if fc.ValueRegex == "" { + errs = append(errs, ErrNoValueRegex) + } + + if _, err := regexp.Compile(fc.ValueRegex); err != nil { + errs = append(errs, ErrInvalidRegex, err) + } + + if len(errs) != 0 { + return errors.Join(errs...) + } + + return nil +} diff --git a/lib/checker/headermatches/config_test.go b/lib/checker/headermatches/config_test.go new file mode 100644 index 0000000..8f190f1 --- /dev/null +++ b/lib/checker/headermatches/config_test.go @@ -0,0 +1,55 @@ +package headermatches + +import ( + "errors" + "testing" +) + +func TestFileConfigValid(t *testing.T) { + for _, tt := range []struct { + name, description string + in fileConfig + err error + }{ + { + name: "simple happy", + description: "the most common usecase", + in: fileConfig{ + Header: "User-Agent", + ValueRegex: ".*", + }, + }, + { + name: "no header", + description: "Header must be set, it is not", + in: fileConfig{ + ValueRegex: ".*", + }, + err: ErrNoHeader, + }, + { + name: "no value regex", + description: "ValueRegex must be set, it is not", + in: fileConfig{ + Header: "User-Agent", + }, + err: ErrNoValueRegex, + }, + { + name: "invalid regex", + description: "the user wrote an invalid value regular expression", + in: fileConfig{ + Header: "User-Agent", + ValueRegex: "[a-z", + }, + err: ErrInvalidRegex, + }, + } { + t.Run(tt.name, func(t *testing.T) { + if err := tt.in.Valid(); !errors.Is(err, tt.err) { + t.Log(tt.description) + t.Fatal(err) + } + }) + } +} diff --git a/lib/checker/headermatches/factory.go b/lib/checker/headermatches/factory.go new file mode 100644 index 0000000..4e32db2 --- /dev/null +++ b/lib/checker/headermatches/factory.go @@ -0,0 +1,66 @@ +package headermatches + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "regexp" + + "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/lib/checker" +) + +func init() { + checker.Register("header_matches", Factory{}) + checker.Register("user_agent", Factory{defaultHeader: "User-Agent"}) +} + +type Factory struct { + defaultHeader string +} + +func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) { + var fc fileConfig + + if f.defaultHeader != "" { + fc.Header = f.defaultHeader + } + + if err := json.Unmarshal([]byte(data), &fc); err != nil { + return nil, errors.Join(checker.ErrUnparseableConfig, err) + } + + if err := fc.Valid(); err != nil { + return nil, errors.Join(checker.ErrInvalidConfig, err) + } + + valueRex, err := regexp.Compile(fc.ValueRegex) + if err != nil { + return nil, errors.Join(ErrInvalidRegex, err) + } + + return &Checker{ + header: http.CanonicalHeaderKey(fc.Header), + regexp: valueRex, + hash: internal.FastHash(fc.String()), + }, nil +} + +func (f Factory) Valid(ctx context.Context, data json.RawMessage) error { + var fc fileConfig + + if f.defaultHeader != "" { + fc.Header = f.defaultHeader + } + + if err := json.Unmarshal([]byte(data), &fc); err != nil { + return err + } + + if err := fc.Valid(); err != nil { + return err + } + + return nil +} diff --git a/lib/checker/headermatches/factory_test.go b/lib/checker/headermatches/factory_test.go new file mode 100644 index 0000000..414d86c --- /dev/null +++ b/lib/checker/headermatches/factory_test.go @@ -0,0 +1,52 @@ +package headermatches + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestFactoryGood(t *testing.T) { + files, err := os.ReadDir("./testdata/good") + if err != nil { + t.Fatal(err) + } + + fac := Factory{} + + for _, fname := range files { + t.Run(fname.Name(), func(t *testing.T) { + data, err := os.ReadFile(filepath.Join("testdata", "good", fname.Name())) + if err != nil { + t.Fatal(err) + } + + if err := fac.Valid(t.Context(), json.RawMessage(data)); err != nil { + t.Fatal(err) + } + }) + } +} + +func TestFactoryBad(t *testing.T) { + files, err := os.ReadDir("./testdata/bad") + if err != nil { + t.Fatal(err) + } + + fac := Factory{} + + for _, fname := range files { + t.Run(fname.Name(), func(t *testing.T) { + data, err := os.ReadFile(filepath.Join("testdata", "bad", fname.Name())) + if err != nil { + t.Fatal(err) + } + + if err := fac.Valid(t.Context(), json.RawMessage(data)); err == nil { + t.Fatal(err) + } + }) + } +} diff --git a/lib/checker/headermatches/testdata/bad/invalid_config.json b/lib/checker/headermatches/testdata/bad/invalid_config.json new file mode 100644 index 0000000..ff30235 --- /dev/null +++ b/lib/checker/headermatches/testdata/bad/invalid_config.json @@ -0,0 +1 @@ +} \ No newline at end of file diff --git a/lib/checker/headermatches/testdata/bad/invalid_value_regex.json b/lib/checker/headermatches/testdata/bad/invalid_value_regex.json new file mode 100644 index 0000000..6df6af2 --- /dev/null +++ b/lib/checker/headermatches/testdata/bad/invalid_value_regex.json @@ -0,0 +1,4 @@ +{ + "header": "User-Agent", + "value_regex": "a(b" +} \ No newline at end of file diff --git a/lib/checker/headermatches/testdata/bad/no_header.json b/lib/checker/headermatches/testdata/bad/no_header.json new file mode 100644 index 0000000..21e543e --- /dev/null +++ b/lib/checker/headermatches/testdata/bad/no_header.json @@ -0,0 +1,3 @@ +{ + "value_regex": "PaleMoon" +} \ No newline at end of file diff --git a/lib/checker/headermatches/testdata/bad/no_value_regex.json b/lib/checker/headermatches/testdata/bad/no_value_regex.json new file mode 100644 index 0000000..54a27a6 --- /dev/null +++ b/lib/checker/headermatches/testdata/bad/no_value_regex.json @@ -0,0 +1,3 @@ +{ + "header": "User-Agent" +} \ No newline at end of file diff --git a/lib/checker/headermatches/testdata/bad/nothing.json b/lib/checker/headermatches/testdata/bad/nothing.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/lib/checker/headermatches/testdata/bad/nothing.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/lib/checker/headermatches/testdata/good/simple.json b/lib/checker/headermatches/testdata/good/simple.json new file mode 100644 index 0000000..bbfa97e --- /dev/null +++ b/lib/checker/headermatches/testdata/good/simple.json @@ -0,0 +1,4 @@ +{ + "header": "User-Agent", + "value_regex": "PaleMoon" +} \ No newline at end of file diff --git a/lib/checker/headermatches/useragent.go b/lib/checker/headermatches/useragent.go new file mode 100644 index 0000000..006a91d --- /dev/null +++ b/lib/checker/headermatches/useragent.go @@ -0,0 +1 @@ +package headermatches