package store import ( "context" "database/sql" "fmt" "log/slog" "os" "path/filepath" "strings" "time" "golang.org/x/text/cases" "golang.org/x/text/language" _ "modernc.org/sqlite" ) const ( TagWork = "work" TagBreak = "break" ) type TimeEntry struct { ID int64 Tag string StartTime time.Time EndTime sql.NullTime } type Store struct { db *sql.DB dbPath string } func NewStore() (*Store, error) { dbPath, err := ensureDatabasePath() if err != nil { return nil, fmt.Errorf("could not determine database path: %w", err) } slog.Debug("Using database at:", "path", dbPath) db, err := sql.Open("sqlite", fmt.Sprintf("%s?_pragma=journal_mode(WAL)", dbPath)) if err != nil { return nil, fmt.Errorf("failed to open database '%s': %w", dbPath, err) } if err = db.Ping(); err != nil { db.Close() return nil, fmt.Errorf("failed to connect to database '%s': %w", dbPath, err) } if err := migrate(db); err != nil { db.Close() return nil, fmt.Errorf("migration failed: %w", err) } return &Store{db: db, dbPath: dbPath}, nil } func migrate(db *sql.DB) error { createTableSQL := ` CREATE TABLE IF NOT EXISTS time_entries ( id INTEGER PRIMARY KEY AUTOINCREMENT, tag TEXT NOT NULL CHECK(tag <> ''), start_time DATETIME NOT NULL, end_time DATETIME NULL, CHECK (end_time IS NULL OR end_time >= start_time) );` if _, err := db.Exec(createTableSQL); err != nil { return fmt.Errorf("failed to create table 'time_entries': %w", err) } createIndexSQL := `CREATE INDEX IF NOT EXISTS idx_time_entries_start_time ON time_entries (start_time);` if _, err := db.Exec(createIndexSQL); err != nil { slog.Warn("Failed to create index on start_time:", "error", err) } return nil } func ensureDatabasePath() (string, error) { configDir, err := os.UserConfigDir() if err != nil { return "", fmt.Errorf("could not get user config dir: %w", err) } workConfigDir := filepath.Join(configDir, "work") if err := os.MkdirAll(workConfigDir, 0o750); err != nil { return "", fmt.Errorf("failed to create config directory '%s': %w", workConfigDir, err) } return filepath.Join(workConfigDir, "worktime.sqlite"), nil } func (s *Store) Close() error { if s.db != nil { slog.Debug("Closing database connection", "path", s.dbPath) return s.db.Close() } return nil } func (s *Store) stopCurrentEntry(ctx context.Context, now time.Time) (bool, error) { query := `UPDATE time_entries SET end_time = ? WHERE end_time IS NULL;` result, err := s.db.ExecContext(ctx, query, now) if err != nil { return false, fmt.Errorf("failed to execute stop current entry query: %w", err) } rowsAffected, err := result.RowsAffected() if err != nil { return false, fmt.Errorf("failed to get affected rows: %w", err) } return rowsAffected > 0, nil } func (s *Store) StartTracking(ctx context.Context, tag string) error { if tag == "" { return fmt.Errorf("cannot start tracking with an empty tag") } now := time.Now() stopped, err := s.stopCurrentEntry(ctx, now) if err != nil { return err } if stopped { slog.Info("Stopped previous time entry.") } query := `INSERT INTO time_entries (tag, start_time, end_time) VALUES (?, ?, NULL);` _, err = s.db.ExecContext(ctx, query, tag, now) if err != nil { return fmt.Errorf("failed to start tracking tag '%s': %w", tag, err) } slog.Info(fmt.Sprintf("Started tracking: %s", tag)) return nil } func (s *Store) StopTracking(ctx context.Context) error { now := time.Now() stopped, err := s.stopCurrentEntry(ctx, now) if err != nil { return err } if stopped { slog.Info(fmt.Sprintf("Stopped tracking at %s", now.Format(time.RFC3339))) } else { slog.Info("No active time entry found to stop.") } return nil } func (s *Store) LogFullDay(ctx context.Context, tag string, date time.Time) error { if tag == "" { return fmt.Errorf("cannot log full day with an empty tag") } tag = strings.ToLower(tag) location := date.Location() dayStart := time.Date(date.Year(), date.Month(), date.Day(), 0, 0, 0, 0, location) dayEnd := dayStart.Add(24 * time.Hour) _, err := s.stopCurrentEntry(ctx, dayStart) if err != nil { slog.Warn("Failed to stop current entry before logging full day", "error", err) } tx, err := s.db.BeginTx(ctx, nil) if err != nil { return err } defer tx.Rollback() query := `INSERT INTO time_entries (tag, start_time, end_time) VALUES (?, ?, ?);` if _, err := tx.ExecContext(ctx, query, tag, dayStart, dayEnd); err != nil { return fmt.Errorf("failed to insert full-day entry: %w", err) } if err := tx.Commit(); err != nil { return fmt.Errorf("failed to commit transaction: %w", err) } titleCaser := cases.Title(language.English) slog.Info(fmt.Sprintf("Successfully logged full day entry: Tag='%s', Date='%s'", tag, dayStart.Format("2006-01-02"))) fmt.Printf("Successfully logged '%s' for %s.\n", titleCaser.String(tag), dayStart.Format("2006-01-02")) return nil } func (s *Store) GetEntriesInRange(ctx context.Context, start, end time.Time) ([]TimeEntry, error) { if start.IsZero() || end.IsZero() || end.Before(start) { return nil, fmt.Errorf("invalid time range: start=%v, end=%v", start, end) } query := ` SELECT id, tag, start_time, end_time FROM time_entries WHERE start_time >= ? AND start_time < ? ORDER BY start_time ASC;` rows, err := s.db.QueryContext(ctx, query, start, end) if err != nil { return nil, fmt.Errorf("failed to query entries in range [%v, %v): %w", start, end, err) } defer rows.Close() var entries []TimeEntry for rows.Next() { var entry TimeEntry if err := rows.Scan(&entry.ID, &entry.Tag, &entry.StartTime, &entry.EndTime); err != nil { return nil, fmt.Errorf("failed to scan entry row: %w", err) } entries = append(entries, entry) } if err = rows.Err(); err != nil { return nil, fmt.Errorf("error during row iteration: %w", err) } return entries, nil } func (s *Store) CalculateSummary(ctx context.Context, period string) (map[string]time.Duration, error) { start, end := GetTimeRangeFromPeriod(period) if start.IsZero() { return nil, fmt.Errorf("invalid period string: '%s'", period) } query := ` SELECT id, tag, start_time, end_time FROM time_entries WHERE (end_time IS NULL OR end_time > ?) AND start_time < ? ORDER BY start_time ASC;` rows, err := s.db.QueryContext(ctx, query, start, end) if err != nil { return nil, fmt.Errorf("failed to query entries: %w", err) } defer rows.Close() summary := make(map[string]time.Duration) now := time.Now() for rows.Next() { var entry TimeEntry if err := rows.Scan(&entry.ID, &entry.Tag, &entry.StartTime, &entry.EndTime); err != nil { return nil, fmt.Errorf("failed to scan entry: %w", err) } effStart := entry.StartTime if effStart.Before(start) { effStart = start } effEnd := now if entry.EndTime.Valid { effEnd = entry.EndTime.Time } if effEnd.After(end) { effEnd = end } if effEnd.After(effStart) { summary[entry.Tag] += effEnd.Sub(effStart) } } return summary, rows.Err() } func (s *Store) ShowSummary(ctx context.Context, period string) error { summary, err := s.CalculateSummary(ctx, period) if err != nil { return err } start, _ := GetTimeRangeFromPeriod(period) titlePeriod := period if !start.IsZero() { _, end := GetTimeRangeFromPeriod(period) if period == ":day" || period == "today" { titlePeriod = fmt.Sprintf("Today (%s)", start.Format("2006-01-02")) } else if period == ":week" { titlePeriod = fmt.Sprintf("Week starting %s", start.Format("Mon, 2006-01-02")) } else if period == ":month" { titlePeriod = fmt.Sprintf("Month %s", start.Format("January 2006")) } else if period == ":year" { titlePeriod = fmt.Sprintf("Year %d", start.Year()) } else if _, err := time.Parse("2006-01-02", period); err == nil { titlePeriod = fmt.Sprintf("Day %s", start.Format("2006-01-02")) } else { titlePeriod = fmt.Sprintf("Period '%s' (%s to %s)", period, start.Format("2006-01-02"), end.Format("2006-01-02")) } } fmt.Printf("\nTime Summary for %s\n", titlePeriod) if len(summary) == 0 { fmt.Println(" No recorded time entries for this period.") return nil } tags := make([]string, 0, len(summary)) for tag := range summary { tags = append(tags, tag) } titleCaser := cases.Title(language.English) totalDuration := time.Duration(0) fmt.Println("------------------------------") for _, tag := range tags { duration := summary[tag] fmt.Printf(" %-12s: %s\n", titleCaser.String(tag), formatDuration(duration)) totalDuration += duration } fmt.Println("------------------------------") fmt.Printf(" Total : %s\n\n", formatDuration(totalDuration)) return nil } func formatDuration(d time.Duration) string { if d < 0 { d = -d sign := "-" d = d.Round(time.Second) h := int64(d.Hours()) m := int64(d.Minutes()) % 60 s := int64(d.Seconds()) % 60 return fmt.Sprintf("%s%02d:%02d:%02d", sign, h, m, s) } d = d.Round(time.Second) h := int64(d.Hours()) m := int64(d.Minutes()) % 60 s := int64(d.Seconds()) % 60 return fmt.Sprintf("%02d:%02d:%02d", h, m, s) } func GetTimeRangeFromPeriod(period string) (time.Time, time.Time) { now := time.Now() year, month, day := now.Date() loc := now.Location() normalizedPeriod := strings.ToLower(strings.TrimPrefix(period, ":")) switch normalizedPeriod { case "week": weekday := now.Weekday() daysToMonday := time.Duration(weekday - time.Monday) if weekday == time.Sunday { daysToMonday = 6 } start := time.Date(year, month, day, 0, 0, 0, 0, loc).Add(-daysToMonday * 24 * time.Hour) end := start.Add(7 * 24 * time.Hour) return start, end case "month": start := time.Date(year, month, 1, 0, 0, 0, 0, loc) end := start.AddDate(0, 1, 0) return start, end case "year": start := time.Date(year, 1, 1, 0, 0, 0, 0, loc) end := start.AddDate(1, 0, 0) return start, end case "day", "today": start := time.Date(year, month, day, 0, 0, 0, 0, loc) end := start.AddDate(0, 0, 1) return start, end default: if t, err := time.ParseInLocation("2006-01-02", period, loc); err == nil { start := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc) end := start.AddDate(0, 0, 1) return start, end } slog.Warn(fmt.Sprintf("Unrecognized period string '%s'. Cannot calculate time range.", period)) return time.Time{}, time.Time{} } } func (s *Store) DB() *sql.DB { return s.db }