work/app.go

360 lines
9.6 KiB
Go

package main
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"os/exec"
	"strings"
	"time"

	"github.com/charmbracelet/huh"
	sshPkg "golang.org/x/crypto/ssh"

	"workctl/internal/config"
	"workctl/internal/ssh"
	"workctl/internal/store"
)
// Flags collects the command-line switches that adjust App's behaviour.
// Only StartInBackground and WithoutTimew are read in this file section; the
// report/export switches are presumably consumed by the command setup
// elsewhere (TODO confirm).
type Flags struct {
	// Report and export selectors.
	ShowWeek, ShowMonth, ShowExport bool
	// ExportName is a file name for the export (not read in this section).
	ExportName string
	// StartInBackground makes connect keep its tunnels open until the
	// context is cancelled instead of opening an interactive SSH session.
	// WithoutTimew suppresses the `timew` calls made alongside tracking.
	StartInBackground, WithoutTimew bool
}
// App ties together the loaded configuration, the time-tracking store and the
// command-line flags that every command and menu action operates on.
type App struct {
	cfg   config.Config // loaded once by NewApp
	store *store.Store  // opened by NewApp, released by Close
	flags Flags         // zero value unless populated by command setup (not visible here)
}
// NewApp loads the configuration and opens the time store. The returned App
// has every flag at its zero value. Any failure is wrapped and returned with a
// nil App.
func NewApp() (*App, error) {
	cfg, err := config.Load()
	if err != nil {
		return nil, fmt.Errorf("error loading config: %w", err)
	}

	app := &App{cfg: cfg}
	if app.store, err = store.NewStore(); err != nil {
		return nil, fmt.Errorf("error initializing time store: %w", err)
	}
	return app, nil
}
// Close releases the time store and returns its Close error. It is a no-op
// returning nil when no store was ever opened.
func (a *App) Close() error {
	if a.store == nil {
		return nil
	}
	return a.store.Close()
}
// Execute is the program entry point for App. Without command-line arguments
// it shows the interactive menu; otherwise it builds the command tree and lets
// it handle os.Args under ctx.
func (a *App) Execute(ctx context.Context) error {
	if len(os.Args) <= 1 {
		return a.makeChoice(ctx)
	}
	root := a.setupCommands()
	return root.ExecuteContext(ctx)
}
// StartTracking starts tracking tag in the local store. Unless
// Flags.WithoutTimew is set, it then also runs `timew start <tag>`.
// The returned error comes from the store only. The timew call is best effort:
// runCommand logs its failure and it never affects the result.
func (a *App) StartTracking(ctx context.Context, tag string) error {
	err := a.store.StartTracking(ctx, tag)
	if err == nil && !a.flags.WithoutTimew {
		_ = a.runCommand("timew", "start", tag) // failure already logged by runCommand
	}
	return err
}
// StopTracking stops tracking in the local store. Unless Flags.WithoutTimew
// is set, it then also runs `timew stop`. The returned error comes from the
// store only. The timew call is best effort: runCommand logs its failure and
// it never affects the result.
func (a *App) StopTracking(ctx context.Context) error {
	err := a.store.StopTracking(ctx)
	if err == nil && !a.flags.WithoutTimew {
		_ = a.runCommand("timew", "stop") // failure already logged by runCommand
	}
	return err
}
// connect implements "Start Work & Connect". It proceeds in five steps:
//
//  1. Starts "work" time tracking. A failure here is only logged.
//  2. Wakes the workstation.
//  3. Opens an SSH connection to the configured SSH host.
//  4. Runs two port forwarders through that connection: local port 2048 to
//     WorkstationIP:22 (SSH) and local port 6000 to WorkstationIP:3389 (RDP).
//  5. Waits for the tunnels to finish being used, then stops tracking.
//
// With Flags.StartInBackground, step 5 keeps the tunnels up until ctx is
// cancelled. Otherwise it runs an interactive SSH session to the workstation
// and tears the tunnels down when that session ends.
//
// It returns an error when no SSH key can be loaded or the connection cannot
// be established. Tracking is not stopped on those error paths.
func (a *App) connect(ctx context.Context) error {
	if err := a.StartTracking(ctx, store.TagWork); err != nil {
		slog.Warn("Failed to start time tracking", "error", err)
	}
	a.wakeWorkstation()

	// getSSHAuth logs the reason and returns nil when the key cannot be used.
	// A nil AuthMethod must not reach the SSH client. x/crypto/ssh calls
	// methods on every ClientConfig.Auth entry, so a nil entry would panic
	// (assuming ssh.NewConnection passes it through unchanged; not visible here).
	auth := a.getSSHAuth()
	if auth == nil {
		return fmt.Errorf("failed to establish primary SSH connection: no usable SSH key for %s@%s",
			a.cfg.SSHUser, a.cfg.SSHHost)
	}

	sshCon, err := ssh.NewConnection(a.cfg.SSHUser, a.cfg.SSHHost, a.cfg.SSHPort, auth)
	if err != nil {
		return fmt.Errorf("failed to establish primary SSH connection: %w", err)
	}
	defer sshCon.Close()
	slog.Info("SSH connection established. Setting up tunnels...")

	// Deferred after sshCon.Close, so it runs first. The forwarders are
	// signalled to stop before the connection they tunnel through is closed.
	tunnelCtx, cancelTunnels := context.WithCancel(ctx)
	defer cancelTunnels()

	// startTunnel runs one forwarder in the background until tunnelCtx ends.
	// An error reported after cancellation is expected shutdown noise, so it
	// is logged at info level rather than error level.
	startTunnel := func(label string, start func(context.Context) error) {
		go func() {
			err := start(tunnelCtx)
			if err == nil {
				return
			}
			if tunnelCtx.Err() != nil {
				slog.Info(label+" forwarder stopped", "error", err)
				return
			}
			slog.Error(label+" forwarder stopped", "error", err)
		}()
	}
	sshForwarder := ssh.NewForwarder(sshCon.Client, "2048", "22", a.cfg.WorkstationIP)
	rdpForwarder := ssh.NewForwarder(sshCon.Client, "6000", "3389", a.cfg.WorkstationIP)
	startTunnel("SSH", sshForwarder.Start)
	startTunnel("RDP", rdpForwarder.Start)

	// Fixed grace period for the forwarders to bind their local ports. This is
	// not a readiness check. It is cut short if ctx is cancelled.
	select {
	case <-ctx.Done():
	case <-time.After(200 * time.Millisecond):
	}

	if a.flags.StartInBackground {
		fmt.Println("\nINFO: Tunnels are active in background.")
		fmt.Println(" Connect manually via SSH: ssh -p 2048 <user>@127.0.0.1")
		fmt.Println(" Connect manually via RDP: xfreerdp /v:127.0.0.1:6000 ...")
		fmt.Println("INFO: Press Ctrl+C to stop.")
		<-ctx.Done()
		slog.Info("Context cancelled, shutting down tunnels...")
	} else {
		// NOTE(review): connectToWorkstation's ssh also asks for -L 6000, but
		// rdpForwarder already holds that port. By default ssh only warns that
		// the forward could not be set up, and RDP keeps going through
		// rdpForwarder.
		fmt.Println("Automatically connecting to workstation via SSH tunnel...")
		a.connectToWorkstation()
		fmt.Println("Workstation SSH session finished.")
	}

	// ctx may already be cancelled at this point (background mode), so use a
	// fresh context to make sure the stop is still recorded.
	if err := a.StopTracking(context.Background()); err != nil {
		slog.Warn("Failed to stop time tracking", "error", err)
	} else {
		slog.Info("Time tracking stopped.")
	}
	return nil
}
// runCommand runs an external program in the foreground and waits for it to
// exit. The program is connected to this process's stdin, stdout and stderr.
//
// The invocation is logged before it starts, with secret-bearing arguments
// redacted via redactArgs. A failure is logged and then returned; callers that
// treat the command as best effort may ignore the error.
func (a *App) runCommand(name string, args ...string) error {
	slog.Info("Executing command", "cmd", name, "args", redactArgs(args))
	cmd := exec.Command(name, args...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Stdin = os.Stdin
	if err := cmd.Run(); err != nil {
		slog.Error("Command failed", "cmd", name, "error", err)
		return err
	}
	return nil
}

// redactArgs returns a copy of args that is safe to log. Any argument carrying
// a password has its value replaced by "****". The password forms recognised
// are the xfreerdp `/p:<password>` switch, which startRDPConnection passes.
// args itself is not modified.
func redactArgs(args []string) []string {
	safe := make([]string, len(args))
	for i, arg := range args {
		if strings.HasPrefix(arg, "/p:") {
			safe[i] = "/p:****"
			continue
		}
		safe[i] = arg
	}
	return safe
}
// makeChoice runs the interactive main menu. After each action it waits for
// Enter and shows the menu again. The loop ends in one of three ways:
//   - The user picks "Exit" or aborts the form (e.g. Ctrl+C): returns nil.
//   - The user picks "Start Work & Connect": hands control to connect and
//     returns its result.
//   - The form fails for any other reason: returns a wrapped error.
func (a *App) makeChoice(ctx context.Context) error {
	// Iterate rather than recurse, so a long menu session does not grow the
	// stack with every action.
	for {
		choice, err := promptMenuChoice()
		if err != nil {
			if errors.Is(err, huh.ErrUserAborted) {
				return nil
			}
			return fmt.Errorf("running main menu: %w", err)
		}

		switch choice {
		case "exit":
			return nil
		case "start work":
			return a.connect(ctx)
		}

		a.runMenuAction(ctx, choice)
		fmt.Println("\nPress Enter to continue...")
		// Only waiting for a line; its content (and any scan error) is irrelevant.
		_, _ = fmt.Scanln()
	}
}

// promptMenuChoice shows the main-menu select form and returns the value of
// the chosen option, or the form's error.
func promptMenuChoice() (string, error) {
	var choice string
	form := huh.NewForm(
		huh.NewGroup(
			huh.NewSelect[string]().
				Title("What would you like to do?").
				Options(
					huh.NewOption("Start Work & Connect", "start work"),
					huh.NewOption("Stop Work", "stop work"),
					huh.NewOption("Start Break", "start break"),
					huh.NewOption("Stop Break", "stop break"),
					huh.NewOption("Show Today Summary", "show day summary"),
					huh.NewOption("Show Week Summary", "show week summary"),
					huh.NewOption("Show Month Summary", "show month summary"),
					huh.NewOption("Export Yearly Timetable", "export"),
					huh.NewOption("Connect to Jump Host (Tunnel)", "connect to jump"),
					huh.NewOption("Connect to Workstation (Tunnel)", "connect to workstation"),
					huh.NewOption("Start RDP Connection", "start rdp connection"),
					huh.NewOption("Wake Workstation", "wake workstation"),
					huh.NewOption("Kill Active Tunnels", "kill tunnels"),
					huh.NewOption("Config: Set Secrets", "set secrets"),
					huh.NewOption("Exit", "exit"),
				).
				Value(&choice),
		),
	)
	if err := form.Run(); err != nil {
		return "", err
	}
	return choice, nil
}

// runMenuAction performs one non-terminal menu action. Failures are logged
// rather than returned so that the menu keeps running.
func (a *App) runMenuAction(ctx context.Context, choice string) {
	showSummary := func(period string) {
		if err := a.store.ShowSummary(ctx, period); err != nil {
			slog.Error("Failed to show summary", "period", period, "error", err)
		}
	}

	switch choice {
	case "stop work":
		if err := a.StopTracking(ctx); err != nil {
			slog.Error("Failed to stop time tracking", "error", err)
		}
		_ = a.killForwardings() // logs each failure itself
	case "start break":
		if err := a.StartTracking(ctx, store.TagBreak); err != nil {
			slog.Error("Failed to start break", "error", err)
		}
	case "stop break":
		// Ending a break means resuming work tracking.
		if err := a.StartTracking(ctx, store.TagWork); err != nil {
			slog.Error("Failed to stop break", "error", err)
		}
	case "show day summary":
		showSummary("today")
	case "show week summary":
		showSummary("week")
	case "show month summary":
		showSummary("month")
	case "export":
		filename := "Arbeitszeiten_" + time.Now().Format("2006") + ".xlsx"
		if err := a.store.ExportSummary(ctx, filename); err != nil {
			slog.Error("Failed to export timetable", "file", filename, "error", err)
		}
	case "connect to jump":
		a.connectToJump()
	case "connect to workstation":
		a.connectToWorkstation()
	case "start rdp connection":
		a.startRDPConnection()
	case "wake workstation":
		a.wakeWorkstation()
	case "kill tunnels":
		_ = a.killForwardings() // logs each failure itself
	case "set secrets":
		// NOTE(review): the result of configCommand is discarded; the user is
		// sent to the CLI instead. Confirm whether this call is still needed.
		_ = a.configCommand()
		fmt.Println("Please run 'workctl config set-secrets' directly from CLI.")
	}
}
// getSSHAuth loads the private key at $HOME/.ssh/hegenberg and returns it as a
// public-key AuthMethod. If the key is passphrase-protected, the configured
// RDP password is used as the passphrase.
//
// Every failure is logged and reported as a nil return value, so callers must
// check for nil before using the result.
func (a *App) getSSHAuth() sshPkg.AuthMethod {
	keyPath := os.ExpandEnv("$HOME/.ssh/hegenberg")
	keyBytes, err := os.ReadFile(keyPath)
	if err != nil {
		slog.Error("Unable to read private key", "path", keyPath, "error", err)
		return nil
	}

	key, err := sshPkg.ParsePrivateKey(keyBytes)
	var missingPassphrase *sshPkg.PassphraseMissingError
	switch {
	case err == nil:
		// Unencrypted key: usable as is.
	case errors.As(err, &missingPassphrase):
		// An empty passphrase cannot decrypt an encrypted key; say so plainly
		// instead of surfacing an opaque decryption error.
		if len(a.cfg.RDPPassword) == 0 {
			slog.Error("Private key requires a passphrase but no RDP password is configured", "path", keyPath)
			return nil
		}
		slog.Info("Key requires passphrase, trying RDP password from config/keyring")
		key, err = sshPkg.ParsePrivateKeyWithPassphrase(keyBytes, []byte(a.cfg.RDPPassword))
		if err != nil {
			slog.Error("Failed to parse key with passphrase", "error", err)
			return nil
		}
	default:
		slog.Error("Failed to parse private key", "error", err)
		return nil
	}
	return sshPkg.PublicKeys(key)
}
// wakeWorkstation uses the system ssh client to log in to the primary SSH
// host. From there it runs a nested `ssh` to the jump host, which executes
// `wakeonlan <WorkstationMac>`. This is best effort: runCommand logs any
// failure and nothing is returned.
func (a *App) wakeWorkstation() {
	slog.Info("Attempting to wake workstation...")

	remoteCmd := fmt.Sprintf("ssh -tt %s@%s \"wakeonlan %s && echo 'Packet sent' && exit\"",
		a.cfg.JumpUser, a.cfg.JumpHost, a.cfg.WorkstationMac)
	port := fmt.Sprintf("%d", a.cfg.SSHPort)
	login := fmt.Sprintf("%s@%s", a.cfg.SSHUser, a.cfg.SSHHost)

	_ = a.runCommand("ssh", "-tt", "-p", port, login, remoteCmd)
}
// connectToJump opens an interactive ssh session to the primary SSH host.
// The session also forwards local port 2048 to port 22 on WorkstationHost.
// It returns when the session ends; runCommand logs any failure.
func (a *App) connectToJump() {
	forward := fmt.Sprintf("2048:%s:22", a.cfg.WorkstationHost)
	port := fmt.Sprintf("%d", a.cfg.SSHPort)
	login := fmt.Sprintf("%s@%s", a.cfg.SSHUser, a.cfg.SSHHost)

	_ = a.runCommand("ssh", "-tt", "-L", forward, "-p", port, login)
}
// connectToWorkstation opens an interactive ssh session to the workstation
// through the tunnel on local port 2048. The session also forwards local port
// 6000 to port 3389 on WorkstationHost for RDP. It returns when the session
// ends; runCommand logs any failure.
func (a *App) connectToWorkstation() {
	rdpForward := fmt.Sprintf("6000:%s:3389", a.cfg.WorkstationHost)
	login := fmt.Sprintf("%s@127.0.0.1", a.cfg.WorkstationUser)

	_ = a.runCommand("ssh", "-tt", "-L", rdpForward, "-p", "2048", login)
}
// startRDPConnection launches xfreerdp against the tunnelled RDP port
// 127.0.0.1:6000 using the configured RDP credentials. It blocks until
// xfreerdp exits; runCommand logs any failure.
func (a *App) startRDPConnection() {
	// NOTE(review): the password travels on the command line, so other local
	// users can read it from the process list while xfreerdp runs.
	user := fmt.Sprintf("/u:%s", a.cfg.RDPUser)
	password := fmt.Sprintf("/p:%s", a.cfg.RDPPassword)

	_ = a.runCommand("xfreerdp",
		user,
		password,
		"/v:127.0.0.1:6000",
		"/size:3000x1350",
		"+clipboard",
		"/dynamic-resolution",
	)
}
// killForwardings terminates local processes that are listening on the tunnel
// ports 2048 and 6000. Each such PID is sent `kill`; if that command fails,
// `kill -9` is tried.
//
// It is best effort: every failure is logged and processing continues. All
// failures are returned joined together. The result is nil when nothing
// failed, including when nothing was listening.
func (a *App) killForwardings() error {
	ports := []string{"2048", "6000"}
	killedSomething := false
	var errs []error
	// One process can hold both ports (connect's in-process forwarders do),
	// so each PID is signalled only once.
	signalled := make(map[string]struct{})

	slog.Info(fmt.Sprintf("Attempting to kill processes listening on ports: %v", strings.Join(ports, ", ")))
	for _, port := range ports {
		pids, err := listenerPIDs(port)
		if err != nil {
			slog.Warn(fmt.Sprintf("'lsof' command failed for port %s: %v", port, err))
			errs = append(errs, fmt.Errorf("lsof failed for port %s: %w", port, err))
			continue
		}
		if len(pids) == 0 {
			slog.Info(fmt.Sprintf("No process found listening on port %s.", port))
			continue
		}
		for _, pid := range pids {
			if _, done := signalled[pid]; done {
				continue
			}
			signalled[pid] = struct{}{}
			if err := terminatePID(pid, port); err != nil {
				errs = append(errs, err)
				continue
			}
			killedSomething = true
		}
	}

	if killedSomething {
		slog.Info("Finished attempting to kill forwarding processes.")
	} else {
		slog.Info("No forwarding processes found or killed.")
	}
	return errors.Join(errs...)
}

// listenerPIDs returns the PIDs that lsof reports as having a TCP socket in
// LISTEN state on port. lsof exits with status 1 when nothing matches; that
// case is reported as an empty result rather than an error.
func listenerPIDs(port string) ([]string, error) {
	// -sTCP:LISTEN limits matches to listening sockets. Without it, lsof also
	// reports every process with a connection to or from this port (e.g. any
	// client talking to some remote port 6000), and those would be killed too.
	out, err := exec.Command("lsof", "-i", "tcp:"+port, "-sTCP:LISTEN", "-t").Output()
	if err != nil {
		var exitErr *exec.ExitError
		if errors.As(err, &exitErr) && exitErr.ExitCode() == 1 {
			return nil, nil
		}
		return nil, err
	}
	// -t prints one PID per line; Fields also drops blank lines.
	return strings.Fields(string(out)), nil
}

// terminatePID runs `kill <pid>` and, if that command fails, `kill -9 <pid>`.
// port is used only in log messages. It returns an error only when both
// attempts fail.
func terminatePID(pid, port string) error {
	slog.Info(fmt.Sprintf("Found process PID %s on port %s. Attempting to kill...", pid, port))
	err := exec.Command("kill", pid).Run()
	if err == nil {
		slog.Info(fmt.Sprintf("Killed PID %s (port %s).", pid, port))
		return nil
	}

	slog.Warn(fmt.Sprintf("Failed to kill PID %s (port %s): %v. Trying kill -9...", pid, port, err))
	if err := exec.Command("kill", "-9", pid).Run(); err != nil {
		slog.Error(fmt.Sprintf("Failed to force kill PID %s (port %s): %v", pid, port, err))
		return fmt.Errorf("kill -9 failed for PID %s: %w", pid, err)
	}
	slog.Info(fmt.Sprintf("Force killed PID %s (port %s).", pid, port))
	return nil
}