180 lines
4.1 KiB
Go
180 lines
4.1 KiB
Go
package main
|
|
|
|
import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"os"
	"strings"
	"time"
)
|
|
|
|
// jwtSecret is the HMAC-SHA256 key used by createToken and verifyToken to
// sign and verify tokens.
//
// It is taken from the JWT_SECRET environment variable when that is set.
// The built-in fallback is hard-coded into the program, so anyone who knows
// it can mint valid tokens (including ones with is_admin=true) for a process
// that uses it; it exists only so local development works without setup.
var jwtSecret = func() []byte {
	if s := os.Getenv("JWT_SECRET"); s != "" {
		return []byte(s)
	}
	slog.Warn("JWT_SECRET is not set; using the insecure built-in development key")
	return []byte("your-secret-key-change-in-production")
}()
|
|
|
|
// Claims is the JWT payload encoded by createToken and decoded by
// verifyToken. The field order below determines the JSON key order inside
// issued tokens.
type Claims struct {
	// UserID is exposed to downstream handlers as the X-User-ID header.
	UserID int `json:"user_id"`
	// Username is exposed to downstream handlers as the X-Username header.
	Username string `json:"username"`
	// IsAdmin is exposed as X-Is-Admin and checked by AdminMiddleware.
	IsAdmin bool `json:"is_admin"`
	// Exp is the expiry as Unix seconds; createToken sets it to now + 24h.
	Exp int64 `json:"exp"`
}
|
|
|
|
// responseWriter wraps an http.ResponseWriter to record the status code and
// body size a handler produced, for LoggingMiddleware.
type responseWriter struct {
	http.ResponseWriter

	status      int   // status recorded by WriteHeader (wrapResponseWriter starts it at 200)
	wroteHeader bool  // whether WriteHeader has already forwarded a status
	written     int64 // body bytes accepted by the underlying writer
}

// Unwrap returns the wrapped http.ResponseWriter. http.ResponseController
// uses it to reach optional capabilities (Flush, Hijack, read/write
// deadlines) that the embedded interface alone hides behind this wrapper,
// since only Header, Write and WriteHeader are promoted.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
|
|
|
|
func wrapResponseWriter(w http.ResponseWriter) *responseWriter {
|
|
return &responseWriter{
|
|
ResponseWriter: w,
|
|
status: http.StatusOK,
|
|
}
|
|
}
|
|
|
|
func (rw *responseWriter) WriteHeader(code int) {
|
|
if rw.wroteHeader {
|
|
return
|
|
}
|
|
rw.status = code
|
|
rw.ResponseWriter.WriteHeader(code)
|
|
rw.wroteHeader = true
|
|
}
|
|
|
|
func (rw *responseWriter) Write(b []byte) (int, error) {
|
|
if !rw.wroteHeader {
|
|
rw.WriteHeader(http.StatusOK)
|
|
}
|
|
n, err := rw.ResponseWriter.Write(b)
|
|
rw.written += int64(n)
|
|
return n, err
|
|
}
|
|
|
|
func createToken(userID int, username string, isAdmin bool) (string, error) {
|
|
claims := Claims{
|
|
UserID: userID,
|
|
Username: username,
|
|
IsAdmin: isAdmin,
|
|
Exp: time.Now().Add(24 * time.Hour).Unix(),
|
|
}
|
|
|
|
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
|
|
|
|
payload, _ := json.Marshal(claims)
|
|
payloadEncoded := base64.RawURLEncoding.EncodeToString(payload)
|
|
|
|
message := header + "." + payloadEncoded
|
|
|
|
h := hmac.New(sha256.New, jwtSecret)
|
|
h.Write([]byte(message))
|
|
signature := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
|
|
|
return message + "." + signature, nil
|
|
}
|
|
|
|
func verifyToken(tokenString string) (*Claims, error) {
|
|
parts := strings.Split(tokenString, ".")
|
|
if len(parts) != 3 {
|
|
return nil, fmt.Errorf("invalid token format")
|
|
}
|
|
|
|
message := parts[0] + "." + parts[1]
|
|
h := hmac.New(sha256.New, jwtSecret)
|
|
h.Write([]byte(message))
|
|
expectedSignature := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
|
|
|
if parts[2] != expectedSignature {
|
|
return nil, fmt.Errorf("invalid signature")
|
|
}
|
|
|
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var claims Claims
|
|
if err := json.Unmarshal(payload, &claims); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if time.Now().Unix() > claims.Exp {
|
|
return nil, fmt.Errorf("token expired")
|
|
}
|
|
|
|
return &claims, nil
|
|
}
|
|
|
|
func AuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
|
claims, err := verifyToken(tokenString)
|
|
if err != nil {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
r.Header.Set("X-User-ID", fmt.Sprintf("%d", claims.UserID))
|
|
r.Header.Set("X-Username", claims.Username)
|
|
r.Header.Set("X-Is-Admin", fmt.Sprintf("%t", claims.IsAdmin))
|
|
|
|
next(w, r)
|
|
}
|
|
}
|
|
|
|
func AdminMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|
return AuthMiddleware(func(w http.ResponseWriter, r *http.Request) {
|
|
isAdmin := r.Header.Get("X-Is-Admin") == "true"
|
|
if !isAdmin {
|
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
next(w, r)
|
|
})
|
|
}
|
|
|
|
func LoggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
|
|
wrapped := wrapResponseWriter(w)
|
|
|
|
defer func() {
|
|
slog.Info("http request",
|
|
"method", r.Method,
|
|
"path", r.URL.Path,
|
|
"query", r.URL.RawQuery,
|
|
"status", wrapped.status,
|
|
"duration_ms", time.Since(start).Milliseconds(),
|
|
"client_ip", r.RemoteAddr,
|
|
"user_agent", r.UserAgent(),
|
|
"bytes_written", wrapped.written,
|
|
)
|
|
}()
|
|
|
|
next(wrapped, r)
|
|
}
|
|
}
|
|
|
|
func CORS(next http.HandlerFunc) http.HandlerFunc {
|
|
return LoggingMiddleware(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
|
|
|
if r.Method == "OPTIONS" {
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
next(w, r)
|
|
})
|
|
}
|