work/internal/store/store.go

376 lines
10 KiB
Go

package store
import (
	"context"
	"database/sql"
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"time"

	"golang.org/x/text/cases"
	"golang.org/x/text/language"
	_ "modernc.org/sqlite"
)
// TagWork is a predefined tag value. The store itself accepts any non-empty
// tag; presumably callers use this one for working time.
const TagWork = "work"

// TagBreak is a predefined tag value, presumably used by callers for breaks.
const TagBreak = "break"
// TimeEntry is one row of the time_entries table.
type TimeEntry struct {
	// ID is the row's autoincrement primary key.
	ID int64
	// Tag is the non-empty label the time was tracked under.
	Tag string
	// StartTime is when the entry began.
	StartTime time.Time
	// EndTime is NULL (Valid == false) while the entry is still running.
	EndTime sql.NullTime
}
// Store provides access to the time-tracking SQLite database.
// Construct it with NewStore and release it with Close.
type Store struct {
	// db is the open database handle.
	db *sql.DB
	// dbPath is the on-disk location of the database; used in log messages.
	dbPath string
}
// NewStore opens the SQLite database at the path chosen by ensureDatabasePath.
// It requests WAL journal mode through the driver's _pragma DSN parameter,
// verifies the connection with a ping, and applies the schema migration.
// On any failure after opening, the handle is closed before the error is
// returned. The caller must Close the returned Store.
func NewStore() (*Store, error) {
	path, err := ensureDatabasePath()
	if err != nil {
		return nil, fmt.Errorf("could not determine database path: %w", err)
	}
	slog.Debug("Using database at:", "path", path)

	dsn := fmt.Sprintf("%s?_pragma=journal_mode(WAL)", path)
	conn, err := sql.Open("sqlite", dsn)
	if err != nil {
		return nil, fmt.Errorf("failed to open database '%s': %w", path, err)
	}

	// abort releases the freshly opened handle (best effort, its error is
	// irrelevant next to the one being reported) and propagates the failure.
	abort := func(cause error) (*Store, error) {
		conn.Close()
		return nil, cause
	}

	if pingErr := conn.Ping(); pingErr != nil {
		return abort(fmt.Errorf("failed to connect to database '%s': %w", path, pingErr))
	}
	if migErr := migrate(conn); migErr != nil {
		return abort(fmt.Errorf("migration failed: %w", migErr))
	}

	return &Store{db: conn, dbPath: path}, nil
}
func migrate(db *sql.DB) error {
createTableSQL := `
CREATE TABLE IF NOT EXISTS time_entries (
id INTEGER PRIMARY KEY AUTOINCREMENT,
tag TEXT NOT NULL CHECK(tag <> ''),
start_time DATETIME NOT NULL,
end_time DATETIME NULL,
CHECK (end_time IS NULL OR end_time >= start_time)
);`
if _, err := db.Exec(createTableSQL); err != nil {
return fmt.Errorf("failed to create table 'time_entries': %w", err)
}
createIndexSQL := `CREATE INDEX IF NOT EXISTS idx_time_entries_start_time ON time_entries (start_time);`
if _, err := db.Exec(createIndexSQL); err != nil {
slog.Warn("Failed to create index on start_time:", "error", err)
}
return nil
}
// ensureDatabasePath returns <user config dir>/work/worktime.sqlite. It creates
// the "work" directory (mode 0o750, subject to umask) if it does not exist.
// The database file itself is not created here.
func ensureDatabasePath() (string, error) {
	base, err := os.UserConfigDir()
	if err != nil {
		return "", fmt.Errorf("could not get user config dir: %w", err)
	}

	dir := filepath.Join(base, "work")
	if mkErr := os.MkdirAll(dir, 0o750); mkErr != nil {
		return "", fmt.Errorf("failed to create config directory '%s': %w", dir, mkErr)
	}

	dbFile := filepath.Join(dir, "worktime.sqlite")
	return dbFile, nil
}
// Close closes the underlying database handle and returns its error.
// If the Store has no handle, it does nothing and returns nil.
func (s *Store) Close() error {
	if s.db == nil {
		return nil
	}
	slog.Debug("Closing database connection", "path", s.dbPath)
	return s.db.Close()
}
// stopCurrentEntry sets end_time = now on every entry whose end_time is NULL.
// It reports whether at least one row was updated. The update is subject to
// the table's CHECK (end_time >= start_time), so it fails if an open entry
// started after now.
func (s *Store) stopCurrentEntry(ctx context.Context, now time.Time) (bool, error) {
	res, err := s.db.ExecContext(ctx, `UPDATE time_entries SET end_time = ? WHERE end_time IS NULL;`, now)
	if err != nil {
		return false, fmt.Errorf("failed to execute stop current entry query: %w", err)
	}

	n, err := res.RowsAffected()
	if err != nil {
		return false, fmt.Errorf("failed to get affected rows: %w", err)
	}
	return n > 0, nil
}
// StartTracking closes any open entry at the current time and opens a new
// entry for tag that starts at that same instant.
//
// Both steps run in a single transaction. If inserting the new entry fails,
// the previously running entry is left open rather than silently stopped.
// An empty tag is rejected before touching the database.
func (s *Store) StartTracking(ctx context.Context, tag string) error {
	if tag == "" {
		return fmt.Errorf("cannot start tracking with an empty tag")
	}
	now := time.Now()

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rollback is a no-op once Commit has succeeded; its error is irrelevant.
	defer tx.Rollback()

	// Same statement as stopCurrentEntry, but run on the transaction.
	res, err := tx.ExecContext(ctx, `UPDATE time_entries SET end_time = ? WHERE end_time IS NULL;`, now)
	if err != nil {
		return fmt.Errorf("failed to execute stop current entry query: %w", err)
	}
	stopped, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get affected rows: %w", err)
	}

	query := `INSERT INTO time_entries (tag, start_time, end_time) VALUES (?, ?, NULL);`
	if _, err := tx.ExecContext(ctx, query, tag, now); err != nil {
		return fmt.Errorf("failed to start tracking tag '%s': %w", tag, err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}

	// Log only after commit so messages reflect what was actually persisted.
	if stopped > 0 {
		slog.Info("Stopped previous time entry.")
	}
	slog.Info(fmt.Sprintf("Started tracking: %s", tag))
	return nil
}
// StopTracking closes any open entry at the current time and logs the result.
// Finding nothing to stop is not an error; only database failures are returned.
func (s *Store) StopTracking(ctx context.Context) error {
	stoppedAt := time.Now()
	didStop, err := s.stopCurrentEntry(ctx, stoppedAt)

	switch {
	case err != nil:
		return err
	case didStop:
		slog.Info(fmt.Sprintf("Stopped tracking at %s", stoppedAt.Format(time.RFC3339)))
	default:
		slog.Info("No active time entry found to stop.")
	}
	return nil
}
// LogFullDay records one closed entry for tag covering the whole calendar day
// of date, in date's location: from that day's midnight to the next midnight.
//
// The tag is lower-cased before it is stored. In the same transaction, any
// open entry that started no later than the day's start is closed at the day's
// start. An open entry that began after it is left running; closing it there
// would violate the end_time >= start_time CHECK. Either both changes are
// committed or neither is. An empty tag is rejected.
func (s *Store) LogFullDay(ctx context.Context, tag string, date time.Time) error {
	if tag == "" {
		return fmt.Errorf("cannot log full day with an empty tag")
	}
	tag = strings.ToLower(tag)

	location := date.Location()
	dayStart := time.Date(date.Year(), date.Month(), date.Day(), 0, 0, 0, 0, location)
	// AddDate instead of Add(24*time.Hour): days with a DST transition are
	// 23 or 25 hours long, and the entry must end exactly at the next midnight.
	dayEnd := dayStart.AddDate(0, 0, 1)

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rollback is a no-op once Commit has succeeded; its error is irrelevant.
	defer tx.Rollback()

	// The start_time <= ? guard uses the same comparison as the table's CHECK
	// constraint, so this UPDATE can never violate it.
	stopQuery := `UPDATE time_entries SET end_time = ? WHERE end_time IS NULL AND start_time <= ?;`
	if _, err := tx.ExecContext(ctx, stopQuery, dayStart, dayStart); err != nil {
		return fmt.Errorf("failed to stop current entry before logging full day: %w", err)
	}

	query := `INSERT INTO time_entries (tag, start_time, end_time) VALUES (?, ?, ?);`
	if _, err := tx.ExecContext(ctx, query, tag, dayStart, dayEnd); err != nil {
		return fmt.Errorf("failed to insert full-day entry: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}

	titleCaser := cases.Title(language.English)
	slog.Info(fmt.Sprintf("Successfully logged full day entry: Tag='%s', Date='%s'", tag, dayStart.Format("2006-01-02")))
	fmt.Printf("Successfully logged '%s' for %s.\n", titleCaser.String(tag), dayStart.Format("2006-01-02"))
	return nil
}
// GetEntriesInRange returns the entries selected by
// `start_time >= start AND start_time < end`, ordered by start_time ascending.
// Selection is by start time only, so an entry that began before start is not
// included even if it runs into the range. It returns an error if either
// bound is zero or start is after end. With no matching rows, the result is nil.
// NOTE(review): the comparison is done by SQLite on the driver's stored time
// representation — confirm it orders correctly across UTC offset changes.
func (s *Store) GetEntriesInRange(ctx context.Context, start, end time.Time) ([]TimeEntry, error) {
	if start.IsZero() || end.IsZero() || start.After(end) {
		return nil, fmt.Errorf("invalid time range: start=%v, end=%v", start, end)
	}

	const q = `
SELECT id, tag, start_time, end_time
FROM time_entries
WHERE start_time >= ? AND start_time < ?
ORDER BY start_time ASC;`

	rows, err := s.db.QueryContext(ctx, q, start, end)
	if err != nil {
		return nil, fmt.Errorf("failed to query entries in range [%v, %v): %w", start, end, err)
	}
	defer rows.Close()

	var result []TimeEntry
	for rows.Next() {
		var e TimeEntry
		if scanErr := rows.Scan(&e.ID, &e.Tag, &e.StartTime, &e.EndTime); scanErr != nil {
			return nil, fmt.Errorf("failed to scan entry row: %w", scanErr)
		}
		result = append(result, e)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error during row iteration: %w", iterErr)
	}
	return result, nil
}
// CalculateSummary totals tracked time per tag for period.
//
// The period is resolved by GetTimeRangeFromPeriod into [start, end). An
// unrecognized period returns an error. Every entry that overlaps the range
// is clipped to it. Entries that are still open (NULL end_time) count up to
// the current time, capped at end. Entries with no positive overlap add
// nothing. On any database error the result is nil; a partially accumulated
// summary is never returned.
func (s *Store) CalculateSummary(ctx context.Context, period string) (map[string]time.Duration, error) {
	start, end := GetTimeRangeFromPeriod(period)
	if start.IsZero() {
		return nil, fmt.Errorf("invalid period string: '%s'", period)
	}

	// Select everything overlapping [start, end): not finished before start
	// (or still open) and begun before end.
	query := `
SELECT id, tag, start_time, end_time
FROM time_entries
WHERE (end_time IS NULL OR end_time > ?)
AND start_time < ?
ORDER BY start_time ASC;`
	rows, err := s.db.QueryContext(ctx, query, start, end)
	if err != nil {
		return nil, fmt.Errorf("failed to query entries: %w", err)
	}
	defer rows.Close()

	summary := make(map[string]time.Duration)
	now := time.Now()
	for rows.Next() {
		var entry TimeEntry
		if err := rows.Scan(&entry.ID, &entry.Tag, &entry.StartTime, &entry.EndTime); err != nil {
			return nil, fmt.Errorf("failed to scan entry: %w", err)
		}

		effStart := entry.StartTime
		if effStart.Before(start) {
			effStart = start
		}
		effEnd := now // open entries run until "now"
		if entry.EndTime.Valid {
			effEnd = entry.EndTime.Time
		}
		if effEnd.After(end) {
			effEnd = end
		}
		if effEnd.After(effStart) {
			summary[entry.Tag] += effEnd.Sub(effStart)
		}
	}
	// Iteration can stop early on error; don't hand back a partial summary.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error during row iteration: %w", err)
	}
	return summary, nil
}
func (s *Store) ShowSummary(ctx context.Context, period string) error {
summary, err := s.CalculateSummary(ctx, period)
if err != nil {
return err
}
start, _ := GetTimeRangeFromPeriod(period)
titlePeriod := period
if !start.IsZero() {
_, end := GetTimeRangeFromPeriod(period)
if period == ":day" || period == "today" {
titlePeriod = fmt.Sprintf("Today (%s)", start.Format("2006-01-02"))
} else if period == ":week" {
titlePeriod = fmt.Sprintf("Week starting %s", start.Format("Mon, 2006-01-02"))
} else if period == ":month" {
titlePeriod = fmt.Sprintf("Month %s", start.Format("January 2006"))
} else if period == ":year" {
titlePeriod = fmt.Sprintf("Year %d", start.Year())
} else if _, err := time.Parse("2006-01-02", period); err == nil {
titlePeriod = fmt.Sprintf("Day %s", start.Format("2006-01-02"))
} else {
titlePeriod = fmt.Sprintf("Period '%s' (%s to %s)", period, start.Format("2006-01-02"), end.Format("2006-01-02"))
}
}
fmt.Printf("\nTime Summary for %s\n", titlePeriod)
if len(summary) == 0 {
fmt.Println(" No recorded time entries for this period.")
return nil
}
tags := make([]string, 0, len(summary))
for tag := range summary {
tags = append(tags, tag)
}
titleCaser := cases.Title(language.English)
totalDuration := time.Duration(0)
fmt.Println("------------------------------")
for _, tag := range tags {
duration := summary[tag]
fmt.Printf(" %-12s: %s\n", titleCaser.String(tag), formatDuration(duration))
totalDuration += duration
}
fmt.Println("------------------------------")
fmt.Printf(" Total : %s\n\n", formatDuration(totalDuration))
return nil
}
// formatDuration renders d as HH:MM:SS after rounding to the nearest second.
// Hours are not wrapped at 24 and may use more than two digits. A negative
// duration gets a leading "-", with the magnitude rounded and formatted.
func formatDuration(d time.Duration) string {
	sign := ""
	if d < 0 {
		sign = "-"
		d = -d
	}
	d = d.Round(time.Second)

	hours := int64(d.Hours())
	minutes := int64(d.Minutes()) % 60
	seconds := int64(d.Seconds()) % 60
	return fmt.Sprintf("%s%02d:%02d:%02d", sign, hours, minutes, seconds)
}
// GetTimeRangeFromPeriod converts a period name into a half-open range
// [start, end) in the local time zone, relative to the current time.
//
// Recognized names are case-insensitive and may carry a leading ':':
//   - "day" / "today": the current calendar day
//   - "week": the current week, Monday 00:00 to the following Monday 00:00
//   - "month": the current calendar month
//   - "year": the current calendar year
//
// Any other value is parsed, exactly as given, as a "2006-01-02" date and
// yields that calendar day. If that fails, a warning is logged and two zero
// Times are returned.
//
// All boundaries are computed with calendar arithmetic (time.Date / AddDate)
// rather than fixed 24-hour durations, so they stay on local midnight across
// DST transitions.
func GetTimeRangeFromPeriod(period string) (time.Time, time.Time) {
	now := time.Now()
	year, month, day := now.Date()
	loc := now.Location()
	normalizedPeriod := strings.ToLower(strings.TrimPrefix(period, ":"))
	switch normalizedPeriod {
	case "week":
		// Days since Monday: Monday → 0, ..., Sunday → 6.
		sinceMonday := (int(now.Weekday()) + 6) % 7
		// time.Date normalizes day-sinceMonday across month/year boundaries.
		start := time.Date(year, month, day-sinceMonday, 0, 0, 0, 0, loc)
		end := start.AddDate(0, 0, 7)
		return start, end
	case "month":
		start := time.Date(year, month, 1, 0, 0, 0, 0, loc)
		end := start.AddDate(0, 1, 0)
		return start, end
	case "year":
		start := time.Date(year, 1, 1, 0, 0, 0, 0, loc)
		end := start.AddDate(1, 0, 0)
		return start, end
	case "day", "today":
		start := time.Date(year, month, day, 0, 0, 0, 0, loc)
		end := start.AddDate(0, 0, 1)
		return start, end
	default:
		if t, err := time.ParseInLocation("2006-01-02", period, loc); err == nil {
			start := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, loc)
			end := start.AddDate(0, 0, 1)
			return start, end
		}
		slog.Warn(fmt.Sprintf("Unrecognized period string '%s'. Cannot calculate time range.", period))
		return time.Time{}, time.Time{}
	}
}
// DB returns the Store's underlying database handle. The handle is shared,
// not copied, so Store.Close closes it as well.
func (s *Store) DB() *sql.DB {
	handle := s.db
	return handle
}