package auth import ( "context" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/hex" "fmt" "net/http" "github.com/gorilla/sessions" "golang.org/x/crypto/bcrypt" ) type contextKey string const UserIDKey contextKey = "user_id" var store *sessions.CookieStore func InitStore(secret string) { store = sessions.NewCookieStore([]byte(secret)) store.Options = &sessions.Options{ Path: "/", MaxAge: 86400 * 30, // 30 days HttpOnly: true, SameSite: http.SameSiteLaxMode, } } func GetSession(r *http.Request) *sessions.Session { sess, _ := store.Get(r, "mdhub") return sess } func SetUserID(w http.ResponseWriter, r *http.Request, userID string) error { sess := GetSession(r) sess.Values["user_id"] = userID return sess.Save(r, w) } func GetUserID(r *http.Request) string { sess := GetSession(r) if v, ok := sess.Values["user_id"].(string); ok { return v } return "" } func ClearSession(w http.ResponseWriter, r *http.Request) error { sess := GetSession(r) sess.Options.MaxAge = -1 return sess.Save(r, w) } func UserIDFromContext(ctx context.Context) string { if v, ok := ctx.Value(UserIDKey).(string); ok { return v } return "" } // Password helpers func HashPassword(password string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return "", err } return string(hash), nil } func CheckPassword(hash, password string) bool { return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil } func GenerateToken() string { b := make([]byte, 32) rand.Read(b) return hex.EncodeToString(b) } // State parameter for CSRF protection func SetOAuthState(w http.ResponseWriter, r *http.Request) string { state := GenerateToken() sess := GetSession(r) sess.Values["oauth_state"] = state sess.Save(r, w) return state } func ValidateOAuthState(r *http.Request, state string) error { sess := GetSession(r) expected, ok := sess.Values["oauth_state"].(string) if !ok || expected != state { return fmt.Errorf("invalid oauth state") } delete(sess.Values, "oauth_state") return nil } // PKCEVerifier generates a cryptographically random PKCE code verifier (43-128 chars, URL-safe base64). func PKCEVerifier() string { b := make([]byte, 32) rand.Read(b) return base64.RawURLEncoding.EncodeToString(b) } // PKCEChallenge derives the S256 code challenge from a verifier. func PKCEChallenge(verifier string) string { h := sha256.Sum256([]byte(verifier)) return base64.RawURLEncoding.EncodeToString(h[:]) }