From e5b344f4ebf29136104689ea2c77e6473492a19f Mon Sep 17 00:00:00 2001 From: James Mills Date: Sat, 9 Jul 2016 18:03:28 -0700 Subject: [PATCH] Added keepalive feature to keep clients connected --- sshd/net.go | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/sshd/net.go b/sshd/net.go index 84d6269..bf40bfc 100644 --- a/sshd/net.go +++ b/sshd/net.go @@ -2,6 +2,7 @@ package sshd import ( "net" + "time" "github.com/shazow/rateio" "golang.org/x/crypto/ssh" @@ -24,7 +25,7 @@ func ListenSSH(laddr string, config *ssh.ServerConfig) (*SSHListener, error) { return &l, nil } -func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) { +func (l *SSHListener) handleConn(conn net.Conn, stop <-chan bool) (*Terminal, error) { if l.RateLimit != nil { // TODO: Configurable Limiter? conn = ReadLimitConn(conn, l.RateLimit()) @@ -38,11 +39,17 @@ func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) { // FIXME: Disconnect if too many faulty requests? (Avoid DoS.) go ssh.DiscardRequests(requests) - return NewSession(sshConn, channels) + terminal, err := NewSession(sshConn, channels) + if err != nil { + return nil, err + } + go KeepAlive(terminal, 2, stop) + return terminal, err } // Accept incoming connections as terminal requests and yield them func (l *SSHListener) ServeTerminal() <-chan *Terminal { + stop := make(chan bool) ch := make(chan *Terminal) go func() { @@ -59,7 +66,7 @@ func (l *SSHListener) ServeTerminal() <-chan *Terminal { // Goroutineify to resume accepting sockets early go func() { - term, err := l.handleConn(conn) + term, err := l.handleConn(conn, stop) if err != nil { logger.Printf("Failed to handshake: %v", err) return @@ -71,3 +78,22 @@ func (l *SSHListener) ServeTerminal() <-chan *Terminal { return ch } + +// KeepAlive Setup a new keepalive goroutine +func KeepAlive(t *Terminal, interval time.Duration, stop <-chan bool) { + // this sends keepalive packets every 2 seconds + // there's no useful response from these, so we can just abort if there's an error + tick := time.NewTicker(interval * time.Second) + defer tick.Stop() + for { + select { + case <-tick.C: + _, err := t.Channel.SendRequest("keepalive", true, nil) + if err != nil { + return + } + case <-stop: + return + } + } +}