school-timetracker/backend/middleware.go
Patryk Hegenberg 84def05c50 feat: change Manual time entry to work with hours instead of start and end time
To support adding hours directly, the logic was changed to accept hours for
manual entries instead of start and end times. This also allows negative
values, which are added to the working time.
2025-11-08 11:27:42 +01:00

204 lines
4.7 KiB
Go

package main
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"golang.org/x/time/rate"
)
// jwtSecret is the HMAC-SHA256 key used to sign and verify tokens. It is
// populated once, at package initialisation, from the JWT_SECRET environment
// variable.
var jwtSecret []byte

// init loads JWT_SECRET and panics when it is unset or empty, so the process
// refuses to start rather than signing tokens with an empty key.
func init() {
	if s := os.Getenv("JWT_SECRET"); s != "" {
		jwtSecret = []byte(s)
		return
	}
	panic("JWT_SECRET environment variable is required")
}
func createToken(userID int, username string, isAdmin bool) (string, error) {
claims := Claims{
UserID: userID,
Username: username,
IsAdmin: isAdmin,
}
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
claimsWithExp := map[string]any{
"user_id": claims.UserID,
"username": claims.Username,
"is_admin": claims.IsAdmin,
"exp": time.Now().Add(2 * time.Hour).Unix(),
}
payload, _ := json.Marshal(claimsWithExp)
payloadEncoded := base64.RawURLEncoding.EncodeToString(payload)
message := header + "." + payloadEncoded
h := hmac.New(sha256.New, jwtSecret)
h.Write([]byte(message))
signature := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
return message + "." + signature, nil
}
// verifyToken validates a token in the format produced by createToken and
// returns its claims.
//
// The token is rejected (non-nil error, nil claims) when any of these hold:
//   - it does not consist of exactly three dot-separated segments;
//   - the signature does not match HMAC-SHA256(jwtSecret, "<header>.<payload>");
//   - the payload is not valid base64url or JSON;
//   - the exp claim is missing or lies in the past;
//   - user_id, username or is_admin is missing or has the wrong JSON type.
//
// The header segment is covered by the signature but its fields are not
// otherwise inspected.
func verifyToken(tokenString string) (*Claims, error) {
	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return nil, fmt.Errorf("invalid token format")
	}

	mac := hmac.New(sha256.New, jwtSecret)
	mac.Write([]byte(parts[0] + "." + parts[1])) // hash.Hash.Write never returns an error
	expected := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
	// hmac.Equal compares in constant time, so response timing does not reveal
	// how much of a forged signature was correct. Comparing the encoded forms
	// keeps accepting only the canonical encoding, as before.
	if !hmac.Equal([]byte(parts[2]), []byte(expected)) {
		return nil, fmt.Errorf("invalid signature")
	}

	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return nil, fmt.Errorf("decoding token payload: %w", err)
	}
	var claimsMap map[string]any
	if err := json.Unmarshal(payload, &claimsMap); err != nil {
		return nil, fmt.Errorf("parsing token payload: %w", err)
	}

	// createToken always sets exp; a token lacking it must not be treated as
	// non-expiring.
	exp, ok := claimsMap["exp"].(float64)
	if !ok {
		return nil, fmt.Errorf("token missing exp claim")
	}
	if time.Now().Unix() > int64(exp) {
		return nil, fmt.Errorf("token expired")
	}

	// Checked assertions: a correctly signed but malformed payload yields an
	// error instead of panicking the request goroutine. JSON numbers decode
	// as float64, hence the conversion for user_id.
	userID, okID := claimsMap["user_id"].(float64)
	username, okName := claimsMap["username"].(string)
	isAdmin, okAdmin := claimsMap["is_admin"].(bool)
	if !okID || !okName || !okAdmin {
		return nil, fmt.Errorf("invalid token claims")
	}
	return &Claims{UserID: int(userID), Username: username, IsAdmin: isAdmin}, nil
}
// JWTMiddleware authenticates requests using the token carried in the
// Authorization header.
//
// The "Bearer" scheme prefix is matched case-insensitively (RFC 7235 §2.1).
// A header without that scheme is passed to verifyToken unchanged, so clients
// that send the bare token keep working. On success the token's claims are
// stored in the Echo context under "user_id", "username" and "is_admin" before
// next is called. A missing header or any verification failure yields
// 401 Unauthorized with a generic message that does not reveal the cause.
func JWTMiddleware() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			authHeader := c.Request().Header.Get("Authorization")
			if authHeader == "" {
				return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
			}
			tokenString := authHeader
			if scheme, rest, found := strings.Cut(authHeader, " "); found && strings.EqualFold(scheme, "Bearer") {
				tokenString = rest
			}
			claims, err := verifyToken(tokenString)
			if err != nil {
				return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
			}
			c.Set("user_id", claims.UserID)
			c.Set("username", claims.Username)
			c.Set("is_admin", claims.IsAdmin)
			c.Logger().Infof("Authenticated user: ID=%d, Username=%s", claims.UserID, claims.Username)
			return next(c)
		}
	}
}
// AdminMiddleware lets a request through only when the Echo context holds
// "is_admin" set to the bool true. A missing value, a non-bool value, or false
// all produce 403 Forbidden. It must run after a middleware that populates
// "is_admin".
func AdminMiddleware() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			// A failed assertion leaves admin at its zero value (false).
			if admin, _ := c.Get("is_admin").(bool); admin {
				return next(c)
			}
			return echo.NewHTTPError(http.StatusForbidden, "Access denied")
		}
	}
}
// CustomLogger returns Echo's request-logging middleware, configured to write
// one line per request: an RFC 3339 timestamp, the response status, the
// human-readable latency, then the method and URI.
func CustomLogger() echo.MiddlewareFunc {
	cfg := middleware.LoggerConfig{
		Format: "${time_rfc3339} | ${status} | ${latency_human} | ${method} ${uri}\n",
	}
	return middleware.LoggerWithConfig(cfg)
}
// LoginRateLimiter throttles login attempts per client IP, keeping one
// token-bucket limiter per address. All map access is guarded by mu.
//
// The zero value is usable but never evicts idle entries. Use
// NewLoginRateLimiter to also start the background eviction loop.
type LoginRateLimiter struct {
	limiters map[string]*rate.Limiter
	// lastSeen records when each IP last asked for its limiter, so the
	// eviction loop can drop only entries that have gone idle.
	lastSeen map[string]time.Time
	mu       sync.Mutex
}

// NewLoginRateLimiter returns a LoginRateLimiter and starts a goroutine that,
// every 10 minutes, evicts the limiters of IPs not seen during the previous
// 10 minutes. The goroutine has no stop mechanism and lives until the process
// exits, so create one limiter per server, not per request.
func NewLoginRateLimiter() *LoginRateLimiter {
	const cleanupInterval = 10 * time.Minute
	l := &LoginRateLimiter{
		limiters: make(map[string]*rate.Limiter),
		lastSeen: make(map[string]time.Time),
	}
	go func() {
		ticker := time.NewTicker(cleanupInterval)
		defer ticker.Stop()
		for range ticker.C {
			l.evictIdle(time.Now().Add(-cleanupInterval))
		}
	}()
	return l
}

// evictIdle removes every limiter whose IP has not been seen since cutoff.
// Only idle entries are dropped, because discarding an active IP's limiter
// would hand it a fresh burst. An IP idle for the whole 10-minute interval has
// already refilled its bucket (a burst of 5 at 5/min refills in one minute),
// so evicting it loses no throttling state.
func (l *LoginRateLimiter) evictIdle(cutoff time.Time) {
	l.mu.Lock()
	defer l.mu.Unlock()
	for ip, seen := range l.lastSeen {
		if seen.Before(cutoff) {
			delete(l.lastSeen, ip)
			delete(l.limiters, ip)
		}
	}
}

// GetLimiter returns the limiter for ip, creating it on first use, and marks
// ip as recently seen. Each limiter allows a burst of 5 attempts and refills
// one attempt every 12 seconds (5 per minute).
func (l *LoginRateLimiter) GetLimiter(ip string) *rate.Limiter {
	l.mu.Lock()
	defer l.mu.Unlock()
	// Lazy init keeps the zero value (or a literal without lastSeen) usable.
	if l.limiters == nil {
		l.limiters = make(map[string]*rate.Limiter)
	}
	if l.lastSeen == nil {
		l.lastSeen = make(map[string]time.Time)
	}
	l.lastSeen[ip] = time.Now()
	limiter, exists := l.limiters[ip]
	if !exists {
		limiter = rate.NewLimiter(rate.Every(time.Minute/5), 5)
		l.limiters[ip] = limiter
	}
	return limiter
}

// Middleware consumes one token from the client IP's limiter on every request.
// When no token is available it responds with 429 Too Many Requests;
// otherwise it calls next.
//
// NOTE(review): c.RealIP() takes the address from X-Forwarded-For / X-Real-IP
// unless the Echo instance has an IPExtractor configured. If the server is
// reachable without a trusted proxy, a client could spoof those headers to get
// a fresh limiter per request. Confirm e.IPExtractor is set appropriately.
func (l *LoginRateLimiter) Middleware() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			ip := c.RealIP()
			limiter := l.GetLimiter(ip)
			if !limiter.Allow() {
				return echo.NewHTTPError(http.StatusTooManyRequests, "Too many login attempts. Please try again later.")
			}
			return next(c)
		}
	}
}
// HTTPSRedirectMiddleware enforces HTTPS when the ENVIRONMENT variable equals
// "production". The variable is re-read on every request. A request whose
// X-Forwarded-Proto header is not exactly "https" receives a
// 301 Moved Permanently to the same host and request URI on https://.
// Outside production, and for requests already marked https, it simply calls
// next.
//
// NOTE(review): the redirect target is built from the request's Host header,
// and with 301 many clients replay a POST as a GET. Confirm both are
// acceptable behind the deployed proxy.
func HTTPSRedirectMiddleware() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			req := c.Request()
			if os.Getenv("ENVIRONMENT") != "production" || req.Header.Get("X-Forwarded-Proto") == "https" {
				return next(c)
			}
			target := "https://" + req.Host + req.RequestURI
			return c.Redirect(http.StatusMovedPermanently, target)
		}
	}
}