auth.go 4.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. package main
import (
	"bytes"
	"container/heap"
	"crypto/rand"
	"crypto/sha1"
	"database/sql"
	"encoding/base64"
	"errors"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/pbkdf2"
)
  15. type SessionData struct {
  16. sid string
  17. username string
  18. priority time.Time
  19. index int
  20. }
  21. type SessionQueue []*SessionData
  22. func (pq SessionQueue) Len() int {
  23. return len(pq)
  24. }
  25. func (pq SessionQueue) Less(i, j int) bool {
  26. return pq[i].priority.Before(pq[j].priority)
  27. }
  28. func (pq SessionQueue) Swap(i, j int) {
  29. pq[i], pq[j] = pq[j], pq[i]
  30. pq[i].index = i
  31. pq[j].index = j
  32. }
  33. func (pq *SessionQueue) Push(x interface{}) {
  34. n := len(*pq)
  35. item := x.(*SessionData)
  36. item.index = n
  37. *pq = append(*pq, item)
  38. }
  39. func (pq *SessionQueue) Pop() interface{} {
  40. old := *pq
  41. n := len(old)
  42. item := old[n-1]
  43. item.index = -1 // for safety
  44. *pq = old[0 : n-1]
  45. return item
  46. }
  47. var (
  48. sessions = make(map[string]*SessionData)
  49. sessionQueue = make(SessionQueue, 0)
  50. )
  51. func sessionExpirer() {
  52. for {
  53. for len(sessionQueue) > 0 && time.Now().After(sessionQueue[0].priority) {
  54. session := heap.Pop(&sessionQueue).(*SessionData)
  55. delete(sessions, session.sid)
  56. }
  57. time.Sleep(time.Second * 5)
  58. }
  59. }
  60. func init() {
  61. go sessionExpirer()
  62. }
  63. func addSession(data *SessionData) {
  64. sessions[data.sid] = data
  65. heap.Push(&sessionQueue, data)
  66. }
  67. func generateHash(password string) (string, error) {
  68. salt := make([]byte, 11)
  69. if _, err := rand.Read(salt); err != nil {
  70. log.Println(err)
  71. return "", errors.New("Couldn't generate random string")
  72. }
  73. passwordBytes := []byte(password)
  74. key := hashPasswordSalt(passwordBytes, salt)
  75. saltStr := base64.StdEncoding.EncodeToString(salt)
  76. keyStr := base64.StdEncoding.EncodeToString(key)
  77. return saltStr + "$" + keyStr, nil
  78. }
  79. func hashOk(password, hashed string) (bool, error) {
  80. parts := strings.Split(hashed, "$")
  81. if len(parts) != 2 {
  82. return false, errors.New("Invalid data stored in database for password")
  83. }
  84. salt, err := base64.StdEncoding.DecodeString(parts[0])
  85. if err != nil {
  86. return false, err
  87. }
  88. passwordHash, err := base64.StdEncoding.DecodeString(parts[1])
  89. if err != nil {
  90. return false, err
  91. }
  92. if len(passwordHash) != 32 {
  93. return false, errors.New("Password hash was not 32 bytes long")
  94. }
  95. key := hashPasswordSalt([]byte(password), salt)
  96. return bytes.Equal(key, passwordHash), nil
  97. }
  98. func hashPasswordSalt(password, salt []byte) []byte {
  99. return pbkdf2.Key(password, salt, 4096, 32, sha1.New)
  100. }
  101. func userOk(username, password string) (bool, error) {
  102. var hash string
  103. err := getDb().QueryRow("SELECT u.password FROM public.user u WHERE lower(u.username) = lower($1)", username).Scan(&hash)
  104. if err != nil {
  105. if err.Error() == "sql: no rows in result set" {
  106. return false, nil
  107. } else {
  108. return false, err
  109. }
  110. }
  111. ok, err := hashOk(password, hash)
  112. if err != nil {
  113. return false, err
  114. }
  115. return ok, nil
  116. }
  117. func tryLogin(username, password string, longerTime bool) (http.Cookie, error) {
  118. if exists, err := userOk(username, password); !exists {
  119. if err != nil {
  120. return http.Cookie{}, err
  121. }
  122. return http.Cookie{},
  123. errors.New("The username or password you entered isn't correct.")
  124. }
  125. sid, err := randString(32)
  126. if err != nil {
  127. return http.Cookie{}, err
  128. }
  129. duration := time.Hour * 1
  130. if longerTime {
  131. duration = time.Hour * 24 * 14
  132. }
  133. loginCookie := http.Cookie{
  134. Name: "id",
  135. Value: sid,
  136. MaxAge: int(duration.Seconds()),
  137. HttpOnly: true,
  138. }
  139. expiration := time.Now().Add(duration)
  140. addSession(&SessionData{sid, username, expiration, 0})
  141. return loginCookie, nil
  142. }
  143. func getSession(req *http.Request) (*SessionData, error) {
  144. cookie, err := req.Cookie("id")
  145. if err != nil {
  146. return nil, err
  147. }
  148. session, exists := sessions[cookie.Value]
  149. if !exists {
  150. return nil, errors.New("Session expired from server")
  151. }
  152. return session, nil
  153. }
  154. func randString(size int) (string, error) {
  155. buf := make([]byte, size)
  156. if _, err := rand.Read(buf); err != nil {
  157. log.Println(err)
  158. return "", errors.New("Couldn't generate random string")
  159. }
  160. return base64.URLEncoding.EncodeToString(buf)[:size], nil
  161. }