diff --git a/cmd/osiris/internal/entrypoint/entrypoint.go b/cmd/osiris/internal/entrypoint/entrypoint.go index 611f809..73d696e 100644 --- a/cmd/osiris/internal/entrypoint/entrypoint.go +++ b/cmd/osiris/internal/entrypoint/entrypoint.go @@ -2,6 +2,7 @@ package entrypoint import ( "context" + "crypto/tls" "fmt" "log/slog" "net" @@ -12,6 +13,7 @@ import ( "github.com/TecharoHQ/anubis/cmd/osiris/internal/config" "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/internal/fingerprint" "github.com/hashicorp/hcl/v2/hclsimple" "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/sync/errgroup" @@ -57,7 +59,7 @@ func Main(opts Options) error { ln.Close() }(gCtx) - slog.Info("listening for HTTP", "bind", cfg.Bind.HTTP) + slog.Info("listening", "for", "http", "bind", cfg.Bind.HTTP) srv := http.Server{Handler: rtr, ErrorLog: internal.GetFilteredHTTPLogger()} @@ -65,6 +67,35 @@ func Main(opts Options) error { }) // HTTPS + g.Go(func() error { + ln, err := net.Listen("tcp", cfg.Bind.HTTPS) + if err != nil { + return fmt.Errorf("(https) can't bind to tcp %s: %w", cfg.Bind.HTTPS, err) + } + defer ln.Close() + + go func(ctx context.Context) { + <-ctx.Done() + ln.Close() + }(gCtx) + + tc := &tls.Config{ + GetCertificate: rtr.GetCertificate, + } + + srv := &http.Server{ + Addr: cfg.Bind.HTTPS, + Handler: rtr, + ErrorLog: internal.GetFilteredHTTPLogger(), + TLSConfig: tc, + } + + fingerprint.ApplyTLSFingerprinter(srv) + + slog.Info("listening", "for", "https", "bind", cfg.Bind.HTTPS) + + return srv.ServeTLS(ln, "", "") + }) // Metrics g.Go(func() error { @@ -101,12 +132,18 @@ func Main(opts Options) error { } }) - slog.Info("listening for Metrics", "bind", cfg.Bind.Metrics) + slog.Info("listening", "for", "metrics", "bind", cfg.Bind.Metrics) - srv := http.Server{Handler: mux, ErrorLog: internal.GetFilteredHTTPLogger()} + srv := http.Server{ + Addr: cfg.Bind.Metrics, + Handler: mux, + ErrorLog: internal.GetFilteredHTTPLogger(), + } return srv.Serve(ln) }) + internal.SetHealth("osiris", healthv1.HealthCheckResponse_SERVING) + return g.Wait() } diff --git a/cmd/osiris/internal/entrypoint/router.go b/cmd/osiris/internal/entrypoint/router.go index d120fba..0f720ad 100644 --- a/cmd/osiris/internal/entrypoint/router.go +++ b/cmd/osiris/internal/entrypoint/router.go @@ -2,6 +2,7 @@ package entrypoint import ( "context" + "crypto/tls" "errors" "fmt" "log/slog" @@ -13,14 +14,17 @@ import ( "sync" "github.com/TecharoHQ/anubis/cmd/osiris/internal/config" + "github.com/TecharoHQ/anubis/internal/fingerprint" "github.com/lum8rjack/go-ja4h" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ) var ( - ErrTargetInvalid = errors.New("[unexpected] target invalid") - ErrNoHandler = errors.New("[unexpected] no handler for domain") + ErrTargetInvalid = errors.New("[unexpected] target invalid") + ErrNoHandler = errors.New("[unexpected] no handler for domain") + ErrInvalidTLSKeypair = errors.New("[unexpected] invalid TLS keypair") + ErrNoCert = errors.New("this server does not have a certificate for that domain") requestsPerDomain = promauto.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "techaro", @@ -36,13 +40,15 @@ var ( ) type Router struct { - lock sync.RWMutex - routes map[string]http.Handler + lock sync.RWMutex + routes map[string]http.Handler + tlsCerts map[string]*tls.Certificate } func (rtr *Router) setConfig(c config.Toplevel) error { var errs []error newMap := map[string]http.Handler{} + newCerts := map[string]*tls.Certificate{} for _, d := range c.Domains { var domainErrs []error @@ -75,6 +81,13 @@ func (rtr *Router) setConfig(c config.Toplevel) error { newMap[d.Name] = h + cert, err := tls.LoadX509KeyPair(d.TLS.Cert, d.TLS.Key) + if err != nil { + domainErrs = append(domainErrs, fmt.Errorf("%w: %w", ErrInvalidTLSKeypair, err)) + } + + newCerts[d.Name] = &cert + if len(domainErrs) != 0 { errs = append(errs, fmt.Errorf("invalid domain %s: %w", d.Name, errors.Join(domainErrs...))) } @@ -86,11 +99,24 @@ func (rtr *Router) setConfig(c config.Toplevel) error { rtr.lock.Lock() rtr.routes = newMap + rtr.tlsCerts = newCerts rtr.lock.Unlock() return nil } +func (rtr *Router) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + rtr.lock.RLock() + cert, ok := rtr.tlsCerts[hello.ServerName] + rtr.lock.RUnlock() + + if !ok { + return nil, ErrNoCert + } + + return cert, nil +} + func NewRouter(c config.Toplevel) (*Router, error) { result := &Router{ routes: map[string]http.Handler{}, @@ -104,17 +130,23 @@ func NewRouter(c config.Toplevel) (*Router, error) { } func (rtr *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { - requestsPerDomain.WithLabelValues(r.Host).Inc() + var host = r.Host + + if strings.Contains(host, ":") { + host, _, _ = net.SplitHostPort(host) + } + + requestsPerDomain.WithLabelValues(host).Inc() var h http.Handler var ok bool ja4hFP := ja4h.JA4H(r) - slog.Info("got request", "method", r.Method, "host", r.Host, "path", r.URL.Path) + slog.Info("got request", "method", r.Method, "host", host, "path", r.URL.Path) rtr.lock.RLock() - h, ok = rtr.routes[r.Host] + h, ok = rtr.routes[host] rtr.lock.RUnlock() if !ok { @@ -125,5 +157,18 @@ func (rtr *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { r.Header.Set("X-Http-Ja4h-Fingerprint", ja4hFP) + if fp := fingerprint.GetTLSFingerprint(r); fp != nil { + if ja3n := fp.JA3N(); ja3n != nil { + r.Header.Set("X-Tls-Ja3n-Fingerprint", ja3n.String()) + } + if ja4 := fp.JA4(); ja4 != nil { + r.Header.Set("X-Tls-Ja4-Fingerprint", ja4.String()) + } + } + + if tcpFP := fingerprint.GetTCPFingerprint(r); tcpFP != nil { + r.Header.Set("X-Tcp-Ja4t-Fingerprint", tcpFP.String()) + } + h.ServeHTTP(w, r) } diff --git a/cmd/osiris/osiris.hcl b/cmd/osiris/osiris.hcl index c20dac9..9eaafa8 100644 --- a/cmd/osiris/osiris.hcl +++ b/cmd/osiris/osiris.hcl @@ -4,7 +4,7 @@ bind { metrics = ":9091" } -domain "anubis.techaro.lol" { +domain "osiris.local.cetacean.club" { tls { cert = "./internal/config/testdata/tls/selfsigned.crt" key = "./internal/config/testdata/tls/selfsigned.key" diff --git a/internal/fingerprint/ja3n.go b/internal/fingerprint/ja3n.go new file mode 100644 index 0000000..a0e5b29 --- /dev/null +++ b/internal/fingerprint/ja3n.go @@ -0,0 +1,97 @@ +package fingerprint + +import ( + "crypto/md5" + "crypto/tls" + "encoding/hex" + "slices" + "strconv" +) + +// TLSFingerprintJA3N represents a JA3N fingerprint +type TLSFingerprintJA3N [md5.Size]byte + +func (f TLSFingerprintJA3N) String() string { + return hex.EncodeToString(f[:]) +} + +func buildJA3N(hello *tls.ClientHelloInfo, sortExtensions bool) TLSFingerprintJA3N { + buf := make([]byte, 0, 256) + + { + var sslVersion uint16 + var hasGrease bool + for _, v := range hello.SupportedVersions { + if v&greaseMask != greaseValue { + if v > sslVersion { + sslVersion = v + } + } else { + hasGrease = true + } + } + + // maximum TLS 1.2 as specified on JA3, as TLS 1.3 is put in SupportedVersions + if slices.Contains(hello.Extensions, extensionSupportedVersions) && hasGrease && sslVersion > tls.VersionTLS12 { + sslVersion = tls.VersionTLS12 + } + + buf = strconv.AppendUint(buf, uint64(sslVersion), 10) + buf = append(buf, ',') + } + + n := 0 + for _, cipher := range hello.CipherSuites { + //if !slices.Contains(greaseValues[:], cipher) { + if cipher&greaseMask != greaseValue { + buf = strconv.AppendUint(buf, uint64(cipher), 10) + buf = append(buf, '-') + n = 1 + } + } + + buf = buf[:len(buf)-n] + buf = append(buf, ',') + n = 0 + + extensions := hello.Extensions + if sortExtensions { + extensions = slices.Clone(extensions) + slices.Sort(extensions) + } + + for _, extension := range extensions { + if extension&greaseMask != greaseValue { + buf = strconv.AppendUint(buf, uint64(extension), 10) + buf = append(buf, '-') + n = 1 + } + } + + buf = buf[:len(buf)-n] + buf = append(buf, ',') + n = 0 + + for _, curve := range hello.SupportedCurves { + if curve&greaseMask != greaseValue { + buf = strconv.AppendUint(buf, uint64(curve), 10) + buf = append(buf, '-') + n = 1 + } + } + + buf = buf[:len(buf)-n] + buf = append(buf, ',') + n = 0 + + for _, point := range hello.SupportedPoints { + buf = strconv.AppendUint(buf, uint64(point), 10) + buf = append(buf, '-') + n = 1 + } + + buf = buf[:len(buf)-n] + + sum := md5.Sum(buf) + return TLSFingerprintJA3N(sum[:]) +} diff --git a/internal/fingerprint/ja4.go b/internal/fingerprint/ja4.go new file mode 100644 index 0000000..cee0143 --- /dev/null +++ b/internal/fingerprint/ja4.go @@ -0,0 +1,178 @@ +package fingerprint + +import ( + "crypto/sha256" + "crypto/tls" + "encoding/hex" + "fmt" + "slices" + "strconv" + "strings" +) + +// TLSFingerprintJA4 represents a JA4 fingerprint +type TLSFingerprintJA4 struct { + A [10]byte + B [6]byte + C [6]byte +} + +func (f *TLSFingerprintJA4) String() string { + if f == nil { + return "" + } + + return strings.Join([]string{ + string(f.A[:]), + hex.EncodeToString(f.B[:]), + hex.EncodeToString(f.C[:]), + }, "_") +} + +func buildJA4(hello *tls.ClientHelloInfo) (ja4 TLSFingerprintJA4) { + buf := make([]byte, 0, 36) + + hasQuic := false + + for _, ext := range hello.Extensions { + if ext == extensionQUICTransportParameters { + hasQuic = true + } + } + + switch hasQuic { + case true: + buf = append(buf, 'q') + case false: + buf = append(buf, 't') + } + + { + var sslVersion uint16 + for _, v := range hello.SupportedVersions { + if v&greaseMask != greaseValue { + if v > sslVersion { + sslVersion = v + } + } + } + + switch sslVersion { + case tls.VersionSSL30: + buf = append(buf, 's', '3') + case tls.VersionTLS10: + buf = append(buf, '1', '0') + case tls.VersionTLS11: + buf = append(buf, '1', '1') + case tls.VersionTLS12: + buf = append(buf, '1', '2') + case tls.VersionTLS13: + buf = append(buf, '1', '3') + default: + sslVersion -= 0x0201 + buf = strconv.AppendUint(buf, uint64(sslVersion>>8), 10) + buf = strconv.AppendUint(buf, uint64(sslVersion&0xff), 10) + } + + } + + if slices.Contains(hello.Extensions, extensionServerName) && hello.ServerName != "" { + buf = append(buf, 'd') + } else { + buf = append(buf, 'i') + } + + ciphers := make([]uint16, 0, len(hello.CipherSuites)) + for _, cipher := range hello.CipherSuites { + if cipher&greaseMask != greaseValue { + ciphers = append(ciphers, cipher) + } + } + + extensionCount := 0 + extensions := make([]uint16, 0, len(hello.Extensions)) + for _, extension := range hello.Extensions { + if extension&greaseMask != greaseValue { + extensionCount++ + if extension != extensionALPN && extension != extensionServerName { + extensions = append(extensions, extension) + } + } + } + + schemes := make([]tls.SignatureScheme, 0, len(hello.SignatureSchemes)) + + for _, scheme := range hello.SignatureSchemes { + if scheme&greaseMask != greaseValue { + schemes = append(schemes, scheme) + } + } + + //TODO: maybe little endian + slices.Sort(ciphers) + slices.Sort(extensions) + //slices.Sort(schemes) + + if len(ciphers) < 10 { + buf = append(buf, '0') + buf = strconv.AppendUint(buf, uint64(len(ciphers)), 10) + } else if len(ciphers) > 99 { + buf = append(buf, '9', '9') + } else { + buf = strconv.AppendUint(buf, uint64(len(ciphers)), 10) + } + + if extensionCount < 10 { + buf = append(buf, '0') + buf = strconv.AppendUint(buf, uint64(extensionCount), 10) + } else if extensionCount > 99 { + buf = append(buf, '9', '9') + } else { + buf = strconv.AppendUint(buf, uint64(extensionCount), 10) + } + + if len(hello.SupportedProtos) > 0 && len(hello.SupportedProtos[0]) > 1 { + buf = append(buf, hello.SupportedProtos[0][0], hello.SupportedProtos[0][len(hello.SupportedProtos[0])-1]) + } else { + buf = append(buf, '0', '0') + } + + copy(ja4.A[:], buf) + + ja4.B = ja4SHA256(uint16SliceToHex(ciphers)) + + extBuf := uint16SliceToHex(extensions) + + if len(schemes) > 0 { + extBuf = append(extBuf, '_') + extBuf = append(extBuf, uint16SliceToHex(schemes)...) + } + + ja4.C = ja4SHA256(extBuf) + + return ja4 +} + +func uint16SliceToHex[T ~uint16](in []T) (out []byte) { + if len(in) == 0 { + return out + } + out = slices.Grow(out, hex.EncodedLen(len(in)*2)+len(in)) + + for _, n := range in { + out = append(out, fmt.Sprintf("%04x", uint16(n))...) + out = append(out, ',') + } + out = out[:len(out)-1] + + return out +} + +func ja4SHA256(buf []byte) [6]byte { + if len(buf) == 0 { + return [6]byte{0, 0, 0, 0, 0, 0} + } + sum := sha256.Sum256(buf) + + return [6]byte(sum[:6]) +} diff --git a/internal/fingerprint/tcp.go b/internal/fingerprint/tcp.go new file mode 100644 index 0000000..21c5c6e --- /dev/null +++ b/internal/fingerprint/tcp.go @@ -0,0 +1,46 @@ +package fingerprint + +import ( + "fmt" + "net/http" + "strings" +) + +// JA4T represents a TCP fingerprint +type JA4T struct { + Window uint32 + Options []uint8 + MSS uint16 + WindowScale uint8 +} + +func (j JA4T) String() string { + var sb strings.Builder + + // Start with the window size + fmt.Fprintf(&sb, "%d", j.Window) + + // Append each option + for i, opt := range j.Options { + if i == 0 { + fmt.Fprint(&sb, "_") + } else { + fmt.Fprint(&sb, "-") + } + fmt.Fprintf(&sb, "%d", opt) + } + + // Append MSS and WindowScale + fmt.Fprintf(&sb, "_%d_%d", j.MSS, j.WindowScale) + + return sb.String() +} + +// GetTCPFingerprint extracts TCP fingerprint from HTTP request context +func GetTCPFingerprint(r *http.Request) *JA4T { + ptr := r.Context().Value(tcpFingerprintKey{}) + if fpPtr, ok := ptr.(*JA4T); ok && ptr != nil && fpPtr != nil { + return fpPtr + } + return nil +} diff --git a/internal/fingerprint/tcp_freebsd.go b/internal/fingerprint/tcp_freebsd.go new file mode 100644 index 0000000..4f35433 --- /dev/null +++ b/internal/fingerprint/tcp_freebsd.go @@ -0,0 +1,106 @@ +//go:build freebsd + +package fingerprint + +import ( + "fmt" + "net" + "syscall" + "unsafe" +) + +type tcpInfo struct { + State uint8 + Options uint8 + SndScale uint8 + RcvScale uint8 + __pad [4]byte + Rto uint32 + Ato uint32 + SndMss uint32 + RcvMss uint32 + Unacked uint32 + Sacked uint32 + Lost uint32 + Retrans uint32 + Fackets uint32 + Last_data_sent uint32 + Last_ack_sent uint32 + Last_data_recv uint32 + Last_ack_recv uint32 + Pmtu uint32 + Rcv_ssthresh uint32 + RTT uint32 + RTTvar uint32 + Snd_ssthresh uint32 + Snd_cwnd uint32 + Advmss uint32 + Reordering uint32 + Rcv_rtt uint32 + Rcv_space uint32 + Total_retrans uint32 + Snd_wnd uint32 + // Truncated for brevity — add more fields if needed +} + +// AssignTCPFingerprint extracts TCP fingerprint information from a connection +func AssignTCPFingerprint(conn net.Conn) (*JA4T, error) { + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return nil, fmt.Errorf("not a TCPConn") + } + + rawConn, err := tcpConn.SyscallConn() + if err != nil { + return nil, fmt.Errorf("SyscallConn failed: %w", err) + } + + var info tcpInfo + var sysErr error + + err = rawConn.Control(func(fd uintptr) { + size := uint32(unsafe.Sizeof(info)) + _, _, errno := syscall.Syscall6( + syscall.SYS_GETSOCKOPT, + fd, + uintptr(syscall.IPPROTO_TCP), + uintptr(syscall.TCP_INFO), + uintptr(unsafe.Pointer(&info)), + uintptr(unsafe.Pointer(&size)), + 0, + ) + if errno != 0 { + sysErr = errno + } + }) + if err != nil { + return nil, fmt.Errorf("SyscallConn.Control: %w", err) + } + if sysErr != nil { + return nil, fmt.Errorf("getsockopt TCP_INFO: %w", sysErr) + } + + fp := &JA4T{ + Window: info.Snd_wnd, + MSS: uint16(info.SndMss), + WindowScale: info.SndScale, + } + + const ( + TCPI_OPT_TIMESTAMPS = 1 << 0 + TCPI_OPT_SACK = 1 << 1 + TCPI_OPT_WSCALE = 1 << 2 + ) + + if info.Options&TCPI_OPT_SACK != 0 { + fp.Options = append(fp.Options, 4, 1) + } + if info.Options&TCPI_OPT_TIMESTAMPS != 0 { + fp.Options = append(fp.Options, 8, 1) + } + if info.Options&TCPI_OPT_WSCALE != 0 { + fp.Options = append(fp.Options, 3) + } + + return fp, nil +} diff --git a/internal/fingerprint/tcp_linux.go b/internal/fingerprint/tcp_linux.go new file mode 100644 index 0000000..1c5ac29 --- /dev/null +++ b/internal/fingerprint/tcp_linux.go @@ -0,0 +1,132 @@ +//go:build linux + +package fingerprint + +import ( + "fmt" + "net" + "syscall" + "unsafe" +) + +type tcpInfo struct { + State uint8 + Ca_state uint8 + Retransmits uint8 + Probes uint8 + Backoff uint8 + Options uint8 + Wnd_scale uint8 + Delivery_rate_app_limited uint8 + + Rto uint32 + Ato uint32 + SndMss uint32 + RcvMss uint32 + + Unacked uint32 + Sacked uint32 + Lost uint32 + Retrans uint32 + Fackets uint32 + + Last_data_sent uint32 + Last_ack_sent uint32 + Last_data_recv uint32 + Last_ack_recv uint32 + PMTU uint32 + Rcv_ssthresh uint32 + RTT uint32 + RTTvar uint32 + Snd_ssthresh uint32 + Snd_cwnd uint32 + Advmss uint32 + Reordering uint32 + Rcv_rtt uint32 + Rcv_space uint32 + Total_retrans uint32 + Pacing_rate uint64 + Max_pacing_rate uint64 + Bytes_acked uint64 + Bytes_received uint64 + Segs_out uint32 + Segs_in uint32 + Notsent_bytes uint32 + Min_rtt uint32 + Data_segs_in uint32 + Data_segs_out uint32 + Delivery_rate uint64 + Busy_time uint64 + Rwnd_limited uint64 + Sndbuf_limited uint64 + Delivered uint32 + Delivered_ce uint32 + Bytes_sent uint64 + Bytes_retrans uint64 + DSACK_dups uint32 + Reord_seen uint32 + Rcv_ooopack uint32 + Snd_wnd uint32 +} + +// AssignTCPFingerprint extracts TCP fingerprint information from a connection +func AssignTCPFingerprint(conn net.Conn) (*JA4T, error) { + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return nil, fmt.Errorf("not a TCPConn") + } + + rawConn, err := tcpConn.SyscallConn() + if err != nil { + return nil, fmt.Errorf("SyscallConn failed: %w", err) + } + + var info tcpInfo + var sysErr error + + err = rawConn.Control(func(fd uintptr) { + size := uint32(unsafe.Sizeof(info)) + _, _, errno := syscall.Syscall6( + syscall.SYS_GETSOCKOPT, + fd, + uintptr(syscall.IPPROTO_TCP), + uintptr(syscall.TCP_INFO), + uintptr(unsafe.Pointer(&info)), + uintptr(unsafe.Pointer(&size)), + 0, + ) + if errno != 0 { + sysErr = errno + } + }) + if err != nil { + return nil, fmt.Errorf("SyscallConn.Control: %w", err) + } + if sysErr != nil { + return nil, fmt.Errorf("getsockopt TCP_INFO: %w", sysErr) + } + + fp := &JA4T{ + Window: info.Snd_wnd, + MSS: uint16(info.SndMss), + } + + const ( + TCPI_OPT_TIMESTAMPS = 1 << 0 + TCPI_OPT_SACK = 1 << 1 + TCPI_OPT_WSCALE = 1 << 2 + ) + + if info.Options&TCPI_OPT_SACK != 0 { + fp.Options = append(fp.Options, 4, 1) + } + if info.Options&TCPI_OPT_TIMESTAMPS != 0 { + fp.Options = append(fp.Options, 8, 1) + } + if info.Options&TCPI_OPT_WSCALE != 0 { + fp.Options = append(fp.Options, 3) + fp.WindowScale = info.Wnd_scale + } + + return fp, nil +} diff --git a/internal/fingerprint/tcp_unsupported.go b/internal/fingerprint/tcp_unsupported.go new file mode 100644 index 0000000..3fbd259 --- /dev/null +++ b/internal/fingerprint/tcp_unsupported.go @@ -0,0 +1,11 @@ +//go:build !linux && !freebsd + +package fingerprint + +import "net" + +// AssignTCPFingerprint is not supported on this platform +func AssignTCPFingerprint(conn net.Conn) (*JA4T, error) { + // Not supported on macOS and other platforms + return &JA4T{}, nil +} diff --git a/internal/fingerprint/tls.go b/internal/fingerprint/tls.go new file mode 100644 index 0000000..a8a6d42 --- /dev/null +++ b/internal/fingerprint/tls.go @@ -0,0 +1,110 @@ +package fingerprint + +import ( + "context" + "crypto/tls" + "log/slog" + "net" + "net/http" + "sync/atomic" +) + +// ApplyTLSFingerprinter configures a TLS server to capture TLS fingerprints +func ApplyTLSFingerprinter(server *http.Server) { + if server.TLSConfig == nil { + return + } + server.TLSConfig = server.TLSConfig.Clone() + + getConfigForClient := server.TLSConfig.GetConfigForClient + + if getConfigForClient == nil { + getConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + } + } + + server.TLSConfig.GetConfigForClient = func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) { + ja3n, ja4 := buildTLSFingerprint(clientHello) + ptr := clientHello.Context().Value(tlsFingerprintKey{}) + if fpPtr, ok := ptr.(*TLSFingerprint); ok && ptr != nil && fpPtr != nil { + fpPtr.ja3n.Store(&ja3n) + fpPtr.ja4.Store(&ja4) + } + return getConfigForClient(clientHello) + } + server.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + ctx = context.WithValue(ctx, tlsFingerprintKey{}, &TLSFingerprint{}) + + if tc, ok := c.(*tls.Conn); ok { + tcpFP, err := AssignTCPFingerprint(tc.NetConn()) + if err == nil { + ctx = context.WithValue(ctx, tcpFingerprintKey{}, tcpFP) + } else { + slog.Debug("ja4t error", "err", err) + } + } + + return ctx + } +} + +type tcpFingerprintKey struct{} +type tlsFingerprintKey struct{} + +// TLSFingerprint represents TLS fingerprint data +type TLSFingerprint struct { + ja3n atomic.Pointer[TLSFingerprintJA3N] + ja4 atomic.Pointer[TLSFingerprintJA4] +} + +// JA3N returns the JA3N fingerprint +func (f *TLSFingerprint) JA3N() *TLSFingerprintJA3N { + return f.ja3n.Load() +} + +// JA4 returns the JA4 fingerprint +func (f *TLSFingerprint) JA4() *TLSFingerprintJA4 { + return f.ja4.Load() +} + +const greaseMask = 0x0F0F +const greaseValue = 0x0a0a + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 + extensionExtendedMasterSecret uint16 = 23 + extensionSessionTicket uint16 = 35 + extensionPreSharedKey uint16 = 41 + extensionEarlyData uint16 = 42 + extensionSupportedVersions uint16 = 43 + extensionCookie uint16 = 44 + extensionPSKModes uint16 = 45 + extensionCertificateAuthorities uint16 = 47 + extensionSignatureAlgorithmsCert uint16 = 50 + extensionKeyShare uint16 = 51 + extensionQUICTransportParameters uint16 = 57 + extensionRenegotiationInfo uint16 = 0xff01 + extensionECHOuterExtensions uint16 = 0xfd00 + extensionEncryptedClientHello uint16 = 0xfe0d +) + +func buildTLSFingerprint(hello *tls.ClientHelloInfo) (TLSFingerprintJA3N, TLSFingerprintJA4) { + return TLSFingerprintJA3N(buildJA3N(hello, true)), buildJA4(hello) +} + +// GetTLSFingerprint extracts TLS fingerprint from HTTP request context +func GetTLSFingerprint(r *http.Request) *TLSFingerprint { + ptr := r.Context().Value(tlsFingerprintKey{}) + if fpPtr, ok := ptr.(*TLSFingerprint); ok && ptr != nil && fpPtr != nil { + return fpPtr + } + return nil +}