home to your local SPACEGIRL 💫 arimelody.space
at dev 193 lines 5.9 kB view raw
1package controller 2 3import ( 4 "database/sql" 5 "fmt" 6 "net/http" 7 "strings" 8 "time" 9 10 "arimelody-web/log" 11 "arimelody-web/model" 12 13 "github.com/jmoiron/sqlx" 14) 15 16const TOKEN_LEN = 64 17 18func GetSessionFromRequest(app *model.AppState, r *http.Request) (*model.Session, error) { 19 sessionCookie, err := r.Cookie(model.COOKIE_TOKEN) 20 if err != nil && err != http.ErrNoCookie { 21 return nil, fmt.Errorf("Failed to retrieve session cookie: %v", err) 22 } 23 24 var session *model.Session 25 26 if sessionCookie != nil { 27 // fetch existing session 28 session, err = GetSession(app.DB, sessionCookie.Value) 29 30 if err != nil && !strings.Contains(err.Error(), "no rows") { 31 return nil, fmt.Errorf("Failed to retrieve session: %v", err) 32 } 33 34 if session != nil { 35 if session.UserAgent != r.UserAgent() { 36 msg := "Session user agent mismatch. A cookie may have been hijacked!" 37 if session.Account != nil { 38 account, _ := GetAccountByID(app.DB, session.Account.ID) 39 msg += " (Account \"" + account.Username + "\")" 40 } 41 app.Log.Warn(log.TYPE_ACCOUNT, msg) 42 err = DeleteSession(app.DB, session.Token) 43 if err != nil { 44 app.Log.Warn(log.TYPE_ACCOUNT, "Failed to delete affected session") 45 } 46 return nil, nil 47 } 48 } 49 } 50 51 return session, nil 52} 53 54func CreateSession(db *sqlx.DB, userAgent string) (*model.Session, error) { 55 tokenString := GenerateAlnumString(TOKEN_LEN) 56 57 session := model.Session{ 58 Token: string(tokenString), 59 UserAgent: userAgent, 60 CreatedAt: time.Now(), 61 ExpiresAt: time.Now().Add(time.Hour * 24), 62 } 63 64 _, err := db.Exec("INSERT INTO session " + 65 "(token, user_agent, created_at, expires_at) VALUES " + 66 "($1, $2, $3, $4)", 67 session.Token, 68 session.UserAgent, 69 session.CreatedAt, 70 session.ExpiresAt, 71 ) 72 if err != nil { 73 return nil, err 74 } 75 76 return &session, nil 77} 78 79// func WriteSession(db *sqlx.DB, session *model.Session) error { 80// _, err := db.Exec( 81// "UPDATE 
session " + 82// "SET account=$2,message=$3,error=$4 " + 83// "WHERE token=$1", 84// session.Token, 85// session.Account.ID, 86// session.Message, 87// session.Error, 88// ) 89// return err 90// } 91 92func SetSessionAttemptAccount(db *sqlx.DB, session *model.Session, account *model.Account) error { 93 var err error 94 session.AttemptAccount = account 95 if account == nil { 96 _, err = db.Exec("UPDATE session SET attempt_account=NULL WHERE token=$1", session.Token) 97 } else { 98 _, err = db.Exec("UPDATE session SET attempt_account=$2 WHERE token=$1", session.Token, account.ID) 99 } 100 return err 101} 102 103func SetSessionAccount(db *sqlx.DB, session *model.Session, account *model.Account) error { 104 var err error 105 session.Account = account 106 if account == nil { 107 _, err = db.Exec("UPDATE session SET account=NULL WHERE token=$1", session.Token) 108 } else { 109 _, err = db.Exec("UPDATE session SET account=$2 WHERE token=$1", session.Token, account.ID) 110 } 111 return err 112} 113 114func SetSessionMessage(db *sqlx.DB, session *model.Session, message string) error { 115 var err error 116 if message == "" { 117 if !session.Message.Valid { return nil } 118 session.Message = sql.NullString{ } 119 _, err = db.Exec("UPDATE session SET message=NULL WHERE token=$1", session.Token) 120 } else { 121 session.Message = sql.NullString{ String: message, Valid: true } 122 _, err = db.Exec("UPDATE session SET message=$2 WHERE token=$1", session.Token, message) 123 } 124 return err 125} 126 127func SetSessionError(db *sqlx.DB, session *model.Session, message string) error { 128 var err error 129 if message == "" { 130 if !session.Error.Valid { return nil } 131 session.Error = sql.NullString{ } 132 _, err = db.Exec("UPDATE session SET error=NULL WHERE token=$1", session.Token) 133 } else { 134 session.Error = sql.NullString{ String: message, Valid: true } 135 _, err = db.Exec("UPDATE session SET error=$2 WHERE token=$1", session.Token, message) 136 } 137 return err 138} 
139 140func GetSession(db *sqlx.DB, token string) (*model.Session, error) { 141 type dbSession struct { 142 model.Session 143 AttemptAccountID sql.NullString `db:"attempt_account"` 144 AccountID sql.NullString `db:"account"` 145 } 146 147 session := dbSession{} 148 err := db.Get( 149 &session, 150 "SELECT * FROM session WHERE token=$1", 151 token, 152 ) 153 if err != nil { 154 return nil, err 155 } 156 157 if session.AccountID.Valid { 158 session.Account, err = GetAccountByID(db, session.AccountID.String) 159 if err != nil { 160 return nil, err 161 } 162 } 163 164 if session.AttemptAccountID.Valid { 165 session.AttemptAccount, err = GetAccountByID(db, session.AttemptAccountID.String) 166 if err != nil { 167 return nil, err 168 } 169 } 170 171 return &session.Session, err 172} 173 174// func GetAllSessionsForAccount(db *sqlx.DB, accountID string) ([]model.Session, error) { 175// sessions := []model.Session{} 176// err := db.Select(&sessions, "SELECT * FROM session WHERE account=$1 AND expires_at>current_timestamp", accountID) 177// return sessions, err 178// } 179 180func DeleteAllSessionsForAccount(db *sqlx.DB, accountID string) error { 181 _, err := db.Exec("DELETE FROM session WHERE account=$1", accountID) 182 return err 183} 184 185func DeleteSession(db *sqlx.DB, token string) error { 186 _, err := db.Exec("DELETE FROM session WHERE token=$1", token) 187 return err 188} 189 190func DeleteExpiredSessions(db *sqlx.DB) error { 191 _, err := db.Exec("DELETE FROM session WHERE expires_at<current_timestamp") 192 return err 193}