work/internal/ssh/forwarder.go

107 lines
2 KiB
Go

package ssh
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
)
// Forwarder accepts TCP connections on a local loopback port and relays each
// one to remoteHost:remotePort by dialling through an SSH client.
type Forwarder struct {
	// sshClient is the connection used to dial the remote target.
	sshClient *ssh.Client

	localPort  string // port Start listens on (bound to 127.0.0.1)
	remotePort string // port dialled on remoteHost via sshClient.Dial
	remoteHost string // host dialled via sshClient.Dial
}
// NewForwarder returns a Forwarder that, once Start is called, listens on
// localPort and relays connections to remoteHost:remotePort through client.
// Construction performs no I/O.
func NewForwarder(client *ssh.Client, localPort, remotePort, remoteHost string) *Forwarder {
	fw := new(Forwarder)
	fw.sshClient = client
	fw.localPort = localPort
	fw.remotePort = remotePort
	fw.remoteHost = remoteHost
	return fw
}
func (f *Forwarder) Start(ctx context.Context, ready chan<- struct{}) error {
localAddr := "127.0.0.1:" + f.localPort
remoteAddr := net.JoinHostPort(f.remoteHost, f.remotePort)
listener, err := net.Listen("tcp", localAddr)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", localAddr, err)
}
if ready != nil {
close(ready)
}
go func() {
<-ctx.Done()
listener.Close()
}()
slog.Info("Port forwarder active", "local", localAddr, "remote", remoteAddr)
for {
localConn, err := listener.Accept()
if err != nil {
select {
case <-ctx.Done():
return nil
default:
}
slog.Error("Accept failed", "error", err)
time.Sleep(100 * time.Millisecond)
continue
}
go f.handleConnection(ctx, localConn, remoteAddr)
}
}
// handleConnection bridges one accepted local connection to remoteAddr via the
// SSH client. Both connections are closed when it returns. It returns when
// both copy directions have finished or ctx is cancelled.
//
// The SSH dial is bounded by a 10-second timeout and by ctx. When one
// direction reaches EOF, its destination is half-closed (CloseWrite) rather
// than fully closed. This lets protocols that shut down their write side after
// sending a request still receive the reply. A full Close is used when the
// connection type has no half-close.
func (f *Forwarder) handleConnection(ctx context.Context, localConn net.Conn, remoteAddr string) {
	defer localConn.Close()

	const dialTimeout = 10 * time.Second
	dialCtx, cancel := context.WithTimeout(ctx, dialTimeout)
	defer cancel()

	// sshClient.Dial takes no context, so race it against dialCtx. The channel
	// is buffered so the dialling goroutine can always deliver and exit.
	type dialResult struct {
		conn net.Conn
		err  error
	}
	results := make(chan dialResult, 1)
	go func() {
		conn, err := f.sshClient.Dial("tcp", remoteAddr)
		results <- dialResult{conn: conn, err: err}
	}()

	var remoteConn net.Conn
	select {
	case res := <-results:
		if res.err != nil {
			slog.Error("Failed to dial remote via SSH", "target", remoteAddr, "error", res.err)
			return
		}
		remoteConn = res.conn
	case <-dialCtx.Done():
		slog.Error("Failed to dial remote via SSH", "target", remoteAddr, "error", dialCtx.Err())
		// The dial may still succeed after we give up; close that connection
		// so it does not leak.
		go func() {
			if res := <-results; res.err == nil {
				res.conn.Close()
			}
		}()
		return
	}
	defer remoteConn.Close()

	// closeWrite signals EOF to c's peer while leaving the read side open,
	// falling back to a full Close when half-close is unsupported.
	closeWrite := func(c net.Conn) {
		if hc, ok := c.(interface{ CloseWrite() error }); ok {
			_ = hc.CloseWrite()
			return
		}
		_ = c.Close()
	}

	var wg sync.WaitGroup
	pipe := func(dst, src net.Conn, direction string) {
		defer wg.Done()
		// Copy errors are expected during teardown (e.g. the other side was
		// closed), so they are only logged at debug level.
		if _, err := io.Copy(dst, src); err != nil && !errors.Is(err, net.ErrClosed) {
			slog.Debug("Forwarding stream ended with error", "direction", direction, "error", err)
		}
		closeWrite(dst)
	}
	wg.Add(2)
	go pipe(localConn, remoteConn, "remote->local")
	go pipe(remoteConn, localConn, "local->remote")

	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	// On cancellation, returning runs the deferred Closes, which unblocks both
	// io.Copy calls so the goroutines above exit.
	select {
	case <-done:
	case <-ctx.Done():
	}
}