From 5de9ff7961b9df81d600348c9e445d5e688d00ab Mon Sep 17 00:00:00 2001 From: Patryk Hegenberg Date: Sun, 11 Jan 2026 11:44:45 +0100 Subject: [PATCH] feat: implement ready chanel to wait for established connection --- app.go | 41 +++++++++++++++++++++++++++++---------- internal/config/config.go | 5 +++++ internal/ssh/forwarder.go | 6 +++++- 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/app.go b/app.go index 64d36e1..8ff9874 100644 --- a/app.go +++ b/app.go @@ -101,21 +101,43 @@ func (a *App) connect(ctx context.Context) error { tunnelCtx, cancelTunnels := context.WithCancel(ctx) defer cancelTunnels() - sshForwarder := ssh.NewForwarder(sshCon.Client, "2048", "22", a.cfg.WorkstationIP) - rdpForwarder := ssh.NewForwarder(sshCon.Client, "6000", "3389", a.cfg.WorkstationIP) + sshForwarder := ssh.NewForwarder(sshCon.Client, config.PortLocalSSH, config.PortRemoteSSH, a.cfg.WorkstationIP) + rdpForwarder := ssh.NewForwarder(sshCon.Client, config.PortLocalRDP, config.PortRemoteRDP, a.cfg.WorkstationIP) + + sshReady := make(chan struct{}) + rdpReady := make(chan struct{}) go func() { - if err := sshForwarder.Start(tunnelCtx); err != nil { + if err := sshForwarder.Start(tunnelCtx, sshReady); err != nil { slog.Error("SSH forwarder stopped", "error", err) } }() go func() { - if err := rdpForwarder.Start(tunnelCtx); err != nil { + if err := rdpForwarder.Start(tunnelCtx, rdpReady); err != nil { slog.Error("RDP forwarder stopped", "error", err) } }() - time.Sleep(200 * time.Millisecond) + slog.Info("Waiting for tunnels to initialize...") + + readyCtx, cancelReady := context.WithTimeout(ctx, 5*time.Second) + defer cancelReady() + + select { + case <-sshReady: + slog.Debug("SSH Tunnel ready") + case <-readyCtx.Done(): + return fmt.Errorf("timeout waiting for SSH tunnel readiness") + } + + select { + case <-rdpReady: + slog.Debug("RDP Tunnel ready") + case <-readyCtx.Done(): + return fmt.Errorf("timeout waiting for RDP tunnel readiness") + } + + slog.Info("All tunnels established and listening.") if a.flags.StartInBackground { fmt.Println("\nINFO: Tunnels are active in background.") @@ -220,7 +242,6 @@ func (a *App) makeChoice(ctx context.Context) error { case "kill tunnels": _ = a.killForwardings() case "set secrets": - _ = a.configCommand() fmt.Println("Please run 'workctl config set-secrets' directly from CLI.") case "exit": return nil @@ -276,7 +297,7 @@ func (a *App) wakeWorkstation() { func (a *App) connectToJump() { args := []string{ "-tt", - "-L", fmt.Sprintf("2048:%s:22", a.cfg.WorkstationHost), + "-L", fmt.Sprintf("%s:%s:%s", config.PortLocalSSH, a.cfg.WorkstationHost, config.PortRemoteSSH), "-p", fmt.Sprintf("%d", a.cfg.SSHPort), fmt.Sprintf("%s@%s", a.cfg.SSHUser, a.cfg.SSHHost), } @@ -286,8 +307,8 @@ func (a *App) connectToJump() { func (a *App) connectToWorkstation() { args := []string{ "-tt", - "-L", fmt.Sprintf("6000:%s:3389", a.cfg.WorkstationHost), - "-p", "2048", + "-L", fmt.Sprintf("%s:%s:%s", config.PortLocalRDP, a.cfg.WorkstationHost, config.PortRemoteRDP), + "-p", config.PortLocalSSH, fmt.Sprintf("%s@127.0.0.1", a.cfg.WorkstationUser), } _ = a.runCommand("ssh", args...) @@ -297,7 +318,7 @@ func (a *App) startRDPConnection() { args := []string{ fmt.Sprintf("/u:%s", a.cfg.RDPUser), fmt.Sprintf("/p:%s", a.cfg.RDPPassword), - "/v:127.0.0.1:6000", + fmt.Sprintf("/v:127.0.0.1:%s", config.PortLocalRDP), "/size:3000x1350", "+clipboard", "/dynamic-resolution", diff --git a/internal/config/config.go b/internal/config/config.go index 1e75913..17c7240 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,6 +14,11 @@ const ( serviceName = "workctl" keySSHPassword = "ssh-password" keyRDPPassword = "rdp-password" + + PortLocalSSH = "2048" + PortLocalRDP = "6000" + PortRemoteSSH = "22" + PortRemoteRDP = "3389" ) type Config struct { diff --git a/internal/ssh/forwarder.go b/internal/ssh/forwarder.go index 281d25e..803ca09 100644 --- a/internal/ssh/forwarder.go +++ b/internal/ssh/forwarder.go @@ -28,7 +28,7 @@ func NewForwarder(client *ssh.Client, localPort, remotePort, remoteHost string) } } -func (f *Forwarder) Start(ctx context.Context) error { +func (f *Forwarder) Start(ctx context.Context, ready chan<- struct{}) error { localAddr := "127.0.0.1:" + f.localPort remoteAddr := net.JoinHostPort(f.remoteHost, f.remotePort) @@ -37,6 +37,10 @@ func (f *Forwarder) Start(ctx context.Context) error { return fmt.Errorf("failed to listen on %s: %w", localAddr, err) } + if ready != nil { + close(ready) + } + go func() { <-ctx.Done() listener.Close()