123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303 |
- package main
-
- import (
- "database/sql"
- "encoding/json"
- "errors"
- _ "github.com/lib/pq"
- "log"
- "os"
- "time"
- )
-
// DB wraps the *sql.DB handle that every query method in this file runs
// against. Build one with InitDatabase.
type DB struct {
	// database is the handle returned by sql.Open("postgres", ...).
	database *sql.DB
}
-
- func InitDatabase() *DB {
- if os.Getenv("PGUSER") == "" {
- os.Setenv("PGUSER", "levyraati")
- }
- if os.Getenv("PGDATABASE") == "" {
- os.Setenv("PGDATABASE", "levyraati")
- }
-
- connStr := "sslmode=disable"
- var err error
- database, err := sql.Open("postgres", connStr)
- if err != nil {
- log.Fatal(err)
- }
- _, err = database.Query("SELECT 1")
- if err != nil {
- log.Fatal(err)
- }
- return &DB{database}
- }
-
- func (db *DB) FindUserBySpotifyID(username string) (string, error) {
- var user string
- err := db.database.QueryRow("SELECT u.username FROM public.user u WHERE u.spotify_id = $1", username).Scan(&user)
- return user, err
- }
-
- func (db *DB) FindHashForUser(username string) (string, error) {
- var hash string
- err := db.database.QueryRow("SELECT u.password FROM public.user u WHERE lower(u.username) = lower($1)", username).Scan(&hash)
- return hash, err
- }
-
- func (db *DB) EntrySynced(userId, roundId int) (bool, error) {
- query := `UPDATE public.entry SET synced = true WHERE user_id = $1 AND round_id = $2`
- res, err := db.database.Exec(query, userId, roundId)
-
- if err != nil {
- return false, err
- }
-
- affected, err := res.RowsAffected()
- if err != nil {
- return false, err
- }
-
- if affected != 1 {
- return false, errors.New("Unknown entry ID")
- }
- return true, nil
-
- }
-
- func (db *DB) FindEntriesToSync() ([]*EntryToSync, error) {
- query := `
- SELECT e.user_id, e.round_id, e.artist, e.title, e.spotify_url, p.article, u.username, r.section
- FROM public.entry e
- JOIN public."user" u ON u.id = e.user_id
- JOIN public.round r ON r.id = e.round_id
- JOIN public.panel p ON p.id = r.panel_id
- WHERE r.start < current_timestamp AND e.synced = false AND p.sync_enabled = true`
- rows, err := db.database.Query(query)
-
- if err != nil {
- log.Println("Error while reading songs from database:", err)
- return nil, err
- }
- defer rows.Close()
- var entries []*EntryToSync
- for rows.Next() {
- var (
- userId, roundId int
- artist, title, spotifyURL, article, username, section string
- )
- err := rows.Scan(&userId, &roundId, &artist, &title, &spotifyURL, &article, &username, §ion)
- if err != nil {
- log.Println("Error while scanning row:", err)
- return nil, err
- }
- entries = append(entries, &EntryToSync{userId, roundId, artist, title, spotifyURL, article, username, section})
- }
- err = rows.Err()
- if err != nil {
- log.Println("Error after reading cursor:", err)
- return nil, err
- }
- return entries, nil
- }
-
// EntryToSync is one pending entry as produced by FindEntriesToSync.
// Field order matters: values are built with positional composite literals.
type EntryToSync struct {
	userId  int // entry.user_id
	roundId int // entry.round_id

	artist     string // entry.artist
	title      string // entry.title
	spotifyURL string // entry.spotify_url
	article    string // panel.article
	username   string // submitter's user.username
	section    string // round.section
}
-
- func (db *DB) FindRoundInfo(roundNum int) (*RoundInfo, error) {
- query := `
- SELECT r.section, p.name, p.article, r.start, r.end
- FROM public.round r
- JOIN public.panel p ON p.id = r.panel_id
- WHERE r.id = $1`
- rows, err := db.database.Query(query, roundNum)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
- var (
- section, panelName, panelArticle string
- start, end *time.Time
- )
- if !rows.Next() {
- return nil, nil
- }
- err = rows.Scan(§ion, &panelName, &panelArticle, &start, &end)
- if err != nil {
- log.Println("Error while scanning row:", err)
- return nil, err
- }
- return &RoundInfo{section, panelName, panelArticle, start, end}, nil
- }
-
// RoundInfo describes a single round as returned by FindRoundInfo.
// Field order matters: values may be built with positional literals.
type RoundInfo struct {
	Section      string // round.section
	PanelName    string // panel.name
	PanelArticle string // panel.article

	// Start and End are nil when the corresponding timestamp column is NULL.
	Start *time.Time
	End   *time.Time
}
-
- func (db *DB) FindAllRoundEntries(round int) ([]*Song, error) {
- query := `
- SELECT r.id, r.section, e.artist, e.title, e.spotify_url, e.synced, u.username
- FROM public.round r
- JOIN public.panel p ON p.id = r.panel_id
- JOIN public.entry e ON r.id = e.round_id
- JOIN public."user" u ON u.id = e.user_id
- WHERE r.id = $1`
- rows, err := db.database.Query(query, round)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
- return db.rowsToSong(rows)
- }
-
- func (db *DB) FindAllEntries(username string) ([]*Song, error) {
- query := `
- SELECT r.id, r.section, e.artist, e.title, e.spotify_url, e.synced, u.username
- FROM public.round r
- LEFT JOIN public.entry e ON r.id = e.round_id AND e.user_id =
- (SELECT u2.id FROM public."user" u2 WHERE lower(u2.username) = lower($1))
- LEFT JOIN public."user" u ON r.id = e.round_id AND u.id = e.user_id
- ORDER BY r.start ASC`
- rows, err := db.database.Query(query, username)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
- return db.rowsToSong(rows)
- }
-
- func (db *DB) rowsToSong(rows *sql.Rows) ([]*Song, error) {
- var songs []*Song
- for i := 0; rows.Next(); i++ {
- song := &Song{}
- songs = append(songs, song)
- var (
- artist, title, url, username *string
- sync *bool
- )
- err := rows.Scan(&songs[i].RoundID, &songs[i].RoundName, &artist, &title, &url, &sync, &username)
- if err != nil {
- return nil, err
- }
- if artist != nil {
- song.Artist = *artist
- }
- if title != nil {
- song.Title = *title
- }
- if url != nil {
- song.URL = *url
- }
- if sync != nil {
- song.Sync = *sync
- }
- if username != nil {
- song.Submitter = *username
- }
- }
- return songs, nil
- }
-
// Song is one row of a round listing: the round it belongs to and, when
// there is one, the entry submitted for that round.
type Song struct {
	RoundID   int    // round.id
	RoundName string // round.section

	Title     string // entry.title; empty when there is no entry
	Artist    string // entry.artist; empty when there is no entry
	URL       string // entry.spotify_url
	Sync      bool   // entry.synced
	Submitter string // username of the entry's author
}
-
- func (db *DB) UpdateEntry(username, round, artist, title, url string) (bool, error) {
- query := `
- INSERT INTO public.entry
- SELECT id, $2, $3, $4, $5, false
- FROM public."user" u
- WHERE lower(u.username) = lower($1)
- ON CONFLICT (user_id, round_id) DO UPDATE SET artist = EXCLUDED.artist, title = EXCLUDED.title, spotify_url = EXCLUDED.spotify_url, synced = EXCLUDED.synced`
- res, err := db.database.Exec(query, username, round, artist, title, url)
-
- if err != nil {
- return false, err
- }
- affected, err := res.RowsAffected()
- if err != nil {
- return false, err
- }
- if affected != 1 {
- return false, nil
- }
- return true, nil
- }
-
- func (db *DB) FindAllPanels() ([]string, error) {
- query := `
- SELECT name FROM public.panel`
- rows, err := db.database.Query(query)
- if err != nil {
- return nil, err
- }
- defer rows.Close()
- names := make([]string, 0)
- for i := 0; rows.Next(); i++ {
- var name string
- err := rows.Scan(&name)
- if err != nil {
- return nil, err
- }
- names = append(names, name)
- }
- return names, nil
- }
-
- func (db *DB) FindPlaylistBySection(sectionName string) ([]string, error) {
- query := `
- SELECT array_to_json(pl.tracks) FROM public.round_playlist pl
- JOIN public.round r ON pl.round_id = r.id
- WHERE r.section = $1`
- row := db.database.QueryRow(query, sectionName)
- var tracksJson []byte
- tracks := make([]string, 0)
- err := row.Scan(&tracksJson)
- if err != nil {
- return nil, err
- }
- err = json.Unmarshal(tracksJson, &tracks)
- if err != nil {
- return nil, err
- }
- return tracks, nil
- }
-
- func (db *DB) UpdatePlaylistBySection(sectionName string, tracks []string) (bool, error) {
- tracksJSON, err := json.Marshal(tracks)
- if err != nil {
- return false, err
- }
- query := `
- INSERT INTO public.round_playlist
- SELECT r.id, ARRAY(SELECT e::text FROM json_array_elements_text($2::json) e)
- FROM public.round r
- WHERE r.section = $1
- ON CONFLICT (round_id) DO UPDATE SET tracks = EXCLUDED.tracks`
-
- res, err := db.database.Exec(query, sectionName, tracksJSON)
-
- if err != nil {
- return false, err
- }
- affected, err := res.RowsAffected()
- if err != nil {
- return false, err
- }
- if affected != 1 {
- return false, nil
- }
- return true, nil
- }
|