feat(og): Forward host header (#370)

* feat(ogtags): enhance target URL handling for OGTagCache, support Unix sockets

Closes: #323 #319
Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* docs: update CHANGELOG.md to include Opengraph passthrough support for Unix sockets

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* docs: update CHANGELOG.md to include Opengraph passthrough support for Unix sockets

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* feat(ogtags): add option to consider host in Open Graph tag cache key

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* feat(ogtags): add option to consider host in OG tag cache key

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* test(ogtags): enhance tests for OGTagCache with host consideration scenarios

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* refactor(ogtags): extract constants for HTTP timeout and max content length

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* refactor(ogtags): restore fetchHTMLDocument method for cache key generation

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* refactor(ogtags): replace maxContentLength field with constant and ensure HTTP scheme is set correctly

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* fix(fetch): add proxy headers

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

---------

Signed-off-by: Jason Cameron <git@jasoncameron.dev>
This commit is contained in:
Jason Cameron 2025-04-29 08:20:04 -04:00 committed by GitHub
parent 7a20a46b0d
commit 4184b42282
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 484 additions and 91 deletions

View File

@ -59,6 +59,7 @@ var (
debugBenchmarkJS = flag.Bool("debug-benchmark-js", false, "respond to every request with a challenge for benchmarking hashrate") debugBenchmarkJS = flag.Bool("debug-benchmark-js", false, "respond to every request with a challenge for benchmarking hashrate")
ogPassthrough = flag.Bool("og-passthrough", false, "enable Open Graph tag passthrough") ogPassthrough = flag.Bool("og-passthrough", false, "enable Open Graph tag passthrough")
ogTimeToLive = flag.Duration("og-expiry-time", 24*time.Hour, "Open Graph tag cache expiration time") ogTimeToLive = flag.Duration("og-expiry-time", 24*time.Hour, "Open Graph tag cache expiration time")
ogCacheConsiderHost = flag.Bool("og-cache-consider-host", false, "enable or disable the use of the host in the Open Graph tag cache")
extractResources = flag.String("extract-resources", "", "if set, extract the static resources to the specified folder") extractResources = flag.String("extract-resources", "", "if set, extract the static resources to the specified folder")
webmasterEmail = flag.String("webmaster-email", "", "if set, displays webmaster's email on the reject page for appeals") webmasterEmail = flag.String("webmaster-email", "", "if set, displays webmaster's email on the reject page for appeals")
) )
@ -272,18 +273,19 @@ func main() {
} }
s, err := libanubis.New(libanubis.Options{ s, err := libanubis.New(libanubis.Options{
BasePrefix: *basePrefix, BasePrefix: *basePrefix,
Next: rp, Next: rp,
Policy: policy, Policy: policy,
ServeRobotsTXT: *robotsTxt, ServeRobotsTXT: *robotsTxt,
PrivateKey: priv, PrivateKey: priv,
CookieDomain: *cookieDomain, CookieDomain: *cookieDomain,
CookiePartitioned: *cookiePartitioned, CookiePartitioned: *cookiePartitioned,
OGPassthrough: *ogPassthrough, OGPassthrough: *ogPassthrough,
OGTimeToLive: *ogTimeToLive, OGTimeToLive: *ogTimeToLive,
RedirectDomains: redirectDomainsList, RedirectDomains: redirectDomainsList,
Target: *target, Target: *target,
WebmasterEmail: *webmasterEmail, WebmasterEmail: *webmasterEmail,
OGCacheConsidersHost: *ogCacheConsiderHost,
}) })
if err != nil { if err != nil {
log.Fatalf("can't construct libanubis.Server: %v", err) log.Fatalf("can't construct libanubis.Server: %v", err)

View File

@ -41,6 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed mojeekbot user agent regex - Fixed mojeekbot user agent regex
- Added support for running anubis behind a base path (e.g. `/myapp`) - Added support for running anubis behind a base path (e.g. `/myapp`)
- Reduce Anubis' paranoia with user cookies ([#365](https://github.com/TecharoHQ/anubis/pull/365)) - Reduce Anubis' paranoia with user cookies ([#365](https://github.com/TecharoHQ/anubis/pull/365))
- Added support for Open Graph passthrough while using Unix sockets
- The Open Graph subsystem now passes the HTTP `Host` header through to the origin
## v1.16.0 ## v1.16.0

View File

@ -9,10 +9,11 @@ This page provides detailed information on how to configure [OpenGraph tag](http
## Configuration Options ## Configuration Options
| Name | Description | Type | Default | Example | | Name | Description | Type | Default | Example |
|------------------|-----------------------------------------------------------|----------|---------|-------------------------| |--------------------------|-----------------------------------------------------------|----------|---------|---------------------------------|
| `OG_PASSTHROUGH` | Enables or disables the Open Graph tag passthrough system | Boolean | `false` | `OG_PASSTHROUGH=true` | | `OG_PASSTHROUGH` | Enables or disables the Open Graph tag passthrough system | Boolean | `false` | `OG_PASSTHROUGH=true` |
| `OG_EXPIRY_TIME` | Configurable cache expiration time for Open Graph tags | Duration | `24h` | `OG_EXPIRY_TIME=1h` | | `OG_EXPIRY_TIME` | Configurable cache expiration time for Open Graph tags | Duration | `24h` | `OG_EXPIRY_TIME=1h` |
| `OG_CACHE_CONSIDER_HOST` | Enables or disables the use of the host in the cache key | Boolean | `false` | `OG_CACHE_CONSIDER_HOST=true` |
## Usage ## Usage
@ -21,6 +22,7 @@ To configure Open Graph tags, you can set the following environment variables, e
```sh ```sh
export OG_PASSTHROUGH=true export OG_PASSTHROUGH=true
export OG_EXPIRY_TIME=1h export OG_EXPIRY_TIME=1h
export OG_CACHE_CONSIDER_HOST=false
``` ```
## Implementation Details ## Implementation Details
@ -33,6 +35,8 @@ When `OG_PASSTHROUGH` is enabled, Anubis will:
The cache expiration time is controlled by `OG_EXPIRY_TIME`. The cache expiration time is controlled by `OG_EXPIRY_TIME`.
When `OG_CACHE_CONSIDER_HOST` is enabled, Anubis will include the host in the cache key for Open Graph tags. This ensures that tags are cached separately for different hosts.
## Example ## Example
Here is an example of how to configure Open Graph tags in your Anubis setup: Here is an example of how to configure Open Graph tags in your Anubis setup:
@ -40,8 +44,19 @@ Here is an example of how to configure Open Graph tags in your Anubis setup:
```sh ```sh
export OG_PASSTHROUGH=true export OG_PASSTHROUGH=true
export OG_EXPIRY_TIME=1h export OG_EXPIRY_TIME=1h
export OG_CACHE_CONSIDER_HOST=false
``` ```
With these settings, Anubis will cache Open Graph tags for 1 hour and pass them through to the challenge page. With these settings, Anubis will cache Open Graph tags for 1 hour and pass them through to the challenge page, not considering the host in the cache key.
## When to Enable `OG_CACHE_CONSIDER_HOST`
In most cases, you would want to keep `OG_CACHE_CONSIDER_HOST` set to `false` to avoid unnecessary cache fragmentation. However, there are some scenarios where enabling this option can be beneficial:
1. **Multi-Tenant Applications**: If you are running a multi-tenant application where different tenants are hosted on different subdomains, enabling `OG_CACHE_CONSIDER_HOST` ensures that the Open Graph tags are cached separately for each tenant. This prevents one tenant's Open Graph tags from being served to another tenant's users.
2. **Different Content for Different Hosts**: If your application serves different content based on the host, enabling `OG_CACHE_CONSIDER_HOST` ensures that the correct Open Graph tags are cached and served for each host. This is useful for applications that have different branding or content for different domains or subdomains.
3. **Security and Privacy Concerns**: In some cases, you may want to ensure that Open Graph tags are not shared between different hosts for security or privacy reasons. Enabling `OG_CACHE_CONSIDER_HOST` ensures that the tags are cached separately for each host, preventing any potential leakage of information between hosts.
For more information, refer to the [installation guide](../installation). For more information, refer to the [installation guide](../installation).

View File

@ -63,6 +63,7 @@ Anubis uses these environment variables for configuration:
| `METRICS_BIND_NETWORK` | `tcp` | The address family that the Anubis metrics server listens on. See `BIND_NETWORK` for more information. | | `METRICS_BIND_NETWORK` | `tcp` | The address family that the Anubis metrics server listens on. See `BIND_NETWORK` for more information. |
| `OG_EXPIRY_TIME` | `24h` | The expiration time for the Open Graph tag cache. | | `OG_EXPIRY_TIME` | `24h` | The expiration time for the Open Graph tag cache. |
| `OG_PASSTHROUGH` | `false` | If set to `true`, Anubis will enable Open Graph tag passthrough. | | `OG_PASSTHROUGH` | `false` | If set to `true`, Anubis will enable Open Graph tag passthrough. |
| `OG_CACHE_CONSIDER_HOST` | `false` | If set to `true`, Anubis will consider the host in the Open Graph tag cache key. |
| `POLICY_FNAME` | unset | The file containing [bot policy configuration](./policies.mdx). See the bot policy documentation for more details. If unset, the default bot policy configuration is used. | | `POLICY_FNAME` | unset | The file containing [bot policy configuration](./policies.mdx). See the bot policy documentation for more details. If unset, the default bot policy configuration is used. |
| `REDIRECT_DOMAINS` | unset | If set, restrict the domains that Anubis can redirect to when passing a challenge.<br/><br/>If this is unset, Anubis may redirect to any domain which could cause security issues in the unlikely case that an attacker passes a challenge for your browser and then tricks you into clicking a link to your domain. | | `REDIRECT_DOMAINS` | unset | If set, restrict the domains that Anubis can redirect to when passing a challenge.<br/><br/>If this is unset, Anubis may redirect to any domain which could cause security issues in the unlikely case that an attacker passes a challenge for your browser and then tricks you into clicking a link to your domain. |
| `SERVE_ROBOTS_TXT` | `false` | If set `true`, Anubis will serve a default `robots.txt` file that disallows all known AI scrapers by name and then additionally disallows every scraper. This is useful if facts and circumstances make it difficult to change the underlying service to serve such a `robots.txt` file. | | `SERVE_ROBOTS_TXT` | `false` | If set `true`, Anubis will serve a default `robots.txt` file that disallows all known AI scrapers by name and then additionally disallows every scraper. This is useful if facts and circumstances make it difficult to change the underlying service to serve such a `robots.txt` file. |

View File

@ -8,18 +8,21 @@ import (
) )
// GetOGTags is the main function that retrieves Open Graph tags for a URL // GetOGTags is the main function that retrieves Open Graph tags for a URL
func (c *OGTagCache) GetOGTags(url *url.URL) (map[string]string, error) { func (c *OGTagCache) GetOGTags(url *url.URL, originalHost string) (map[string]string, error) {
if url == nil { if url == nil {
return nil, errors.New("nil URL provided, cannot fetch OG tags") return nil, errors.New("nil URL provided, cannot fetch OG tags")
} }
urlStr := c.getTarget(url)
target := c.getTarget(url)
cacheKey := c.generateCacheKey(target, originalHost)
// Check cache first // Check cache first
if cachedTags := c.checkCache(urlStr); cachedTags != nil { if cachedTags := c.checkCache(cacheKey); cachedTags != nil {
return cachedTags, nil return cachedTags, nil
} }
// Fetch HTML content // Fetch HTML content, passing the original host
doc, err := c.fetchHTMLDocument(urlStr) doc, err := c.fetchHTMLDocumentWithCache(target, originalHost, cacheKey)
if errors.Is(err, syscall.ECONNREFUSED) { if errors.Is(err, syscall.ECONNREFUSED) {
slog.Debug("Connection refused, returning empty tags") slog.Debug("Connection refused, returning empty tags")
return nil, nil return nil, nil
@ -35,17 +38,28 @@ func (c *OGTagCache) GetOGTags(url *url.URL) (map[string]string, error) {
ogTags := c.extractOGTags(doc) ogTags := c.extractOGTags(doc)
// Store in cache // Store in cache
c.cache.Set(urlStr, ogTags, c.ogTimeToLive) c.cache.Set(cacheKey, ogTags, c.ogTimeToLive)
return ogTags, nil return ogTags, nil
} }
// generateCacheKey returns the key under which Open Graph tags for target are
// stored. When the cache is configured to consider the host, the key is target
// followed by "|" and originalHost; otherwise the key is target on its own and
// originalHost is ignored.
func (c *OGTagCache) generateCacheKey(target string, originalHost string) string {
	if !c.ogCacheConsiderHost {
		return target
	}
	return target + "|" + originalHost
}
// checkCache checks if we have the tags cached and returns them if so // checkCache checks if we have the tags cached and returns them if so
func (c *OGTagCache) checkCache(urlStr string) map[string]string { func (c *OGTagCache) checkCache(cacheKey string) map[string]string {
if cachedTags, ok := c.cache.Get(urlStr); ok { if cachedTags, ok := c.cache.Get(cacheKey); ok {
slog.Debug("cache hit", "tags", cachedTags) slog.Debug("cache hit", "tags", cachedTags)
return cachedTags return cachedTags
} }
slog.Debug("cache miss", "url", urlStr) slog.Debug("cache miss", "url", cacheKey)
return nil return nil
} }

View File

@ -4,12 +4,13 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect"
"testing" "testing"
"time" "time"
) )
func TestCheckCache(t *testing.T) { func TestCheckCache(t *testing.T) {
cache := NewOGTagCache("http://example.com", true, time.Minute) cache := NewOGTagCache("http://example.com", true, time.Minute, false)
// Set up test data // Set up test data
urlStr := "http://example.com/page" urlStr := "http://example.com/page"
@ -17,18 +18,19 @@ func TestCheckCache(t *testing.T) {
"og:title": "Test Title", "og:title": "Test Title",
"og:description": "Test Description", "og:description": "Test Description",
} }
cacheKey := cache.generateCacheKey(urlStr, "example.com")
// Test cache miss // Test cache miss
tags := cache.checkCache(urlStr) tags := cache.checkCache(cacheKey)
if tags != nil { if tags != nil {
t.Errorf("expected nil tags on cache miss, got %v", tags) t.Errorf("expected nil tags on cache miss, got %v", tags)
} }
// Manually add to cache // Manually add to cache
cache.cache.Set(urlStr, expectedTags, time.Minute) cache.cache.Set(cacheKey, expectedTags, time.Minute)
// Test cache hit // Test cache hit
tags = cache.checkCache(urlStr) tags = cache.checkCache(cacheKey)
if tags == nil { if tags == nil {
t.Fatal("expected non-nil tags on cache hit, got nil") t.Fatal("expected non-nil tags on cache hit, got nil")
} }
@ -67,7 +69,7 @@ func TestGetOGTags(t *testing.T) {
defer ts.Close() defer ts.Close()
// Create an instance of OGTagCache with a short TTL for testing // Create an instance of OGTagCache with a short TTL for testing
cache := NewOGTagCache(ts.URL, true, 1*time.Minute) cache := NewOGTagCache(ts.URL, true, 1*time.Minute, false)
// Parse the test server URL // Parse the test server URL
parsedURL, err := url.Parse(ts.URL) parsedURL, err := url.Parse(ts.URL)
@ -76,7 +78,8 @@ func TestGetOGTags(t *testing.T) {
} }
// Test fetching OG tags from the test server // Test fetching OG tags from the test server
ogTags, err := cache.GetOGTags(parsedURL) // Pass the host from the parsed test server URL
ogTags, err := cache.GetOGTags(parsedURL, parsedURL.Host)
if err != nil { if err != nil {
t.Fatalf("failed to get OG tags: %v", err) t.Fatalf("failed to get OG tags: %v", err)
} }
@ -95,13 +98,15 @@ func TestGetOGTags(t *testing.T) {
} }
// Test fetching OG tags from the cache // Test fetching OG tags from the cache
ogTags, err = cache.GetOGTags(parsedURL) // Pass the host from the parsed test server URL
ogTags, err = cache.GetOGTags(parsedURL, parsedURL.Host)
if err != nil { if err != nil {
t.Fatalf("failed to get OG tags from cache: %v", err) t.Fatalf("failed to get OG tags from cache: %v", err)
} }
// Test fetching OG tags from the cache (3rd time) // Test fetching OG tags from the cache (3rd time)
newOgTags, err := cache.GetOGTags(parsedURL) // Pass the host from the parsed test server URL
newOgTags, err := cache.GetOGTags(parsedURL, parsedURL.Host)
if err != nil { if err != nil {
t.Fatalf("failed to get OG tags from cache: %v", err) t.Fatalf("failed to get OG tags from cache: %v", err)
} }
@ -120,3 +125,116 @@ func TestGetOGTags(t *testing.T) {
} }
} }
// TestGetOGTagsWithHostConsideration exercises GetOGTags against a single test
// server while varying the original host passed in, both with and without host
// consideration enabled on the cache. Cache hits and misses are detected by
// counting how many requests actually reach the test server.
func TestGetOGTagsWithHostConsideration(t *testing.T) {
	// hostRequest is one GetOGTags call made with the given host, paired with the
	// number of server loads expected once that call has returned.
	type hostRequest struct {
		host              string
		expectedLoadCount int
	}

	var loadCount int // how many times the test server has been hit

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		loadCount++
		w.Header().Set("Content-Type", "text/html")
		w.Write([]byte(`
<!DOCTYPE html>
<html>
<head>
<meta property="og:title" content="Test Title" />
<meta property="og:description" content="Test Description" />
</head>
<body><p>Content</p></body>
</html>
`))
	}))
	defer ts.Close()

	parsedURL, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatalf("failed to parse test server URL: %v", err)
	}

	// The server always serves the same document, so every request should yield
	// exactly these tags regardless of whether it was a hit or a miss.
	expectedTags := map[string]string{
		"og:title":       "Test Title",
		"og:description": "Test Description",
	}

	scenarios := []struct {
		name                string
		ogCacheConsiderHost bool
		requests            []hostRequest
	}{
		{
			name:                "Host Not Considered - Same Host",
			ogCacheConsiderHost: false,
			requests: []hostRequest{
				{host: "host1", expectedLoadCount: 1}, // miss
				{host: "host1", expectedLoadCount: 1}, // hit: same host, and host is ignored anyway
			},
		},
		{
			name:                "Host Not Considered - Different Host",
			ogCacheConsiderHost: false,
			requests: []hostRequest{
				{host: "host1", expectedLoadCount: 1}, // miss
				{host: "host2", expectedLoadCount: 1}, // hit: host is ignored
			},
		},
		{
			name:                "Host Considered - Same Host",
			ogCacheConsiderHost: true,
			requests: []hostRequest{
				{host: "host1", expectedLoadCount: 1}, // miss
				{host: "host1", expectedLoadCount: 1}, // hit: same host
			},
		},
		{
			name:                "Host Considered - Different Host",
			ogCacheConsiderHost: true,
			requests: []hostRequest{
				{host: "host1", expectedLoadCount: 1}, // miss
				{host: "host2", expectedLoadCount: 2}, // miss: different host
				{host: "host2", expectedLoadCount: 2}, // hit: same as second request
				{host: "host1", expectedLoadCount: 2}, // hit: same as first request
			},
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			// Each scenario starts with a fresh counter and a fresh cache.
			loadCount = 0
			cache := NewOGTagCache(ts.URL, true, 1*time.Minute, sc.ogCacheConsiderHost)

			for idx, step := range sc.requests {
				requestNo := idx + 1

				ogTags, err := cache.GetOGTags(parsedURL, step.host)
				if err != nil {
					// Nothing else can be checked for this request once it has failed.
					t.Errorf("Request %d (host: %s): unexpected error: %v", requestNo, step.host, err)
					continue
				}

				if !reflect.DeepEqual(ogTags, expectedTags) {
					t.Errorf("Request %d (host: %s): expected tags %v, got %v", requestNo, step.host, expectedTags, ogTags)
				}

				// The running load count reveals whether this request hit the cache.
				if loadCount != step.expectedLoadCount {
					t.Errorf("Request %d (host: %s): expected load count %d, got %d (cache hit/miss mismatch)", requestNo, step.host, step.expectedLoadCount, loadCount)
				}
			}
		})
	}
}

View File

@ -1,6 +1,7 @@
package ogtags package ogtags
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/net/html" "golang.org/x/net/html"
@ -16,17 +17,35 @@ var (
emptyMap = map[string]string{} // used to indicate an empty result in the cache. Can't use nil as it would be a cache miss. emptyMap = map[string]string{} // used to indicate an empty result in the cache. Can't use nil as it would be a cache miss.
) )
func (c *OGTagCache) fetchHTMLDocument(urlStr string) (*html.Node, error) { // fetchHTMLDocumentWithCache fetches the HTML document from the given URL string,
resp, err := c.client.Get(urlStr) // preserving the original host header.
func (c *OGTagCache) fetchHTMLDocumentWithCache(urlStr string, originalHost string, cacheKey string) (*html.Node, error) {
req, err := http.NewRequestWithContext(context.Background(), "GET", urlStr, nil)
if err != nil {
return nil, fmt.Errorf("failed to create http request: %w", err)
}
// Set the Host header to the original host
if originalHost != "" {
req.Host = originalHost
}
// Add proxy headers
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("User-Agent", "Anubis-OGTag-Fetcher/1.0") // For tracking purposes
// Send the request
resp, err := c.client.Do(req)
if err != nil { if err != nil {
var netErr net.Error var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() { if errors.As(err, &netErr) && netErr.Timeout() {
slog.Debug("og: request timed out", "url", urlStr) slog.Debug("og: request timed out", "url", urlStr)
c.cache.Set(urlStr, emptyMap, c.ogTimeToLive/2) // Cache empty result for half the TTL to not spam the server c.cache.Set(cacheKey, emptyMap, c.ogTimeToLive/2) // Cache empty result for half the TTL to not spam the server
} }
return nil, fmt.Errorf("http get failed: %w", err) return nil, fmt.Errorf("http get failed: %w", err)
} }
// this defer will call MaxBytesReader's Close, which closes the original body.
// Ensure the response body is closed
defer func(Body io.ReadCloser) { defer func(Body io.ReadCloser) {
err := Body.Close() err := Body.Close()
if err != nil { if err != nil {
@ -36,19 +55,17 @@ func (c *OGTagCache) fetchHTMLDocument(urlStr string) (*html.Node, error) {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
slog.Debug("og: received non-OK status code", "url", urlStr, "status", resp.StatusCode) slog.Debug("og: received non-OK status code", "url", urlStr, "status", resp.StatusCode)
c.cache.Set(urlStr, emptyMap, c.ogTimeToLive) // Cache empty result for non-successful status codes c.cache.Set(cacheKey, emptyMap, c.ogTimeToLive) // Cache empty result for non-successful status codes
return nil, fmt.Errorf("%w: page not found", ErrOgHandled) return nil, fmt.Errorf("%w: page not found", ErrOgHandled)
} }
// Check content type // Check content type
ct := resp.Header.Get("Content-Type") ct := resp.Header.Get("Content-Type")
if ct == "" { if ct == "" {
// assume non html body
return nil, fmt.Errorf("missing Content-Type header") return nil, fmt.Errorf("missing Content-Type header")
} else { } else {
mediaType, _, err := mime.ParseMediaType(ct) mediaType, _, err := mime.ParseMediaType(ct)
if err != nil { if err != nil {
// Malformed Content-Type header
slog.Debug("og: malformed Content-Type header", "url", urlStr, "contentType", ct) slog.Debug("og: malformed Content-Type header", "url", urlStr, "contentType", ct)
return nil, fmt.Errorf("%w malformed Content-Type header: %w", ErrOgHandled, err) return nil, fmt.Errorf("%w malformed Content-Type header: %w", ErrOgHandled, err)
} }
@ -59,17 +76,16 @@ func (c *OGTagCache) fetchHTMLDocument(urlStr string) (*html.Node, error) {
} }
} }
resp.Body = http.MaxBytesReader(nil, resp.Body, c.maxContentLength) resp.Body = http.MaxBytesReader(nil, resp.Body, maxContentLength)
doc, err := html.Parse(resp.Body) doc, err := html.Parse(resp.Body)
if err != nil { if err != nil {
// Check if the error is specifically because the limit was exceeded // Check if the error is specifically because the limit was exceeded
var maxBytesErr *http.MaxBytesError var maxBytesErr *http.MaxBytesError
if errors.As(err, &maxBytesErr) { if errors.As(err, &maxBytesErr) {
slog.Debug("og: content exceeded max length", "url", urlStr, "limit", c.maxContentLength) slog.Debug("og: content exceeded max length", "url", urlStr, "limit", maxContentLength)
return nil, fmt.Errorf("content too large: exceeded %d bytes", c.maxContentLength) return nil, fmt.Errorf("content too large: exceeded %d bytes", maxContentLength)
} }
// parsing error (e.g., malformed HTML)
return nil, fmt.Errorf("failed to parse HTML: %w", err) return nil, fmt.Errorf("failed to parse HTML: %w", err)
} }

View File

@ -2,6 +2,7 @@ package ogtags
import ( import (
"fmt" "fmt"
"golang.org/x/net/html"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -78,8 +79,8 @@ func TestFetchHTMLDocument(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
cache := NewOGTagCache("", true, time.Minute) cache := NewOGTagCache("", true, time.Minute, false)
doc, err := cache.fetchHTMLDocument(ts.URL) doc, err := cache.fetchHTMLDocument(ts.URL, "anything")
if tt.expectError { if tt.expectError {
if err == nil { if err == nil {
@ -105,9 +106,9 @@ func TestFetchHTMLDocumentInvalidURL(t *testing.T) {
t.Skip("test requires theoretical network egress") t.Skip("test requires theoretical network egress")
} }
cache := NewOGTagCache("", true, time.Minute) cache := NewOGTagCache("", true, time.Minute, false)
doc, err := cache.fetchHTMLDocument("http://invalid.url.that.doesnt.exist.example") doc, err := cache.fetchHTMLDocument("http://invalid.url.that.doesnt.exist.example", "anything")
if err == nil { if err == nil {
t.Error("expected error for invalid URL, got nil") t.Error("expected error for invalid URL, got nil")
@ -117,3 +118,9 @@ func TestFetchHTMLDocumentInvalidURL(t *testing.T) {
t.Error("expected nil document for invalid URL, got non-nil") t.Error("expected nil document for invalid URL, got non-nil")
} }
} }
// fetchHTMLDocument is a convenience wrapper around fetchHTMLDocumentWithCache:
// it derives the cache key for urlStr and originalHost with generateCacheKey and
// forwards all three values, returning the result unchanged.
func (c *OGTagCache) fetchHTMLDocument(urlStr string, originalHost string) (*html.Node, error) {
	return c.fetchHTMLDocumentWithCache(urlStr, originalHost, c.generateCacheKey(urlStr, originalHost))
}

View File

@ -104,7 +104,7 @@ func TestIntegrationGetOGTags(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
// Create cache instance // Create cache instance
cache := NewOGTagCache(ts.URL, true, 1*time.Minute) cache := NewOGTagCache(ts.URL, true, 1*time.Minute, false)
// Create URL for test // Create URL for test
testURL, _ := url.Parse(ts.URL) testURL, _ := url.Parse(ts.URL)
@ -112,7 +112,8 @@ func TestIntegrationGetOGTags(t *testing.T) {
testURL.RawQuery = tc.query testURL.RawQuery = tc.query
// Get OG tags // Get OG tags
ogTags, err := cache.GetOGTags(testURL) // Pass the host from the test URL
ogTags, err := cache.GetOGTags(testURL, testURL.Host)
// Check error expectation // Check error expectation
if tc.expectError { if tc.expectError {
@ -139,7 +140,8 @@ func TestIntegrationGetOGTags(t *testing.T) {
} }
// Test cache retrieval // Test cache retrieval
cachedOGTags, err := cache.GetOGTags(testURL) // Pass the host from the test URL
cachedOGTags, err := cache.GetOGTags(testURL, testURL.Host)
if err != nil { if err != nil {
t.Fatalf("failed to get OG tags from cache: %v", err) t.Fatalf("failed to get OG tags from cache: %v", err)
} }

View File

@ -1,51 +1,111 @@
package ogtags package ogtags
import ( import (
"context"
"log/slog"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"time" "time"
"github.com/TecharoHQ/anubis/decaymap" "github.com/TecharoHQ/anubis/decaymap"
) )
const (
// maxContentLength is the byte limit applied (via http.MaxBytesReader) to a
// fetched response body before it is parsed as HTML: 16 MiB (16 << 20 bytes).
// A page should not reasonably need more than this to expose its meta tags.
maxContentLength = 16 << 20
// httpTimeout is the Timeout set on the http.Client built by NewOGTagCache.
// TODO: make this configurable?
httpTimeout = 5 * time.Second
)
type OGTagCache struct { type OGTagCache struct {
cache *decaymap.Impl[string, map[string]string] cache *decaymap.Impl[string, map[string]string]
target string targetURL *url.URL
ogPassthrough bool ogCacheConsiderHost bool
ogTimeToLive time.Duration ogPassthrough bool
approvedTags []string ogTimeToLive time.Duration
approvedPrefixes []string approvedTags []string
client *http.Client approvedPrefixes []string
maxContentLength int64 client *http.Client
} }
func NewOGTagCache(target string, ogPassthrough bool, ogTimeToLive time.Duration) *OGTagCache { func NewOGTagCache(target string, ogPassthrough bool, ogTimeToLive time.Duration, ogTagsConsiderHost bool) *OGTagCache {
// Predefined approved tags and prefixes // Predefined approved tags and prefixes
// In the future, these could come from configuration // In the future, these could come from configuration
defaultApprovedTags := []string{"description", "keywords", "author"} defaultApprovedTags := []string{"description", "keywords", "author"}
defaultApprovedPrefixes := []string{"og:", "twitter:", "fediverse:"} defaultApprovedPrefixes := []string{"og:", "twitter:", "fediverse:"}
client := &http.Client{
Timeout: 5 * time.Second, /*make this configurable?*/ var parsedTargetURL *url.URL
var err error
if target == "" {
// Default to localhost if target is empty
parsedTargetURL, _ = url.Parse("http://localhost")
} else {
parsedTargetURL, err = url.Parse(target)
if err != nil {
slog.Debug("og: failed to parse target URL, treating as non-unix", "target", target, "error", err)
// If parsing fails, treat it as a non-unix target for backward compatibility or default behavior
// For now, assume it's not a scheme issue but maybe an invalid char, etc.
// A simple string target might be intended if it's not a full URL.
parsedTargetURL = &url.URL{Scheme: "http", Host: target} // Assume http if scheme missing and host-like
if !strings.Contains(target, "://") && !strings.HasPrefix(target, "unix:") {
// If it looks like just a host/host:port (and not unix), prepend http:// (todo: is this bad...? Trace path to see if i can yell at user to do it right)
parsedTargetURL, _ = url.Parse("http://" + target) // fetch cares about scheme but anubis doesn't
}
}
} }
const maxContentLength = 16 << 20 // 16 MiB in bytes client := &http.Client{
Timeout: httpTimeout,
}
// Configure custom transport for Unix sockets
if parsedTargetURL.Scheme == "unix" {
socketPath := parsedTargetURL.Path // For unix scheme, path is the socket path
client.Transport = &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", socketPath)
},
}
}
return &OGTagCache{ return &OGTagCache{
cache: decaymap.New[string, map[string]string](), cache: decaymap.New[string, map[string]string](),
target: target, targetURL: parsedTargetURL, // Store the parsed URL
ogPassthrough: ogPassthrough, ogPassthrough: ogPassthrough,
ogTimeToLive: ogTimeToLive, ogTimeToLive: ogTimeToLive,
approvedTags: defaultApprovedTags, ogCacheConsiderHost: ogTagsConsiderHost, // todo: refactor to be a separate struct
approvedPrefixes: defaultApprovedPrefixes, approvedTags: defaultApprovedTags,
client: client, approvedPrefixes: defaultApprovedPrefixes,
maxContentLength: maxContentLength, client: client,
} }
} }
// getTarget constructs the target URL string for fetching OG tags.
// For Unix sockets, it creates a "fake" HTTP URL that the custom dialer understands.
func (c *OGTagCache) getTarget(u *url.URL) string { func (c *OGTagCache) getTarget(u *url.URL) string {
return c.target + u.Path if c.targetURL.Scheme == "unix" {
// The custom dialer ignores the host, but we need a valid http URL structure.
// Use "unix" as a placeholder host. Path and Query from original request are appended.
fakeURL := &url.URL{
Scheme: "http", // Scheme must be http/https for client.Get
Host: "unix", // Arbitrary host, ignored by custom dialer
Path: u.Path,
RawQuery: u.RawQuery,
}
return fakeURL.String()
}
// For regular http/https targets
target := *c.targetURL // Make a copy
target.Path = u.Path
target.RawQuery = u.RawQuery
return target.String()
} }
func (c *OGTagCache) Cleanup() { func (c *OGTagCache) Cleanup() {
c.cache.Cleanup() if c.cache != nil {
c.cache.Cleanup()
}
} }

View File

@ -1,7 +1,16 @@
package ogtags package ogtags
import ( import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url" "net/url"
"os"
"path/filepath"
"reflect"
"strings"
"testing" "testing"
"time" "time"
) )
@ -29,14 +38,23 @@ func TestNewOGTagCache(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
cache := NewOGTagCache(tt.target, tt.ogPassthrough, tt.ogTimeToLive) cache := NewOGTagCache(tt.target, tt.ogPassthrough, tt.ogTimeToLive, false)
if cache == nil { if cache == nil {
t.Fatal("expected non-nil cache, got nil") t.Fatal("expected non-nil cache, got nil")
} }
if cache.target != tt.target { // Check the parsed targetURL, handling the default case for empty target
t.Errorf("expected target %s, got %s", tt.target, cache.target) expectedURLStr := tt.target
if tt.target == "" {
// Default behavior when target is empty is now http://localhost
expectedURLStr = "http://localhost"
} else if !strings.Contains(tt.target, "://") && !strings.HasPrefix(tt.target, "unix:") {
// Handle case where target is just host or host:port (and not unix)
expectedURLStr = "http://" + tt.target
}
if cache.targetURL.String() != expectedURLStr {
t.Errorf("expected targetURL %s, got %s", expectedURLStr, cache.targetURL.String())
} }
if cache.ogPassthrough != tt.ogPassthrough { if cache.ogPassthrough != tt.ogPassthrough {
@ -50,6 +68,45 @@ func TestNewOGTagCache(t *testing.T) {
} }
} }
// TestNewOGTagCache_UnixSocket checks that a unix:// target produces a cache
// whose parsed target URL and HTTP client transport are set up for a Unix
// domain socket.
func TestNewOGTagCache_UnixSocket(t *testing.T) {
	sockFile := filepath.Join(t.TempDir(), "test.sock")

	ogCache := NewOGTagCache("unix://"+sockFile, true, 5*time.Minute, false)
	if ogCache == nil {
		t.Fatal("expected non-nil cache, got nil")
	}

	if got := ogCache.targetURL.Scheme; got != "unix" {
		t.Errorf("expected targetURL scheme 'unix', got '%s'", got)
	}
	if got := ogCache.targetURL.Path; got != sockFile {
		t.Errorf("expected targetURL path '%s', got '%s'", sockFile, got)
	}

	// The client must carry an *http.Transport with a custom dialer.
	tr, isHTTPTransport := ogCache.client.Transport.(*http.Transport)
	if !isHTTPTransport {
		t.Fatalf("expected client transport to be *http.Transport, got %T", ogCache.client.Transport)
	}
	if tr.DialContext == nil {
		t.Fatal("expected client transport DialContext to be non-nil for unix socket")
	}

	// Nothing listens on the socket, so the dial is expected to fail with
	// "connection refused" or "no such file or directory". Any other failure
	// is reported; an unexpected success is only logged.
	conn, dialErr := tr.DialContext(context.Background(), "", "")
	switch {
	case dialErr == nil:
		conn.Close()
		t.Log("DialContext seems functional, but couldn't verify path without a listener")
	case strings.Contains(dialErr.Error(), "connect: connection refused"),
		strings.Contains(dialErr.Error(), "connect: no such file or directory"):
		// Expected outcome when no listener exists.
	default:
		t.Errorf("DialContext failed with unexpected error: %v", dialErr)
	}
}
func TestGetTarget(t *testing.T) { func TestGetTarget(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -66,24 +123,39 @@ func TestGetTarget(t *testing.T) {
expected: "http://example.com", expected: "http://example.com",
}, },
{ {
name: "With complex path", name: "With complex path",
target: "http://example.com", target: "http://example.com",
path: "/pag(#*((#@)ΓΓΓΓe/Γ", path: "/pag(#*((#@)ΓΓΓΓe/Γ",
query: "id=123", query: "id=123",
expected: "http://example.com/pag(#*((#@)ΓΓΓΓe/Γ", // Expect URL encoding and query parameter
expected: "http://example.com/pag%28%23%2A%28%28%23@%29%CE%93%CE%93%CE%93%CE%93e/%CE%93?id=123",
}, },
{ {
name: "With query and path", name: "With query and path",
target: "http://example.com", target: "http://example.com",
path: "/page", path: "/page",
query: "id=123", query: "id=123",
expected: "http://example.com/page", expected: "http://example.com/page?id=123",
},
{
name: "Unix socket target",
target: "unix:/tmp/anubis.sock",
path: "/some/path",
query: "key=value&flag=true",
expected: "http://unix/some/path?key=value&flag=true", // Scheme becomes http, host is 'unix'
},
{
name: "Unix socket target with ///",
target: "unix:///var/run/anubis.sock",
path: "/",
query: "",
expected: "http://unix/",
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
cache := NewOGTagCache(tt.target, false, time.Minute) cache := NewOGTagCache(tt.target, false, time.Minute, false)
u := &url.URL{ u := &url.URL{
Path: tt.path, Path: tt.path,
@ -98,3 +170,86 @@ func TestGetTarget(t *testing.T) {
}) })
} }
} }
// TestIntegrationGetOGTags_UnixSocket tests fetching OG tags via a Unix socket.
//
// It serves a fixed HTML page over a Unix domain socket in a temporary
// directory, points an OGTagCache at that socket, and checks that both the
// first lookup and a repeated lookup for the same URL return the expected tags.
func TestIntegrationGetOGTags_UnixSocket(t *testing.T) {
	socketPath := filepath.Join(t.TempDir(), "anubis-test.sock")

	// Bind the socket before starting the server. Once Listen returns, the
	// socket queues incoming connections, so no start-up sleep is required.
	listener, err := net.Listen("unix", socketPath)
	if err != nil {
		t.Fatalf("Failed to listen on unix socket %s: %v", socketPath, err)
	}

	server := &http.Server{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Type", "text/html")
			fmt.Fprintln(w, `<!DOCTYPE html><html><head><meta property="og:title" content="Unix Socket Test" /></head><body>Test</body></html>`)
		}),
	}

	// Serve in the background and hand the result back over a channel so the
	// test goroutine does the logging and knows when the server has stopped.
	serveDone := make(chan error, 1)
	go func() {
		serveDone <- server.Serve(listener)
	}()

	defer func() {
		// Bound the shutdown so a stuck connection cannot hang the test.
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := server.Shutdown(ctx); err != nil {
			t.Logf("Error shutting down server: %v", err)
		}
		if err := <-serveDone; err != nil && !errors.Is(err, http.ErrServerClosed) {
			t.Logf("Unix socket server error: %v", err)
		}
		// Shutdown normally closes the listener already; net.ErrClosed is expected then.
		if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			t.Logf("Error closing listener: %v", err)
		}
		if _, err := os.Stat(socketPath); err == nil {
			if err := os.Remove(socketPath); err != nil {
				t.Logf("Error removing socket file %s: %v", socketPath, err)
			}
		}
	}()

	// Create cache instance pointing to the Unix socket.
	cache := NewOGTagCache("unix://"+socketPath, true, 1*time.Minute, false)

	// Only the path and query of the request URL matter here.
	testReqURL, err := url.Parse("/some/page?query=1")
	if err != nil {
		t.Fatalf("Failed to parse request URL: %v", err)
	}

	expectedTags := map[string]string{
		"og:title": "Unix Socket Test",
	}

	// First lookup. The host argument is empty, as it is irrelevant for unix sockets.
	ogTags, err := cache.GetOGTags(testReqURL, "")
	if err != nil {
		t.Fatalf("GetOGTags failed for unix socket: %v", err)
	}
	if !reflect.DeepEqual(ogTags, expectedTags) {
		t.Errorf("Expected OG tags %v, got %v", expectedTags, ogTags)
	}

	// Second lookup for the same URL (should hit cache).
	cachedTags, err := cache.GetOGTags(testReqURL, "")
	if err != nil {
		t.Fatalf("GetOGTags (cache hit) failed for unix socket: %v", err)
	}
	if !reflect.DeepEqual(cachedTags, expectedTags) {
		t.Errorf("Expected cached OG tags %v, got %v", expectedTags, cachedTags)
	}
}

View File

@ -12,7 +12,7 @@ import (
// TestExtractOGTags updated with correct expectations based on filtering logic // TestExtractOGTags updated with correct expectations based on filtering logic
func TestExtractOGTags(t *testing.T) { func TestExtractOGTags(t *testing.T) {
// Use a cache instance that reflects the default approved lists // Use a cache instance that reflects the default approved lists
testCache := NewOGTagCache("", false, time.Minute) testCache := NewOGTagCache("", false, time.Minute, false)
// Manually set approved tags/prefixes based on the user request for clarity // Manually set approved tags/prefixes based on the user request for clarity
testCache.approvedTags = []string{"description"} testCache.approvedTags = []string{"description"}
testCache.approvedPrefixes = []string{"og:"} testCache.approvedPrefixes = []string{"og:"}
@ -189,7 +189,7 @@ func TestIsOGMetaTag(t *testing.T) {
func TestExtractMetaTagInfo(t *testing.T) { func TestExtractMetaTagInfo(t *testing.T) {
// Use a cache instance that reflects the default approved lists // Use a cache instance that reflects the default approved lists
testCache := NewOGTagCache("", false, time.Minute) testCache := NewOGTagCache("", false, time.Minute, false)
testCache.approvedTags = []string{"description"} testCache.approvedTags = []string{"description"}
testCache.approvedPrefixes = []string{"og:"} testCache.approvedPrefixes = []string{"og:"}

View File

@ -33,9 +33,10 @@ type Options struct {
CookieName string CookieName string
CookiePartitioned bool CookiePartitioned bool
OGPassthrough bool OGPassthrough bool
OGTimeToLive time.Duration OGTimeToLive time.Duration
Target string OGCacheConsidersHost bool
Target string
WebmasterEmail string WebmasterEmail string
BasePrefix string BasePrefix string
@ -89,7 +90,7 @@ func New(opts Options) (*Server, error) {
policy: opts.Policy, policy: opts.Policy,
opts: opts, opts: opts,
DNSBLCache: decaymap.New[string, dnsbl.DroneBLResponse](), DNSBLCache: decaymap.New[string, dnsbl.DroneBLResponse](),
OGTags: ogtags.NewOGTagCache(opts.Target, opts.OGPassthrough, opts.OGTimeToLive), OGTags: ogtags.NewOGTagCache(opts.Target, opts.OGPassthrough, opts.OGTimeToLive, opts.OGCacheConsidersHost),
} }
mux := http.NewServeMux() mux := http.NewServeMux()

View File

@ -54,7 +54,7 @@ func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, rule *polic
var ogTags map[string]string = nil var ogTags map[string]string = nil
if s.opts.OGPassthrough { if s.opts.OGPassthrough {
var err error var err error
ogTags, err = s.OGTags.GetOGTags(r.URL) ogTags, err = s.OGTags.GetOGTags(r.URL, r.Host)
if err != nil { if err != nil {
lg.Error("failed to get OG tags", "err", err) lg.Error("failed to get OG tags", "err", err)
} }