package main import ( "bytes" "container/heap" "crypto/rand" "crypto/sha1" "encoding/base64" "errors" "github.com/zmb3/spotify" "golang.org/x/crypto/pbkdf2" "log" "net/http" "strings" "time" ) type SessionData struct { sid string username string spotify string priority time.Time index int spotifyClient *spotify.Client } type SessionQueue []*SessionData func (pq SessionQueue) Len() int { return len(pq) } func (pq SessionQueue) Less(i, j int) bool { return pq[i].priority.Before(pq[j].priority) } func (pq SessionQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] pq[i].index = i pq[j].index = j } func (pq *SessionQueue) Push(x interface{}) { n := len(*pq) item := x.(*SessionData) item.index = n *pq = append(*pq, item) } func (pq *SessionQueue) Pop() interface{} { old := *pq n := len(old) item := old[n-1] item.index = -1 // for safety *pq = old[0 : n-1] return item } var ( sessions = make(map[string]*SessionData) sessionQueue = make(SessionQueue, 0) ) func sessionExpirer() { for { for len(sessionQueue) > 0 && time.Now().After(sessionQueue[0].priority) { session := heap.Pop(&sessionQueue).(*SessionData) delete(sessions, session.sid) } time.Sleep(time.Second * 5) } } func init() { go sessionExpirer() } func addSession(data *SessionData) { sessions[data.sid] = data heap.Push(&sessionQueue, data) } func generateHash(password string) (string, error) { salt := make([]byte, 11) if _, err := rand.Read(salt); err != nil { log.Println(err) return "", errors.New("Couldn't generate random string") } passwordBytes := []byte(password) key := hashPasswordSalt(passwordBytes, salt) saltStr := base64.StdEncoding.EncodeToString(salt) keyStr := base64.StdEncoding.EncodeToString(key) return saltStr + "$" + keyStr, nil } func hashOk(password, hashed string) (bool, error) { parts := strings.Split(hashed, "$") if len(parts) != 2 { return false, errors.New("Invalid data stored in database for password") } salt, err := base64.StdEncoding.DecodeString(parts[0]) if err != nil { return false, err } passwordHash, err := base64.StdEncoding.DecodeString(parts[1]) if err != nil { return false, err } if len(passwordHash) != 32 { return false, errors.New("Password hash was not 32 bytes long") } key := hashPasswordSalt([]byte(password), salt) return bytes.Equal(key, passwordHash), nil } func hashPasswordSalt(password, salt []byte) []byte { return pbkdf2.Key(password, salt, 4096, 32, sha1.New) } func spotifyOk(db *DB, spotifyID string) (string, error) { username, err := db.FindUserBySpotifyID(spotifyID) if err != nil { if err.Error() == "sql: no rows in result set" { return "", nil } else { return "", err } } return username, nil } func userOk(db *DB, username, password string) (bool, error) { hash, err := db.FindHashForUser(username) if err != nil { if err.Error() == "sql: no rows in result set" { return false, nil } else { return false, err } } ok, err := hashOk(password, hash) if err != nil { return false, err } return ok, nil } func tryLoginWithSpotify(db *DB, spotifyID string) (http.Cookie, error) { username, err := spotifyOk(db, spotifyID) if username == "" { if err != nil { return http.Cookie{}, err } return http.Cookie{}, errors.New("This spotify account is not connected to any account") } return addSessionAndReturnMatchingCookie(username, spotifyID, true) } func tryLogin(db *DB, username, password string, longerTime bool) (http.Cookie, error) { if exists, err := userOk(db, username, password); !exists { if err != nil { return http.Cookie{}, err } return http.Cookie{}, errors.New("The username or password you entered isn't correct.") } return addSessionAndReturnMatchingCookie(username, "", longerTime) } func addSessionAndReturnMatchingCookie(username, spotifyID string, longerTime bool) (http.Cookie, error) { sid, err := randString(32) if err != nil { return http.Cookie{}, err } duration := time.Hour * 1 if longerTime { duration = time.Hour * 24 * 14 } loginCookie := http.Cookie{ Name: "id", Value: sid, MaxAge: int(duration.Seconds()), HttpOnly: true, } expiration := time.Now().Add(duration) addSession(&SessionData{sid, username, spotifyID, expiration, 0, nil}) return loginCookie, nil } func getSession(req *http.Request) (*SessionData, error) { cookie, err := req.Cookie("id") if err != nil { return nil, err } return getSessionById(cookie.Value) } func getSessionById(sid string) (*SessionData, error) { session, exists := sessions[sid] if !exists { return nil, errors.New("Session expired from server") } return session, nil } func randString(size int) (string, error) { buf := make([]byte, size) if _, err := rand.Read(buf); err != nil { log.Println(err) return "", errors.New("Couldn't generate random string") } return base64.URLEncoding.EncodeToString(buf)[:size], nil }