school-timetracker/backend/middleware.go

205 lines
4.8 KiB
Go

package main
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"golang.org/x/time/rate"
)
// jwtSecret is the HMAC key that createToken and verifyToken use to sign and
// check tokens. It is populated once, at start-up, by init.
var jwtSecret []byte

// init loads the signing key from the JWT_SECRET environment variable and
// panics when the variable is unset or empty, so the process never runs
// without a key.
func init() {
	if value := os.Getenv("JWT_SECRET"); value != "" {
		jwtSecret = []byte(value)
		return
	}
	panic("JWT_SECRET environment variable is required")
}
func createToken(userID int, username string, isAdmin bool) (string, error) {
claims := Claims{
UserID: userID,
Username: username,
IsAdmin: isAdmin,
}
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
claimsWithExp := map[string]any{
"user_id": claims.UserID,
"username": claims.Username,
"is_admin": claims.IsAdmin,
"exp": time.Now().Add(2 * time.Hour).Unix(),
}
payload, _ := json.Marshal(claimsWithExp)
payloadEncoded := base64.RawURLEncoding.EncodeToString(payload)
message := header + "." + payloadEncoded
h := hmac.New(sha256.New, jwtSecret)
h.Write([]byte(message))
signature := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
return message + "." + signature, nil
}
// verifyToken checks a token produced by createToken and returns its claims.
//
// The token must consist of exactly three dot-separated parts. The third part
// must equal the base64url-encoded HMAC-SHA256 of "part0.part1" keyed with
// jwtSecret. The payload must be valid base64url-encoded JSON containing a
// numeric "user_id", a string "username" and a boolean "is_admin".
//
// If an "exp" claim is present it must be numeric and must not lie in the
// past. A token without "exp" is accepted without an expiry check, matching
// the previous behaviour.
//
// Any violation yields a nil *Claims and a non-nil error; the function never
// panics on malformed claims.
func verifyToken(tokenString string) (*Claims, error) {
	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return nil, fmt.Errorf("invalid token format")
	}

	mac := hmac.New(sha256.New, jwtSecret)
	mac.Write([]byte(parts[0] + "." + parts[1]))
	expectedSignature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))

	// Constant-time comparison: a plain != leaks, through timing, how many
	// leading bytes of a forged signature are correct.
	if !hmac.Equal([]byte(parts[2]), []byte(expectedSignature)) {
		return nil, fmt.Errorf("invalid signature")
	}

	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return nil, fmt.Errorf("decoding token payload: %w", err)
	}

	var claimsMap map[string]any
	if err := json.Unmarshal(payload, &claimsMap); err != nil {
		return nil, fmt.Errorf("parsing token claims: %w", err)
	}

	if rawExp, present := claimsMap["exp"]; present {
		// encoding/json decodes every JSON number into float64.
		exp, ok := rawExp.(float64)
		if !ok {
			return nil, fmt.Errorf("invalid exp claim")
		}
		if time.Now().Unix() > int64(exp) {
			return nil, fmt.Errorf("token expired")
		}
	}

	// Checked assertions: a missing or mistyped claim is reported as an error
	// instead of panicking inside the request handler.
	userID, ok := claimsMap["user_id"].(float64)
	if !ok {
		return nil, fmt.Errorf("invalid user_id claim")
	}
	username, ok := claimsMap["username"].(string)
	if !ok {
		return nil, fmt.Errorf("invalid username claim")
	}
	isAdmin, ok := claimsMap["is_admin"].(bool)
	if !ok {
		return nil, fmt.Errorf("invalid is_admin claim")
	}

	return &Claims{
		UserID:   int(userID),
		Username: username,
		IsAdmin:  isAdmin,
	}, nil
}
// JWTMiddleware returns Echo middleware that authenticates a request from its
// Authorization header.
//
// A leading "Bearer " is stripped from the header value and the remainder is
// passed to verifyToken. On success the claims are stored in the Echo context
// under the keys "user_id", "username" and "is_admin", an info log line is
// written, and the next handler runs. A missing header or a token that fails
// verification produces a 401 "Unauthorized" HTTP error.
func JWTMiddleware() echo.MiddlewareFunc {
	const bearerPrefix = "Bearer "

	// reject builds a fresh 401 error for each failed request.
	reject := func() error {
		return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			rawHeader := c.Request().Header.Get("Authorization")
			if rawHeader == "" {
				return reject()
			}

			claims, err := verifyToken(strings.TrimPrefix(rawHeader, bearerPrefix))
			if err != nil {
				return reject()
			}

			// Expose the identity to downstream handlers and middleware.
			c.Set("user_id", claims.UserID)
			c.Set("username", claims.Username)
			c.Set("is_admin", claims.IsAdmin)

			c.Logger().Infof("Authenticated user: ID=%d, Username=%s", claims.UserID, claims.Username)
			return next(c)
		}
	}
}
// AdminMiddleware returns Echo middleware that only lets a request through
// when the context value "is_admin" is a bool equal to true. JWTMiddleware
// sets that key; any other case yields a 403 "Access denied" HTTP error.
func AdminMiddleware() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if admin, isBool := c.Get("is_admin").(bool); isBool && admin {
				return next(c)
			}
			return echo.NewHTTPError(http.StatusForbidden, "Access denied")
		}
	}
}
// CustomLogger returns Echo's logger middleware configured to emit one line
// per request: RFC 3339 timestamp, status, human-readable latency, method and
// URI, separated by " | ".
func CustomLogger() echo.MiddlewareFunc {
	const lineFormat = "${time_rfc3339} | ${status} | ${latency_human} | ${method} ${uri}\n"

	cfg := middleware.LoggerConfig{Format: lineFormat}
	return middleware.LoggerWithConfig(cfg)
}
// LoginRateLimiter keeps one rate.Limiter per client in a mutex-guarded map.
// GetLimiter looks entries up by IP address string and creates them on demand.
type LoginRateLimiter struct {
// limiters maps a client IP string to its limiter; access goes through mu.
limiters map[string]*rate.Limiter
// mu guards limiters.
mu sync.Mutex
}
// NewLoginRateLimiter returns a LoginRateLimiter with an empty limiter map and
// starts a background goroutine that, every 10 minutes, swaps the map for a
// new empty one while holding the mutex. That discards all per-IP limiters
// accumulated so far.
//
// The goroutine has no stop signal; it runs for the lifetime of the process.
func NewLoginRateLimiter() *LoginRateLimiter {
	rl := &LoginRateLimiter{limiters: map[string]*rate.Limiter{}}

	go func(target *LoginRateLimiter) {
		tick := time.NewTicker(10 * time.Minute)
		defer tick.Stop()

		for {
			<-tick.C
			target.mu.Lock()
			target.limiters = map[string]*rate.Limiter{}
			target.mu.Unlock()
		}
	}(rl)

	return rl
}
// GetLimiter returns the limiter registered for ip, creating and storing one
// on first use. New limiters refill one token every 12 seconds (five per
// minute) and allow a burst of five. The lookup and insert run under the
// mutex.
func (l *LoginRateLimiter) GetLimiter(ip string) *rate.Limiter {
	l.mu.Lock()
	defer l.mu.Unlock()

	if existing, found := l.limiters[ip]; found {
		return existing
	}

	created := rate.NewLimiter(rate.Every(time.Minute/5), 5)
	l.limiters[ip] = created
	return created
}
// Middleware returns Echo middleware that rate-limits requests per client IP
// as reported by c.RealIP(). When the client's limiter has no token available
// the request is rejected with a 429 HTTP error; otherwise the next handler
// runs.
func (l *LoginRateLimiter) Middleware() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if l.GetLimiter(c.RealIP()).Allow() {
				return next(c)
			}
			return echo.NewHTTPError(http.StatusTooManyRequests, "Too many login attempts. Please try again later.")
		}
	}
}
// HTTPSRedirectMiddleware returns Echo middleware that, when the ENVIRONMENT
// environment variable equals "production", answers any request whose
// X-Forwarded-Proto header is not "https" with a 301 redirect to the same
// host and request URI under https://. All other requests are passed to the
// next handler. ENVIRONMENT is read on every request.
func HTTPSRedirectMiddleware() echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			req := c.Request()

			// Only enforce HTTPS in production.
			inProduction := os.Getenv("ENVIRONMENT") == "production"
			if !inProduction || req.Header.Get("X-Forwarded-Proto") == "https" {
				return next(c)
			}

			target := "https://" + req.Host + req.RequestURI
			return c.Redirect(http.StatusMovedPermanently, target)
		}
	}
}