{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}

-- | Implementation of the API. This module is the main point of the program.
module Server (application) where

import Control.Monad (forever, void, when)
import Control.Monad.Extra (maybeM, whenM)
import Control.Monad.IO.Class (MonadIO (liftIO))
import Control.Monad.Logger (logWarnN)
import Control.Monad.Reader (forM)
import Control.Monad.Trans (lift)
import qualified Data.Aeson as A
import qualified Data.ByteString.Char8 as C8
import Data.Coerce (coerce)
import Data.Functor ((<&>))
import qualified Data.Map as M
import Data.Pool (Pool)
import Data.Proxy (Proxy (Proxy))
import Data.Swagger (toSchema)
import Data.Text (Text)
import Data.Time (NominalDiffTime, UTCTime (utctDay), addUTCTime, diffUTCTime, getCurrentTime, nominalDay)
import qualified Data.Vector as V
import Database.Persist
import Database.Persist.Postgresql (SqlBackend, runMigration)
import Fmt ((+|), (|+))
import qualified Network.WebSockets as WS
import Servant (Application, err401, err404, serve, throwError)
import Servant.API (NoContent (..), (:<|>) (..))
import Servant.Server (Handler, hoistServer)
import Servant.Swagger (toSwagger)

import API
import GTFS
import Persist
import Server.ControlRoom
import Server.GTFS_RT (gtfsRealtimeServer)
import Server.Util (Service, ServiceM, runService)
import Yesod (toWaiAppPlain)
import System.IO.Unsafe

-- | Build the WAI 'Application' for the complete API.
--
-- Runs the database migration first, then serves 'server' with every
-- handler hoisted through 'runService'.
application :: GTFS -> Pool SqlBackend -> IO Application
application gtfs dbpool = do
  doMigration dbpool
  pure
    $ serve (Proxy @CompleteAPI)
    $ hoistServer (Proxy @CompleteAPI) runService
    $ server gtfs dbpool

-- | Run the persistent migration 'migrateAll' against the given pool.
doMigration :: Pool SqlBackend -> IO ()
doMigration pool = runSql pool $
  -- TODO: before that, check if the uuid module is enabled
  -- in sql: check if SELECT * FROM pg_extension WHERE extname = 'uuid-ossp';
  -- returns an empty list
  runMigration migrateAll

-- | All handlers of the 'CompleteAPI': the swagger description, the public
-- API (including the GTFS realtime and admin sub-servers), and the
-- control-room application.
server :: GTFS -> Pool SqlBackend -> Service CompleteAPI
server gtfs@GTFS{..} dbpool =
  handleDebugAPI
    :<|> ( handleStations
             :<|> handleTimetable
             :<|> handleTrip
             :<|> handleRegister
             :<|> handleTripPing
             :<|> handleWS
             :<|> handleDebugState
             :<|> gtfsRealtimeServer gtfs dbpool
             :<|> adminServer gtfs dbpool
         )
    -- NOTE(review): the Yesod app is built with 'unsafePerformIO' inside a
    -- pure value; presumably this endpoint's server type offers no way to run
    -- IO at construction time -- confirm, and consider building the app once
    -- in 'application' instead.
    :<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom "http://localhost:4000/cr" gtfs dbpool)))
  where
    -- All stations of the loaded GTFS feed.
    handleStations = pure stations

    -- Trips of the given day (today if no day is given), filtered by the
    -- station of their last stop.
    handleTimetable station maybeDay = do
      -- TODO: resolve "overlay" trips (perhaps just additional CalendarDates?)
      day <- liftIO $ maybeM (getCurrentTime <&> utctDay) pure (pure maybeDay)
      pure
        -- NOTE(review): this keeps only trips whose last stop IS at
        -- `station`. The comment that used to sit here said "don't send
        -- stations ending at this station", which reads as the opposite --
        -- confirm the intended direction. Also note that V.last is partial
        -- and will fail on a trip without stops.
        . M.filter ((==) station . stationId . stopStation . V.last . tripStops)
        $ tripsOnDay gtfs day

    -- Look up a single trip by its id; unknown ids yield a 404.
    handleTrip trip = maybe (throwError err404) pure (M.lookup trip trips)

    -- Register a new running trip and hand out its token. The registration
    -- expires 'validityPeriod' after the current time.
    handleRegister tripID = do
      -- TODO registration may carry extra information!
      expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
      RunningTripKey token <- runSql dbpool $ insert (RunningTrip expires False tripID Nothing)
      pure token

    -- Store a single ping after validating its token.
    handleTripPing ping = do
      lift $ checkTokenValid dbpool (coerce $ tripPingToken ping)
      -- TODO: are these always inserted in order?
      -- insert_ because the generated key is not needed here.
      runSql dbpool $ insert_ ping
      pure NoContent

    -- Receive pings over a websocket connection, forever. Messages that fail
    -- to decode are logged and answered with a close message.
    handleWS conn = do
      -- TODO test this!!
      liftIO $ WS.forkPingThread conn 30
      forever $ do
        msg <- liftIO $ WS.receiveData conn
        case A.eitherDecode msg of
          Left err -> do
            logWarnN ("stray websocket message: "+|show msg|+" (could not decode: "+|err|+")")
            liftIO $ WS.sendClose conn (C8.pack err)
          Right ping -> do
            lift $ checkTokenValid dbpool (coerce $ tripPingToken ping)
            void $ runSql dbpool $ insert ping

    -- Map from the uuid of every unblocked, not yet expired running trip to
    -- the pings stored for it.
    handleDebugState = do
      now <- liftIO getCurrentTime
      runSql dbpool $ do
        running <- selectList [RunningTripBlocked ==. False, RunningTripExpires >=. now] []
        pairs <- forM running $ \(Entity token@(RunningTripKey uuid) _) -> do
          entities <- selectList [TripPingToken ==. token] []
          pure (uuid, fmap entityVal entities)
        pure (M.fromList pairs)

    -- Swagger description of the public API.
    handleDebugAPI = pure $ toSwagger (Proxy @API)

-- | Handlers of the 'AdminAPI': managing announcements and schedule
-- amendments.
adminServer :: GTFS -> Pool SqlBackend -> Service AdminAPI
adminServer gtfs dbpool =
  addAnnounce
    :<|> delAnnounce
    :<|> modTripDate Added Cancelled
    :<|> modTripDate Cancelled Added
    :<|> extraTrip
  where
    -- Store an announcement and return the uuid of its key.
    addAnnounce ann = runSql dbpool $ do
      AnnouncementKey uuid <- insert ann
      pure uuid

    -- Delete the announcement with the given uuid.
    delAnnounce uuid = runSql dbpool $ do
      delete (AnnouncementKey uuid)
      pure NoContent

    -- Amend the schedule of a trip on a day. If an amendment already exists
    -- for that trip and day it is deleted when its status equals `two` and
    -- left untouched otherwise; if none exists, one with status `one` is
    -- inserted.
    modTripDate one two tripId day = runSql dbpool $
      getBy (TripAndDay tripId day) >>= \case
        Just (Entity key (ScheduleAmendment _ _ status)) -> do
          when (status == two) $ delete key
          pure NoContent
        Nothing -> do
          -- insert_ because the generated key is not needed here.
          insert_ (ScheduleAmendment tripId day one)
          pure NoContent

    extraTrip = error "unimplemented!"

-- TODO: proper debug logging for expired tokens

-- | Check that a token belongs to a usable running trip.
--
-- Throws 'err404' if no running trip is stored for the token, and 'err401'
-- if the trip is blocked or its expiry time has passed.
checkTokenValid :: Pool SqlBackend -> Token -> Handler ()
checkTokenValid dbpool token = do
  trip <- runSql dbpool (get (coerce token)) >>= maybe (throwError err404) pure
  when (runningTripBlocked trip) $ throwError err401
  whenM (hasExpired (runningTripExpires trip)) $ throwError err401

-- | Whether the current time is strictly later than the given limit.
hasExpired :: MonadIO m => UTCTime -> m Bool
hasExpired limit = (> limit) <$> liftIO getCurrentTime

-- | How long a trip registration stays valid: one nominal day.
validityPeriod :: NominalDiffTime
validityPeriod = nominalDay