diff --git a/cmd/anubis/main.go b/cmd/anubis/main.go
index ff2d14f..76f99fe 100644
--- a/cmd/anubis/main.go
+++ b/cmd/anubis/main.go
@@ -59,6 +59,7 @@ var (
debugBenchmarkJS = flag.Bool("debug-benchmark-js", false, "respond to every request with a challenge for benchmarking hashrate")
ogPassthrough = flag.Bool("og-passthrough", false, "enable Open Graph tag passthrough")
ogTimeToLive = flag.Duration("og-expiry-time", 24*time.Hour, "Open Graph tag cache expiration time")
+ ogCacheConsiderHost = flag.Bool("og-cache-consider-host", false, "enable or disable the use of the host in the Open Graph tag cache")
extractResources = flag.String("extract-resources", "", "if set, extract the static resources to the specified folder")
webmasterEmail = flag.String("webmaster-email", "", "if set, displays webmaster's email on the reject page for appeals")
)
@@ -272,18 +273,19 @@ func main() {
}
s, err := libanubis.New(libanubis.Options{
- BasePrefix: *basePrefix,
- Next: rp,
- Policy: policy,
- ServeRobotsTXT: *robotsTxt,
- PrivateKey: priv,
- CookieDomain: *cookieDomain,
- CookiePartitioned: *cookiePartitioned,
- OGPassthrough: *ogPassthrough,
- OGTimeToLive: *ogTimeToLive,
- RedirectDomains: redirectDomainsList,
- Target: *target,
- WebmasterEmail: *webmasterEmail,
+ BasePrefix: *basePrefix,
+ Next: rp,
+ Policy: policy,
+ ServeRobotsTXT: *robotsTxt,
+ PrivateKey: priv,
+ CookieDomain: *cookieDomain,
+ CookiePartitioned: *cookiePartitioned,
+ OGPassthrough: *ogPassthrough,
+ OGTimeToLive: *ogTimeToLive,
+ RedirectDomains: redirectDomainsList,
+ Target: *target,
+ WebmasterEmail: *webmasterEmail,
+ OGCacheConsidersHost: *ogCacheConsiderHost,
})
if err != nil {
log.Fatalf("can't construct libanubis.Server: %v", err)
diff --git a/docs/docs/CHANGELOG.md b/docs/docs/CHANGELOG.md
index efaae6b..277add0 100644
--- a/docs/docs/CHANGELOG.md
+++ b/docs/docs/CHANGELOG.md
@@ -41,6 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed mojeekbot user agent regex
- Added support for running anubis behind a base path (e.g. `/myapp`)
- Reduce Anubis' paranoia with user cookies ([#365](https://github.com/TecharoHQ/anubis/pull/365))
+- Added support for Opengraph passthrough while using unix sockets
+- The opengraph subsystem now passes the HTTP `HOST` header through to the origin
## v1.16.0
diff --git a/docs/docs/admin/configuration/open-graph.mdx b/docs/docs/admin/configuration/open-graph.mdx
index 87dd404..98cdd74 100644
--- a/docs/docs/admin/configuration/open-graph.mdx
+++ b/docs/docs/admin/configuration/open-graph.mdx
@@ -9,10 +9,11 @@ This page provides detailed information on how to configure [OpenGraph tag](http
## Configuration Options
-| Name | Description | Type | Default | Example |
-|------------------|-----------------------------------------------------------|----------|---------|-------------------------|
-| `OG_PASSTHROUGH` | Enables or disables the Open Graph tag passthrough system | Boolean | `false` | `OG_PASSTHROUGH=true` |
-| `OG_EXPIRY_TIME` | Configurable cache expiration time for Open Graph tags | Duration | `24h` | `OG_EXPIRY_TIME=1h` |
+| Name | Description | Type | Default | Example |
+|--------------------------|-----------------------------------------------------------|----------|---------|---------------------------------|
+| `OG_PASSTHROUGH` | Enables or disables the Open Graph tag passthrough system | Boolean | `false` | `OG_PASSTHROUGH=true` |
+| `OG_EXPIRY_TIME` | Configurable cache expiration time for Open Graph tags | Duration | `24h` | `OG_EXPIRY_TIME=1h` |
+| `OG_CACHE_CONSIDER_HOST` | Enables or disables the use of the host in the cache key | Boolean | `false` | `OG_CACHE_CONSIDER_HOST=true` |
## Usage
@@ -21,6 +22,7 @@ To configure Open Graph tags, you can set the following environment variables, e
```sh
export OG_PASSTHROUGH=true
export OG_EXPIRY_TIME=1h
+export OG_CACHE_CONSIDER_HOST=false
```
## Implementation Details
@@ -33,6 +35,8 @@ When `OG_PASSTHROUGH` is enabled, Anubis will:
The cache expiration time is controlled by `OG_EXPIRY_TIME`.
+When `OG_CACHE_CONSIDER_HOST` is enabled, Anubis will include the host in the cache key for Open Graph tags. This ensures that tags are cached separately for different hosts.
+
## Example
Here is an example of how to configure Open Graph tags in your Anubis setup:
@@ -40,8 +44,19 @@ Here is an example of how to configure Open Graph tags in your Anubis setup:
```sh
export OG_PASSTHROUGH=true
export OG_EXPIRY_TIME=1h
+export OG_CACHE_CONSIDER_HOST=false
```
-With these settings, Anubis will cache Open Graph tags for 1 hour and pass them through to the challenge page.
+With these settings, Anubis will cache Open Graph tags for 1 hour and pass them through to the challenge page, not considering the host in the cache key.
+
+## When to Enable `OG_CACHE_CONSIDER_HOST`
+
+In most cases, you would want to keep `OG_CACHE_CONSIDER_HOST` set to `false` to avoid unnecessary cache fragmentation. However, there are some scenarios where enabling this option can be beneficial:
+
+1. **Multi-Tenant Applications**: If you are running a multi-tenant application where different tenants are hosted on different subdomains, enabling `OG_CACHE_CONSIDER_HOST` ensures that the Open Graph tags are cached separately for each tenant. This prevents one tenant's Open Graph tags from being served to another tenant's users.
+
+2. **Different Content for Different Hosts**: If your application serves different content based on the host, enabling `OG_CACHE_CONSIDER_HOST` ensures that the correct Open Graph tags are cached and served for each host. This is useful for applications that have different branding or content for different domains or subdomains.
+
+3. **Security and Privacy Concerns**: In some cases, you may want to ensure that Open Graph tags are not shared between different hosts for security or privacy reasons. Enabling `OG_CACHE_CONSIDER_HOST` ensures that the tags are cached separately for each host, preventing any potential leakage of information between hosts.
For more information, refer to the [installation guide](../installation).
diff --git a/docs/docs/admin/installation.mdx b/docs/docs/admin/installation.mdx
index 57b5886..1fe2e0f 100644
--- a/docs/docs/admin/installation.mdx
+++ b/docs/docs/admin/installation.mdx
@@ -63,6 +63,7 @@ Anubis uses these environment variables for configuration:
| `METRICS_BIND_NETWORK` | `tcp` | The address family that the Anubis metrics server listens on. See `BIND_NETWORK` for more information. |
| `OG_EXPIRY_TIME` | `24h` | The expiration time for the Open Graph tag cache. |
| `OG_PASSTHROUGH` | `false` | If set to `true`, Anubis will enable Open Graph tag passthrough. |
+| `OG_CACHE_CONSIDER_HOST` | `false` | If set to `true`, Anubis will consider the host in the Open Graph tag cache key. |
| `POLICY_FNAME` | unset | The file containing [bot policy configuration](./policies.mdx). See the bot policy documentation for more details. If unset, the default bot policy configuration is used. |
| `REDIRECT_DOMAINS` | unset | If set, restrict the domains that Anubis can redirect to when passing a challenge.
If this is unset, Anubis may redirect to any domain which could cause security issues in the unlikely case that an attacker passes a challenge for your browser and then tricks you into clicking a link to your domain. |
| `SERVE_ROBOTS_TXT` | `false` | If set `true`, Anubis will serve a default `robots.txt` file that disallows all known AI scrapers by name and then additionally disallows every scraper. This is useful if facts and circumstances make it difficult to change the underlying service to serve such a `robots.txt` file. |
diff --git a/internal/ogtags/cache.go b/internal/ogtags/cache.go
index b3e35e4..903b723 100644
--- a/internal/ogtags/cache.go
+++ b/internal/ogtags/cache.go
@@ -8,18 +8,21 @@ import (
)
// GetOGTags is the main function that retrieves Open Graph tags for a URL
-func (c *OGTagCache) GetOGTags(url *url.URL) (map[string]string, error) {
+func (c *OGTagCache) GetOGTags(url *url.URL, originalHost string) (map[string]string, error) {
if url == nil {
return nil, errors.New("nil URL provided, cannot fetch OG tags")
}
- urlStr := c.getTarget(url)
+
+ target := c.getTarget(url)
+ cacheKey := c.generateCacheKey(target, originalHost)
+
// Check cache first
- if cachedTags := c.checkCache(urlStr); cachedTags != nil {
+ if cachedTags := c.checkCache(cacheKey); cachedTags != nil {
return cachedTags, nil
}
- // Fetch HTML content
- doc, err := c.fetchHTMLDocument(urlStr)
+ // Fetch HTML content, passing the original host
+ doc, err := c.fetchHTMLDocumentWithCache(target, originalHost, cacheKey)
if errors.Is(err, syscall.ECONNREFUSED) {
slog.Debug("Connection refused, returning empty tags")
return nil, nil
@@ -35,17 +38,28 @@ func (c *OGTagCache) GetOGTags(url *url.URL) (map[string]string, error) {
ogTags := c.extractOGTags(doc)
// Store in cache
- c.cache.Set(urlStr, ogTags, c.ogTimeToLive)
+ c.cache.Set(cacheKey, ogTags, c.ogTimeToLive)
return ogTags, nil
}
+func (c *OGTagCache) generateCacheKey(target string, originalHost string) string {
+ var cacheKey string
+
+ if c.ogCacheConsiderHost {
+ cacheKey = target + "|" + originalHost
+ } else {
+ cacheKey = target
+ }
+ return cacheKey
+}
+
// checkCache checks if we have the tags cached and returns them if so
-func (c *OGTagCache) checkCache(urlStr string) map[string]string {
- if cachedTags, ok := c.cache.Get(urlStr); ok {
+func (c *OGTagCache) checkCache(cacheKey string) map[string]string {
+ if cachedTags, ok := c.cache.Get(cacheKey); ok {
slog.Debug("cache hit", "tags", cachedTags)
return cachedTags
}
- slog.Debug("cache miss", "url", urlStr)
+ slog.Debug("cache miss", "url", cacheKey)
return nil
}
diff --git a/internal/ogtags/cache_test.go b/internal/ogtags/cache_test.go
index cd32414..fbacf22 100644
--- a/internal/ogtags/cache_test.go
+++ b/internal/ogtags/cache_test.go
@@ -4,12 +4,13 @@ import (
"net/http"
"net/http/httptest"
"net/url"
+ "reflect"
"testing"
"time"
)
func TestCheckCache(t *testing.T) {
- cache := NewOGTagCache("http://example.com", true, time.Minute)
+ cache := NewOGTagCache("http://example.com", true, time.Minute, false)
// Set up test data
urlStr := "http://example.com/page"
@@ -17,18 +18,19 @@ func TestCheckCache(t *testing.T) {
"og:title": "Test Title",
"og:description": "Test Description",
}
+ cacheKey := cache.generateCacheKey(urlStr, "example.com")
// Test cache miss
- tags := cache.checkCache(urlStr)
+ tags := cache.checkCache(cacheKey)
if tags != nil {
t.Errorf("expected nil tags on cache miss, got %v", tags)
}
// Manually add to cache
- cache.cache.Set(urlStr, expectedTags, time.Minute)
+ cache.cache.Set(cacheKey, expectedTags, time.Minute)
// Test cache hit
- tags = cache.checkCache(urlStr)
+ tags = cache.checkCache(cacheKey)
if tags == nil {
t.Fatal("expected non-nil tags on cache hit, got nil")
}
@@ -67,7 +69,7 @@ func TestGetOGTags(t *testing.T) {
defer ts.Close()
// Create an instance of OGTagCache with a short TTL for testing
- cache := NewOGTagCache(ts.URL, true, 1*time.Minute)
+ cache := NewOGTagCache(ts.URL, true, 1*time.Minute, false)
// Parse the test server URL
parsedURL, err := url.Parse(ts.URL)
@@ -76,7 +78,8 @@ func TestGetOGTags(t *testing.T) {
}
// Test fetching OG tags from the test server
- ogTags, err := cache.GetOGTags(parsedURL)
+ // Pass the host from the parsed test server URL
+ ogTags, err := cache.GetOGTags(parsedURL, parsedURL.Host)
if err != nil {
t.Fatalf("failed to get OG tags: %v", err)
}
@@ -95,13 +98,15 @@ func TestGetOGTags(t *testing.T) {
}
// Test fetching OG tags from the cache
- ogTags, err = cache.GetOGTags(parsedURL)
+ // Pass the host from the parsed test server URL
+ ogTags, err = cache.GetOGTags(parsedURL, parsedURL.Host)
if err != nil {
t.Fatalf("failed to get OG tags from cache: %v", err)
}
// Test fetching OG tags from the cache (3rd time)
- newOgTags, err := cache.GetOGTags(parsedURL)
+ // Pass the host from the parsed test server URL
+ newOgTags, err := cache.GetOGTags(parsedURL, parsedURL.Host)
if err != nil {
t.Fatalf("failed to get OG tags from cache: %v", err)
}
@@ -120,3 +125,116 @@ func TestGetOGTags(t *testing.T) {
}
}
+
+// TestGetOGTagsWithHostConsideration tests the behavior of the cache with and without host consideration and for multiple hosts in a theoretical setup.
+func TestGetOGTagsWithHostConsideration(t *testing.T) {
+ var loadCount int // Counter to track how many times the test route is loaded
+
+ // Create a test server
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ loadCount++ // Increment counter on each request to the server
+ w.Header().Set("Content-Type", "text/html")
+ w.Write([]byte(`
+
+
+
Content
+ + `)) + })) + defer ts.Close() + + parsedURL, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("failed to parse test server URL: %v", err) + } + + expectedTags := map[string]string{ + "og:title": "Test Title", + "og:description": "Test Description", + } + + testCases := []struct { + name string + ogCacheConsiderHost bool + requests []struct { + host string + expectedLoadCount int // Expected load count *after* this request + } + }{ + { + name: "Host Not Considered - Same Host", + ogCacheConsiderHost: false, + requests: []struct { + host string + expectedLoadCount int + }{ + {"host1", 1}, // First request, miss + {"host1", 1}, // Second request, same host, hit (host ignored) + }, + }, + { + name: "Host Not Considered - Different Host", + ogCacheConsiderHost: false, + requests: []struct { + host string + expectedLoadCount int + }{ + {"host1", 1}, // First request, miss + {"host2", 1}, // Second request, different host, hit (host ignored) + }, + }, + { + name: "Host Considered - Same Host", + ogCacheConsiderHost: true, + requests: []struct { + host string + expectedLoadCount int + }{ + {"host1", 1}, // First request, miss + {"host1", 1}, // Second request, same host, hit + }, + }, + { + name: "Host Considered - Different Host", + ogCacheConsiderHost: true, + requests: []struct { + host string + expectedLoadCount int + }{ + {"host1", 1}, // First request, miss + {"host2", 2}, // Second request, different host, miss + {"host2", 2}, // Third request, same as second, hit + {"host1", 2}, // Fourth request, same as first, hit + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + loadCount = 0 // Reset load count for each test case + cache := NewOGTagCache(ts.URL, true, 1*time.Minute, tc.ogCacheConsiderHost) + + for i, req := range tc.requests { + ogTags, err := cache.GetOGTags(parsedURL, req.host) + if err != nil { + t.Errorf("Request %d (host: %s): unexpected error: %v", i+1, req.host, err) + continue // Skip further checks for this request if error occurred + } + + // Verify tags are correct (should always be the same in this setup) + if !reflect.DeepEqual(ogTags, expectedTags) { + t.Errorf("Request %d (host: %s): expected tags %v, got %v", i+1, req.host, expectedTags, ogTags) + } + + // Verify the load count to check cache hit/miss behavior + if loadCount != req.expectedLoadCount { + t.Errorf("Request %d (host: %s): expected load count %d, got %d (cache hit/miss mismatch)", i+1, req.host, req.expectedLoadCount, loadCount) + } + } + }) + } +} diff --git a/internal/ogtags/fetch.go b/internal/ogtags/fetch.go index 7e02eca..312e040 100644 --- a/internal/ogtags/fetch.go +++ b/internal/ogtags/fetch.go @@ -1,6 +1,7 @@ package ogtags import ( + "context" "errors" "fmt" "golang.org/x/net/html" @@ -16,17 +17,35 @@ var ( emptyMap = map[string]string{} // used to indicate an empty result in the cache. Can't use nil as it would be a cache miss. ) -func (c *OGTagCache) fetchHTMLDocument(urlStr string) (*html.Node, error) { - resp, err := c.client.Get(urlStr) +// fetchHTMLDocumentWithCache fetches the HTML document from the given URL string, +// preserving the original host header. +func (c *OGTagCache) fetchHTMLDocumentWithCache(urlStr string, originalHost string, cacheKey string) (*html.Node, error) { + req, err := http.NewRequestWithContext(context.Background(), "GET", urlStr, nil) + if err != nil { + return nil, fmt.Errorf("failed to create http request: %w", err) + } + + // Set the Host header to the original host + if originalHost != "" { + req.Host = originalHost + } + + // Add proxy headers + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("User-Agent", "Anubis-OGTag-Fetcher/1.0") // For tracking purposes + + // Send the request + resp, err := c.client.Do(req) if err != nil { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { slog.Debug("og: request timed out", "url", urlStr) - c.cache.Set(urlStr, emptyMap, c.ogTimeToLive/2) // Cache empty result for half the TTL to not spam the server + c.cache.Set(cacheKey, emptyMap, c.ogTimeToLive/2) // Cache empty result for half the TTL to not spam the server } return nil, fmt.Errorf("http get failed: %w", err) } - // this defer will call MaxBytesReader's Close, which closes the original body. + + // Ensure the response body is closed defer func(Body io.ReadCloser) { err := Body.Close() if err != nil { @@ -36,19 +55,17 @@ func (c *OGTagCache) fetchHTMLDocument(urlStr string) (*html.Node, error) { if resp.StatusCode != http.StatusOK { slog.Debug("og: received non-OK status code", "url", urlStr, "status", resp.StatusCode) - c.cache.Set(urlStr, emptyMap, c.ogTimeToLive) // Cache empty result for non-successful status codes + c.cache.Set(cacheKey, emptyMap, c.ogTimeToLive) // Cache empty result for non-successful status codes return nil, fmt.Errorf("%w: page not found", ErrOgHandled) } // Check content type ct := resp.Header.Get("Content-Type") if ct == "" { - // assume non html body return nil, fmt.Errorf("missing Content-Type header") } else { mediaType, _, err := mime.ParseMediaType(ct) if err != nil { - // Malformed Content-Type header slog.Debug("og: malformed Content-Type header", "url", urlStr, "contentType", ct) return nil, fmt.Errorf("%w malformed Content-Type header: %w", ErrOgHandled, err) } @@ -59,17 +76,16 @@ func (c *OGTagCache) fetchHTMLDocument(urlStr string) (*html.Node, error) { } } - resp.Body = http.MaxBytesReader(nil, resp.Body, c.maxContentLength) + resp.Body = http.MaxBytesReader(nil, resp.Body, maxContentLength) doc, err := html.Parse(resp.Body) if err != nil { // Check if the error is specifically because the limit was exceeded var maxBytesErr *http.MaxBytesError if errors.As(err, &maxBytesErr) { - slog.Debug("og: content exceeded max length", "url", urlStr, "limit", c.maxContentLength) - return nil, fmt.Errorf("content too large: exceeded %d bytes", c.maxContentLength) + slog.Debug("og: content exceeded max length", "url", urlStr, "limit", maxContentLength) + return nil, fmt.Errorf("content too large: exceeded %d bytes", maxContentLength) } - // parsing error (e.g., malformed HTML) return nil, fmt.Errorf("failed to parse HTML: %w", err) } diff --git a/internal/ogtags/fetch_test.go b/internal/ogtags/fetch_test.go index 60af957..7462287 100644 --- a/internal/ogtags/fetch_test.go +++ b/internal/ogtags/fetch_test.go @@ -2,6 +2,7 @@ package ogtags import ( "fmt" + "golang.org/x/net/html" "io" "net/http" "net/http/httptest" @@ -78,8 +79,8 @@ func TestFetchHTMLDocument(t *testing.T) { })) defer ts.Close() - cache := NewOGTagCache("", true, time.Minute) - doc, err := cache.fetchHTMLDocument(ts.URL) + cache := NewOGTagCache("", true, time.Minute, false) + doc, err := cache.fetchHTMLDocument(ts.URL, "anything") if tt.expectError { if err == nil { @@ -105,9 +106,9 @@ func TestFetchHTMLDocumentInvalidURL(t *testing.T) { t.Skip("test requires theoretical network egress") } - cache := NewOGTagCache("", true, time.Minute) + cache := NewOGTagCache("", true, time.Minute, false) - doc, err := cache.fetchHTMLDocument("http://invalid.url.that.doesnt.exist.example") + doc, err := cache.fetchHTMLDocument("http://invalid.url.that.doesnt.exist.example", "anything") if err == nil { t.Error("expected error for invalid URL, got nil") @@ -117,3 +118,9 @@ func TestFetchHTMLDocumentInvalidURL(t *testing.T) { t.Error("expected nil document for invalid URL, got non-nil") } } + +// fetchHTMLDocument allows you to call fetchHTMLDocumentWithCache without a duplicate generateCacheKey call +func (c *OGTagCache) fetchHTMLDocument(urlStr string, originalHost string) (*html.Node, error) { + cacheKey := c.generateCacheKey(urlStr, originalHost) + return c.fetchHTMLDocumentWithCache(urlStr, originalHost, cacheKey) +} diff --git a/internal/ogtags/integration_test.go b/internal/ogtags/integration_test.go index 9eaaa3a..32245c2 100644 --- a/internal/ogtags/integration_test.go +++ b/internal/ogtags/integration_test.go @@ -104,7 +104,7 @@ func TestIntegrationGetOGTags(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Create cache instance - cache := NewOGTagCache(ts.URL, true, 1*time.Minute) + cache := NewOGTagCache(ts.URL, true, 1*time.Minute, false) // Create URL for test testURL, _ := url.Parse(ts.URL) @@ -112,7 +112,8 @@ func TestIntegrationGetOGTags(t *testing.T) { testURL.RawQuery = tc.query // Get OG tags - ogTags, err := cache.GetOGTags(testURL) + // Pass the host from the test URL + ogTags, err := cache.GetOGTags(testURL, testURL.Host) // Check error expectation if tc.expectError { @@ -139,7 +140,8 @@ func TestIntegrationGetOGTags(t *testing.T) { } // Test cache retrieval - cachedOGTags, err := cache.GetOGTags(testURL) + // Pass the host from the test URL + cachedOGTags, err := cache.GetOGTags(testURL, testURL.Host) if err != nil { t.Fatalf("failed to get OG tags from cache: %v", err) } diff --git a/internal/ogtags/ogtags.go b/internal/ogtags/ogtags.go index 72185bb..c88d280 100644 --- a/internal/ogtags/ogtags.go +++ b/internal/ogtags/ogtags.go @@ -1,51 +1,111 @@ package ogtags import ( + "context" + "log/slog" + "net" "net/http" "net/url" + "strings" "time" "github.com/TecharoHQ/anubis/decaymap" ) +const ( + maxContentLength = 16 << 20 // 16 MiB in bytes, if there is a reasonable reason that you need more than this...Why? + httpTimeout = 5 * time.Second /*todo: make this configurable?*/ +) + type OGTagCache struct { - cache *decaymap.Impl[string, map[string]string] - target string - ogPassthrough bool - ogTimeToLive time.Duration - approvedTags []string - approvedPrefixes []string - client *http.Client - maxContentLength int64 + cache *decaymap.Impl[string, map[string]string] + targetURL *url.URL + ogCacheConsiderHost bool + ogPassthrough bool + ogTimeToLive time.Duration + approvedTags []string + approvedPrefixes []string + client *http.Client } -func NewOGTagCache(target string, ogPassthrough bool, ogTimeToLive time.Duration) *OGTagCache { +func NewOGTagCache(target string, ogPassthrough bool, ogTimeToLive time.Duration, ogTagsConsiderHost bool) *OGTagCache { // Predefined approved tags and prefixes // In the future, these could come from configuration defaultApprovedTags := []string{"description", "keywords", "author"} defaultApprovedPrefixes := []string{"og:", "twitter:", "fediverse:"} - client := &http.Client{ - Timeout: 5 * time.Second, /*make this configurable?*/ + + var parsedTargetURL *url.URL + var err error + + if target == "" { + // Default to localhost if target is empty + parsedTargetURL, _ = url.Parse("http://localhost") + } else { + parsedTargetURL, err = url.Parse(target) + if err != nil { + slog.Debug("og: failed to parse target URL, treating as non-unix", "target", target, "error", err) + // If parsing fails, treat it as a non-unix target for backward compatibility or default behavior + // For now, assume it's not a scheme issue but maybe an invalid char, etc. + // A simple string target might be intended if it's not a full URL. + parsedTargetURL = &url.URL{Scheme: "http", Host: target} // Assume http if scheme missing and host-like + if !strings.Contains(target, "://") && !strings.HasPrefix(target, "unix:") { + // If it looks like just a host/host:port (and not unix), prepend http:// (todo: is this bad...? Trace path to see if i can yell at user to do it right) + parsedTargetURL, _ = url.Parse("http://" + target) // fetch cares about scheme but anubis doesn't + } + } } - const maxContentLength = 16 << 20 // 16 MiB in bytes + client := &http.Client{ + Timeout: httpTimeout, + } + + // Configure custom transport for Unix sockets + if parsedTargetURL.Scheme == "unix" { + socketPath := parsedTargetURL.Path // For unix scheme, path is the socket path + client.Transport = &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", socketPath) + }, + } + } return &OGTagCache{ - cache: decaymap.New[string, map[string]string](), - target: target, - ogPassthrough: ogPassthrough, - ogTimeToLive: ogTimeToLive, - approvedTags: defaultApprovedTags, - approvedPrefixes: defaultApprovedPrefixes, - client: client, - maxContentLength: maxContentLength, + cache: decaymap.New[string, map[string]string](), + targetURL: parsedTargetURL, // Store the parsed URL + ogPassthrough: ogPassthrough, + ogTimeToLive: ogTimeToLive, + ogCacheConsiderHost: ogTagsConsiderHost, // todo: refactor to be a separate struct + approvedTags: defaultApprovedTags, + approvedPrefixes: defaultApprovedPrefixes, + client: client, } } +// getTarget constructs the target URL string for fetching OG tags. +// For Unix sockets, it creates a "fake" HTTP URL that the custom dialer understands. func (c *OGTagCache) getTarget(u *url.URL) string { - return c.target + u.Path + if c.targetURL.Scheme == "unix" { + // The custom dialer ignores the host, but we need a valid http URL structure. + // Use "unix" as a placeholder host. Path and Query from original request are appended. + fakeURL := &url.URL{ + Scheme: "http", // Scheme must be http/https for client.Get + Host: "unix", // Arbitrary host, ignored by custom dialer + Path: u.Path, + RawQuery: u.RawQuery, + } + return fakeURL.String() + } + + // For regular http/https targets + target := *c.targetURL // Make a copy + target.Path = u.Path + target.RawQuery = u.RawQuery + return target.String() + } func (c *OGTagCache) Cleanup() { - c.cache.Cleanup() + if c.cache != nil { + c.cache.Cleanup() + } } diff --git a/internal/ogtags/ogtags_test.go b/internal/ogtags/ogtags_test.go index 8cd5b0d..4d97ad2 100644 --- a/internal/ogtags/ogtags_test.go +++ b/internal/ogtags/ogtags_test.go @@ -1,7 +1,16 @@ package ogtags import ( + "context" + "errors" + "fmt" + "net" + "net/http" "net/url" + "os" + "path/filepath" + "reflect" + "strings" "testing" "time" ) @@ -29,14 +38,23 @@ func TestNewOGTagCache(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cache := NewOGTagCache(tt.target, tt.ogPassthrough, tt.ogTimeToLive) + cache := NewOGTagCache(tt.target, tt.ogPassthrough, tt.ogTimeToLive, false) if cache == nil { t.Fatal("expected non-nil cache, got nil") } - if cache.target != tt.target { - t.Errorf("expected target %s, got %s", tt.target, cache.target) + // Check the parsed targetURL, handling the default case for empty target + expectedURLStr := tt.target + if tt.target == "" { + // Default behavior when target is empty is now http://localhost + expectedURLStr = "http://localhost" + } else if !strings.Contains(tt.target, "://") && !strings.HasPrefix(tt.target, "unix:") { + // Handle case where target is just host or host:port (and not unix) + expectedURLStr = "http://" + tt.target + } + if cache.targetURL.String() != expectedURLStr { + t.Errorf("expected targetURL %s, got %s", expectedURLStr, cache.targetURL.String()) } if cache.ogPassthrough != tt.ogPassthrough { @@ -50,6 +68,45 @@ func TestNewOGTagCache(t *testing.T) { } } +// TestNewOGTagCache_UnixSocket specifically tests unix socket initialization +func TestNewOGTagCache_UnixSocket(t *testing.T) { + tempDir := t.TempDir() + socketPath := filepath.Join(tempDir, "test.sock") + target := "unix://" + socketPath + + cache := NewOGTagCache(target, true, 5*time.Minute, false) + + if cache == nil { + t.Fatal("expected non-nil cache, got nil") + } + + if cache.targetURL.Scheme != "unix" { + t.Errorf("expected targetURL scheme 'unix', got '%s'", cache.targetURL.Scheme) + } + if cache.targetURL.Path != socketPath { + t.Errorf("expected targetURL path '%s', got '%s'", socketPath, cache.targetURL.Path) + } + + // Check if the client transport is configured for Unix sockets + transport, ok := cache.client.Transport.(*http.Transport) + if !ok { + t.Fatalf("expected client transport to be *http.Transport, got %T", cache.client.Transport) + } + if transport.DialContext == nil { + t.Fatal("expected client transport DialContext to be non-nil for unix socket") + } + + // Attempt a dummy dial to see if it uses the correct path (optional, more involved check) + dummyConn, err := transport.DialContext(context.Background(), "", "") + if err == nil { + dummyConn.Close() + t.Log("DialContext seems functional, but couldn't verify path without a listener") + } else if !strings.Contains(err.Error(), "connect: connection refused") && !strings.Contains(err.Error(), "connect: no such file or directory") { + // We expect connection refused or not found if nothing is listening + t.Errorf("DialContext failed with unexpected error: %v", err) + } +} + func TestGetTarget(t *testing.T) { tests := []struct { name string @@ -66,24 +123,39 @@ func TestGetTarget(t *testing.T) { expected: "http://example.com", }, { - name: "With complex path", - target: "http://example.com", - path: "/pag(#*((#@)ΓΓΓΓe/Γ", - query: "id=123", - expected: "http://example.com/pag(#*((#@)ΓΓΓΓe/Γ", + name: "With complex path", + target: "http://example.com", + path: "/pag(#*((#@)ΓΓΓΓe/Γ", + query: "id=123", + // Expect URL encoding and query parameter + expected: "http://example.com/pag%28%23%2A%28%28%23@%29%CE%93%CE%93%CE%93%CE%93e/%CE%93?id=123", }, { name: "With query and path", target: "http://example.com", path: "/page", query: "id=123", - expected: "http://example.com/page", + expected: "http://example.com/page?id=123", + }, + { + name: "Unix socket target", + target: "unix:/tmp/anubis.sock", + path: "/some/path", + query: "key=value&flag=true", + expected: "http://unix/some/path?key=value&flag=true", // Scheme becomes http, host is 'unix' + }, + { + name: "Unix socket target with ///", + target: "unix:///var/run/anubis.sock", + path: "/", + query: "", + expected: "http://unix/", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cache := NewOGTagCache(tt.target, false, time.Minute) + cache := NewOGTagCache(tt.target, false, time.Minute, false) u := &url.URL{ Path: tt.path, @@ -98,3 +170,86 @@ func TestGetTarget(t *testing.T) { }) } } + +// TestIntegrationGetOGTags_UnixSocket tests fetching OG tags via a Unix socket. +func TestIntegrationGetOGTags_UnixSocket(t *testing.T) { + tempDir := t.TempDir() + + socketPath := filepath.Join(tempDir, "anubis-test.sock") + + // Ensure the socket does not exist initially + _ = os.Remove(socketPath) + + // Create a simple HTTP server listening on the Unix socket + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("Failed to listen on unix socket %s: %v", socketPath, err) + } + defer func(listener net.Listener, socketPath string) { + if listener != nil { + if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + t.Logf("Error closing listener: %v", err) + } + } + + if _, err := os.Stat(socketPath); err == nil { + if err := os.Remove(socketPath); err != nil { + t.Logf("Error removing socket file %s: %v", socketPath, err) + } + } + }(listener, socketPath) + + server := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprintln(w, `Test`) + }), + } + go func() { + if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Logf("Unix socket server error: %v", err) + } + }() + defer func(server *http.Server, ctx context.Context) { + err := server.Shutdown(ctx) + if err != nil { + t.Logf("Error shutting down server: %v", err) + } + }(server, context.Background()) // Ensure server is shut down + + // Wait a moment for the server to start + time.Sleep(100 * time.Millisecond) + + // Create cache instance pointing to the Unix socket + targetURL := "unix://" + socketPath + cache := NewOGTagCache(targetURL, true, 1*time.Minute, false) + + // Create a dummy URL for the request (path and query matter) + testReqURL, _ := url.Parse("/some/page?query=1") + + // Get OG tags + // Pass an empty string for host, as it's irrelevant for unix sockets + ogTags, err := cache.GetOGTags(testReqURL, "") + + if err != nil { + t.Fatalf("GetOGTags failed for unix socket: %v", err) + } + + expectedTags := map[string]string{ + "og:title": "Unix Socket Test", + } + + if !reflect.DeepEqual(ogTags, expectedTags) { + t.Errorf("Expected OG tags %v, got %v", expectedTags, ogTags) + } + + // Test cache retrieval (should hit cache) + // Pass an empty string for host + cachedTags, err := cache.GetOGTags(testReqURL, "") + if err != nil { + t.Fatalf("GetOGTags (cache hit) failed for unix socket: %v", err) + } + if !reflect.DeepEqual(cachedTags, expectedTags) { + t.Errorf("Expected cached OG tags %v, got %v", expectedTags, cachedTags) + } +} diff --git a/internal/ogtags/parse_test.go b/internal/ogtags/parse_test.go index 54815b3..e25a211 100644 --- a/internal/ogtags/parse_test.go +++ b/internal/ogtags/parse_test.go @@ -12,7 +12,7 @@ import ( // TestExtractOGTags updated with correct expectations based on filtering logic func TestExtractOGTags(t *testing.T) { // Use a cache instance that reflects the default approved lists - testCache := NewOGTagCache("", false, time.Minute) + testCache := NewOGTagCache("", false, time.Minute, false) // Manually set approved tags/prefixes based on the user request for clarity testCache.approvedTags = []string{"description"} testCache.approvedPrefixes = []string{"og:"} @@ -189,7 +189,7 @@ func TestIsOGMetaTag(t *testing.T) { func TestExtractMetaTagInfo(t *testing.T) { // Use a cache instance that reflects the default approved lists - testCache := NewOGTagCache("", false, time.Minute) + testCache := NewOGTagCache("", false, time.Minute, false) testCache.approvedTags = []string{"description"} testCache.approvedPrefixes = []string{"og:"} diff --git a/lib/config.go b/lib/config.go index 81d2bcd..44b6479 100644 --- a/lib/config.go +++ b/lib/config.go @@ -33,9 +33,10 @@ type Options struct { CookieName string CookiePartitioned bool - OGPassthrough bool - OGTimeToLive time.Duration - Target string + OGPassthrough bool + OGTimeToLive time.Duration + OGCacheConsidersHost bool + Target string WebmasterEmail string BasePrefix string @@ -89,7 +90,7 @@ func New(opts Options) (*Server, error) { policy: opts.Policy, opts: opts, DNSBLCache: decaymap.New[string, dnsbl.DroneBLResponse](), - OGTags: ogtags.NewOGTagCache(opts.Target, opts.OGPassthrough, opts.OGTimeToLive), + OGTags: ogtags.NewOGTagCache(opts.Target, opts.OGPassthrough, opts.OGTimeToLive, opts.OGCacheConsidersHost), } mux := http.NewServeMux() diff --git a/lib/http.go b/lib/http.go index 9e134b3..bfcddfe 100644 --- a/lib/http.go +++ b/lib/http.go @@ -54,7 +54,7 @@ func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, rule *polic var ogTags map[string]string = nil if s.opts.OGPassthrough { var err error - ogTags, err = s.OGTags.GetOGTags(r.URL) + ogTags, err = s.OGTags.GetOGTags(r.URL, r.Host) if err != nil { lg.Error("failed to get OG tags", "err", err) }