anubis/lib/policy/checker.go
Jason Cameron b2b2679bae
perf: replace cidranger with bart for significant performance improvements (#675)
* feat: replace cidranger with bart improving performance by 3-20x

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* perf: replace cidranger with bart for IP range checking

- Replace cidranger.Ranger with bart.Lite in RemoteAddrChecker
- Use netip.ParsePrefix instead of net.ParseCIDR for modern IP handling
- Improve performance: 3-20x faster lookups with zero heap allocations
- Update imports to use github.com/gaissmai/bart and net/netip
- Remove cidranger dependency from go.mod

Benchmark results:
- IPv4 lookups: 4x faster (15.58ns vs 63.25ns, 0 vs 2 allocs)
- IPv6 lookups: 3x faster (26.51ns vs 76.96ns, 0 vs 2 allocs)
- Insertions: 20x faster (976ns vs 19,191ns)
- Large tables: 14x faster (5.2ns vs 74.85ns)

* docs: clarify CHANGELOG to not give false impressions

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* perf: optimize string concatenation in RemoteAddrChecker hash generation

Replace fmt.Fprintln with strings.Join for 7x faster performance:
- Before: 935.1 ns/op, 784 B/op, 22 allocs/op
- After: 133.2 ns/op, 192 B/op, 1 alloc/op

The hash is used for JWT cookie validation and error code generation.
Comma separation provides the same deterministic uniqueness as newlines
but with significantly better performance during policy initialization.

* chore: remove accidentally committed string benchmark

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* style: apply Copilot suggestions

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* fix: reference the right var name

i cannot write a merge commit

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

---------

Signed-off-by: Jason Cameron <git@jasoncameron.dev>
2025-06-17 11:57:55 -04:00

175 lines
3.9 KiB
Go

package policy
import (
	"errors"
	"fmt"
	"net/http"
	"net/netip"
	"regexp"
	"sort"
	"strings"

	"github.com/TecharoHQ/anubis/internal"
	"github.com/TecharoHQ/anubis/lib/policy/checker"
	"github.com/gaissmai/bart"
)
var (
// ErrMisconfiguration is the sentinel error wrapped by the checkers in this
// file when a configured value (CIDR range or regular expression) fails to
// parse, or when a request has no usable X-Real-Ip header. Callers can detect
// it with errors.Is.
ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration")
)
// staticHashChecker is a checker that accepts every request and reports a
// fixed, stored hash.
type staticHashChecker struct {
	// hash is returned verbatim by Hash.
	hash string
}

// Check matches unconditionally: the request is not inspected and no error is
// ever returned.
func (s staticHashChecker) Check(r *http.Request) (bool, error) {
	return true, nil
}

// Hash returns the hash stored in the checker.
func (s staticHashChecker) Hash() string {
	return s.hash
}
// NewStaticHashChecker returns a checker that matches every request and whose
// Hash is internal.FastHash applied to hashable.
func NewStaticHashChecker(hashable string) checker.Impl {
	digest := internal.FastHash(hashable)
	return staticHashChecker{hash: digest}
}
// RemoteAddrChecker matches requests whose X-Real-Ip header holds an address
// contained in a configured set of IP prefixes.
type RemoteAddrChecker struct {
// prefixTable holds the prefixes inserted by NewRemoteAddrChecker and is
// queried by Check.
prefixTable *bart.Lite
// hash is the value returned by Hash.
hash string
}
// NewRemoteAddrChecker builds a RemoteAddrChecker from a list of CIDR strings.
//
// Every entry is parsed with netip.ParsePrefix and inserted into the lookup
// table. The first entry that fails to parse aborts construction with an error
// wrapping both ErrMisconfiguration and the parse error.
//
// The checker's hash is internal.FastHash of the entries joined with commas,
// so it depends on the order and exact spelling of cidrs.
func NewRemoteAddrChecker(cidrs []string) (checker.Impl, error) {
	prefixes := &bart.Lite{}

	for i := range cidrs {
		parsed, parseErr := netip.ParsePrefix(cidrs[i])
		if parseErr != nil {
			return nil, fmt.Errorf("%w: range %s not parsing: %w", ErrMisconfiguration, cidrs[i], parseErr)
		}
		prefixes.Insert(parsed)
	}

	rac := &RemoteAddrChecker{
		prefixTable: prefixes,
		hash:        internal.FastHash(strings.Join(cidrs, ",")),
	}
	return rac, nil
}
// Check reports whether the address in the request's X-Real-Ip header is
// contained in the checker's prefix table.
//
// It returns an error wrapping ErrMisconfiguration when the header is absent
// or empty, or when its value cannot be parsed by netip.ParseAddr.
func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) {
	realIP := r.Header.Get("X-Real-Ip")
	if realIP == "" {
		return false, fmt.Errorf("%w: header X-Real-Ip is not set", ErrMisconfiguration)
	}

	addr, parseErr := netip.ParseAddr(realIP)
	if parseErr != nil {
		return false, fmt.Errorf("%w: %s is not an IP address: %w", ErrMisconfiguration, realIP, parseErr)
	}

	matched := rac.prefixTable.Contains(addr)
	return matched, nil
}
// Hash returns the hash stored in the checker.
func (rac *RemoteAddrChecker) Hash() string { return rac.hash }
// HeaderMatchesChecker matches requests by testing the value of a single
// request header against a regular expression.
type HeaderMatchesChecker struct {
// header is the name of the request header whose value is tested.
header string
// regexp is the compiled pattern applied to the header value.
regexp *regexp.Regexp
// hash is the value returned by Hash.
hash string
}
// NewUserAgentChecker returns a checker that matches the User-Agent request
// header against rexStr. It delegates to NewHeaderMatchesChecker and returns
// whatever that constructor returns, including its error.
func NewUserAgentChecker(rexStr string) (checker.Impl, error) {
	const userAgentHeader = "User-Agent"
	return NewHeaderMatchesChecker(userAgentHeader, rexStr)
}
// NewHeaderMatchesChecker returns a checker that matches the value of the
// named request header against the regular expression rexStr.
//
// Surrounding whitespace is trimmed from both header and rexStr before they
// are used for matching. The hash, however, is computed from the untrimmed
// inputs as header + ": " + rexStr.
//
// If rexStr does not compile, the returned error wraps ErrMisconfiguration and
// the compile error.
func NewHeaderMatchesChecker(header, rexStr string) (checker.Impl, error) {
	pattern, compileErr := regexp.Compile(strings.TrimSpace(rexStr))
	if compileErr != nil {
		return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, compileErr)
	}

	hmc := &HeaderMatchesChecker{
		header: strings.TrimSpace(header),
		regexp: pattern,
		hash:   internal.FastHash(header + ": " + rexStr),
	}
	return hmc, nil
}
// Check reports whether the value of the configured request header matches the
// checker's regular expression. A missing header is tested as the empty
// string. It never returns an error.
func (hmc *HeaderMatchesChecker) Check(r *http.Request) (bool, error) {
	value := r.Header.Get(hmc.header)
	return hmc.regexp.MatchString(value), nil
}
// Hash returns the hash stored in the checker.
func (hmc *HeaderMatchesChecker) Hash() string { return hmc.hash }
// PathChecker matches requests by testing the request URL path against a
// regular expression.
type PathChecker struct {
// regexp is the compiled pattern applied to the URL path.
regexp *regexp.Regexp
// hash is the value returned by Hash.
hash string
}
// NewPathChecker returns a checker that matches the request URL path against
// the regular expression rexStr.
//
// Surrounding whitespace is trimmed from rexStr before compilation; the hash
// is computed from the untrimmed rexStr. If the expression does not compile,
// the returned error wraps ErrMisconfiguration and the compile error.
func NewPathChecker(rexStr string) (checker.Impl, error) {
	pattern, compileErr := regexp.Compile(strings.TrimSpace(rexStr))
	if compileErr != nil {
		return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, compileErr)
	}

	pc := &PathChecker{
		regexp: pattern,
		hash:   internal.FastHash(rexStr),
	}
	return pc, nil
}
// Check reports whether the request's URL path matches the checker's regular
// expression. It never returns an error.
func (pc *PathChecker) Check(r *http.Request) (bool, error) {
	path := r.URL.Path
	return pc.regexp.MatchString(path), nil
}
// Hash returns the hash stored in the checker.
func (pc *PathChecker) Hash() string { return pc.hash }
// NewHeaderExistsChecker returns a checker that matches requests carrying a
// non-empty value for the header named key. Surrounding whitespace is trimmed
// from key.
func NewHeaderExistsChecker(key string) checker.Impl {
	name := strings.TrimSpace(key)
	return headerExistsChecker{header: name}
}
// headerExistsChecker matches requests that carry a non-empty value for a
// given header.
type headerExistsChecker struct {
	// header is the name of the request header to look up.
	header string
}

// Check reports whether r has a non-empty value for the configured header.
// A header that is absent or set to the empty string does not match. It never
// returns an error.
func (hec headerExistsChecker) Check(r *http.Request) (bool, error) {
	present := r.Header.Get(hec.header) != ""
	return present, nil
}
// Hash returns internal.FastHash of the configured header name. The value is
// recomputed on every call rather than stored.
func (hec headerExistsChecker) Hash() string {
	digest := internal.FastHash(hec.header)
	return digest
}
// NewHeadersChecker builds one checker per entry of headermap, which maps a
// request header name to a regular expression, and returns them as a
// checker.List.
//
// An entry whose expression is exactly ".*" becomes a header-exists check
// (the header must be present with a non-empty value). Every other entry
// becomes a HeaderMatchesChecker with the expression compiled after trimming
// surrounding whitespace. Header names are trimmed of surrounding whitespace
// before being used for lookups; the per-entry hash is still computed from
// the untrimmed key and expression.
//
// Entries are processed in sorted key order, so the order of the returned
// list and of any reported errors is the same on every run for the same
// input.
//
// If any expression fails to compile, no checker is returned and the error
// joins one error per failing entry.
func NewHeadersChecker(headermap map[string]string) (checker.Impl, error) {
	// Go randomizes map iteration order. Walking the keys in sorted order
	// keeps the resulting list (and anything derived from its order, such as
	// a combined hash — presumably what checker.List computes; verify there)
	// stable across restarts.
	keys := make([]string, 0, len(headermap))
	for key := range headermap {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	var result checker.List
	var errs []error

	for _, key := range keys {
		rexStr := headermap[key]
		// Trim once so both checker kinds look up the same header name,
		// matching NewHeaderMatchesChecker and NewHeaderExistsChecker.
		header := strings.TrimSpace(key)

		if rexStr == ".*" {
			result = append(result, headerExistsChecker{header})
			continue
		}

		rex, err := regexp.Compile(strings.TrimSpace(rexStr))
		if err != nil {
			errs = append(errs, fmt.Errorf("while compiling header %s regex %s: %w", key, rexStr, err))
			continue
		}

		result = append(result, &HeaderMatchesChecker{header, rex, internal.FastHash(key + ": " + rexStr)})
	}

	if len(errs) != 0 {
		return nil, errors.Join(errs...)
	}

	return result, nil
}