db.go 5.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package main
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. _ "github.com/lib/pq"
  7. "log"
  8. "os"
  9. )
  10. type DB struct {
  11. database *sql.DB
  12. }
  13. func InitDatabase() *DB {
  14. if os.Getenv("PGUSER") == "" {
  15. os.Setenv("PGUSER", "levyraati")
  16. }
  17. if os.Getenv("PGDATABASE") == "" {
  18. os.Setenv("PGDATABASE", "levyraati")
  19. }
  20. connStr := "sslmode=disable"
  21. var err error
  22. database, err := sql.Open("postgres", connStr)
  23. if err != nil {
  24. log.Fatal(err)
  25. }
  26. _, err = database.Query("SELECT 1")
  27. if err != nil {
  28. log.Fatal(err)
  29. }
  30. return &DB{database}
  31. }
  32. func (db *DB) FindHashForUser(username string) (string, error) {
  33. var hash string
  34. err := db.database.QueryRow("SELECT u.password FROM public.user u WHERE lower(u.username) = lower($1)", username).Scan(&hash)
  35. return hash, err
  36. }
  37. func (db *DB) EntrySynced(userId, roundId int) (bool, error) {
  38. query := `UPDATE public.entry SET synced = true WHERE user_id = $1 AND round_id = $2`
  39. res, err := db.database.Exec(query, userId, roundId)
  40. if err != nil {
  41. return false, err
  42. }
  43. affected, err := res.RowsAffected()
  44. if err != nil {
  45. return false, err
  46. }
  47. if affected != 1 {
  48. return false, errors.New("Unknown entry ID")
  49. }
  50. return true, nil
  51. }
  52. func (db *DB) FindEntriesToSync() ([]*EntryToSync, error) {
  53. query := `
  54. SELECT e.user_id, e.round_id, e.artist, e.title, e.spotify_url, p.article, u.username, r.section
  55. FROM public.entry e
  56. JOIN public."user" u ON u.id = e.user_id
  57. JOIN public.round r ON r.id = e.round_id
  58. JOIN public.panel p ON p.id = r.panel_id
  59. WHERE r.start < current_timestamp AND e.synced = false`
  60. rows, err := db.database.Query(query)
  61. if err != nil {
  62. log.Println("Error while reading songs from database:", err)
  63. return nil, err
  64. }
  65. defer rows.Close()
  66. var entries []*EntryToSync
  67. for rows.Next() {
  68. var (
  69. userId, roundId int
  70. artist, title, spotifyURL, article, username, section string
  71. )
  72. err := rows.Scan(&userId, &roundId, &artist, &title, &spotifyURL, &article, &username, &section)
  73. if err != nil {
  74. log.Println("Error while scanning row:", err)
  75. return nil, err
  76. }
  77. entries = append(entries, &EntryToSync{userId, roundId, artist, title, spotifyURL, article, username, section})
  78. }
  79. err = rows.Err()
  80. if err != nil {
  81. log.Println("Error after reading cursor:", err)
  82. return nil, err
  83. }
  84. return entries, nil
  85. }
  86. type EntryToSync struct {
  87. userId, roundId int
  88. artist, title, spotifyURL, article, username, section string
  89. }
  90. func (db *DB) FindAllEntries(username string) ([]*Song, error) {
  91. var songs []*Song
  92. query := `
  93. SELECT r.id, r.section, e.artist, e.title, e.spotify_url, e.synced
  94. FROM public.round r
  95. LEFT JOIN public.entry e ON r.id = e.round_id
  96. LEFT JOIN public."user" u ON u.id = e.user_id AND lower(u.username) = lower($1)
  97. ORDER BY r.start ASC`
  98. rows, err := db.database.Query(query, username)
  99. if err != nil {
  100. return nil, err
  101. }
  102. defer rows.Close()
  103. for i := 0; rows.Next(); i++ {
  104. song := &Song{}
  105. songs = append(songs, song)
  106. var (
  107. artist, title, url *string
  108. sync *bool
  109. )
  110. err = rows.Scan(&songs[i].RoundID, &songs[i].RoundName, &artist, &title, &url, &sync)
  111. if err != nil {
  112. return nil, err
  113. }
  114. if artist != nil {
  115. song.Artist = *artist
  116. }
  117. if title != nil {
  118. song.Title = *title
  119. }
  120. if url != nil {
  121. song.URL = *url
  122. }
  123. if sync != nil {
  124. song.Sync = *sync
  125. }
  126. }
  127. return songs, nil
  128. }
  129. type Song struct {
  130. RoundID int
  131. RoundName string
  132. Title string
  133. Artist string
  134. URL string
  135. Sync bool
  136. }
  137. func (db *DB) UpdateEntry(username, round, artist, title, url string) (bool, error) {
  138. query := `
  139. INSERT INTO public.entry
  140. SELECT id, $2, $3, $4, $5, false
  141. FROM public."user" u
  142. WHERE lower(u.username) = lower($1)
  143. ON CONFLICT (user_id, round_id) DO UPDATE SET artist = EXCLUDED.artist, title = EXCLUDED.title, spotify_url = EXCLUDED.spotify_url, synced = EXCLUDED.synced`
  144. res, err := db.database.Exec(query, username, round, artist, title, url)
  145. if err != nil {
  146. return false, err
  147. }
  148. affected, err := res.RowsAffected()
  149. if err != nil {
  150. return false, err
  151. }
  152. if affected != 1 {
  153. return false, nil
  154. }
  155. return true, nil
  156. }
  157. func (db *DB) FindAllPanels() ([]string, error) {
  158. query := `
  159. SELECT name FROM public.panel`
  160. rows, err := db.database.Query(query)
  161. if err != nil {
  162. return nil, err
  163. }
  164. defer rows.Close()
  165. names := make([]string, 0)
  166. for i := 0; rows.Next(); i++ {
  167. var name string
  168. err := rows.Scan(&name)
  169. if err != nil {
  170. return nil, err
  171. }
  172. names = append(names, name)
  173. }
  174. return names, nil
  175. }
  176. func (db *DB) FindPlaylistBySection(sectionName string) ([]string, error) {
  177. query := `
  178. SELECT array_to_json(pl.tracks) FROM public.round_playlist pl
  179. JOIN public.round r ON pl.round_id = r.id
  180. WHERE r.section = $1`
  181. row := db.database.QueryRow(query, sectionName)
  182. var tracksJson []byte
  183. tracks := make([]string, 0)
  184. err := row.Scan(&tracksJson)
  185. if err != nil {
  186. return nil, err
  187. }
  188. err = json.Unmarshal(tracksJson, &tracks)
  189. if err != nil {
  190. return nil, err
  191. }
  192. return tracks, nil
  193. }
  194. func (db *DB) UpdatePlaylistBySection(sectionName string, tracks []string) (bool, error) {
  195. tracksJSON, err := json.Marshal(tracks)
  196. if err != nil {
  197. return false, err
  198. }
  199. query := `
  200. INSERT INTO public.round_playlist
  201. SELECT r.id, ARRAY(SELECT e::text FROM json_array_elements_text($2::json) e)
  202. FROM public.round r
  203. WHERE r.section = $1
  204. ON CONFLICT (round_id) DO UPDATE SET tracks = EXCLUDED.tracks`
  205. res, err := db.database.Exec(query, sectionName, tracksJSON)
  206. if err != nil {
  207. return false, err
  208. }
  209. affected, err := res.RowsAffected()
  210. if err != nil {
  211. return false, err
  212. }
  213. if affected != 1 {
  214. return false, nil
  215. }
  216. return true, nil
  217. }