auth.go 4.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. package main
import (
	"bytes"
	"container/heap"
	"crypto/rand"
	"crypto/sha1"
	"database/sql"
	"encoding/base64"
	"errors"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"

	_ "github.com/lib/pq"
	"golang.org/x/crypto/pbkdf2"
)
  17. type SessionData struct {
  18. sid string
  19. username string
  20. priority time.Time
  21. index int
  22. }
  23. type SessionQueue []*SessionData
  24. func (pq SessionQueue) Len() int {
  25. return len(pq)
  26. }
  27. func (pq SessionQueue) Less(i, j int) bool {
  28. return pq[i].priority.Before(pq[j].priority)
  29. }
  30. func (pq SessionQueue) Swap(i, j int) {
  31. pq[i], pq[j] = pq[j], pq[i]
  32. pq[i].index = i
  33. pq[j].index = j
  34. }
  35. func (pq *SessionQueue) Push(x interface{}) {
  36. n := len(*pq)
  37. item := x.(*SessionData)
  38. item.index = n
  39. *pq = append(*pq, item)
  40. }
  41. func (pq *SessionQueue) Pop() interface{} {
  42. old := *pq
  43. n := len(old)
  44. item := old[n-1]
  45. item.index = -1 // for safety
  46. *pq = old[0 : n-1]
  47. return item
  48. }
  49. var (
  50. sessions = make(map[string]*SessionData)
  51. sessionQueue = make(SessionQueue, 0)
  52. db *sql.DB
  53. )
  54. func sessionExpirer() {
  55. for {
  56. for len(sessionQueue) > 0 && time.Now().After(sessionQueue[0].priority) {
  57. session := heap.Pop(&sessionQueue).(*SessionData)
  58. delete(sessions, session.sid)
  59. }
  60. time.Sleep(time.Second * 5)
  61. }
  62. }
  63. func init() {
  64. go sessionExpirer()
  65. connStr := "user=levyraati dbname=levyraati sslmode=disable"
  66. var err error
  67. db, err = sql.Open("postgres", connStr)
  68. if err != nil {
  69. log.Fatal(err)
  70. }
  71. _, err = db.Query("SELECT 1")
  72. if err != nil {
  73. log.Fatal(err)
  74. }
  75. }
  76. func addSession(data *SessionData) {
  77. sessions[data.sid] = data
  78. heap.Push(&sessionQueue, data)
  79. }
  80. func generateHash(password string) (string, error) {
  81. salt := make([]byte, 11)
  82. if _, err := rand.Read(salt); err != nil {
  83. log.Println(err)
  84. return "", errors.New("Couldn't generate random string")
  85. }
  86. passwordBytes := []byte(password)
  87. key := hashPasswordSalt(passwordBytes, salt)
  88. saltStr := base64.StdEncoding.EncodeToString(salt)
  89. keyStr := base64.StdEncoding.EncodeToString(key)
  90. return saltStr + "$" + keyStr, nil
  91. }
  92. func hashOk(password, hashed string) (bool, error) {
  93. parts := strings.Split(hashed, "$")
  94. if len(parts) != 2 {
  95. return false, errors.New("Invalid data stored in database for password")
  96. }
  97. salt, err := base64.StdEncoding.DecodeString(parts[0])
  98. if err != nil {
  99. return false, err
  100. }
  101. passwordHash, err := base64.StdEncoding.DecodeString(parts[1])
  102. if err != nil {
  103. return false, err
  104. }
  105. if len(passwordHash) != 32 {
  106. return false, errors.New("Password hash was not 32 bytes long")
  107. }
  108. key := hashPasswordSalt([]byte(password), salt)
  109. return bytes.Equal(key, passwordHash), nil
  110. }
  111. func hashPasswordSalt(password, salt []byte) []byte {
  112. return pbkdf2.Key(password, salt, 4096, 32, sha1.New)
  113. }
  114. func userOk(username, password string) (bool, error) {
  115. var hash string
  116. err := db.QueryRow("SELECT u.password FROM public.user u WHERE lower(u.username) = lower($1)", username).Scan(&hash)
  117. if err != nil {
  118. if err.Error() == "sql: no rows in result set" {
  119. return false, nil
  120. } else {
  121. return false, err
  122. }
  123. }
  124. ok, err := hashOk(password, hash)
  125. if err != nil {
  126. return false, err
  127. }
  128. return ok, nil
  129. }
  130. func tryLogin(username, password string, longerTime bool) (http.Cookie, error) {
  131. if exists, err := userOk(username, password); !exists {
  132. if err != nil {
  133. return http.Cookie{}, err
  134. }
  135. return http.Cookie{},
  136. errors.New("The username or password you entered isn't correct.")
  137. }
  138. sid, err := randString(32)
  139. if err != nil {
  140. return http.Cookie{}, err
  141. }
  142. duration := time.Hour * 1
  143. if longerTime {
  144. duration = time.Hour * 24 * 14
  145. }
  146. loginCookie := http.Cookie{
  147. Name: "id",
  148. Value: sid,
  149. MaxAge: int(duration.Seconds()),
  150. HttpOnly: true,
  151. }
  152. expiration := time.Now().Add(duration)
  153. addSession(&SessionData{sid, username, expiration, 0})
  154. return loginCookie, nil
  155. }
  156. func getSession(req *http.Request) (*SessionData, error) {
  157. cookie, err := req.Cookie("id")
  158. if err != nil {
  159. return nil, err
  160. }
  161. session, exists := sessions[cookie.Value]
  162. if !exists {
  163. return nil, errors.New("Session expired from server")
  164. }
  165. return session, nil
  166. }
  167. func randString(size int) (string, error) {
  168. buf := make([]byte, size)
  169. if _, err := rand.Read(buf); err != nil {
  170. log.Println(err)
  171. return "", errors.New("Couldn't generate random string")
  172. }
  173. return base64.URLEncoding.EncodeToString(buf)[:size], nil
  174. }