user login & middleware complete

This commit is contained in:
Simon Martens
2025-05-22 21:12:29 +02:00
parent 3f57e7a18d
commit 36e34d9e7b
17 changed files with 808 additions and 26 deletions

View File

@@ -0,0 +1,145 @@
package collections
import (
"sync"
"time"
"github.com/Theodor-Springmann-Stiftung/musenalm/dbmodels"
)
type cacheEntry struct {
user dbmodels.FixedUser
session dbmodels.FixedSession
}
type UserSessionCache struct {
mu sync.RWMutex
capacity int
cache sync.Map
approximateSize int
cleanupInterval time.Duration
stopCleanupSignal chan struct{}
}
func NewUserSessionCache(capacity int, cleanupInterval time.Duration) *UserSessionCache {
if capacity <= 0 {
capacity = 1000
}
if cleanupInterval <= 0 {
cleanupInterval = 5 * time.Minute
}
cache := &UserSessionCache{
capacity: capacity,
cache: sync.Map{},
cleanupInterval: cleanupInterval,
stopCleanupSignal: make(chan struct{}),
}
go cache.startCleanupRoutine()
return cache
}
func (c *UserSessionCache) Set(user *dbmodels.User, session *dbmodels.Session) (*dbmodels.FixedUser, *dbmodels.FixedSession) {
if user == nil || session == nil {
return nil, nil
}
newEntry := &cacheEntry{
user: user.Fixed(),
session: session.Fixed(),
}
_, loaded := c.cache.LoadOrStore(session.Token(), newEntry)
if !loaded {
c.cache.Store(session.Token(), newEntry)
c.mu.Lock()
c.approximateSize++
c.mu.Unlock()
}
return &newEntry.user, &newEntry.session
}
func (c *UserSessionCache) Get(sessionTokenClear string) (*dbmodels.FixedUser, *dbmodels.FixedSession, bool) {
if sessionTokenClear == "" {
return nil, nil, false
}
value, ok := c.cache.Load(sessionTokenClear)
if !ok {
return nil, nil, false
}
entry, ok := value.(*cacheEntry)
if !ok {
c.cache.Delete(sessionTokenClear)
return nil, nil, false
}
if time.Now().After(entry.session.Expires.Time()) {
c.cache.Delete(sessionTokenClear)
c.mu.Lock()
c.approximateSize--
c.mu.Unlock()
return nil, nil, false
}
return &entry.user, &entry.session, true
}
func (c *UserSessionCache) Delete(sessionTokenClear string) {
if sessionTokenClear == "" {
return
}
_, loaded := c.cache.LoadAndDelete(sessionTokenClear)
if loaded {
c.mu.Lock()
c.approximateSize--
c.mu.Unlock()
}
}
func (c *UserSessionCache) startCleanupRoutine() {
ticker := time.NewTicker(c.cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
c.cleanupExpiredItems()
case <-c.stopCleanupSignal:
return
}
}
}
func (c *UserSessionCache) cleanupExpiredItems() {
now := time.Now()
var newSize int
c.cache.Range(func(key, value any) bool {
entry, ok := value.(*cacheEntry)
if !ok {
c.cache.Delete(key)
return true
}
if now.After(entry.session.Expires.Time()) {
c.cache.Delete(key)
} else {
newSize++
}
return true
})
c.mu.Lock()
c.approximateSize = newSize
c.mu.Unlock()
}
func (c *UserSessionCache) StopCleanup() {
select {
case <-c.stopCleanupSignal:
default:
close(c.stopCleanupSignal)
}
}

View File

@@ -0,0 +1,188 @@
package security
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"log/slog"
"sync"
"time"
)
// --- NonceCache ---
type nonceEntry struct {
expiresAt time.Time
}
// NonceCache stores nonces and their expiration times.
// It is safe for concurrent use.
type NonceCache struct {
mu sync.RWMutex
nonces map[string]nonceEntry
defaultExpiration time.Duration
cleanupInterval time.Duration
stopCleanup chan struct{} // Channel to signal cleanup goroutine to stop
}
// NewNonceCache creates a new in-memory nonce cache.
// defaultExpiration: The default duration for which a nonce is valid.
// cleanupInterval: How often to scan for and remove expired nonces.
func NewNonceCache(defaultExpiration, cleanupInterval time.Duration) *NonceCache {
if defaultExpiration <= 0 {
defaultExpiration = 15 * time.Minute // Default to 15 minutes if invalid
}
if cleanupInterval <= 0 {
cleanupInterval = 5 * time.Minute // Default to 5 minutes if invalid
}
nc := &NonceCache{
nonces: make(map[string]nonceEntry),
defaultExpiration: defaultExpiration,
cleanupInterval: cleanupInterval,
stopCleanup: make(chan struct{}),
}
go nc.startCleanupRoutine()
return nc
}
func (nc *NonceCache) Add(nonce string) {
nc.addWithExpiration(nonce, nc.defaultExpiration)
}
func (nc *NonceCache) addWithExpiration(nonce string, expiresIn time.Duration) {
if nonce == "" {
return
}
nc.mu.Lock()
defer nc.mu.Unlock()
nc.nonces[nonce] = nonceEntry{expiresAt: time.Now().Add(expiresIn)}
}
func (nc *NonceCache) Use(nonce string) bool {
if nonce == "" {
return false
}
nc.mu.Lock()
defer nc.mu.Unlock()
entry, exists := nc.nonces[nonce]
if !exists {
return false
}
if time.Now().After(entry.expiresAt) {
delete(nc.nonces, nonce)
return false
}
delete(nc.nonces, nonce)
return true
}
func (nc *NonceCache) startCleanupRoutine() {
ticker := time.NewTicker(nc.cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
nc.cleanupExpiredNonces()
case <-nc.stopCleanup:
return
}
}
}
func (nc *NonceCache) cleanupExpiredNonces() {
nc.mu.Lock()
defer nc.mu.Unlock()
now := time.Now()
slog.Debug("Cleaning up expired nonces", "current_time", now, "nonces_count", len(nc.nonces))
for nonce, entry := range nc.nonces {
if now.After(entry.expiresAt) {
delete(nc.nonces, nonce)
}
}
}
func (nc *NonceCache) StopCleanup() {
select {
case <-nc.stopCleanup:
default:
close(nc.stopCleanup)
}
}
const (
defaultServerSecretSize = 32 // bytes
defaultNonceSize = 16 // bytes for raw nonce before encoding
)
type CSRFProtector struct {
serverSecret []byte
nonceCache *NonceCache
}
func NewCSRFProtector(nonceExpiration, nonceCleanupInterval time.Duration) (*CSRFProtector, error) {
secretToUse := make([]byte, defaultServerSecretSize)
if _, err := rand.Read(secretToUse); err != nil {
return nil, fmt.Errorf("failed to generate server secret: %w", err)
}
return &CSRFProtector{
serverSecret: secretToUse,
nonceCache: NewNonceCache(nonceExpiration, nonceCleanupInterval),
}, nil
}
func (p *CSRFProtector) GenerateTokenBundle() (nonceB64 string, validationTokenB64 string, err error) {
nonceBytes := make([]byte, defaultNonceSize)
if _, errRand := rand.Read(nonceBytes); errRand != nil {
return "", "", fmt.Errorf("failed to generate nonce bytes: %w", errRand)
}
nonceB64 = base64.URLEncoding.EncodeToString(nonceBytes)
p.nonceCache.Add(nonceB64)
mac := hmac.New(sha256.New, p.serverSecret)
mac.Write([]byte(nonceB64)) // Sign the base64 encoded nonce string
validationTokenBytes := mac.Sum(nil)
validationTokenB64 = base64.URLEncoding.EncodeToString(validationTokenBytes)
return nonceB64, validationTokenB64, nil
}
func (p *CSRFProtector) ValidateTokenBundle(nonceSubmittedB64 string, validationTokenSubmittedB64 string) (bool, error) {
if nonceSubmittedB64 == "" || validationTokenSubmittedB64 == "" {
return false, errors.New("submitted nonce or validation token is empty")
}
mac := hmac.New(sha256.New, p.serverSecret)
mac.Write([]byte(nonceSubmittedB64))
expectedMACTokenBytes := mac.Sum(nil)
validationTokenSubmittedBytes, err := base64.URLEncoding.DecodeString(validationTokenSubmittedB64)
if err != nil {
return false, fmt.Errorf("failed to decode submitted validation token: %w", err)
}
if !hmac.Equal(validationTokenSubmittedBytes, expectedMACTokenBytes) {
return false, errors.New("validation token (HMAC) mismatch")
}
if !p.nonceCache.Use(nonceSubmittedB64) {
return false, errors.New("nonce not found in cache, expired, or already used")
}
return true, nil
}
func (p *CSRFProtector) StopNonceCacheCleanup() {
if p.nonceCache != nil {
p.nonceCache.StopCleanup()
}
}