auth.go 4.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. package main
import (
	"bytes"
	"container/heap"
	"crypto/rand"
	"crypto/sha1"
	"database/sql"
	"encoding/base64"
	"errors"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/zmb3/spotify"
	"golang.org/x/crypto/pbkdf2"
)
  16. type SessionData struct {
  17. sid string
  18. username string
  19. spotify string
  20. priority time.Time
  21. index int
  22. spotifyClient *spotify.Client
  23. }
  24. type SessionQueue []*SessionData
  25. func (pq SessionQueue) Len() int {
  26. return len(pq)
  27. }
  28. func (pq SessionQueue) Less(i, j int) bool {
  29. return pq[i].priority.Before(pq[j].priority)
  30. }
  31. func (pq SessionQueue) Swap(i, j int) {
  32. pq[i], pq[j] = pq[j], pq[i]
  33. pq[i].index = i
  34. pq[j].index = j
  35. }
  36. func (pq *SessionQueue) Push(x interface{}) {
  37. n := len(*pq)
  38. item := x.(*SessionData)
  39. item.index = n
  40. *pq = append(*pq, item)
  41. }
  42. func (pq *SessionQueue) Pop() interface{} {
  43. old := *pq
  44. n := len(old)
  45. item := old[n-1]
  46. item.index = -1 // for safety
  47. *pq = old[0 : n-1]
  48. return item
  49. }
  50. var (
  51. sessions = make(map[string]*SessionData)
  52. sessionQueue = make(SessionQueue, 0)
  53. )
  54. func sessionExpirer() {
  55. for {
  56. for len(sessionQueue) > 0 && time.Now().After(sessionQueue[0].priority) {
  57. session := heap.Pop(&sessionQueue).(*SessionData)
  58. delete(sessions, session.sid)
  59. }
  60. time.Sleep(time.Second * 5)
  61. }
  62. }
  63. func init() {
  64. go sessionExpirer()
  65. }
  66. func addSession(data *SessionData) {
  67. sessions[data.sid] = data
  68. heap.Push(&sessionQueue, data)
  69. }
  70. func generateHash(password string) (string, error) {
  71. salt := make([]byte, 11)
  72. if _, err := rand.Read(salt); err != nil {
  73. log.Println(err)
  74. return "", errors.New("Couldn't generate random string")
  75. }
  76. passwordBytes := []byte(password)
  77. key := hashPasswordSalt(passwordBytes, salt)
  78. saltStr := base64.StdEncoding.EncodeToString(salt)
  79. keyStr := base64.StdEncoding.EncodeToString(key)
  80. return saltStr + "$" + keyStr, nil
  81. }
  82. func hashOk(password, hashed string) (bool, error) {
  83. parts := strings.Split(hashed, "$")
  84. if len(parts) != 2 {
  85. return false, errors.New("Invalid data stored in database for password")
  86. }
  87. salt, err := base64.StdEncoding.DecodeString(parts[0])
  88. if err != nil {
  89. return false, err
  90. }
  91. passwordHash, err := base64.StdEncoding.DecodeString(parts[1])
  92. if err != nil {
  93. return false, err
  94. }
  95. if len(passwordHash) != 32 {
  96. return false, errors.New("Password hash was not 32 bytes long")
  97. }
  98. key := hashPasswordSalt([]byte(password), salt)
  99. return bytes.Equal(key, passwordHash), nil
  100. }
  101. func hashPasswordSalt(password, salt []byte) []byte {
  102. return pbkdf2.Key(password, salt, 4096, 32, sha1.New)
  103. }
  104. func spotifyOk(db *DB, spotifyID string) (string, error) {
  105. username, err := db.FindUserBySpotifyID(spotifyID)
  106. if err != nil {
  107. if err.Error() == "sql: no rows in result set" {
  108. return "", nil
  109. } else {
  110. return "", err
  111. }
  112. }
  113. return username, nil
  114. }
  115. func userOk(db *DB, username, password string) (bool, error) {
  116. hash, err := db.FindHashForUser(username)
  117. if err != nil {
  118. if err.Error() == "sql: no rows in result set" {
  119. return false, nil
  120. } else {
  121. return false, err
  122. }
  123. }
  124. ok, err := hashOk(password, hash)
  125. if err != nil {
  126. return false, err
  127. }
  128. return ok, nil
  129. }
  130. func tryLoginWithSpotify(db *DB, spotifyID string) (http.Cookie, error) {
  131. username, err := spotifyOk(db, spotifyID)
  132. if username == "" {
  133. if err != nil {
  134. return http.Cookie{}, err
  135. }
  136. return http.Cookie{},
  137. errors.New("This spotify account is not connected to any account")
  138. }
  139. return addSessionAndReturnMatchingCookie(username, spotifyID, true)
  140. }
  141. func tryLogin(db *DB, username, password string, longerTime bool) (http.Cookie, error) {
  142. if exists, err := userOk(db, username, password); !exists {
  143. if err != nil {
  144. return http.Cookie{}, err
  145. }
  146. return http.Cookie{},
  147. errors.New("The username or password you entered isn't correct.")
  148. }
  149. return addSessionAndReturnMatchingCookie(username, "", longerTime)
  150. }
  151. func addSessionAndReturnMatchingCookie(username, spotifyID string, longerTime bool) (http.Cookie, error) {
  152. sid, err := randString(32)
  153. if err != nil {
  154. return http.Cookie{}, err
  155. }
  156. duration := time.Hour * 1
  157. if longerTime {
  158. duration = time.Hour * 24 * 14
  159. }
  160. loginCookie := http.Cookie{
  161. Name: "id",
  162. Value: sid,
  163. MaxAge: int(duration.Seconds()),
  164. HttpOnly: true,
  165. }
  166. expiration := time.Now().Add(duration)
  167. addSession(&SessionData{sid, username, spotifyID, expiration, 0, nil})
  168. return loginCookie, nil
  169. }
  170. func getSession(req *http.Request) (*SessionData, error) {
  171. cookie, err := req.Cookie("id")
  172. if err != nil {
  173. return nil, err
  174. }
  175. return getSessionById(cookie.Value)
  176. }
  177. func getSessionById(sid string) (*SessionData, error) {
  178. session, exists := sessions[sid]
  179. if !exists {
  180. return nil, errors.New("Session expired from server")
  181. }
  182. return session, nil
  183. }
  184. func randString(size int) (string, error) {
  185. buf := make([]byte, size)
  186. if _, err := rand.Read(buf); err != nil {
  187. log.Println(err)
  188. return "", errors.New("Couldn't generate random string")
  189. }
  190. return base64.URLEncoding.EncodeToString(buf)[:size], nil
  191. }