anubis/cmd/anubis/main.go
CXM 3520421757
fix: determine bind network from bind address (#714)
* fix: determine bind network from bind address

* docs: update CHANGELOG

Signed-off-by: Xe Iaso <me@xeiaso.net>

---------

Signed-off-by: Xe Iaso <me@xeiaso.net>
Co-authored-by: Xe Iaso <me@xeiaso.net>
2025-06-27 17:57:37 +00:00

513 lines
18 KiB
Go

package main
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"embed"
"encoding/hex"
"errors"
"flag"
"fmt"
"io/fs"
"log"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"os/signal"
"path/filepath"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/TecharoHQ/anubis"
"github.com/TecharoHQ/anubis/data"
"github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/internal/thoth"
libanubis "github.com/TecharoHQ/anubis/lib"
botPolicy "github.com/TecharoHQ/anubis/lib/policy"
"github.com/TecharoHQ/anubis/lib/policy/config"
"github.com/TecharoHQ/anubis/web"
"github.com/facebookgo/flagenv"
_ "github.com/joho/godotenv/autoload"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// Command-line flags. main calls flagenv.Parse before flag.Parse, so each flag
// can also be supplied through the environment (the messages in this file refer
// to them in upper-case form, e.g. COOKIE_DOMAIN for -cookie-domain).

// Listener configuration for the main HTTP server and the metrics server.
var (
	bind               = flag.String("bind", ":8923", "network address to bind HTTP to")
	bindNetwork        = flag.String("bind-network", "tcp", "network family to bind HTTP to, e.g. unix, tcp")
	socketMode         = flag.String("socket-mode", "0770", "socket mode (permissions) for unix domain sockets.")
	metricsBind        = flag.String("metrics-bind", ":9090", "network address to bind metrics to")
	metricsBindNetwork = flag.String("metrics-bind-network", "tcp", "network family for the metrics server to bind to")
)

// Upstream (reverse proxy) target and URL prefix handling.
var (
	target                   = flag.String("target", "http://localhost:3923", "target to reverse proxy to, set to an empty string to disable proxying when only using auth request")
	targetSNI                = flag.String("target-sni", "", "if set, the value of the TLS handshake hostname when forwarding requests to the target")
	targetHost               = flag.String("target-host", "", "if set, the value of the Host header when forwarding requests to the target")
	targetInsecureSkipVerify = flag.Bool("target-insecure-skip-verify", false, "if true, skips TLS validation for the backend")
	basePrefix               = flag.String("base-prefix", "", "base prefix (root URL) the application is served under e.g. /myapp")
	stripBasePrefix          = flag.Bool("strip-base-prefix", false, "if true, strips the base prefix from requests forwarded to the target server")
)

// Challenge, policy and response behaviour.
var (
	challengeDifficulty = flag.Int("difficulty", anubis.DefaultDifficulty, "difficulty of the challenge")
	policyFname         = flag.String("policy-fname", "", "full path to anubis policy document (defaults to a sensible built-in policy)")
	debugBenchmarkJS    = flag.Bool("debug-benchmark-js", false, "respond to every request with a challenge for benchmarking hashrate")
	robotsTxt           = flag.Bool("serve-robots-txt", false, "serve a robots.txt file that disallows all robots")
	// NOTE(review): this help text says an unset value allows any domain, but
	// main warns that an unset value only allows redirects to the request's own
	// domain — confirm which is current and align the two.
	redirectDomains = flag.String("redirect-domains", "", "list of domains separated by commas which anubis is allowed to redirect to. Leaving this unset allows any domain.")
	webmasterEmail  = flag.String("webmaster-email", "", "if set, displays webmaster's email on the reject page for appeals")
)

// Authorization cookie settings.
var (
	cookieDomain        = flag.String("cookie-domain", "", "if set, the top-level domain that the Anubis cookie will be valid for")
	cookieDynamicDomain = flag.Bool("cookie-dynamic-domain", false, "if set, automatically set the cookie Domain value based on the request domain")
	cookieExpiration    = flag.Duration("cookie-expiration-time", anubis.CookieDefaultExpirationTime, "The amount of time the authorization cookie is valid for")
	cookiePartitioned   = flag.Bool("cookie-partitioned", false, "if true, sets the partitioned flag on Anubis cookies, enabling CHIPS support")
)

// JWT signing keys; main rejects combinations of these.
var (
	hs512Secret              = flag.String("hs512-secret", "", "secret used to sign JWTs, uses ed25519 if not set")
	ed25519PrivateKeyHex     = flag.String("ed25519-private-key-hex", "", "private key used to sign JWTs, if not set a random one will be assigned")
	ed25519PrivateKeyHexFile = flag.String("ed25519-private-key-hex-file", "", "file name containing value for ed25519-private-key-hex")
)

// Client IP address detection.
var (
	useRemoteAddress = flag.Bool("use-remote-address", false, "read the client's IP address from the network request, useful for debugging and running Anubis on bare metal")
	xffStripPrivate  = flag.Bool("xff-strip-private", true, "if set, strip private addresses from X-Forwarded-For")
)

// Open Graph passthrough (used only when the policy does not enable it).
var (
	ogPassthrough       = flag.Bool("og-passthrough", false, "enable Open Graph tag passthrough")
	ogTimeToLive        = flag.Duration("og-expiry-time", 24*time.Hour, "Open Graph tag cache expiration time")
	ogCacheConsiderHost = flag.Bool("og-cache-consider-host", false, "enable or disable the use of the host in the Open Graph tag cache")
)

// Thoth IP reputation service.
var (
	thothInsecure = flag.Bool("thoth-insecure", false, "if set, connect to Thoth over plain HTTP/2, don't enable this unless support told you to")
	thothURL      = flag.String("thoth-url", "", "if set, URL for Thoth, the IP reputation database for Anubis")
	thothToken    = flag.String("thoth-token", "", "if set, API token for Thoth, the IP reputation database for Anubis")
)

// One-shot modes and logging.
var (
	healthcheck      = flag.Bool("healthcheck", false, "run a health check against Anubis")
	extractResources = flag.String("extract-resources", "", "if set, extract the static resources to the specified folder")
	versionFlag      = flag.Bool("version", false, "print Anubis version")
	slogLevel        = flag.String("slog-level", "INFO", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)")
)
// keyFromHex decodes a hex-encoded Ed25519 seed and expands it into a full
// private key.
//
// value must decode to exactly ed25519.SeedSize bytes. An error is returned
// when value is not valid hex or when the decoded length is wrong.
func keyFromHex(value string) (ed25519.PrivateKey, error) {
	seed, decodeErr := hex.DecodeString(value)
	switch {
	case decodeErr != nil:
		return nil, fmt.Errorf("supplied key is not hex-encoded: %w", decodeErr)
	case len(seed) != ed25519.SeedSize:
		return nil, fmt.Errorf("supplied key is not %d bytes long, got %d bytes", ed25519.SeedSize, len(seed))
	default:
		return ed25519.NewKeyFromSeed(seed), nil
	}
}
// doHealthCheck probes the Prometheus metrics endpoint of a running Anubis
// instance and reports whether it answered with HTTP 200.
//
// The endpoint location is derived from -metrics-bind and -metrics-bind-network
// the same way the metrics listener is bound:
//   - a bare ":port" address is probed on localhost,
//   - any other TCP address ("host:port") is probed as-is,
//   - a "unix" network is probed by dialing the socket directly.
//
// An empty -metrics-bind-network falls back to parseBindNetFromAddr, like
// setupListener does. The request is bounded by a timeout so an endpoint that
// accepts connections but never responds fails the check instead of hanging it.
func doHealthCheck() error {
	if *metricsBind == "" {
		return errors.New("metrics-bind is not set, there is no metrics endpoint to health check")
	}

	network, address := *metricsBindNetwork, *metricsBind
	if network == "" {
		// Same compatibility path as setupListener.
		network, address = parseBindNetFromAddr(address)
	}

	transport := http.DefaultTransport.(*http.Transport).Clone()
	host := address
	switch {
	case network == "unix":
		socketPath := address
		transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(ctx, "unix", socketPath)
		}
		// The dialer ignores the host, but the URL still needs a valid one.
		host = "localhost"
	case strings.HasPrefix(address, ":"):
		// A bare port means the server listens on every interface.
		host = "localhost" + address
	}

	client := &http.Client{Transport: transport, Timeout: 5 * time.Second}
	resp, err := client.Get("http://" + host + anubis.BasePrefix + "/metrics")
	if err != nil {
		return fmt.Errorf("failed to fetch metrics: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}
	return nil
}
// parseBindNetFromAddr determines the bind network and the listen address from
// a bind address. setupListener uses it when no network is configured.
//
// Supported forms:
//   - ":port": TCP on all interfaces; returned unchanged.
//   - "host:port": TCP; the host:port part is returned.
//   - "tcp://host:port", "http://host:port", "https://host:port": TCP; the
//     host:port part is returned.
//   - "unix:///abs/path.sock" or "unix://rel/path.sock": Unix domain socket;
//     the socket path is returned.
//
// An address that cannot be parsed, uses any other scheme, or names an empty
// Unix socket path terminates the process via log.Fatal.
func parseBindNetFromAddr(address string) (string, string) {
	if !strings.Contains(address, "://") {
		if strings.HasPrefix(address, ":") {
			// A bare port is already a valid TCP listen address meaning "all
			// interfaces". Turning it into "localhost:port" would silently
			// restrict the listener to loopback.
			return "tcp", address
		}
		// Give other scheme-less addresses a scheme so that url.Parse puts
		// the host:port part into URL.Host.
		address = "http://" + address
	}

	bindURL, err := url.Parse(address)
	if err != nil {
		log.Fatal(fmt.Errorf("failed to parse bind URL: %w", err))
	}

	switch bindURL.Scheme {
	case "unix":
		// "unix:///abs.sock" puts the path in URL.Path, while "unix://rel.sock"
		// puts (the first segment of) a relative path in URL.Host. Join both
		// so that relative socket paths are not dropped.
		socketPath := bindURL.Host + bindURL.Path
		if socketPath == "" {
			log.Fatal(fmt.Errorf("missing unix socket path in address %s", address))
		}
		return "unix", socketPath
	case "tcp", "http", "https":
		return "tcp", bindURL.Host
	default:
		log.Fatal(fmt.Errorf("unsupported network scheme %s in address %s", bindURL.Scheme, address))
	}
	return "", address // unreachable: log.Fatal exits the process
}
// setupListener binds a listener for network/address and returns it together
// with a human-readable form of the address for log messages.
//
// An empty network selects compatibility mode: the network and the address are
// both derived from address by parseBindNetFromAddr. For "unix" listeners the
// socket file's permissions are set from the -socket-mode flag, parsed as an
// octal number. Every failure is fatal (log.Fatal).
func setupListener(network string, address string) (net.Listener, string) {
	if network == "" {
		// keep compatibility
		network, address = parseBindNetFromAddr(address)
	}

	var formattedAddress string
	switch network {
	case "unix":
		formattedAddress = "unix:" + address
	case "tcp":
		if strings.HasPrefix(address, ":") { // assume it's just a port e.g. :4259
			formattedAddress = "http://localhost" + address
		} else {
			formattedAddress = "http://" + address
		}
	default:
		formattedAddress = fmt.Sprintf(`(%s) %s`, network, address)
	}

	listener, err := net.Listen(network, address)
	if err != nil {
		log.Fatal(fmt.Errorf("failed to bind to %s: %w", formattedAddress, err))
	}

	// additional permission handling for unix sockets
	if network == "unix" {
		// Release the listener before exiting on a configuration error.
		closeListener := func() {
			if cerr := listener.Close(); cerr != nil {
				log.Printf("failed to close listener: %v", cerr)
			}
		}

		mode, err := strconv.ParseUint(*socketMode, 8, 0)
		if err != nil {
			closeListener()
			log.Fatal(fmt.Errorf("could not parse socket mode %s: %w", *socketMode, err))
		}

		// Keep the chmod error in its own variable so that closing the
		// listener cannot replace it before it is reported.
		if chmodErr := os.Chmod(address, os.FileMode(mode)); chmodErr != nil {
			closeListener()
			log.Fatal(fmt.Errorf("could not change socket mode: %w", chmodErr))
		}
	}

	return listener, formattedAddress
}
// makeReverseProxy returns an http.Handler that forwards requests to target.
//
// For a "unix" target, connections are dialed to the socket named by the URL
// path, and that path is cleared so it is not reused as the upstream request
// path. A non-empty targetSNI sets the TLS server name. insecureSkipVerify
// disables certificate validation and logs a warning. A non-empty targetHost
// replaces the Host of every forwarded request. An error is returned only when
// target cannot be parsed as a URL.
func makeReverseProxy(target string, targetSNI string, targetHost string, insecureSkipVerify bool) (http.Handler, error) {
	upstream, err := url.Parse(target)
	if err != nil {
		return nil, fmt.Errorf("failed to parse target URL: %w", err)
	}

	transport := http.DefaultTransport.(*http.Transport).Clone()

	// Unix socket support modelled on oauth2-proxy:
	// https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124
	if upstream.Scheme == "unix" {
		socketPath := upstream.Path
		upstream.Path = "" // keep the socket path out of proxied request paths
		transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(ctx, "unix", socketPath)
		}
		// Teach the transport how to round-trip "unix" URLs.
		transport.RegisterProtocol("unix", libanubis.UnixRoundTripper{Transport: transport})
	}

	if insecureSkipVerify || targetSNI != "" {
		tlsConfig := &tls.Config{ServerName: targetSNI}
		if insecureSkipVerify {
			slog.Warn("TARGET_INSECURE_SKIP_VERIFY is set to true, TLS certificate validation will not be performed", "target", target)
			tlsConfig.InsecureSkipVerify = true
		}
		transport.TLSClientConfig = tlsConfig
	}

	proxy := httputil.NewSingleHostReverseProxy(upstream)
	proxy.Transport = transport
	if targetHost == "" {
		return proxy, nil
	}

	direct := proxy.Director
	proxy.Director = func(req *http.Request) {
		direct(req)
		req.Host = targetHost
	}
	return proxy, nil
}
// startDecayMapCleanup calls s.CleanupDecayMap once every hour until ctx is
// cancelled. It blocks, so callers run it in its own goroutine.
func startDecayMapCleanup(ctx context.Context, s *libanubis.Server) {
	hourly := time.NewTicker(time.Hour)
	defer hourly.Stop()

	done := ctx.Done()
	for {
		select {
		case <-done:
			return
		case <-hourly.C:
			s.CleanupDecayMap()
		}
	}
}
// main wires up and runs Anubis.
//
// It reads configuration from flags and the environment, then either runs one
// of the one-shot modes (-version, -healthcheck, -extract-resources) and
// returns, or loads the bot policy, builds the libanubis server, optionally
// starts the metrics server, and serves HTTP until SIGINT or SIGTERM triggers
// a graceful shutdown with a 5 second timeout.
func main() {
	flagenv.Parse()
	flag.Parse()

	if *versionFlag {
		fmt.Println("Anubis", anubis.Version)
		return
	}

	internal.InitSlog(*slogLevel)

	// The health check probes the metrics endpoint of an instance that is
	// already running. Do it before loading policies, dialing Thoth or binding
	// any listener: otherwise this short-lived process competes with the
	// running instance for the same addresses, and on success it would carry
	// on starting a second server instead of exiting.
	if *healthcheck {
		log.Println("running healthcheck")
		if err := doHealthCheck(); err != nil {
			log.Fatal(err)
		}
		return
	}

	if *extractResources != "" {
		if err := extractEmbedFS(data.BotPolicies, ".", *extractResources); err != nil {
			log.Fatal(err)
		}
		if err := extractEmbedFS(web.Static, "static", *extractResources); err != nil {
			log.Fatal(err)
		}
		fmt.Printf("Extracted embedded static files to %s\n", *extractResources)
		return
	}

	var rp http.Handler
	// When Anubis is configured through systemd environment variables, the
	// target cannot be set to an empty string, only to whitespace, so a
	// whitespace-only target also disables the reverse proxy.
	if strings.TrimSpace(*target) != "" {
		var err error
		rp, err = makeReverseProxy(*target, *targetSNI, *targetHost, *targetInsecureSkipVerify)
		if err != nil {
			log.Fatalf("can't make reverse proxy: %v", err)
		}
	}

	if *cookieDomain != "" && *cookieDynamicDomain {
		log.Fatalf("you can't set COOKIE_DOMAIN and COOKIE_DYNAMIC_DOMAIN at the same time")
	}

	ctx := context.Background()

	// Thoth configuration: both URL and token are required to connect.
	switch {
	case *thothURL != "" && *thothToken == "":
		slog.Warn("THOTH_URL is set but no THOTH_TOKEN is set")
	case *thothURL == "" && *thothToken != "":
		slog.Warn("THOTH_TOKEN is set but no THOTH_URL is set")
	case *thothURL != "" && *thothToken != "":
		slog.Debug("connecting to Thoth")
		thothClient, err := thoth.New(ctx, *thothURL, *thothToken, *thothInsecure)
		if err != nil {
			log.Fatalf("can't dial thoth at %s: %v", *thothURL, err)
		}
		ctx = thoth.With(ctx, thothClient)
	}

	policy, err := libanubis.LoadPoliciesOrDefault(ctx, *policyFname, *challengeDifficulty)
	if err != nil {
		log.Fatalf("can't parse policy file: %v", err)
	}

	// Hash of every deny rule keyed by rule name; logged at startup as
	// rule-error-ids.
	ruleErrorIDs := make(map[string]string)
	for _, rule := range policy.Bots {
		if rule.Action != config.RuleDeny {
			continue
		}
		ruleErrorIDs[rule.Name] = rule.Hash()
	}

	// replace the bot policy rules with a single rule that always benchmarks
	if *debugBenchmarkJS {
		policy.Bots = []botPolicy.Bot{{
			Name:   "",
			Rules:  botPolicy.NewHeaderExistsChecker("User-Agent"),
			Action: config.RuleBenchmark,
		}}
	}

	if *basePrefix != "" && !strings.HasPrefix(*basePrefix, "/") {
		log.Fatalf("[misconfiguration] base-prefix must start with a slash, eg: /%s", *basePrefix)
	} else if strings.HasSuffix(*basePrefix, "/") {
		log.Fatalf("[misconfiguration] base-prefix must not end with a slash")
	}

	if *stripBasePrefix && *basePrefix == "" {
		log.Fatalf("[misconfiguration] strip-base-prefix is set to true, but base-prefix is not set, " +
			"this may result in unexpected behavior")
	}

	// Pick the JWT signing key. The key sources are mutually exclusive; with
	// none configured a random Ed25519 key is generated.
	var ed25519Priv ed25519.PrivateKey
	if *hs512Secret != "" && (*ed25519PrivateKeyHex != "" || *ed25519PrivateKeyHexFile != "") {
		log.Fatal("do not specify both HS512 and ED25519 secrets")
	} else if *hs512Secret != "" {
		ed25519Priv = ed25519.PrivateKey(*hs512Secret)
	} else if *ed25519PrivateKeyHex != "" && *ed25519PrivateKeyHexFile != "" {
		log.Fatal("do not specify both ED25519_PRIVATE_KEY_HEX and ED25519_PRIVATE_KEY_HEX_FILE")
	} else if *ed25519PrivateKeyHex != "" {
		ed25519Priv, err = keyFromHex(*ed25519PrivateKeyHex)
		if err != nil {
			log.Fatalf("failed to parse and validate ED25519_PRIVATE_KEY_HEX: %v", err)
		}
	} else if *ed25519PrivateKeyHexFile != "" {
		hexFile, err := os.ReadFile(*ed25519PrivateKeyHexFile)
		if err != nil {
			log.Fatalf("failed to read ED25519_PRIVATE_KEY_HEX_FILE %s: %v", *ed25519PrivateKeyHexFile, err)
		}
		ed25519Priv, err = keyFromHex(string(bytes.TrimSpace(hexFile)))
		if err != nil {
			log.Fatalf("failed to parse and validate content of ED25519_PRIVATE_KEY_HEX_FILE: %v", err)
		}
	} else {
		_, ed25519Priv, err = ed25519.GenerateKey(rand.Reader)
		if err != nil {
			log.Fatalf("failed to generate ed25519 key: %v", err)
		}
		slog.Warn("generating random key, Anubis will have strange behavior when multiple instances are behind the same load balancer target, for more information: see https://anubis.techaro.lol/docs/admin/installation#key-generation")
	}

	var redirectDomainsList []string
	if *redirectDomains != "" {
		for _, domain := range strings.Split(*redirectDomains, ",") {
			// Validate exactly the (trimmed) value that is stored.
			domain = strings.TrimSpace(domain)
			if _, err := url.Parse(domain); err != nil {
				log.Fatalf("cannot parse redirect-domain %q: %s", domain, err.Error())
			}
			redirectDomainsList = append(redirectDomainsList, domain)
		}
	} else {
		slog.Warn("REDIRECT_DOMAINS is not set, Anubis will only redirect to the same domain a request is coming from, see https://anubis.techaro.lol/docs/admin/configuration/redirect-domains")
	}

	// If OpenGraph configuration values are not set in the config file, use the
	// values from flags / envvars.
	if !policy.OpenGraph.Enabled {
		policy.OpenGraph.Enabled = *ogPassthrough
		policy.OpenGraph.ConsiderHost = *ogCacheConsiderHost
		policy.OpenGraph.TimeToLive = *ogTimeToLive
		policy.OpenGraph.Override = map[string]string{}
	}

	// NOTE(review): -cookie-dynamic-domain is only validated above and is not
	// passed in these options — confirm whether libanubis.Options should
	// receive it.
	s, err := libanubis.New(libanubis.Options{
		BasePrefix:        *basePrefix,
		StripBasePrefix:   *stripBasePrefix,
		Next:              rp,
		Policy:            policy,
		ServeRobotsTXT:    *robotsTxt,
		ED25519PrivateKey: ed25519Priv,
		HS512Secret:       []byte(*hs512Secret),
		CookieDomain:      *cookieDomain,
		CookieExpiration:  *cookieExpiration,
		CookiePartitioned: *cookiePartitioned,
		RedirectDomains:   redirectDomainsList,
		Target:            *target,
		WebmasterEmail:    *webmasterEmail,
		OpenGraph:         policy.OpenGraph,
	})
	if err != nil {
		log.Fatalf("can't construct libanubis.Server: %v", err)
	}

	wg := new(sync.WaitGroup)
	// Install the signal handler. Derive from ctx (not a fresh Background) so
	// values attached above, such as the Thoth client, stay available to
	// everything started below.
	ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
	defer stop()

	if *metricsBind != "" {
		wg.Add(1)
		go metricsServer(ctx, wg.Done)
	}

	go startDecayMapCleanup(ctx, s)

	// RemoteXRealIP is given the network family. In compatibility mode
	// (-bind-network empty) derive it from -bind the same way setupListener
	// does, so it receives the real family instead of an empty string.
	network := *bindNetwork
	if network == "" {
		network, _ = parseBindNetFromAddr(*bind)
	}

	var h http.Handler
	h = s
	h = internal.RemoteXRealIP(*useRemoteAddress, network, h)
	h = internal.XForwardedForToXRealIP(h)
	h = internal.XForwardedForUpdate(*xffStripPrivate, h)

	srv := http.Server{Handler: h, ErrorLog: internal.GetFilteredHTTPLogger()}
	listener, listenerUrl := setupListener(*bindNetwork, *bind)
	slog.Info(
		"listening",
		"url", listenerUrl,
		"difficulty", *challengeDifficulty,
		"serveRobotsTXT", *robotsTxt,
		"target", *target,
		"version", anubis.Version,
		"use-remote-address", *useRemoteAddress,
		"debug-benchmark-js", *debugBenchmarkJS,
		// Report the effective OpenGraph settings, which may come from the
		// policy file rather than from the flags.
		"og-passthrough", policy.OpenGraph.Enabled,
		"og-expiry-time", policy.OpenGraph.TimeToLive,
		"base-prefix", *basePrefix,
		"cookie-expiration-time", *cookieExpiration,
		"rule-error-ids", ruleErrorIDs,
	)

	go func() {
		<-ctx.Done()
		c, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := srv.Shutdown(c); err != nil {
			log.Printf("cannot shut down: %v", err)
		}
	}()

	if err := srv.Serve(listener); !errors.Is(err, http.ErrServerClosed) {
		log.Fatal(err)
	}
	wg.Wait()
}
func metricsServer(ctx context.Context, done func()) {
defer done()
mux := http.NewServeMux()
mux.Handle(anubis.BasePrefix+"/metrics", promhttp.Handler())
srv := http.Server{Handler: mux, ErrorLog: internal.GetFilteredHTTPLogger()}
listener, metricsUrl := setupListener(*metricsBindNetwork, *metricsBind)
slog.Debug("listening for metrics", "url", metricsUrl)
if *healthcheck {
log.Println("running healthcheck")
if err := doHealthCheck(); err != nil {
log.Fatal(err)
}
return
}
go func() {
<-ctx.Done()
c, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(c); err != nil {
log.Printf("cannot shut down: %v", err)
}
}()
if err := srv.Serve(listener); !errors.Is(err, http.ErrServerClosed) {
log.Fatal(err)
}
}
// extractEmbedFS copies the subtree of fsys rooted at root to destDir/root on
// disk. Directories are created with mode 0o700 and files are written with
// mode 0o644. The first error encountered while walking, reading or writing
// aborts the copy and is returned.
func extractEmbedFS(fsys embed.FS, root string, destDir string) error {
	visit := func(p string, entry fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}

		rel, err := filepath.Rel(root, p)
		if err != nil {
			return err
		}
		target := filepath.Join(destDir, root, rel)

		if !entry.IsDir() {
			contents, err := fs.ReadFile(fsys, p)
			if err != nil {
				return err
			}
			return os.WriteFile(target, contents, 0o644)
		}
		return os.MkdirAll(target, 0o700)
	}

	return fs.WalkDir(fsys, root, visit)
}