mirror of
https://github.com/Theodor-Springmann-Stiftung/musenalm.git
synced 2025-10-29 09:15:33 +00:00
189 lines
4.7 KiB
Go
189 lines
4.7 KiB
Go
package security
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// --- NonceCache ---
|
|
|
|
// Nonce bookkeeping types used by the in-memory nonce cache.
type (
	// nonceEntry is the per-nonce value held in NonceCache: the instant after
	// which the nonce stops being accepted.
	nonceEntry struct {
		expiresAt time.Time
	}

	// NonceCache is an in-memory store of nonces, each with its own expiry
	// time. Its methods guard the map with mu, so a *NonceCache may be shared
	// between goroutines. Create one with NewNonceCache: the zero value has a
	// nil map and a nil stop channel.
	NonceCache struct {
		// mu guards nonces. NOTE(review): only exclusive Lock/Unlock appear in
		// this file; if nothing else uses RLock, a sync.Mutex would suffice.
		mu sync.RWMutex

		nonces            map[string]nonceEntry // nonce string -> expiry entry
		defaultExpiration time.Duration         // lifetime Add gives each nonce
		cleanupInterval   time.Duration         // period of the background sweep
		stopCleanup       chan struct{}         // closed to stop the cleanup goroutine
	}
)
|
|
|
|
// NewNonceCache creates a new in-memory nonce cache.
|
|
// defaultExpiration: The default duration for which a nonce is valid.
|
|
// cleanupInterval: How often to scan for and remove expired nonces.
|
|
func NewNonceCache(defaultExpiration, cleanupInterval time.Duration) *NonceCache {
|
|
if defaultExpiration <= 0 {
|
|
defaultExpiration = 15 * time.Minute // Default to 15 minutes if invalid
|
|
}
|
|
if cleanupInterval <= 0 {
|
|
cleanupInterval = 5 * time.Minute // Default to 5 minutes if invalid
|
|
}
|
|
nc := &NonceCache{
|
|
nonces: make(map[string]nonceEntry),
|
|
defaultExpiration: defaultExpiration,
|
|
cleanupInterval: cleanupInterval,
|
|
stopCleanup: make(chan struct{}),
|
|
}
|
|
go nc.startCleanupRoutine()
|
|
return nc
|
|
}
|
|
|
|
func (nc *NonceCache) Add(nonce string) {
|
|
nc.addWithExpiration(nonce, nc.defaultExpiration)
|
|
}
|
|
|
|
func (nc *NonceCache) addWithExpiration(nonce string, expiresIn time.Duration) {
|
|
if nonce == "" {
|
|
return
|
|
}
|
|
nc.mu.Lock()
|
|
defer nc.mu.Unlock()
|
|
nc.nonces[nonce] = nonceEntry{expiresAt: time.Now().Add(expiresIn)}
|
|
}
|
|
|
|
func (nc *NonceCache) Use(nonce string) bool {
|
|
if nonce == "" {
|
|
return false
|
|
}
|
|
nc.mu.Lock()
|
|
defer nc.mu.Unlock()
|
|
|
|
entry, exists := nc.nonces[nonce]
|
|
if !exists {
|
|
return false
|
|
}
|
|
|
|
if time.Now().After(entry.expiresAt) {
|
|
delete(nc.nonces, nonce)
|
|
return false
|
|
}
|
|
|
|
delete(nc.nonces, nonce)
|
|
return true
|
|
}
|
|
|
|
func (nc *NonceCache) startCleanupRoutine() {
|
|
ticker := time.NewTicker(nc.cleanupInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
nc.cleanupExpiredNonces()
|
|
case <-nc.stopCleanup:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (nc *NonceCache) cleanupExpiredNonces() {
|
|
nc.mu.Lock()
|
|
defer nc.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
slog.Debug("Cleaning up expired nonces", "current_time", now, "nonces_count", len(nc.nonces))
|
|
for nonce, entry := range nc.nonces {
|
|
if now.After(entry.expiresAt) {
|
|
delete(nc.nonces, nonce)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (nc *NonceCache) StopCleanup() {
|
|
select {
|
|
case <-nc.stopCleanup:
|
|
default:
|
|
close(nc.stopCleanup)
|
|
}
|
|
}
|
|
|
|
// defaultServerSecretSize is the length, in bytes, of the random HMAC key
// that NewCSRFProtector generates.
const defaultServerSecretSize = 32

// defaultNonceSize is the number of random bytes drawn for each nonce, before
// base64 encoding.
const defaultNonceSize = 16
|
|
|
|
// CSRFProtector issues and validates CSRF token bundles. A bundle is a random
// nonce together with an HMAC-SHA256 of that nonce keyed with serverSecret.
// The nonce is also recorded in nonceCache, so each bundle validates at most
// once.
type CSRFProtector struct {
	// serverSecret is the HMAC key. NewCSRFProtector fills it with random
	// bytes, so tokens from one CSRFProtector do not validate on another.
	serverSecret []byte
	// nonceCache tracks issued, not-yet-consumed nonces and their expiry.
	nonceCache   *NonceCache
}
|
|
|
|
func NewCSRFProtector(nonceExpiration, nonceCleanupInterval time.Duration) (*CSRFProtector, error) {
|
|
secretToUse := make([]byte, defaultServerSecretSize)
|
|
if _, err := rand.Read(secretToUse); err != nil {
|
|
return nil, fmt.Errorf("failed to generate server secret: %w", err)
|
|
}
|
|
|
|
return &CSRFProtector{
|
|
serverSecret: secretToUse,
|
|
nonceCache: NewNonceCache(nonceExpiration, nonceCleanupInterval),
|
|
}, nil
|
|
}
|
|
|
|
func (p *CSRFProtector) GenerateTokenBundle() (nonceB64 string, validationTokenB64 string, err error) {
|
|
nonceBytes := make([]byte, defaultNonceSize)
|
|
if _, errRand := rand.Read(nonceBytes); errRand != nil {
|
|
return "", "", fmt.Errorf("failed to generate nonce bytes: %w", errRand)
|
|
}
|
|
nonceB64 = base64.URLEncoding.EncodeToString(nonceBytes)
|
|
|
|
p.nonceCache.Add(nonceB64)
|
|
|
|
mac := hmac.New(sha256.New, p.serverSecret)
|
|
mac.Write([]byte(nonceB64)) // Sign the base64 encoded nonce string
|
|
validationTokenBytes := mac.Sum(nil)
|
|
validationTokenB64 = base64.URLEncoding.EncodeToString(validationTokenBytes)
|
|
|
|
return nonceB64, validationTokenB64, nil
|
|
}
|
|
|
|
func (p *CSRFProtector) ValidateTokenBundle(nonceSubmittedB64 string, validationTokenSubmittedB64 string) (bool, error) {
|
|
if nonceSubmittedB64 == "" || validationTokenSubmittedB64 == "" {
|
|
return false, errors.New("submitted nonce or validation token is empty")
|
|
}
|
|
|
|
mac := hmac.New(sha256.New, p.serverSecret)
|
|
mac.Write([]byte(nonceSubmittedB64))
|
|
expectedMACTokenBytes := mac.Sum(nil)
|
|
|
|
validationTokenSubmittedBytes, err := base64.URLEncoding.DecodeString(validationTokenSubmittedB64)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to decode submitted validation token: %w", err)
|
|
}
|
|
|
|
if !hmac.Equal(validationTokenSubmittedBytes, expectedMACTokenBytes) {
|
|
return false, errors.New("validation token (HMAC) mismatch")
|
|
}
|
|
|
|
if !p.nonceCache.Use(nonceSubmittedB64) {
|
|
return false, errors.New("nonce not found in cache, expired, or already used")
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (p *CSRFProtector) StopNonceCacheCleanup() {
|
|
if p.nonceCache != nil {
|
|
p.nonceCache.StopCleanup()
|
|
}
|
|
}
|