school-timetracker/backend/middleware.go

125 lines
3.4 KiB
Go

package main
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
)
// jwtSecret is the HMAC-SHA256 key that createToken signs with and
// verifyToken checks against.
// NOTE(review): the key is a placeholder hard-coded in source; anyone who can
// read this file can forge tokens. It should be loaded from configuration
// before deployment — confirm how the project supplies secrets.
var jwtSecret = []byte("your-secret-key-change-in-production")
// Claims is the JSON payload embedded in a token by createToken and
// recovered by verifyToken. Field order and tags determine the exact bytes
// that get signed, so they must not be rearranged.
type Claims struct {
	// UserID is serialized as "user_id".
	UserID int `json:"user_id"`
	// Username is serialized as "username".
	Username string `json:"username"`
	// IsAdmin is serialized as "is_admin".
	IsAdmin bool `json:"is_admin"`
	// Exp is the expiry instant in Unix seconds, serialized as "exp".
	Exp int64 `json:"exp"`
}
// createToken builds a signed, JWT-style (HS256) token for the given user.
//
// The token is three unpadded base64url segments joined by dots: a fixed
// header, the JSON-encoded Claims, and an HMAC-SHA256 signature computed over
// "header.payload" with jwtSecret. The Exp claim is set to 24 hours after the
// time of the call.
//
// It returns the token, or an empty string and an error if the claims cannot
// be JSON-encoded.
func createToken(userID int, username string, isAdmin bool) (string, error) {
	claims := Claims{
		UserID:   userID,
		Username: username,
		IsAdmin:  isAdmin,
		Exp:      time.Now().Add(24 * time.Hour).Unix(),
	}

	payload, err := json.Marshal(claims)
	if err != nil {
		return "", fmt.Errorf("encoding claims: %w", err)
	}

	header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
	message := header + "." + base64.RawURLEncoding.EncodeToString(payload)

	mac := hmac.New(sha256.New, jwtSecret)
	mac.Write([]byte(message)) // hash.Hash documents that Write never returns an error
	signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))

	return message + "." + signature, nil
}
// verifyToken validates a token in the format produced by createToken and
// returns the claims it carries.
//
// The token must have exactly three dot-separated segments. The signature
// segment is recomputed as HMAC-SHA256 over "header.payload" with jwtSecret
// and checked before the payload is decoded. A token whose Exp is earlier
// than the current Unix time is rejected.
//
// On any failure the returned claims are nil and the error names the step
// that failed.
func verifyToken(tokenString string) (*Claims, error) {
	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return nil, fmt.Errorf("invalid token format")
	}
	headerPart, payloadPart, signaturePart := parts[0], parts[1], parts[2]

	mac := hmac.New(sha256.New, jwtSecret)
	mac.Write([]byte(headerPart + "." + payloadPart))
	expectedSignature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))

	// Compare with hmac.Equal rather than != so the time taken does not
	// reveal how much of a forged signature was correct.
	if !hmac.Equal([]byte(signaturePart), []byte(expectedSignature)) {
		return nil, fmt.Errorf("invalid signature")
	}

	payload, err := base64.RawURLEncoding.DecodeString(payloadPart)
	if err != nil {
		return nil, fmt.Errorf("decoding token payload: %w", err)
	}

	var claims Claims
	if err := json.Unmarshal(payload, &claims); err != nil {
		return nil, fmt.Errorf("parsing token claims: %w", err)
	}

	if time.Now().Unix() > claims.Exp {
		return nil, fmt.Errorf("token expired")
	}
	return &claims, nil
}
// AuthMiddleware wraps next so that it runs only for requests whose
// Authorization header carries a token accepted by verifyToken.
//
// Requests with a missing header or an invalid token receive 401
// Unauthorized and next is not called. On success the verified claims are
// passed to next through the request headers X-User-ID, X-Username and
// X-Is-Admin; any values the client sent under those names are overwritten.
func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
	const scheme = "Bearer "

	return func(w http.ResponseWriter, r *http.Request) {
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		// Authentication scheme names are case-insensitive (RFC 7235 §2.1),
		// so "bearer" and "BEARER" are accepted as well as "Bearer". A header
		// without the prefix is still treated as the bare token, as before.
		tokenString := authHeader
		if len(authHeader) >= len(scheme) && strings.EqualFold(authHeader[:len(scheme)], scheme) {
			tokenString = authHeader[len(scheme):]
		}

		claims, err := verifyToken(tokenString)
		if err != nil {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		r.Header.Set("X-User-ID", fmt.Sprintf("%d", claims.UserID))
		r.Header.Set("X-Username", claims.Username)
		r.Header.Set("X-Is-Admin", fmt.Sprintf("%t", claims.IsAdmin))
		next(w, r)
	}
}
// AdminMiddleware wraps next so that it runs only for authenticated admins.
//
// The returned handler is AuthMiddleware around an admin check: once
// AuthMiddleware has let the request through, next is called only if the
// X-Is-Admin request header equals "true"; otherwise the response is 403
// Forbidden.
func AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
	adminOnly := func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("X-Is-Admin") == "true" {
			next(w, r)
			return
		}
		http.Error(w, "Forbidden", http.StatusForbidden)
	}
	return AuthMiddleware(adminOnly)
}
// CORS wraps next with permissive cross-origin headers.
//
// Every response gets Access-Control-Allow-Origin "*", plus the fixed
// allow-lists of methods and request headers set below. OPTIONS requests are
// answered directly with 200 OK and never reach next; all other methods are
// passed through.
func CORS(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		headers := w.Header()
		headers.Set("Access-Control-Allow-Origin", "*")
		headers.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		headers.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

		switch r.Method {
		case http.MethodOptions:
			// Preflight: the headers above are the whole answer.
			w.WriteHeader(http.StatusOK)
		default:
			next(w, r)
		}
	}
}