// Files
// musenalm/helpers/security/nonce_cache.go
// 2025-05-22 21:12:29 +02:00
//
// 189 lines
// 4.7 KiB
// Go

package security
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"log/slog"
"sync"
"time"
)
// --- NonceCache ---
// nonceEntry is the value stored for each nonce held by NonceCache.
type nonceEntry struct {
// expiresAt is the instant after which the nonce is rejected by Use and
// becomes eligible for removal by the periodic cleanup.
expiresAt time.Time
}
// NonceCache stores nonces and their expiration times.
// It is safe for concurrent use.
type NonceCache struct {
// mu guards nonces; the methods in this file take it before touching the map.
mu sync.RWMutex
// nonces maps each nonce string to its entry (expiration time).
nonces map[string]nonceEntry
// defaultExpiration is the lifetime given to nonces stored through Add.
defaultExpiration time.Duration
// cleanupInterval is the period of the ticker that drives the cleanup goroutine.
cleanupInterval time.Duration
// stopCleanup is closed by StopCleanup to make the cleanup goroutine return.
stopCleanup chan struct{} // Channel to signal cleanup goroutine to stop
}
// NewNonceCache builds an in-memory nonce cache and launches its background
// cleanup goroutine before returning.
//
// defaultExpiration is how long a nonce stored via Add stays valid; a value
// <= 0 is replaced by 15 minutes. cleanupInterval is how often expired nonces
// are swept out; a value <= 0 is replaced by 5 minutes.
//
// The cleanup goroutine keeps running until StopCleanup is called.
func NewNonceCache(defaultExpiration, cleanupInterval time.Duration) *NonceCache {
	const (
		fallbackExpiration = 15 * time.Minute
		fallbackInterval   = 5 * time.Minute
	)

	// Start from the fallbacks and override them only with valid arguments.
	cache := &NonceCache{
		nonces:            make(map[string]nonceEntry),
		defaultExpiration: fallbackExpiration,
		cleanupInterval:   fallbackInterval,
		stopCleanup:       make(chan struct{}),
	}
	if defaultExpiration > 0 {
		cache.defaultExpiration = defaultExpiration
	}
	if cleanupInterval > 0 {
		cache.cleanupInterval = cleanupInterval
	}

	go cache.startCleanupRoutine()
	return cache
}
// Add stores nonce in the cache using the cache's default expiration.
// An empty nonce is ignored. Adding a nonce that is already present replaces
// its expiration time.
func (nc *NonceCache) Add(nonce string) {
	ttl := nc.defaultExpiration
	nc.addWithExpiration(nonce, ttl)
}
// addWithExpiration records nonce as valid until expiresIn from now,
// overwriting any existing entry for the same nonce. Empty nonces are ignored.
func (nc *NonceCache) addWithExpiration(nonce string, expiresIn time.Duration) {
	if len(nonce) == 0 {
		return
	}

	nc.mu.Lock()
	defer nc.mu.Unlock()

	deadline := time.Now().Add(expiresIn)
	nc.nonces[nonce] = nonceEntry{expiresAt: deadline}
}
// Use consumes nonce. It reports true only when the nonce is present in the
// cache and has not yet expired. A nonce that is found is removed from the
// cache whether or not it has expired, so any given entry can succeed at most
// once. Empty and unknown nonces report false.
func (nc *NonceCache) Use(nonce string) bool {
	if len(nonce) == 0 {
		return false
	}

	nc.mu.Lock()
	defer nc.mu.Unlock()

	entry, found := nc.nonces[nonce]
	if !found {
		return false
	}

	expired := time.Now().After(entry.expiresAt)
	// Single-use: the entry is dropped regardless of the outcome.
	delete(nc.nonces, nonce)
	return !expired
}
// startCleanupRoutine blocks, sweeping expired nonces once per cleanupInterval,
// and returns as soon as stopCleanup becomes readable (StopCleanup closes it).
// NewNonceCache runs it on its own goroutine.
func (nc *NonceCache) startCleanupRoutine() {
	sweep := time.NewTicker(nc.cleanupInterval)
	// Release the ticker's resources when the loop exits.
	defer sweep.Stop()

	for {
		select {
		case <-nc.stopCleanup:
			return
		case <-sweep.C:
			nc.cleanupExpiredNonces()
		}
	}
}
// cleanupExpiredNonces removes every nonce whose expiration lies strictly
// before the current time. The write lock is held for the whole sweep, and a
// debug log line records the sweep time and the entry count before removal.
func (nc *NonceCache) cleanupExpiredNonces() {
	nc.mu.Lock()
	defer nc.mu.Unlock()

	// One timestamp for the whole pass so every entry is judged consistently.
	cutoff := time.Now()
	slog.Debug("Cleaning up expired nonces", "current_time", cutoff, "nonces_count", len(nc.nonces))

	for key, entry := range nc.nonces {
		if !cutoff.After(entry.expiresAt) {
			continue // still valid
		}
		delete(nc.nonces, key)
	}
}
// StopCleanup signals the background cleanup goroutine to exit by closing
// stopCleanup. It is idempotent and may be called from multiple goroutines.
//
// The "is it already closed?" check and the close itself run under nc.mu.
// Without that, two concurrent callers could both observe the channel as still
// open and both call close, and closing an already-closed channel panics.
func (nc *NonceCache) StopCleanup() {
	nc.mu.Lock()
	defer nc.mu.Unlock()

	select {
	case <-nc.stopCleanup:
		// Already closed by an earlier call; nothing to do.
	default:
		close(nc.stopCleanup)
	}
}
// Sizes of the random values generated for CSRF protection.
const (
// defaultServerSecretSize is the length, in bytes, of the random HMAC key
// generated by NewCSRFProtector.
defaultServerSecretSize = 32 // bytes
// defaultNonceSize is the length, in bytes, of the raw random nonce generated
// by GenerateTokenBundle, before it is base64-encoded.
defaultNonceSize = 16 // bytes for raw nonce before encoding
)
// CSRFProtector issues and validates CSRF token bundles. A bundle pairs a
// nonce with an HMAC-SHA256 of that nonce keyed by serverSecret; the nonce is
// tracked in nonceCache so it can be redeemed at most once.
type CSRFProtector struct {
// serverSecret is the HMAC key; NewCSRFProtector fills it with random bytes.
serverSecret []byte
// nonceCache holds the nonces that have been issued and not yet used or expired.
nonceCache *NonceCache
}
// NewCSRFProtector creates a CSRFProtector with a freshly generated random
// server secret of defaultServerSecretSize bytes and a new nonce cache
// configured with nonceExpiration and nonceCleanupInterval.
//
// The secret is generated per instance and kept only in memory, so a token
// bundle can only be validated by the instance that issued it.
//
// An error is returned if reading from the random source fails; in that case
// no nonce cache (and no cleanup goroutine) is created.
func NewCSRFProtector(nonceExpiration, nonceCleanupInterval time.Duration) (*CSRFProtector, error) {
	key := make([]byte, defaultServerSecretSize)
	_, err := rand.Read(key)
	if err != nil {
		return nil, fmt.Errorf("failed to generate server secret: %w", err)
	}

	protector := &CSRFProtector{serverSecret: key}
	protector.nonceCache = NewNonceCache(nonceExpiration, nonceCleanupInterval)
	return protector, nil
}
// GenerateTokenBundle issues a new token bundle.
//
// It draws defaultNonceSize random bytes, encodes them with URL-safe base64
// (nonceB64), registers that string in the nonce cache, and signs it with
// HMAC-SHA256 under the server secret. The URL-safe base64 encoding of the MAC
// is returned as validationTokenB64. Both values must later be presented to
// ValidateTokenBundle.
//
// A non-nil error means the random source failed; both strings are then empty
// and nothing has been added to the cache.
func (p *CSRFProtector) GenerateTokenBundle() (nonceB64 string, validationTokenB64 string, err error) {
	rawNonce := make([]byte, defaultNonceSize)
	if _, readErr := rand.Read(rawNonce); readErr != nil {
		return "", "", fmt.Errorf("failed to generate nonce bytes: %w", readErr)
	}

	enc := base64.URLEncoding
	nonce := enc.EncodeToString(rawNonce)
	p.nonceCache.Add(nonce)

	// The MAC covers the encoded nonce string, not the raw bytes, because that
	// is the form the client sends back.
	signer := hmac.New(sha256.New, p.serverSecret)
	signer.Write([]byte(nonce)) // hash.Hash.Write never returns an error
	token := enc.EncodeToString(signer.Sum(nil))

	return nonce, token, nil
}
// ValidateTokenBundle checks a nonce/validation-token pair previously produced
// by GenerateTokenBundle.
//
// The checks run in this order, and the first failure is returned as
// (false, error):
//  1. both inputs are non-empty;
//  2. the validation token is valid URL-safe base64;
//  3. the decoded token equals HMAC-SHA256(serverSecret, nonceSubmittedB64),
//     compared in constant time;
//  4. the nonce is still in the cache and unexpired; this step consumes it.
//
// The nonce is only consumed after the MAC has been verified, so a request
// with a forged token cannot burn a legitimate nonce. On success it returns
// (true, nil).
func (p *CSRFProtector) ValidateTokenBundle(nonceSubmittedB64 string, validationTokenSubmittedB64 string) (bool, error) {
	if len(nonceSubmittedB64) == 0 || len(validationTokenSubmittedB64) == 0 {
		return false, errors.New("submitted nonce or validation token is empty")
	}

	signer := hmac.New(sha256.New, p.serverSecret)
	signer.Write([]byte(nonceSubmittedB64)) // hash.Hash.Write never returns an error
	want := signer.Sum(nil)

	got, err := base64.URLEncoding.DecodeString(validationTokenSubmittedB64)

	// Cases are evaluated top to bottom and stop at the first match, so the
	// nonce is touched only when decoding and the MAC comparison both passed.
	switch {
	case err != nil:
		return false, fmt.Errorf("failed to decode submitted validation token: %w", err)
	case !hmac.Equal(got, want):
		return false, errors.New("validation token (HMAC) mismatch")
	case !p.nonceCache.Use(nonceSubmittedB64):
		return false, errors.New("nonce not found in cache, expired, or already used")
	}
	return true, nil
}
// StopNonceCacheCleanup stops the background cleanup goroutine of the
// protector's nonce cache. It does nothing when the protector has no cache.
func (p *CSRFProtector) StopNonceCacheCleanup() {
	if p.nonceCache == nil {
		return
	}
	p.nonceCache.StopCleanup()
}