// auth.go (3.9KB)
  1. package main
  import (
	"bytes"
	"container/heap"
	"crypto/rand"
	"crypto/sha1"
	"database/sql"
	"encoding/base64"
	"errors"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/pbkdf2"
)
  // SessionData is one logged-in session held in server memory.
//
// Keep the field order stable: values may be built with positional composite
// literals (sid, username, priority, index).
type SessionData struct {
	sid      string    // session ID; the key used in the sessions map
	username string    // user the session belongs to
	priority time.Time // expiration instant; the heap pops the earliest first
	index    int       // position in SessionQueue, kept current by Swap/Push; -1 once popped
}

// SessionQueue is a min-heap of sessions ordered by expiration time. It
// implements container/heap.Interface; modify it only through heap.Push and
// heap.Pop so the heap invariant and each element's index stay correct.
type SessionQueue []*SessionData

// Len reports how many sessions are queued.
func (pq SessionQueue) Len() int { return len(pq) }

// Less places the session that expires soonest at the root of the heap.
func (pq SessionQueue) Less(i, j int) bool {
	return pq[i].priority.Before(pq[j].priority)
}

// Swap exchanges two entries and updates their recorded positions.
func (pq SessionQueue) Swap(i, j int) {
	pq[i], pq[j] = pq[j], pq[i]
	pq[i].index = i
	pq[j].index = j
}

// Push appends x, which must be a *SessionData. Call heap.Push, not this.
func (pq *SessionQueue) Push(x interface{}) {
	item := x.(*SessionData)
	item.index = len(*pq)
	*pq = append(*pq, item)
}

// Pop removes and returns the last element. Call heap.Pop, not this.
func (pq *SessionQueue) Pop() interface{} {
	old := *pq
	n := len(old)
	item := old[n-1]
	// Clear the vacated slot: otherwise the backing array keeps the popped
	// session reachable (blocking GC) until a later Push overwrites it.
	old[n-1] = nil
	item.index = -1 // for safety
	*pq = old[:n-1]
	return item
}
  47. var (
  48. sessions = make(map[string]*SessionData)
  49. sessionQueue = make(SessionQueue, 0)
  50. )
  51. func sessionExpirer() {
  52. for {
  53. for len(sessionQueue) > 0 && time.Now().After(sessionQueue[0].priority) {
  54. session := heap.Pop(&sessionQueue).(*SessionData)
  55. delete(sessions, session.sid)
  56. }
  57. time.Sleep(time.Second * 5)
  58. }
  59. }
  60. func init() {
  61. go sessionExpirer()
  62. }
  63. func addSession(data *SessionData) {
  64. sessions[data.sid] = data
  65. heap.Push(&sessionQueue, data)
  66. }
  67. func generateHash(password string) (string, error) {
  68. salt := make([]byte, 11)
  69. if _, err := rand.Read(salt); err != nil {
  70. log.Println(err)
  71. return "", errors.New("Couldn't generate random string")
  72. }
  73. passwordBytes := []byte(password)
  74. key := hashPasswordSalt(passwordBytes, salt)
  75. saltStr := base64.StdEncoding.EncodeToString(salt)
  76. keyStr := base64.StdEncoding.EncodeToString(key)
  77. return saltStr + "$" + keyStr, nil
  78. }
  79. func hashOk(password, hashed string) (bool, error) {
  80. parts := strings.Split(hashed, "$")
  81. if len(parts) != 2 {
  82. return false, errors.New("Invalid data stored in database for password")
  83. }
  84. salt, err := base64.StdEncoding.DecodeString(parts[0])
  85. if err != nil {
  86. return false, err
  87. }
  88. passwordHash, err := base64.StdEncoding.DecodeString(parts[1])
  89. if err != nil {
  90. return false, err
  91. }
  92. if len(passwordHash) != 32 {
  93. return false, errors.New("Password hash was not 32 bytes long")
  94. }
  95. key := hashPasswordSalt([]byte(password), salt)
  96. return bytes.Equal(key, passwordHash), nil
  97. }
  98. func hashPasswordSalt(password, salt []byte) []byte {
  99. return pbkdf2.Key(password, salt, 4096, 32, sha1.New)
  100. }
  101. func userOk(db *DB, username, password string) (bool, error) {
  102. hash, err := db.FindHashForUser(username)
  103. if err != nil {
  104. if err.Error() == "sql: no rows in result set" {
  105. return false, nil
  106. } else {
  107. return false, err
  108. }
  109. }
  110. ok, err := hashOk(password, hash)
  111. if err != nil {
  112. return false, err
  113. }
  114. return ok, nil
  115. }
  116. func tryLogin(db *DB, username, password string, longerTime bool) (http.Cookie, error) {
  117. if exists, err := userOk(db, username, password); !exists {
  118. if err != nil {
  119. return http.Cookie{}, err
  120. }
  121. return http.Cookie{},
  122. errors.New("The username or password you entered isn't correct.")
  123. }
  124. sid, err := randString(32)
  125. if err != nil {
  126. return http.Cookie{}, err
  127. }
  128. duration := time.Hour * 1
  129. if longerTime {
  130. duration = time.Hour * 24 * 14
  131. }
  132. loginCookie := http.Cookie{
  133. Name: "id",
  134. Value: sid,
  135. MaxAge: int(duration.Seconds()),
  136. HttpOnly: true,
  137. }
  138. expiration := time.Now().Add(duration)
  139. addSession(&SessionData{sid, username, expiration, 0})
  140. return loginCookie, nil
  141. }
  142. func getSession(req *http.Request) (*SessionData, error) {
  143. cookie, err := req.Cookie("id")
  144. if err != nil {
  145. return nil, err
  146. }
  147. session, exists := sessions[cookie.Value]
  148. if !exists {
  149. return nil, errors.New("Session expired from server")
  150. }
  151. return session, nil
  152. }
  // randString returns a string of exactly size characters drawn from the
// URL-safe base64 alphabet (A-Z, a-z, 0-9, '-', '_').
//
// It reads size bytes from crypto/rand, base64url-encodes them, and keeps the
// first size characters, so each character carries 6 random bits. If
// crypto/rand fails, the cause is logged and a generic error is returned.
func randString(size int) (string, error) {
	raw := make([]byte, size)
	_, err := rand.Read(raw)
	if err != nil {
		log.Println(err)
		return "", errors.New("Couldn't generate random string")
	}
	encoded := base64.URLEncoding.EncodeToString(raw)
	return encoded[:size], nil
}