package main import ( "bytes" "container/heap" "crypto/rand" "crypto/sha1" "database/sql" "encoding/base64" "errors" _ "github.com/lib/pq" "golang.org/x/crypto/pbkdf2" "log" "net/http" "strings" "time" ) type SessionData struct { sid string username string priority time.Time index int } type SessionQueue []*SessionData func (pq SessionQueue) Len() int { return len(pq) } func (pq SessionQueue) Less(i, j int) bool { return pq[i].priority.Before(pq[j].priority) } func (pq SessionQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] pq[i].index = i pq[j].index = j } func (pq *SessionQueue) Push(x interface{}) { n := len(*pq) item := x.(*SessionData) item.index = n *pq = append(*pq, item) } func (pq *SessionQueue) Pop() interface{} { old := *pq n := len(old) item := old[n-1] item.index = -1 // for safety *pq = old[0 : n-1] return item } var ( sessions = make(map[string]*SessionData) sessionQueue = make(SessionQueue, 0) db *sql.DB ) func sessionExpirer() { for { for len(sessionQueue) > 0 && time.Now().After(sessionQueue[0].priority) { session := heap.Pop(&sessionQueue).(*SessionData) delete(sessions, session.sid) } time.Sleep(time.Second * 5) } } func init() { go sessionExpirer() connStr := "user=levyraati dbname=levyraati sslmode=disable" var err error db, err = sql.Open("postgres", connStr) if err != nil { log.Fatal(err) } _, err = db.Query("SELECT 1") if err != nil { log.Fatal(err) } } func addSession(data *SessionData) { sessions[data.sid] = data heap.Push(&sessionQueue, data) } func generateHash(password string) (string, error) { salt := make([]byte, 11) if _, err := rand.Read(salt); err != nil { log.Println(err) return "", errors.New("Couldn't generate random string") } passwordBytes := []byte(password) key := hashPasswordSalt(passwordBytes, salt) saltStr := base64.StdEncoding.EncodeToString(salt) keyStr := base64.StdEncoding.EncodeToString(key) return saltStr + "$" + keyStr, nil } func hashOk(password, hashed string) (bool, error) { parts := strings.Split(hashed, "$") if len(parts) != 2 { return false, errors.New("Invalid data stored in database for password") } salt, err := base64.StdEncoding.DecodeString(parts[0]) if err != nil { return false, err } passwordHash, err := base64.StdEncoding.DecodeString(parts[1]) if err != nil { return false, err } if len(passwordHash) != 32 { return false, errors.New("Password hash was not 32 bytes long") } key := hashPasswordSalt([]byte(password), salt) return bytes.Equal(key, passwordHash), nil } func hashPasswordSalt(password, salt []byte) []byte { return pbkdf2.Key(password, salt, 4096, 32, sha1.New) } func userOk(username, password string) (bool, error) { var hash string err := db.QueryRow("SELECT u.password FROM public.user u WHERE lower(u.username) = lower($1)", username).Scan(&hash) if err != nil { if err.Error() == "sql: no rows in result set" { return false, nil } else { return false, err } } ok, err := hashOk(password, hash) if err != nil { return false, err } return ok, nil } func tryLogin(username, password string, longerTime bool) (http.Cookie, error) { if exists, err := userOk(username, password); !exists { if err != nil { return http.Cookie{}, err } return http.Cookie{}, errors.New("The username or password you entered isn't correct.") } sid, err := randString(32) if err != nil { return http.Cookie{}, err } duration := time.Hour * 1 if longerTime { duration = time.Hour * 24 * 14 } loginCookie := http.Cookie{ Name: "id", Value: sid, MaxAge: int(duration.Seconds()), HttpOnly: true, } expiration := time.Now().Add(duration) addSession(&SessionData{sid, username, expiration, 0}) return loginCookie, nil } func getSession(req *http.Request) (*SessionData, error) { cookie, err := req.Cookie("id") if err != nil { return nil, err } session, exists := sessions[cookie.Value] if !exists { return nil, errors.New("Session expired from server") } return session, nil } func randString(size int) (string, error) { buf := make([]byte, size) if _, err := rand.Read(buf); err != nil { log.Println(err) return "", errors.New("Couldn't generate random string") } return base64.URLEncoding.EncodeToString(buf)[:size], nil }