package ssh import ( "context" "fmt" "io" "log/slog" "net" "sync" "time" "golang.org/x/crypto/ssh" ) type Forwarder struct { sshClient *ssh.Client localPort string remotePort string remoteHost string } func NewForwarder(client *ssh.Client, localPort, remotePort, remoteHost string) *Forwarder { return &Forwarder{ sshClient: client, localPort: localPort, remotePort: remotePort, remoteHost: remoteHost, } } func (f *Forwarder) Start(ctx context.Context, ready chan<- struct{}) error { localAddr := "127.0.0.1:" + f.localPort remoteAddr := net.JoinHostPort(f.remoteHost, f.remotePort) listener, err := net.Listen("tcp", localAddr) if err != nil { return fmt.Errorf("failed to listen on %s: %w", localAddr, err) } if ready != nil { close(ready) } go func() { <-ctx.Done() listener.Close() }() slog.Info("Port forwarder active", "local", localAddr, "remote", remoteAddr) for { localConn, err := listener.Accept() if err != nil { select { case <-ctx.Done(): return nil default: } slog.Error("Accept failed", "error", err) time.Sleep(100 * time.Millisecond) continue } go f.handleConnection(ctx, localConn, remoteAddr) } } func (f *Forwarder) handleConnection(ctx context.Context, localConn net.Conn, remoteAddr string) { defer localConn.Close() _, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() remoteConn, err := f.sshClient.Dial("tcp", remoteAddr) if err != nil { slog.Error("Failed to dial remote via SSH", "target", remoteAddr, "error", err) return } defer remoteConn.Close() var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() _, _ = io.Copy(localConn, remoteConn) // localConn.SetWriteDeadline(time.Now()) localConn.Close() }() go func() { defer wg.Done() _, _ = io.Copy(remoteConn, localConn) remoteConn.Close() }() done := make(chan struct{}) go func() { wg.Wait() close(done) }() select { case <-done: case <-ctx.Done(): } }