{-# LANGUAGE DerivingStrategies    #-}
{-# LANGUAGE ExplicitNamespaces    #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE OverloadedLists       #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE TypeApplications      #-}

-- | Implementation of the API. This module is the central part of the
-- program: it assembles all HTTP and WebSocket handlers into a single
-- servant 'Application'.
module Server (application) where

import           Control.Concurrent.STM      (TQueue, TVar, atomically,
                                              newTQueue, newTVar, readTQueue,
                                              readTVar, writeTQueue, writeTVar)
import           Control.Monad               (forever, unless, void, when)
import           Control.Monad.Catch         (finally, handle)
import           Control.Monad.Extra         (ifM, maybeM, unlessM, whenJust,
                                              whenM)
import           Control.Monad.IO.Class      (MonadIO (liftIO))
import           Control.Monad.Logger        (LoggingT, logWarnN)
import           Control.Monad.Reader        (forM)
import           Control.Monad.Trans         (lift)
import qualified Data.Aeson                  as A
import           Data.ByteString             (ByteString)
import qualified Data.ByteString.Char8       as C8
import           Data.ByteString.Lazy        (toStrict)
import           Data.Coerce                 (coerce)
import           Data.Functor                ((<&>))
import qualified Data.Map                    as M
import           Data.Pool                   (Pool)
import           Data.Proxy                  (Proxy (Proxy))
import           Data.Swagger                (toSchema)
import           Data.Text                   (Text)
import           Data.Text.Encoding          (decodeUtf8)
import           Data.Time                   (NominalDiffTime,
                                              UTCTime (utctDay), addUTCTime,
                                              diffUTCTime, getCurrentTime,
                                              nominalDay)
import qualified Data.Vector                 as V
import           Database.Persist
import           Database.Persist.Postgresql (SqlBackend, runMigration)
import           Fmt                         ((+|), (|+))
import qualified Network.WebSockets          as WS
import           Prometheus
import           Prometheus.Metric.GHC
import           Servant                     (Application,
                                              ServerError (errBody), err401,
                                              err404, serve, throwError)
import           Servant.API                 (NoContent (..), (:<|>) (..))
import           Servant.Server              (Handler, hoistServer)
import           Servant.Swagger             (toSwagger)
import           System.IO.Unsafe
import           Yesod                       (toWaiAppPlain)

import           API
import           Config                      (ServerConfig)
import           Extrapolation               (Extrapolator (..),
                                              LinearExtrapolator (..))
import           GTFS
import           Persist
import           Server.ControlRoom
import           Server.GTFS_RT              (gtfsRealtimeServer)
import           Server.Util                 (Service, ServiceM, runService,
                                              sendErrorMsg)

-- | Build the WAI 'Application' serving 'CompleteAPI'.
--
-- Before anything is served this runs the database migration, registers the
-- Prometheus metrics (a gauge of open WebSocket connections plus the GHC
-- runtime metrics), creates the empty table of ping subscribers and builds
-- the Yesod control-room app.
application :: GTFS -> Pool SqlBackend -> ServerConfig -> IO Application
application gtfs dbpool settings = do
  doMigration dbpool
  metrics <- Metrics
    <$> register (gauge (Info "ws_connections" "Number of WS Connections"))
  void $ register ghcMetrics
  subscribers <- atomically $ newTVar mempty
  -- Built here, in IO, instead of through unsafePerformIO inside the pure
  -- server definition: it is constructed once, at startup, and any exception
  -- from its setup surfaces here rather than on some later request.
  controlRoom <- toWaiAppPlain (ControlRoom gtfs dbpool settings)
  pure
    $ serve (Proxy @CompleteAPI)
    $ hoistServer (Proxy @CompleteAPI) runService
    $ server gtfs metrics subscribers dbpool controlRoom

-- | Apply the persistent schema migration ('migrateAll') through the pool.
doMigration :: Pool SqlBackend -> IO ()
doMigration pool = runSql pool $
  -- TODO: before that, check if the uuid module is enabled
  -- in sql: check if SELECT * FROM pg_extension WHERE extname = 'uuid-ossp';
  -- returns an empty list
  runMigration migrateAll

-- | All handlers of 'CompleteAPI'.
--
-- @subscribers@ maps each trip to the queues of the WebSocket clients
-- currently subscribed to it: a @Just ping@ written to a queue is forwarded to
-- that client, a 'Nothing' ends its subscription. @controlRoom@ is served
-- unchanged as the last part of the API.
server
  :: GTFS
  -> Metrics
  -> TVar (M.Map TripID [TQueue (Maybe TrainPing)])
  -> Pool SqlBackend
  -> Application
  -> Service CompleteAPI
server gtfs@GTFS{..} Metrics{..} subscribers dbpool controlRoom =
  handleDebugAPI
    :<|> ( handleStations
          :<|> handleTimetable
          :<|> handleTrip
          :<|> handleRegister
          :<|> handleTrainPing (throwError err401)
          :<|> handleWS
          :<|> handleSubscribe
          :<|> handleDebugState
          :<|> handleDebugTrain
          :<|> handleDebugRegister
          :<|> gtfsRealtimeServer gtfs dbpool
         )
    :<|> metrics
    :<|> pure controlRoom
  where
    -- All stations known from the GTFS data.
    handleStations = pure stations

    -- Trips running on the given day (default: the current UTC day) that
    -- depart from the given station.
    handleTimetable station maybeDay = do
      -- TODO: resolve "overlay" trips (perhaps just additional CalendarDates?)
      day <- liftIO $ maybeM (getCurrentTime <&> utctDay) pure (pure maybeDay)
      -- Only send trips whose *first* stop is this station, i.e. don't send
      -- trips ending here. Using (!?) instead of a partial head/last means a
      -- trip without stops is skipped rather than crashing the handler.
      let startsHere trip =
            fmap (stationId . stopStation) (tripStops trip V.!? 0) == Just station
      pure . M.filter startsHere $ tripsOnDay gtfs day

    -- The trip with the given ID, or 404 if the GTFS data has no such trip.
    handleTrip trip = maybe (throwError err404) pure (M.lookup trip trips)

    -- Hand out a new token for a trip, provided it runs today (UTC day).
    handleRegister tripID RegisterJson{..} = do
      now <- liftIO getCurrentTime
      let today = utctDay now
      unless (runsOnDay gtfs tripID today) $
        sendErrorMsg "this trip does not run today."
      issueToken now tripID today registerAgent

    -- Like handleRegister, but for an arbitrary day and without checking
    -- whether the trip runs on it.
    handleDebugRegister tripID day = do
      now <- liftIO getCurrentTime
      issueToken now tripID day "debug key"

    -- Insert a 'Running' for the trip and day that expires 'validityPeriod'
    -- after @now@, and return its token.
    issueToken now tripID day agent = do
      let expires = addUTCTime validityPeriod now
      RunningKey token <- runSql dbpool $
        insert (Running expires False tripID day Nothing agent)
      pure token

    -- Check the ping's token. If it is unknown, blocked or expired, run
    -- @onError@ and return Nothing. Otherwise:
    --   * store the ping;
    --   * store a new extrapolated anchor, if its delay differs from the most
    --     recent stored anchor for this trip and day;
    --   * forward the ping to every subscriber of the trip;
    --   * return the anchor.
    -- It is used both in the servant monad (as an endpoint) and in IO (from
    -- handleWS), which is why it has no monomorphic signature.
    handleTrainPing onError ping =
      isTokenValid dbpool (coerce $ trainPingToken ping) >>= \case
        Nothing -> do
          onError
          pure Nothing
        Just running@Running{..} -> do
          let anchor = extrapolateAnchorFromPing LinearExtrapolator gtfs running ping
          -- TODO: are these always inserted in order?
          runSql dbpool $ do
            void (insert ping)
            lastAnchor <- selectFirst
              [TrainAnchorTrip ==. runningTrip, TrainAnchorDay ==. runningDay]
              [Desc TrainAnchorWhen]
            -- only insert new estimates if they've actually changed anything
            let lastDelay = trainAnchorDelay . entityVal <$> lastAnchor
            when (lastDelay /= Just (trainAnchorDelay anchor)) $
              void $ insert anchor
          liftIO $ atomically $ do
            queues <- readTVar subscribers <&> M.lookup runningTrip
            whenJust queues $ mapM_ (\q -> writeTQueue q (Just ping))
          pure (Just anchor)

    -- Receive train pings over a WebSocket for as long as the connection is
    -- open, answering each accepted ping with the resulting anchor.
    handleWS conn = do
      liftIO $ WS.forkPingThread conn 30
      incGauge metricsWSGauge
      -- `forever` only ends through an exception. A closed connection
      -- surfaces as a ConnectionException from receiveData, so this handler
      -- is the single place where the gauge is decremented.
      handle (\(_ :: WS.ConnectionException) -> decGauge metricsWSGauge) $
        forever $ do
          msg <- liftIO $ WS.receiveData conn
          case A.eitherDecode msg of
            Left err -> do
              logWarnN ("stray websocket message: "+|show msg|+" (could not decode: "+|err|+")")
              -- No decGauge here: a conforming peer answers this close with
              -- its own close frame, making the next receiveData throw into
              -- the handler above, which already decrements the gauge.
              -- Decrementing here too would count this connection twice.
              liftIO $ WS.sendClose conn (C8.pack err)
              -- TODO: send a close msg (Nothing) to the subscribed queues?
            Right ping -> liftIO $ do
              -- if invalid token, send a "polite" close request. Note that
              -- the client may ignore this and continue sending messages,
              -- which will continue to be handled.
              mAnchor <- handleTrainPing (WS.sendClose conn ("" :: ByteString)) ping
              whenJust mAnchor $ \anchor -> WS.sendTextData conn (A.encode anchor)

    -- Stream the pings of one trip on one day to a WebSocket client:
    --   * first the most recent stored ping, if there is one;
    --   * then every new ping as it arrives.
    -- This continues until the client's queue yields Nothing or the
    -- connection closes. The queue is unregistered however the handler ends.
    handleSubscribe tripId day conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do
      queue <- atomically $ do
        q <- newTQueue
        qs <- readTVar subscribers
        writeTVar subscribers $ M.insertWith (<>) tripId [q] qs
        pure q
      flip finally (removeSubscriber queue) $ do
        -- send most recent ping, if any (so we won't have to wait for movement)
        lastPing <- runSql dbpool $ do
          tokens <- selectList [RunningDay ==. day, RunningTrip ==. tripId] [] <&> fmap entityKey
          selectFirst [TrainPingToken <-. tokens] [Desc TrainPingTimestamp] <&> fmap entityVal
        whenJust lastPing $ \ping -> WS.sendTextData conn (A.encode ping)
        handle (\(_ :: WS.ConnectionException) -> pure ()) (forwardPings queue)
      where
        -- Relay pings from the queue to the client. On Nothing, close the
        -- socket and return, instead of blocking on the queue forever.
        forwardPings queue = atomically (readTQueue queue) >>= \case
          Just ping -> do
            WS.sendTextData conn (A.encode ping)
            forwardPings queue
          Nothing ->
            WS.sendClose conn (C8.pack "train ended")
        -- Drop this client's queue. The trip's entry is removed entirely once
        -- no subscribers remain, so the map does not accumulate empty lists.
        removeSubscriber queue = atomically $ do
          qs <- readTVar subscribers
          writeTVar subscribers $ M.update (dropQueue queue) tripId qs
        dropQueue queue qs =
          let rest = filter (/= queue) qs
          in if null rest then Nothing else Just rest

    -- For every token that is neither blocked nor expired, the pings sent
    -- with it, keyed by the token's raw key value.
    handleDebugState = do
      now <- liftIO getCurrentTime
      runSql dbpool $ do
        running <- selectList [RunningBlocked ==. False, RunningExpires >=. now] []
        pairs <- forM running $ \(Entity token@(RunningKey uuid) _) -> do
          entities <- selectList [TrainPingToken ==. token] []
          pure (uuid, fmap entityVal entities)
        pure (M.fromList pairs)

    -- All pings sent for the trip on the given day, across all its tokens.
    handleDebugTrain tripId day = do
      unless (runsOnDay gtfs tripId day) $
        sendErrorMsg ("this trip does not run on "+|day|+".")
      runSql dbpool $ do
        tokens <- selectList [RunningTrip ==. tripId, RunningDay ==. day] []
        pings <- forM tokens $ \(Entity token _) ->
          selectList [TrainPingToken ==. token] [] <&> fmap entityVal
        pure (concat pings)

    -- The swagger description of 'API'.
    handleDebugAPI = pure $ toSwagger (Proxy @API)

    -- The registered Prometheus metrics in text exposition format.
    metrics = exportMetricsAsText <&> (decodeUtf8 . toStrict)

-- TODO: proper debug logging for expired tokens

-- | Look up the 'Running' a token belongs to. Returns 'Nothing' if the token
-- is unknown, blocked, or has expired.
isTokenValid :: MonadIO m => Pool SqlBackend -> Token -> m (Maybe Running)
isTokenValid dbpool token = runSql dbpool $ get (coerce token) >>= \case
  Just trip | not (runningBlocked trip) ->
    ifM (hasExpired (runningExpires trip)) (pure Nothing) (pure (Just trip))
  _ -> pure Nothing

-- | Whether the current time is strictly past the given instant.
hasExpired :: MonadIO m => UTCTime -> m Bool
hasExpired limit = do
  now <- liftIO getCurrentTime
  pure (now > limit)

-- | How long a freshly issued token stays valid: one day after issuance.
validityPeriod :: NominalDiffTime
validityPeriod = nominalDay