123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208 |
- package main
-
import (
	"bytes"
	"container/heap"
	"crypto/rand"
	"crypto/sha1"
	"database/sql"
	"encoding/base64"
	"errors"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"

	_ "github.com/lib/pq"
	"golang.org/x/crypto/pbkdf2"
)
-
// SessionData is one logged-in session kept in server memory.
type SessionData struct {
	sid      string    // session ID; also the value of the client's "id" cookie
	username string    // username as supplied to tryLogin
	priority time.Time // expiry instant; SessionQueue orders sessions by it
	index    int       // position in SessionQueue, kept current by the heap methods; -1 once popped
}

// SessionQueue is a min-heap of sessions ordered by expiry time, earliest
// first. It implements heap.Interface and is meant to be driven through
// heap.Push and heap.Pop, which maintain the ordering invariant.
type SessionQueue []*SessionData

// Len reports how many sessions are queued.
func (pq SessionQueue) Len() int {
	return len(pq)
}

// Less orders the session that expires sooner first.
func (pq SessionQueue) Less(i, j int) bool {
	return pq[i].priority.Before(pq[j].priority)
}

// Swap exchanges two entries and updates their recorded positions.
func (pq SessionQueue) Swap(i, j int) {
	pq[i], pq[j] = pq[j], pq[i]
	pq[i].index = i
	pq[j].index = j
}

// Push appends x, which must be a *SessionData. heap.Push calls it and then
// restores heap order.
func (pq *SessionQueue) Push(x interface{}) {
	item := x.(*SessionData)
	item.index = len(*pq)
	*pq = append(*pq, item)
}

// Pop removes and returns the last element. heap.Pop swaps the minimum into
// that position before calling it.
func (pq *SessionQueue) Pop() interface{} {
	old := *pq
	n := len(old)
	item := old[n-1]
	// Clear the vacated slot: otherwise the backing array keeps the removed
	// session reachable until a later Push overwrites it.
	old[n-1] = nil
	item.index = -1 // for safety
	*pq = old[:n-1]
	return item
}
-
- var (
- sessions = make(map[string]*SessionData)
- sessionQueue = make(SessionQueue, 0)
- db *sql.DB
- )
-
- func sessionExpirer() {
- for {
- for len(sessionQueue) > 0 && time.Now().After(sessionQueue[0].priority) {
- session := heap.Pop(&sessionQueue).(*SessionData)
- delete(sessions, session.sid)
- }
- time.Sleep(time.Second * 5)
- }
- }
-
- func init() {
- go sessionExpirer()
-
- connStr := "user=levyraati dbname=levyraati sslmode=disable"
- var err error
- db, err = sql.Open("postgres", connStr)
- if err != nil {
- log.Fatal(err)
- }
- _, err = db.Query("SELECT 1")
- if err != nil {
- log.Fatal(err)
- }
- }
-
- func addSession(data *SessionData) {
- sessions[data.sid] = data
- heap.Push(&sessionQueue, data)
- }
-
- func generateHash(password string) (string, error) {
- salt := make([]byte, 11)
-
- if _, err := rand.Read(salt); err != nil {
- log.Println(err)
- return "", errors.New("Couldn't generate random string")
- }
-
- passwordBytes := []byte(password)
- key := hashPasswordSalt(passwordBytes, salt)
-
- saltStr := base64.StdEncoding.EncodeToString(salt)
- keyStr := base64.StdEncoding.EncodeToString(key)
- return saltStr + "$" + keyStr, nil
- }
-
- func hashOk(password, hashed string) (bool, error) {
- parts := strings.Split(hashed, "$")
- if len(parts) != 2 {
- return false, errors.New("Invalid data stored in database for password")
- }
- salt, err := base64.StdEncoding.DecodeString(parts[0])
- if err != nil {
- return false, err
- }
- passwordHash, err := base64.StdEncoding.DecodeString(parts[1])
- if err != nil {
- return false, err
- }
- if len(passwordHash) != 32 {
- return false, errors.New("Password hash was not 32 bytes long")
- }
-
- key := hashPasswordSalt([]byte(password), salt)
- return bytes.Equal(key, passwordHash), nil
- }
-
- func hashPasswordSalt(password, salt []byte) []byte {
- return pbkdf2.Key(password, salt, 4096, 32, sha1.New)
- }
-
- func userOk(username, password string) (bool, error) {
- var hash string
- err := db.QueryRow("SELECT u.password FROM public.user u WHERE lower(u.username) = lower($1)", username).Scan(&hash)
- if err != nil {
- if err.Error() == "sql: no rows in result set" {
- return false, nil
- } else {
- return false, err
- }
- }
-
- ok, err := hashOk(password, hash)
- if err != nil {
- return false, err
- }
- return ok, nil
- }
-
- func tryLogin(username, password string, longerTime bool) (http.Cookie, error) {
- if exists, err := userOk(username, password); !exists {
- if err != nil {
- return http.Cookie{}, err
- }
- return http.Cookie{},
- errors.New("The username or password you entered isn't correct.")
- }
-
- sid, err := randString(32)
- if err != nil {
- return http.Cookie{}, err
- }
-
- duration := time.Hour * 1
- if longerTime {
- duration = time.Hour * 24 * 14
- }
-
- loginCookie := http.Cookie{
- Name: "id",
- Value: sid,
- MaxAge: int(duration.Seconds()),
- HttpOnly: true,
- }
-
- expiration := time.Now().Add(duration)
- addSession(&SessionData{sid, username, expiration, 0})
-
- return loginCookie, nil
- }
-
- func getSession(req *http.Request) (*SessionData, error) {
- cookie, err := req.Cookie("id")
- if err != nil {
- return nil, err
- }
-
- session, exists := sessions[cookie.Value]
- if !exists {
- return nil, errors.New("Session expired from server")
- }
- return session, nil
- }
-
// randString returns a URL-safe string of exactly size characters built from
// crypto/rand output. It draws size random bytes and keeps the first size
// characters of their base64url encoding; those are always data characters,
// never padding.
func randString(size int) (string, error) {
	raw := make([]byte, size)
	_, err := rand.Read(raw)
	if err != nil {
		log.Println(err)
		return "", errors.New("Couldn't generate random string")
	}
	encoded := base64.URLEncoding.EncodeToString(raw)
	return encoded[:size], nil
}
|