db.go 7.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. package main
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. _ "github.com/lib/pq"
  7. "log"
  8. "os"
  9. "time"
  10. )
  11. type DB struct {
  12. database *sql.DB
  13. }
  14. func InitDatabase() *DB {
  15. if os.Getenv("PGUSER") == "" {
  16. os.Setenv("PGUSER", "levyraati")
  17. }
  18. if os.Getenv("PGDATABASE") == "" {
  19. os.Setenv("PGDATABASE", "levyraati")
  20. }
  21. connStr := "sslmode=disable"
  22. var err error
  23. database, err := sql.Open("postgres", connStr)
  24. if err != nil {
  25. log.Fatal(err)
  26. }
  27. _, err = database.Query("SELECT 1")
  28. if err != nil {
  29. log.Fatal(err)
  30. }
  31. return &DB{database}
  32. }
  33. func (db *DB) FindUserBySpotifyID(username string) (string, error) {
  34. var user string
  35. err := db.database.QueryRow("SELECT u.username FROM public.user u WHERE u.spotify_id = $1", username).Scan(&user)
  36. return user, err
  37. }
  38. func (db *DB) FindHashForUser(username string) (string, error) {
  39. var hash string
  40. err := db.database.QueryRow("SELECT u.password FROM public.user u WHERE lower(u.username) = lower($1)", username).Scan(&hash)
  41. return hash, err
  42. }
  43. func (db *DB) EntrySynced(userId, roundId int) (bool, error) {
  44. query := `UPDATE public.entry SET synced = true WHERE user_id = $1 AND round_id = $2`
  45. res, err := db.database.Exec(query, userId, roundId)
  46. if err != nil {
  47. return false, err
  48. }
  49. affected, err := res.RowsAffected()
  50. if err != nil {
  51. return false, err
  52. }
  53. if affected != 1 {
  54. return false, errors.New("Unknown entry ID")
  55. }
  56. return true, nil
  57. }
  58. func (db *DB) FindEntriesToSync() ([]*EntryToSync, error) {
  59. query := `
  60. SELECT e.user_id, e.round_id, e.artist, e.title, e.spotify_url, p.article, u.username, r.section
  61. FROM public.entry e
  62. JOIN public."user" u ON u.id = e.user_id
  63. JOIN public.round r ON r.id = e.round_id
  64. JOIN public.panel p ON p.id = r.panel_id
  65. WHERE r.start < current_timestamp AND e.synced = false AND p.sync_enabled = true`
  66. rows, err := db.database.Query(query)
  67. if err != nil {
  68. log.Println("Error while reading songs from database:", err)
  69. return nil, err
  70. }
  71. defer rows.Close()
  72. var entries []*EntryToSync
  73. for rows.Next() {
  74. var (
  75. userId, roundId int
  76. artist, title, spotifyURL, article, username, section string
  77. )
  78. err := rows.Scan(&userId, &roundId, &artist, &title, &spotifyURL, &article, &username, &section)
  79. if err != nil {
  80. log.Println("Error while scanning row:", err)
  81. return nil, err
  82. }
  83. entries = append(entries, &EntryToSync{userId, roundId, artist, title, spotifyURL, article, username, section})
  84. }
  85. err = rows.Err()
  86. if err != nil {
  87. log.Println("Error after reading cursor:", err)
  88. return nil, err
  89. }
  90. return entries, nil
  91. }
  92. type EntryToSync struct {
  93. userId, roundId int
  94. artist, title, spotifyURL, article, username, section string
  95. }
  96. func (db *DB) FindRoundInfo(roundNum int) (*RoundInfo, error) {
  97. query := `
  98. SELECT r.section, p.name, p.article, r.start, r.end
  99. FROM public.round r
  100. JOIN public.panel p ON p.id = r.panel_id
  101. WHERE r.id = $1`
  102. rows, err := db.database.Query(query, roundNum)
  103. if err != nil {
  104. return nil, err
  105. }
  106. defer rows.Close()
  107. var (
  108. section, panelName, panelArticle string
  109. start, end *time.Time
  110. )
  111. if !rows.Next() {
  112. return nil, nil
  113. }
  114. err = rows.Scan(&section, &panelName, &panelArticle, &start, &end)
  115. if err != nil {
  116. log.Println("Error while scanning row:", err)
  117. return nil, err
  118. }
  119. return &RoundInfo{section, panelName, panelArticle, start, end}, nil
  120. }
  121. type RoundInfo struct {
  122. Section, PanelName, PanelArticle string
  123. Start, End *time.Time
  124. }
  125. func (db *DB) FindAllRoundEntries(round int) ([]*Song, error) {
  126. query := `
  127. SELECT r.id, r.section, e.artist, e.title, e.spotify_url, e.synced, u.username
  128. FROM public.round r
  129. JOIN public.panel p ON p.id = r.panel_id
  130. JOIN public.entry e ON r.id = e.round_id
  131. JOIN public."user" u ON u.id = e.user_id
  132. WHERE r.id = $1`
  133. rows, err := db.database.Query(query, round)
  134. if err != nil {
  135. return nil, err
  136. }
  137. defer rows.Close()
  138. return db.rowsToSong(rows)
  139. }
  140. func (db *DB) FindAllEntries(username string) ([]*Song, error) {
  141. query := `
  142. SELECT r.id, r.section, e.artist, e.title, e.spotify_url, e.synced, u.username
  143. FROM public.round r
  144. LEFT JOIN public.entry e ON r.id = e.round_id AND e.user_id =
  145. (SELECT u2.id FROM public."user" u2 WHERE lower(u2.username) = lower($1))
  146. LEFT JOIN public."user" u ON r.id = e.round_id AND u.id = e.user_id
  147. ORDER BY r.start ASC`
  148. rows, err := db.database.Query(query, username)
  149. if err != nil {
  150. return nil, err
  151. }
  152. defer rows.Close()
  153. return db.rowsToSong(rows)
  154. }
  155. func (db *DB) rowsToSong(rows *sql.Rows) ([]*Song, error) {
  156. var songs []*Song
  157. for i := 0; rows.Next(); i++ {
  158. song := &Song{}
  159. songs = append(songs, song)
  160. var (
  161. artist, title, url, username *string
  162. sync *bool
  163. )
  164. err := rows.Scan(&songs[i].RoundID, &songs[i].RoundName, &artist, &title, &url, &sync, &username)
  165. if err != nil {
  166. return nil, err
  167. }
  168. if artist != nil {
  169. song.Artist = *artist
  170. }
  171. if title != nil {
  172. song.Title = *title
  173. }
  174. if url != nil {
  175. song.URL = *url
  176. }
  177. if sync != nil {
  178. song.Sync = *sync
  179. }
  180. if username != nil {
  181. song.Submitter = *username
  182. }
  183. }
  184. return songs, nil
  185. }
  186. type Song struct {
  187. RoundID int
  188. RoundName string
  189. Title string
  190. Artist string
  191. URL string
  192. Sync bool
  193. Submitter string
  194. }
  195. func (db *DB) UpdateEntry(username, round, artist, title, url string) (bool, error) {
  196. query := `
  197. INSERT INTO public.entry
  198. SELECT id, $2, $3, $4, $5, false
  199. FROM public."user" u
  200. WHERE lower(u.username) = lower($1)
  201. ON CONFLICT (user_id, round_id) DO UPDATE SET artist = EXCLUDED.artist, title = EXCLUDED.title, spotify_url = EXCLUDED.spotify_url, synced = EXCLUDED.synced`
  202. res, err := db.database.Exec(query, username, round, artist, title, url)
  203. if err != nil {
  204. return false, err
  205. }
  206. affected, err := res.RowsAffected()
  207. if err != nil {
  208. return false, err
  209. }
  210. if affected != 1 {
  211. return false, nil
  212. }
  213. return true, nil
  214. }
  215. func (db *DB) FindAllPanels() ([]string, error) {
  216. query := `
  217. SELECT name FROM public.panel`
  218. rows, err := db.database.Query(query)
  219. if err != nil {
  220. return nil, err
  221. }
  222. defer rows.Close()
  223. names := make([]string, 0)
  224. for i := 0; rows.Next(); i++ {
  225. var name string
  226. err := rows.Scan(&name)
  227. if err != nil {
  228. return nil, err
  229. }
  230. names = append(names, name)
  231. }
  232. return names, nil
  233. }
  234. func (db *DB) FindPlaylistBySection(sectionName string) ([]string, error) {
  235. query := `
  236. SELECT array_to_json(pl.tracks) FROM public.round_playlist pl
  237. JOIN public.round r ON pl.round_id = r.id
  238. WHERE r.section = $1`
  239. row := db.database.QueryRow(query, sectionName)
  240. var tracksJson []byte
  241. tracks := make([]string, 0)
  242. err := row.Scan(&tracksJson)
  243. if err != nil {
  244. return nil, err
  245. }
  246. err = json.Unmarshal(tracksJson, &tracks)
  247. if err != nil {
  248. return nil, err
  249. }
  250. return tracks, nil
  251. }
  252. func (db *DB) UpdatePlaylistBySection(sectionName string, tracks []string) (bool, error) {
  253. tracksJSON, err := json.Marshal(tracks)
  254. if err != nil {
  255. return false, err
  256. }
  257. query := `
  258. INSERT INTO public.round_playlist
  259. SELECT r.id, ARRAY(SELECT e::text FROM json_array_elements_text($2::json) e)
  260. FROM public.round r
  261. WHERE r.section = $1
  262. ON CONFLICT (round_id) DO UPDATE SET tracks = EXCLUDED.tracks`
  263. res, err := db.database.Exec(query, sectionName, tracksJSON)
  264. if err != nil {
  265. return false, err
  266. }
  267. affected, err := res.RowsAffected()
  268. if err != nil {
  269. return false, err
  270. }
  271. if affected != 1 {
  272. return false, nil
  273. }
  274. return true, nil
  275. }